An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The attention layers implement mechanisms that allow the model to weigh the importance of different tokens in a sequence when making predictions.
In particular, self-attention enables the transformer to consider relationships between all tokens simultaneously. Multi-head attention enhances this by splitting the focus across multiple perspectives, improving the model’s ability to capture complex dependencies.
The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key. Attention is all you need, 2017 - Vaswani et al.
The attention mechanism implemented in the transformer is referred to as scaled dot-product attention where the input consists of queries and keys of dimension \(d_k\), and values of dimension \(d_v\). We compute the dot products of the query with all keys, divide each by \(\sqrt{d_k}\), and apply a softmax function to obtain the weights on the values. This process is illustrated in the figure below.
In practice, attention is computed on a set of queries, packed together into a matrix \(Q\), simultaneously. The keys and values are also packed together into matrices \(K\) and \(V\). The output is calculated as: \[ \begin{equation} \mathbf{Attention}(Q,K,V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right)V \end{equation} \] The softmax operation yields the attention matrix \(\mathbf{A}\) where each row sums to unity. The elements of this attention matrix tell us how much each token influences (or depends on) the others. The larger the number, the larger the influence. Multiplying this matrix by the values matrix \(V\) i.e., \(\mathbf{A} \cdot V\), is akin to combining information from all tokens to create a new, context-aware representation of each token. In other words, by incorporating attention, we are finding a new representation for each token.
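To make the formula concrete, here is a minimal numeric sketch (not from the original post) with two tokens and \(d_k = 2\); it verifies that each row of the attention matrix sums to 1 and that the output contains one new, context-aware representation per token:

import torch

Q = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])    # queries, shape (2, 2)
K = torch.tensor([[1.0, 0.0],
                  [1.0, 1.0]])    # keys, shape (2, 2)
V = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])    # values, shape (2, 2)

d_k = Q.size(-1)
scores = Q @ K.T / d_k ** 0.5     # compatibility of each query with each key
A = scores.softmax(dim=-1)        # attention matrix, shape (2, 2)

print(A.sum(dim=-1))              # each row sums to 1
print(A @ V)                      # new, context-aware representation per token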
Before being fed into the attention mechanism, each query, key, and value matrix has the shape (batch_size, seq_len, d_model), where batch_size is the number of sequences fed simultaneously into the mechanism, seq_len is the number of tokens in each sequence, and d_model is the size of each token in the embedding space. Here is some sample code:
import math
import torch

q = torch.rand(10, 5, 512) # shape: 10 sequences, each is 5 tokens long,
k = torch.rand(10, 5, 512) # each token's embedding is of size 512
v = torch.rand(10, 5, 512)
In the transformer, this scaled dot-product attention is implemented using h different attention heads. This means that each token’s embedding will be split into \(h\) sections and each is fed into a different attention head. Each of these sections is of size \(\frac{d_{model}}{h} = d_k\). Each attention head will now receive an input of shape (batch_size, seq_len, d_k). Overall, the input into the attention mechanism will be of shape (batch_size, h, seq_len, d_k).
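As a brief sketch of this splitting (using the assumed values batch_size = 10, seq_len = 5, d_model = 512, and h = 8 from the sample code above), the split into heads is just a reshape of the embedding dimension followed by moving the head axis forward:

import torch

batch_size, seq_len, d_model, h = 10, 5, 512, 8
d_k = d_model // h    # 64

x = torch.rand(batch_size, seq_len, d_model)

# split each token's embedding into h sections of size d_k, then bring the
# head dimension next to the batch dimension
x_heads = x.view(batch_size, seq_len, h, d_k).transpose(1, 2)
print(x_heads.shape)  # torch.Size([10, 8, 5, 64])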
Let’s now track the attention computation by watching the shapes of the matrices through the process. We shall denote batch_size as \(B\) and seq_len as \(L\).
\[
\begin{align}
&Q \in \mathbb{R}^{B \times h \times L \times d_k}\;; \;\;\; K \in
\mathbb{R}^{B \times h \times L \times d_k}\;; \;\;\; V \in
\mathbb{R}^{B \times h \times L \times d_v}
\end{align}
\] Attention is done, fundamentally, at the level of individual sequences because the goal is to find how each token in the sequence is related to every other token. To transpose \(K\), we simply swap the last two dimensions. We swap only these because they tell us the shape of sequences in embedding space as they are fed into single attention heads, i.e., (seq_len, d_k). \[ K^T \in \mathbb{R}^{B \times h \times d_k \times L} \\[0.5 cm] \]
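A quick shape check of this transpose (a sketch, assuming \(B = 10\), \(h = 8\), \(L = 5\), and \(d_k = 64\)):

import torch

K = torch.rand(10, 8, 5, 64)    # (B, h, L, d_k)
K_T = K.transpose(-2, -1)       # swap only the last two dimensions
print(K_T.shape)                # torch.Size([10, 8, 64, 5])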
The result of the dot product of \(Q\) and \(K^T\) (followed by division by \(\sqrt{d_k}\) and a softmax operation) is a new matrix, called the attention matrix, having dimensions \(B \times h \times L \times L\). \[ \mathbf{softmax} \left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{B \times h \times L \times L} \] Focusing on the last two dimensions alone, we see that the attention matrix for each sequence has the shape (seq_len, seq_len). It’s not difficult to see why this shape is so natural (or to be expected) for the attention matrix: it’s a square matrix whose elements tell us how each token in the sequence is related to (or influenced by) every other token. We reiterate that, due to the use of the softmax function, the sum of the elements in each row of this matrix will be 1 for every sequence. This means that when we finally multiply this attention matrix with the value matrix \(V\), we are simply finding a weighted average (or sum) of the corresponding elements in \(V\).
Here is some code to illustrate this:
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.softmax(dim=-1)

# Attention matrix of the first sequence
print(scores[0])
>>> tensor([[0.1701, 0.2121, 0.1963, 0.2013, 0.2202],
        [0.2008, 0.1792, 0.2257, 0.2123, 0.1821],
        [0.1632, 0.1958, 0.2279, 0.2090, 0.2040],
        [0.1668, 0.2157, 0.1949, 0.1908, 0.2318],
        [0.1851, 0.1832, 0.2083, 0.2026, 0.2208]])

scores[0].sum(dim=-1)
>>> tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
Finally, we can find the shape of the new token representations resulting from computing the attention. \[ \mathbf{Attention}(Q,K,V) \in \mathbb{R}^{B \times h \times L \times d_k} \] From this tracking, we can see that we began with query, key, and value matrices of shape (batch_size, h, seq_len, d_k) and ended up with a matrix of the same shape. In other words, by incorporating attention, we are finding a new representation for each token that takes into account how each token relates to every other.
Let’s now implement a function to compute a single head’s attention from queries, keys, and values.
def attention(query, key, value, mask=None, dropout=None):
    '''
    Computes "scaled dot-product attention"
    '''
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
q = torch.rand(10, 8, 5, 64)
k = torch.rand(10, 8, 5, 64)
v = torch.rand(10, 8, 5, 64)

d_k = q.size(-1)

a, p = attention(q, k, v)

a.shape, p.shape
>>> (torch.Size([10, 8, 5, 64]), torch.Size([10, 8, 5, 5]))
When reading a sentence, the transformer network needs to process different kinds of relationships, such as word meaning, syntax, or context, to arrive at a richer understanding. The best way to achieve this is to make the attention mechanism explicitly process these different aspects by breaking it up into multiple attention heads, each one processing a different relationship. For example, if the sentence is “The cat sat on the mat,” one attention head might focus on who is sitting (the cat), another on where it is sitting (on the mat), and another on the action (sat).
Instead of performing a single attention function with \(d_{model}\)-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values \(h\) times with different, learned linear projections to \(d_k\), \(d_k\) and \(d_v\) dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding \(d_v\)-dimensional output values. These are concatenated and once again projected, resulting in the final values. Attention is all you need, 2017 - Vaswani et al.
In single-head attention, everything gets squeezed into just one representation. To make it work, the attention mechanism needs to average out all the different features and dependencies in the data. This averaging smooths out details. Instead of capturing distinct patterns separately, the model tends to blend and blur them together, losing the ability to recognize nuanced relationships. Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. \[ \begin{align} \mathbf{MultiHead}(Q, K, V) &= \mathbf{Concat}(head_1, \cdots, head_h)W^O \\ \text{where } head_i &= \mathbf{Attention}(QW^Q_i, KW^K_i, VW^V_i) \end{align} \] where the projections are parameter matrices \(W^Q_i \in \mathbb{R}^{d_{model} \times d_k}\), \(W^K_i \in \mathbb{R}^{d_{model} \times d_k}\), \(W^V_i \in \mathbb{R}^{d_{model} \times d_v}\), and \(W^O \in \mathbb{R}^{hd_v \times d_{model}}\).
In the transformer, we employ \(h=8\) parallel attention layers/heads. For each of these, we use \(d_k = d_v = d_{model}/h=64\). This reduced dimensionality of each head leads to the total computational cost being similar to that of single-head attention with full dimensionality.
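As a quick arithmetic check of these numbers (a sketch using the configuration quoted above):

d_model, h = 512, 8
d_k = d_v = d_model // h

print(d_k)        # 64: each head works on vectors 1/h the size of d_model
print(h * d_v)    # 512: concatenating the h head outputs recovers d_model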
In the following implementation of multi-head attention:
The \(h\) learned \(\mathbf{W_i^{Q/K/V}}\) matrices in the paper are implemented as a single linear layer of dimensions \(d_{model} \times d_{model}\) rather than as \(h\) different matrices of shape \(d_{model} \times d_k\). The single matrix mirrors the concatenation of \(h\) different matrices (a brief sketch of this equivalence appears after these implementation notes). That is, \[ \begin{align} \mathbf{Concat}(W^{Q/K/V}_1, \cdots, W^{Q/K/V}_h) &\in \mathbb{R}^{d_{model}\times h \cdot d_k} \\[0.5 cm] &\in \mathbb{R}^{d_{model}\times d_{model}} \end{align} \]
Each linear layer implements the transformation \(x\cdot A^T + b\), which differs from just multiplying the query by the matrix \(W^Q_i\) because there is a bias vector added. However, the presence of the bias terms makes this implementation more robust during the training process.
Each query, key, and value matrix is of dimensions \((n \times \text{seq\_len} \times d_{model})\) where \(n\) is the number of sequences in the query (i.e., batch size), seq_len is the number of tokens in each sequence, and \(d_{model}\) is the token’s size in embedding space.
The queries (and keys and values) are transformed at once using the \(d_{model} \times d_{model}\) linear layer before being reshaped into tensors of dimensions (\(n, h, \text{seq\_len}, d_k\)), a process which mimics having the queries be transformed by \(h\) different matrices. A query of shape \((n \times \text{seq\_len} \times d_{model})\) is transformed by the linear layer to a tensor of shape \((n \times \text{seq\_len} \times d_{model})\) i.e. it retains the same shape, but has been projected to a new subspace.
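Here is the sketch referred to above: a minimal, hypothetical demonstration (not part of the original code) that one \(d_{model} \times d_{model}\) projection followed by a reshape matches \(h\) separate \(d_{model} \times d_k\) projections whose weights are slices of the larger matrix. The bias is omitted for clarity.

import torch
import torch.nn as nn

d_model, h = 512, 8
d_k = d_model // h

x = torch.rand(3, d_model)                       # 3 tokens
big = nn.Linear(d_model, d_model, bias=False)    # single combined projection
W = big.weight                                   # shape (d_model, d_model)

# h separate per-head projections, each using a (d_k, d_model) slice of W
per_head = [x @ W[i * d_k:(i + 1) * d_k].T for i in range(h)]

# single projection, then split into h sections of size d_k
combined = big(x).view(3, h, d_k)

print(torch.allclose(combined[:, 0], per_head[0], atol=1e-6))  # True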
Let’s focus on the multi-head attention figure. We make the following observations:
For a single attention head, each of the \(Q\), \(K\), and \(V\) matrices is passed through \(1\) linear layer (to make a total of \(3\) linear layers). There is also an additional linear layer (represented by the learned matrix \(W^O\)) through which the concatenated outputs of all the attention heads are passed. That makes a total of \(4\) linear layers. As already mentioned, each of these layers will have dimensions \((d_{model}, d_{model})\).
There are \(h\) attention heads with identical implementation, i.e., the implementation structure is the same although the weights associated with each linear layer will be different.
Although not captured in the above figure, a mask is applied in the decoder self-attention to prevent it from “looking” at future tokens when predicting the next token. This is what’s referred to as masked multi-head attention in the paper and it prevents the decoder from “cheating”.
From an implementation perspective, and based on the above observations, we’ll need two functions: one to clone layers that are used multiple times (like the linear layers) and another to create a mask to be applied in masked multi-head attention. Let’s see how to implement multi-head attention step by step, beginning with the function to clone layers.
import copy
import torch.nn as nn

def clones(module, N):
    '''
    Produce N identical layers
    '''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
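A brief usage check of this helper (a sketch): the clones are deep copies, so each layer keeps its own, independent parameters.

layers = clones(nn.Linear(4, 4), 3)
print(len(layers))                           # 3
print(layers[0].weight is layers[1].weight)  # False: independent weights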
And now, attention.
n, seq_len, d_model = 10, 5, 512
h = 8
d_k = d_model // h  # = 64

q = torch.rand(n, seq_len, d_model)
k = torch.rand(n, seq_len, d_model)
v = torch.rand(n, seq_len, d_model)

# 3 linear layers for projecting Q, K, and V matrices
# we implement h linear layers of shape (h, d_model, d_k) as a single
# linear layer of shape (d_model, d_model)
linears = clones(nn.Linear(d_model, d_model), 3)

# 1. Linear projections: Q, K, and V are each passed through the single
# layer at once
q, k, v = [
    linear(x).view(n, seq_len, h, d_k).transpose(1, 2)
    for linear, x in zip(linears, (q, k, v))
]

# 2. Compute attention over projected vectors
m_att, p = attention(q, k, v)

# 3. 'Concat' using a view operation
m_att = m_att.transpose(1, 2).contiguous().view(n, seq_len, d_model)
m_att.shape
>>> torch.Size([10, 5, 512])

# 4. Linear operation after concat
m_att_output = nn.Linear(d_model, d_model)(m_att)

m_att_output.shape
>>> torch.Size([10, 5, 512])
When training the Transformer, we need to make sure that the decoder predicts each word without peeking at future words in the sequence. To enforce this, we apply a mask to the self-attention layer in the decoder. This mask ensures that when the model predicts the word at position \(i\), it only has access to the words at positions before \(i\). We achieve this by blocking attention to future positions, meaning each token only considers past words. Another safeguard comes from the fact that the target embeddings are shifted right by one position — which ensures the model doesn’t see the exact word it’s supposed to predict.
We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. Attention is all you need, 2017 - Vaswani et al.
Without a mask, the decoder would attend to all tokens in the output sequence, including future words. This would be cheating because the model would simply copy the next word instead of genuinely predicting it. To avoid this, we modify self-attention in the decoder so that predictions at each step depend only on previous tokens. This is like taking a square matrix and turning off its upper triangle, leaving only the diagonal and lower triangle. Each token still attends to itself (the diagonal remains), but crucially, it cannot access future tokens. Since decoder inputs are shifted right, it never sees the actual target word it’s supposed to predict. For example, in the sentence “The cat sat down”, when the model predicts “sat”, it only sees "<sos> The cat", ensuring that predictions are based only on past words.
Let’s now see an implementation:
def subsequent_mask(size):
    '''
    Mask out subsequent positions.
    '''
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(
        torch.ones(attn_shape), diagonal=1
    ).type(torch.uint8)
    return subsequent_mask == 0

subsequent_mask(5)
>>> tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])
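Before visualising the mask, here is a brief check (a sketch, not from the original post) that passing it to the attention function defined earlier zeroes out the weights on future positions:

q_m = torch.rand(1, 5, 64)
k_m = torch.rand(1, 5, 64)
v_m = torch.rand(1, 5, 64)

_, p_attn = attention(q_m, k_m, v_m, mask=subsequent_mask(5))
print(p_attn[0])  # upper triangle is 0: no attention to future positions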
And a visualisation of masking. In the output mask, each row represents the position the model is predicting, while each column represents a position it may attend to. Consider the token predicted for position \(12\), for instance. We see that when predicting this token, the decoder will only attend to tokens at positions \(0\) to \(12\).
import pandas as pd
import altair as alt

RUN_EXAMPLES = True

def show_example(fn, args=[]):
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)

def example_mask():
    LS_data = pd.concat([
        pd.DataFrame({
            'Subsequent Mask': subsequent_mask(20)[0][x, y].flatten(),
            'Window': y,
            'Masking': x,
        })
        for y in range(20) for x in range(20)
    ])
    return (
        alt.Chart(LS_data)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X('Window:O'),
            alt.Y('Masking:O'),
            alt.Color('Subsequent Mask:Q',
                      scale=alt.Scale(scheme='viridis')),
        )
        .interactive()
    )

show_example(example_mask)
As usual, all the code we have written so far can be encapsulated in a single class for reusability.
class MultiheadedAttention(nn.Module):
    '''
    Implements multi-headed attention
    '''
    def __init__(self, h, d_model, dropout=0.1):
        '''
        Take in the model size and number of heads
        '''
        super(MultiheadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h  # We assume d_v always equals d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)  # Four
        # linear layers for: Q, K, V, and final output of multi-head
        # attention (after concat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, mask: torch.Tensor = None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask applied to all h heads
        n = query.size(0)

        # 1) Do all the linear projections in batch from d_model
        query, key, value = [
            linear(x).view(n, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch
        # (yields the newly projected token representation x)
        x, self.attn = attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view - to reshape the tensor to (n,
        # seq_len, d_model) - and then apply a final linear
        x = x.transpose(1, 2).contiguous().view(n, -1, self.h * self.d_k)
        del query
        del key
        del value

        # 4) The top-most linear layer after concat
        return self.linears[-1](x)
Let’s test our class implementation.
q = torch.rand(n, seq_len, d_model)
k = torch.rand(n, seq_len, d_model)
v = torch.rand(n, seq_len, d_model)

m_att = MultiheadedAttention(8, 512)
m_att_output = m_att(q, k, v)

m_att_output.shape
>>> torch.Size([10, 5, 512])
As we can see, our implementation works as expected.
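As an optional extra check (a sketch), the same module accepts the causal mask from earlier, which is how it would be used in the decoder's masked multi-head attention:

masked_output = m_att(q, k, v, mask=subsequent_mask(seq_len))

masked_output.shape
>>> torch.Size([10, 5, 512])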
In this tutorial, we explored scaled dot-product attention and multi-head attention, and understood how the transformer forms new representations of tokens based on how individual tokens in input sequences relate to each other. The most important part of the transformer architecture is now behind us. Next, we explore the other sub-layer of the encoder and decoder: the position-wise feed-forward network!
This series of tutorials has benefited a lot from Harvard NLP’s codebase for The Annotated Transformer.