Masked Attention in Transformer Architectures

 

Introduction

Masked attention is typically used in the decoder part of the (vanilla) transformer architecture to prevent the model from looking at future tokens. This left me with the impression that models trained with the Masked Language Modelling (MLM) objective use masked attention. However, masked language models like BERT (Bidirectional Encoder Representations from Transformers) DO NOT USE A MASK in the attention layer. I will come back to this at the end. Despite the difference, it is essential to understand how masked attention works. For that, let’s create some toy embeddings for 8 tokens with an embedding size of 32 (that is, $X \in \mathbb{R}^{8 \times 32}$).

# import all necessary libraries
import torch
from matplotlib import pyplot as plt
x = torch.randn(8,32,generator=torch.random.manual_seed(42)) # toy embeddings: 8 tokens, embedding size 32

Self-Attention

Let’s recall the process involved in the self-attention block. It can be written simply as

$$z = Softmax(\frac{QK^T}{\sqrt{d_k}})V$$

The query ($Q$), key ($K$) and value ($V$) matrices are obtained by multiplying the embeddings $X$ with the $W_Q$, $W_K$ and $W_V$ matrices, respectively.

Let’s fix the sizes of $W_Q$, $W_K$ and $W_V$ to $32 \times 16$. Therefore, the size of $Q, K, V$ is $8 \times 16$ (notice that the number of rows (tokens) remains the same throughout the entire computation). I am excluding the batch dimension for a quicker grasp of the concept. If a mask is used, we rewrite the above equation as follows

$$z = Softmax(\frac{QK^T+Mask}{\sqrt{d_k}})V$$

def self_attention(q,k,v,has_mask=False,mask=None):
  d_k = q.size(-1)
  # raw scores QK^T, shape (8, 8)
  scores = torch.matmul(q,k.T)
  if has_mask:
    # additive mask: -inf entries knock positions out of the softmax
    scores = scores + mask
  # scale by sqrt(d_k), as in the equation above
  scores = scores / d_k**0.5
  # row-wise softmax gives the attention weights
  alpha = torch.softmax(scores,dim=1)
  # each output row is a convex combination of the rows of v
  out = torch.matmul(alpha,v)
  return alpha,out

Let’s obtain $Q, K, V$ and then compute the attention matrix $\alpha$.

Wq = torch.randn(32,16,generator=torch.random.manual_seed(10))
Wk = torch.randn(32,16,generator=torch.random.manual_seed(11))
Wv = torch.randn(32,16,generator=torch.random.manual_seed(12))
q = torch.matmul(x,Wq) # notice the order of multiplication
k = torch.matmul(x,Wk)
v = torch.matmul(x,Wv)
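
As a quick sanity check, the projections indeed keep the token dimension at 8 (this print is just illustrative):

print(q.shape, k.shape, v.shape) # each is torch.Size([8, 16])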

Let’s pass $Q, K, V$ through the self-attention function and display the attention matrix without applying a mask.
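
One minimal way to do this (the plotting choices below are just for illustration):

alpha, z = self_attention(q,k,v) # no mask
print(alpha.shape, z.shape) # torch.Size([8, 8]) torch.Size([8, 16])
plt.imshow(alpha, cmap='viridis') # visualize the 8 x 8 attention matrix
plt.colorbar()
plt.show()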

When the attention matrix $\alpha$ multiplies the value matrix $V$, it takes a linear (convex) combination of all the projected values of the tokens. Therefore, the output $z$ will be of size $8 \times 16$. Each row in $z$ is the new representation obtained by a convex combination of the rows of the value matrix. Just think of what happens at a position where the attention matrix has the value $1$ (it simply passes through the row of $V$ corresponding to the column with the element $1$ in the attention matrix).
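
To see this concretely, here is a small illustrative check: a one-hot attention row simply copies the corresponding row of $V$.

alpha_onehot = torch.zeros(1, 8)
alpha_onehot[0, 2] = 1.0 # put all attention on token 2
print(torch.allclose(torch.matmul(alpha_onehot, v), v[2])) # True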

Causal (Masked) Attention

How do we build a causal attention (also called left-to-right attention) mask? This is particularly used in the decoder of the transformer and also in generative models. The idea is to let the model attend only to the present and past tokens, not to the future tokens in the input sequence. This can be done by creating a mask matrix with $-\infty$ in the upper triangular elements, for example,

$$ Mask = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix}$$

# boolean lower-triangular matrix: True where attention is allowed
mask = (torch.triu(torch.ones(8,8)) == 1).transpose(0,1)
# convert to an additive float mask: 0 where allowed, -inf where blocked
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)

Adding the mask to $QK^T$ and passing it through a softmax function produces the attention map shown below
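
One way to reproduce such a map (again, the plotting details are only illustrative):

alpha_causal, z_causal = self_attention(q,k,v,has_mask=True,mask=mask)
plt.imshow(alpha_causal, cmap='viridis') # upper triangle is exactly zero
plt.colorbar()
plt.show()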

As one can notice, the zeroth row contains an element with value $1$ at the $(0,0)$ position and zeros elsewhere. When this gets multiplied with a value matrix of size $8 \times 16$, it effectively returns the first row of the value matrix (that is, the projected representation of the first token of the input sequence). The same idea holds for the rest of the rows. That is, the model causally pays attention to the input sequence from left to right. This type of attention is particularly used in decoder-based models like GPT (Generative Pre-trained Transformer).
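
A quick check of this claim, reusing the tensors computed above (illustrative):

print(torch.allclose(z_causal[0], v[0])) # True: the first token attends only to itself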

Masked Language Modelling

Unlike in the case of causal attention, in MLM some of the tokens in the input sequence need to be randomly masked. The masking can be done in two ways: either we use a mask inside the softmax of self-attention, or we use special tokens. If we use the former approach, the attention paid to the masked tokens should be zero. This can once again be achieved by setting $-\infty$ at the respective locations of the mask, which would look something like

$$ Mask = \begin{bmatrix} 0 & -\infty & 0 \\ 0 & -\infty & 0 \\ 0 & -\infty & 0 \end{bmatrix}$$

It is straightforward to convert this into code.

mask_indices = [1,3] # positions of the masked tokens
mask = torch.zeros(8,8)
mask[:,mask_indices] = float('-inf') # block attention to the masked positions
print(mask)

We can notice that the elements in the $1$-st and $3$-rd columns of the attention matrix are all zeros. (Of course, the 4th column is also nearly zero, but that is by chance!) This means that the projected representations of the masked tokens will not be used while attending to the input tokens. This is exactly what we wanted! However, there seems to be a catch with this approach (think about it). Therefore, the simple alternative approach is to use a special [MASK] token, as shown in the figure below.

That is, we explicitly let the model know that a set of words is blank (__) and ask it to fill them in! Note, however, that the embeddings for the masked tokens will not be fine-tuned during pretraining.
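
A minimal sketch of the special-token approach (the token ids below are made up for illustration, and MASK_ID is a hypothetical id for the [MASK] token, e.g. 103 in BERT's WordPiece vocabulary):

MASK_ID = 103 # hypothetical [MASK] token id
token_ids = torch.tensor([7, 42, 13, 99, 5, 28, 64, 3]) # toy input ids for 8 tokens
mask_positions = torch.tensor([1, 3]) # the same positions as before
masked_ids = token_ids.clone()
masked_ids[mask_positions] = MASK_ID # replace the chosen tokens with [MASK]
print(masked_ids) # tensor([  7, 103,  13, 103,   5,  28,  64,   3])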

When we create transformer layers in PyTorch (for example, with nn.TransformerEncoderLayer or the full nn.Transformer), we can optionally pass the respective masks along with the input embeddings, namely src_mask, tgt_mask and src_key_padding_mask. All these masks are meant for MASKING ATTENTION over the encoder inputs (a zero matrix of size $T \times T$ when nothing is masked), over the decoder inputs in the masked attention layer (upper triangular with negative infinities, as discussed above), and over the padded tokens (because attention to the embeddings of padded tokens is not required), respectively.
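
For instance, a minimal sketch of passing these masks to an encoder layer (the hyper-parameters and tensor shapes are illustrative; by default nn.TransformerEncoderLayer expects input of shape (seq_len, batch, d_model)):

import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=32, nhead=4)
src = torch.randn(8, 1, 32) # (seq_len, batch, d_model)
src_mask = torch.triu(torch.full((8, 8), float('-inf')), diagonal=1) # causal additive mask
src_key_padding_mask = torch.zeros(1, 8, dtype=torch.bool) # True marks padded positions (none here)
out = encoder_layer(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
print(out.shape) # torch.Size([8, 1, 32])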

Question to Ponder

The attention weights $\alpha$ are computed as \(\alpha = Softmax(QK^T)\) (ignoring the $\sqrt{d_k}$ scaling for brevity).

In the case of masked attention, we used an additive mask as follows (I have just renamed the mask to “UTMask”)

\[\alpha = Softmax(QK^T+UTMask)\]

where $UTMask$ is a strictly upper triangular matrix whose nonzero entries are $-\infty$. This serves the purpose. However, we could achieve the same effect with a multiplicative mask applied to the output of the softmax, as follows,

\(\alpha = Softmax(QK^T)*LTMask\), where $LTMask$ is a lower triangular matrix of ones (including the diagonal) and $*$ denotes element-wise multiplication.
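
A small illustrative sketch of this multiplicative variant, reusing $q$ and $k$ from above (with $d_k = 16$); note that the surviving entries in each row would need to be renormalized if we want them to sum to one again.

LTMask = torch.tril(torch.ones(8, 8)) # ones on and below the diagonal
alpha_mult = torch.softmax(torch.matmul(q, k.T) / 16**0.5, dim=1) * LTMask
print(alpha_mult.sum(dim=1)) # rows no longer sum to 1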

I am just wondering why the Masked Multi-Head Attention layer in the transformer architecture prefers an additive mask over a multiplicative (Boolean) mask. Is it due to how back-propagation behaves with additive versus multiplicative operators? Are there any other reasons?

I posted the same question on Stack Overflow.