Transformers


4. Multi-Head Attention

An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The attention layers implement mechanisms that allow the model to weigh the importance of different tokens in a sequence when making predictions.

In particular, self-attention enables the transformer to consider relationships between all tokens simultaneously. Multi-head attention enhances this by splitting the focus across multiple perspectives, improving the model’s ability to capture complex dependencies.

The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. Attention is all you need, 2017 - Vaswani et al.

4.1 Scaled Dot-Product Attention

The attention mechanism implemented in the transformer is referred to as scaled dot-product attention where the input consists of queries and keys of dimension \(d_k\), and values of dimension \(d_v\). We compute the dot products of the query with all keys, divide each by \(\sqrt{d_k}\), and apply a softmax function to obtain the weights on the values. This process is illustrated in the figure below.

In practice, attention is computed on a set of queries, packed together into a matrix \(Q\), simultaneously. The keys and values are also packed together into matrices \(K\) and \(V\). The output is calculated as: \[ \begin{equation} \mathbf{Attention}(Q,K,V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right)V \end{equation} \] The softmax operation yields the attention matrix \(\mathbf{A}\) where each row sums to unity. The elements of this attention matrix tell us how much each token influences (or depends on) the others. The larger the number, the larger the influence. Multiplying this matrix by the values matrix \(V\) i.e., \(\mathbf{A} \cdot V\), is akin to combining information from all tokens to create a new, context-aware representation of each token. In other words, by incorporating attention, we are finding a new representation for each token.
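To make the formula concrete, here is a minimal sketch of our own (assuming a single sequence of three tokens and \(d_k = d_v = 4\)) that computes the attention weights and the new token representations directly from the definition:

import math
import torch

q = torch.rand(3, 4)  # 3 tokens, each query of dimension d_k = 4
k = torch.rand(3, 4)
v = torch.rand(3, 4)

# Attention matrix A: each row sums to 1 thanks to the softmax
A = torch.softmax(q @ k.T / math.sqrt(4), dim=-1)  # shape (3, 3)

# New, context-aware representation of each token: a weighted sum of the values
output = A @ v  # shape (3, 4)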

4.1.1 Implementation of Scaled Dot-Product Attention

Before being fed into the attention mechanism, each query, key, and value matrix has the shape (batch_size, seq_len, d_model) where batch_size is the number of sequences fed simultaneously into the mechanism, seq_len is the number of tokens in each sequence, and d_model is the size of each token's embedding. Here is some sample code:

q = torch.rand(10, 5, 512) # shape: 10 sequences, each is 5 tokens long,
k = torch.rand(10, 5, 512) # each token's embedding is of size 512
v = torch.rand(10, 5, 512)

In the transformer, this scaled dot-product attention is implemented using \(h\) different attention heads. This means that each token's embedding is split into \(h\) sections, each of which is fed into a different attention head. Each of these sections is of size \(\frac{d_{model}}{h} = d_k\). Each attention head therefore receives an input of shape (batch_size, seq_len, d_k), and, overall, the input into the attention mechanism will be of shape (batch_size, h, seq_len, d_k).
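To make this splitting concrete, here is a small sketch (reusing the q, k, and v tensors defined above, with \(h = 8\) and \(d_k = 512/8 = 64\)) that performs the split with a view followed by a transpose; the code below assumes this reshaping has been applied:

h, d_k = 8, 64  # d_model = 512 = h * d_k

# (batch_size, seq_len, d_model) -> (batch_size, h, seq_len, d_k)
q = q.view(10, 5, h, d_k).transpose(1, 2)
k = k.view(10, 5, h, d_k).transpose(1, 2)
v = v.view(10, 5, h, d_k).transpose(1, 2)

q.shape
>>> torch.Size([10, 8, 5, 64])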

Let’s now track the attention computation by watching the shapes of the matrices through the process. We shall denote batch_size as \(B\) and seq_len as \(L\).

\[ \begin{align} &Q \in \mathbb{R}^{B \times h \times L \times d_k}\;; \;\;\; K \in \mathbb{R}^{B \times h \times L \times d_k}\;; \;\;\; V \in \mathbb{R}^{B \times h \times L \times d_v} \end{align} \] Attention is computed, fundamentally, at the level of individual sequences because the goal is to find how each token in a sequence is related to every other token. To transpose \(K\), we simply swap the last two dimensions. We swap only these two because they describe the shape of a sequence in embedding space as it is fed into a single attention head, i.e., (seq_len, d_k). \[ K^T \in \mathbb{R}^{B \times h \times d_k \times L} \] The result of the dot product of \(Q\) and \(K^T\) (followed by division by \(\sqrt{d_k}\) and a softmax operation) is a new matrix, called the attention matrix, of dimensions \(B \times h \times L \times L\). \[ \mathbf{softmax} \left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{B \times h \times L \times L} \] Focusing on the last two dimensions alone, we see that the attention matrix for each sequence has the shape (seq_len, seq_len). It's not difficult to see why this shape is so natural for the attention matrix: it's a square matrix whose elements tell us how each token in the sequence is related to (or influenced by) every other token. We reiterate that, due to the use of the softmax function, the sum of the elements in each row of this matrix is 1. This means that when we finally multiply the attention matrix with the value matrix \(V\), we are simply computing a weighted average of the corresponding rows of \(V\).

Here is some code to illustrate this:

d_k = q.size(-1)  

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.softmax(dim=-1)

# Attention matrix of the first sequence, first head
print(scores[0, 0])
>>> tensor([[0.1701, 0.2121, 0.1963, 0.2013, 0.2202], 
            [0.2008, 0.1792, 0.2257, 0.2123, 0.1821], 
            [0.1632, 0.1958, 0.2279, 0.2090, 0.2040], 
            [0.1668, 0.2157, 0.1949, 0.1908, 0.2318], 
            [0.1851, 0.1832, 0.2083, 0.2026, 0.2208]])

scores[0, 0].sum(dim=-1)
>>> tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

Finally, we can find the shape of the new token representations resulting from computing the attention. \[ \mathbf{Attention(Q,K,V)} \in \mathbb{R}^{B \times h \times L \times d_k} \] From this tracking, we can see that we began with query, key, and value matrices of shape (batch_size, h, seq_len, d_k) and ended up with a matrix of the same shape. In other words, by incorporating attention, we are finding a new representation for each token that takes into account how each token relates to every other.

Let’s now implement a function to compute a single head’s attention from queries, keys, and values.

def attention(query, key, value, mask=None, dropout=None):
    '''
    Computes "scaled dot-product attention"
    '''
    
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn


q = torch.rand(10, 8, 5, 64)
k = torch.rand(10, 8, 5, 64)
v = torch.rand(10, 8, 5, 64)  

d_k = q.size(-1)

a, p = attention(q, k, v)
a.shape, p.shape
>>> (torch.Size([10, 8, 5, 64]), torch.Size([10, 8, 5, 5]))

4.2 Multi-Head Attention

When reading a sentence, the transformer network needs to process different kinds of relationships, such as word meaning, syntax, and context, to arrive at a richer understanding. The best way to achieve this is to make the attention mechanism explicitly process these different aspects by breaking it up into multiple attention heads, each one processing a different relationship. For example, if the sentence is “The cat sat on the mat,” one attention head might focus on who is sitting (the cat), another on where it is sitting (on the mat), and another on the action (sat).

Instead of performing a single attention function with \(d_{model}\)-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values \(h\) times with different, learned linear projections to \(d_k\), \(d_k\) and \(d_v\) dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding \(d_v\)-dimensional output values. These are concatenated and once again projected, resulting in the final values. Attention is all you need, 2017 - Vaswani et al.

In single-head attention, everything gets squeezed into just one representation. To make it work, the attention mechanism needs to average out all the different features and dependencies in the data. This averaging smooths out details. Instead of capturing distinct patterns separately, the model tends to blend and blur them together, losing the ability to recognize nuanced relationships. Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. \[ \begin{align}     \mathbf{MultiHead}(Q, K, V) &= \mathbf{Concat}(head_1, \cdots, head_h)W^O \\     \text{where} \; head_i &= \mathbf{Attention}(QW^Q_i, KW^K_i, VW^V_i) \end{align} \] where the projections are parameter matrices \(W^Q_i \in \mathbb{R}^{d_{model} \times d_k}\), \(W^K_i \in \mathbb{R}^{d_{model} \times d_k}\), \(W^V_i \in \mathbb{R}^{d_{model} \times d_v}\), and \(W^O \in \mathbb{R}^{hd_v \times d_{model}}\).

In the transformer, we employ \(h=8\) parallel attention layers/heads. For each of these, we use \(d_k = d_v = d_{model}/h=64\). This reduced dimensionality of each head leads to the total computational cost being similar to that of single-head attention with full dimensionality.
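A quick way to see why the cost stays comparable: with \(d_{model} = 512\) and \(h = 8\), the eight per-head projection matrices of shape \(512 \times 64\) contain exactly as many parameters as a single full-dimensional \(512 \times 512\) projection. Here is a small sketch of that arithmetic:

d_model, h = 512, 8
d_k = d_model // h                # 64

per_head = d_model * d_k          # one head's W_i^Q: 32,768 parameters
all_heads = h * per_head          # 262,144 parameters across all heads
single_full = d_model * d_model   # 262,144 parameters for one full projection

assert all_heads == single_full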

   

Looking at the multi-head attention figure, we can make the following observations: the queries, keys, and values each pass through their own linear projection; the projected versions are fed into \(h\) scaled dot-product attention heads in parallel; the outputs of the heads are concatenated; and a final linear layer produces the output of the sub-layer.

4.2.1 Implementation of Multi-Head Attention

From an implementation perspective, and based on the above observations, we'll need two helper functions: one to clone layers that are used multiple times (like the linear layers) and another to create the mask applied in masked multi-head attention. Let's see how to implement multi-head attention step by step, beginning with the function to clone layers.

import copy

def clones(module, N):
    '''
    Produce N identical layers
    '''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
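As a quick sanity check (a small usage sketch of our own), the returned ModuleList holds N independent copies of the module, so each projection layer learns its own weights:

layers = clones(nn.Linear(512, 512), 4)

len(layers)
>>> 4

# deep copies: the layers are distinct modules with their own parameters
layers[0] is layers[1]
>>> False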

And now, attention.

n, seq_len, d_model = 10, 5, 512
h = 8
d_k = d_model // h  # 512 / 8 = 64

q = torch.rand(n, seq_len, d_model)
k = torch.rand(n, seq_len, d_model)
v = torch.rand(n, seq_len, d_model)

# 3 linear layers for projecting the Q, K, and V matrices.
# We implement the h linear layers of shape (d_model, d_k) as a single
# linear layer of shape (d_model, d_model).
linears = clones(nn.Linear(d_model, d_model), 3)

# 1. Linear projections: Q, K, and V are each passed through their own
#    linear layer at once, then split into h heads
q, k, v = [
    linear(x).view(n, seq_len, h, d_k).transpose(1, 2)
    for linear, x in zip(linears, (q, k, v))
]

# 2. Compute attention over the projected vectors
m_att, p = attention(q, k, v)

# 3. 'Concat' using a view operation
m_att = m_att.transpose(1, 2).contiguous().view(n, seq_len, d_model)
m_att.shape
>>> torch.Size([10, 5, 512])

# 4. Linear projection after the concat
m_att_output = nn.Linear(d_model, d_model)(m_att)
m_att_output.shape
>>> torch.Size([10, 5, 512])

When training the Transformer, we need to make sure that the decoder predicts each word without peeking at future words in the sequence. To enforce this, we apply a mask to the self-attention layer in the decoder. This mask ensures that when the model predicts the word at position \(i\), it only has access to the words at positions before \(i\). We achieve this by blocking attention to future positions, meaning each token only considers past words. Another safeguard comes from the fact that the target embeddings are shifted right by one position — which ensures the model doesn’t see the exact word it’s supposed to predict.

We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. Attention is all you need, 2017 - Vaswani et al.

Without a mask, the decoder would attend to all tokens in the output sequence, including future words. This would be cheating because the model would simply copy the next word instead of genuinely predicting it. To avoid this, we modify self-attention in the decoder so that predictions at each step depend only on previous tokens. This is like taking a square matrix and turning off its upper triangle, leaving only the diagonal and lower triangle. Each token still attends to itself (the diagonal remains), but crucially, it cannot access future tokens. Since decoder inputs are shifted right, it never sees the actual target word it’s supposed to predict. For example, in the sentence “The cat sat down”, when the model predicts “sat”, it only sees "<sos> The cat", ensuring that predictions are based only on past words.
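As a small illustration of the right shift (the token lists here are purely illustrative):

tgt = ["<sos>", "The", "cat", "sat", "down", "<eos>"]

decoder_input  = tgt[:-1]   # ["<sos>", "The", "cat", "sat", "down"]
decoder_target = tgt[1:]    # ["The", "cat", "sat", "down", "<eos>"]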

Let’s now see an implementation:

def subsequent_mask(size):
    '''
    Mask out subsequent positions.
    '''
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(
        torch.ones(attn_shape), diagonal=1
    ).type(torch.uint8)

    return subsequent_mask == 0

subsequent_mask(5)
>>> tensor([[[ True, False, False, False, False], 
             [ True, True,  False, False, False], 
             [ True, True,  True,  False, False], 
             [ True, True,  True,  True,  False], 
             [ True, True,  True,  True,  True]]])
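Before visualising the mask, we can check its effect by passing it to the attention function implemented earlier (a small sketch with fresh random tensors): every entry above the diagonal of the resulting attention matrix should be numerically zero, i.e., no weight is placed on future positions.

mask = subsequent_mask(5)   # shape (1, 5, 5); broadcasts over batch and heads

q = torch.rand(10, 8, 5, 64)
k = torch.rand(10, 8, 5, 64)
v = torch.rand(10, 8, 5, 64)

x, p_masked = attention(q, k, v, mask=mask)

# strictly upper-triangular entries (future positions) carry no attention weight
(p_masked[0, 0].triu(diagonal=1) == 0).all()
>>> tensor(True)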

And now a visualisation of the masking. In the output mask, each row represents the position being predicted, while the columns represent the positions the decoder is allowed to attend to. Consider the token predicted at position \(12\), for instance: when predicting this token, the decoder will only attend to tokens at positions \(0\) to \(12\).

RUN_EXAMPLES = True

def show_example(fn, args=[]):
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)

def example_mask():
    LS_data = pd.concat([
        pd.DataFrame(
            {'Subsequent Mask': subsequent_mask(20)[0][x,y].flatten(), 
            'Window': y,  'Masking': x}) 
            for y in range(20) for x in range(20)
        ])  

    return (
        alt.Chart(LS_data)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X('Window:O'),
            alt.Y('Masking:O'),
            alt.Color(
                'Subsequent Mask:Q', 
                scale=alt.Scale(scheme='viridis')),
        )
        .interactive()
    )  

show_example(example_mask)
Figure: visualisation of the subsequent mask (mask.png)

4.3 Tying Everything Together

As usual, all the code we have written so far can be encapsulated in a single class for reusability.

class MultiheadedAttention(nn.Module):
    '''
    Implements multi-headed attention
    '''

    def __init__(self, h, d_model, dropout=0.1):
        '''
        Take in the model size and number of heads
        '''
        super(MultiheadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h # We assume d_v always equals d_k
        self.h = h
        # Four linear layers: for Q, K, V, and the final output of
        # multi-head attention (after the concat)
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)  

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, mask: torch.Tensor = None):

        if mask is not None:
            mask = mask.unsqueeze(1) # same mask applied to all h heads
        n = query.size(0) 

        # 1) Do all the linear projections in batch from d_model
        query, key, value = [
            linear(x).view(n, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ] 

        # 2) Apply attention on all the projected vectors in batch
        #    (yields the new, context-aware token representations x)
        x, self.attn = attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view, reshaping the tensor back to
        #    (n, seq_len, d_model), before the final linear layer below
        x = x.transpose(1, 2).contiguous().view(n, -1, self.h * self.d_k)
        
        del query
        del key
        del value
        
        # 4) The top-most linear layer after concat
        return self.linears[-1](x)

Let’s test our class implementation.

q = torch.rand(n, seq_len, d_model)
k = torch.rand(n, seq_len, d_model)
v = torch.rand(n, seq_len, d_model)  

m_att = MultiheadedAttention(8, 512)
m_att_output = m_att(q, k, v)
m_att_output.shape
>>> torch.Size([10, 5, 512])

As we can see, our implementation works as expected.
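We can also pass the subsequent mask from earlier to the same instance (a final sketch); inside forward the mask is broadcast across all heads, and the stored attention weights should place zero weight on future positions:

mask = subsequent_mask(seq_len)          # shape (1, 5, 5)
m_att_masked = m_att(q, k, v, mask=mask)

m_att_masked.shape
>>> torch.Size([10, 5, 512])

# the attention weights of each head are lower-triangular
(m_att.attn[0, 0].triu(diagonal=1) == 0).all()
>>> tensor(True)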

4.4 Conclusion

In this tutorial, we explored scaled dot-product attention and multi-head attention, and saw how the transformer forms new representations of tokens based on how the tokens in an input sequence relate to each other. The most important part of the transformer architecture is now behind us. Next, we explore the other sub-layer of the encoder and decoder: the position-wise feed-forward network!

Acknowledgements

This series of tutorials has benefited greatly from Harvard NLP's codebase for The Annotated Transformer