There is a residual connection around each sub-layer in the encoder and decoder stacks, followed by layer normalisation.
We employ a residual connection around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is \(\mathbf{LayerNorm}(x + \mathbf{Sublayer}(x))\), where \(\mathbf{Sublayer}(x)\) is the function implemented by the sub-layer itself. (Attention Is All You Need, Vaswani et al., 2017)
These two operations appear in five different places within the encoder and decoder stacks: around the two sub-layers of each encoder layer and around the three sub-layers of each decoder layer. See the figure below.
Given some input vector \(x=[x_1, x_2, \cdots, x_H]\), layer normalisation proceeds in three steps. First, compute the mean and variance over the \(H\) features:

\[ \mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2 \]

Then centre and scale each feature, where the small constant \(\epsilon\) prevents division by zero (the implementation below adds \(\epsilon\) to \(\sigma\) rather than \(\sigma^2\), hence the approximation):

\[ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \approx \frac{x_i - \mu}{\sigma + \epsilon} \]

Finally, apply a learnable scale \(\gamma\) and shift \(\beta\):

\[ y_i = \gamma \hat{x}_i + \beta \]

In the implementation below, \(\gamma\) is initialised to ones and \(\beta\) to zeros.
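To make the steps concrete, here is a minimal sketch on a toy feature vector (standalone PyTorch, for illustration only, not part of the model code we build below):

```python
import torch

# Toy feature vector with H = 4 features
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
eps = 1e-6

mu = x.mean()                              # step 1: mean over the features
var = x.var(unbiased=False)                # step 1: variance over the features
x_hat = (x - mu) / torch.sqrt(var + eps)   # step 2: centre and scale

print(x_hat)                       # tensor([-1.3416, -0.4472,  0.4472,  1.3416])
print(x_hat.mean())                # ~0
print(x_hat.var(unbiased=False))   # ~1
```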
Before proceeding to implementation, let’s get a primer on why normalisation of layer inputs is important.
Training deep learning models tends to be slow and unstable when the inputs to each layer are on different scales (some very large and others very small) and are uncentered (some very positive and others very negative). This problem is often addressed by batch normalisation, commonly applied in CNNs, where each batch of inputs fed into the network is normalised (centered and scaled) using its mean and variance. This step accelerates the convergence of deep networks.
Batch normalisation has the following problems:

- For small batch sizes, the mean and variance become unrepresentative of the entire dataset and thus unreliable
- It's difficult to parallelise a network where batch normalisation is applied
- When working with sequence data where the input sizes vary, batch normalisation becomes very complicated
The solution to these problems is layer normalisation where, rather than normalising across batches, we do it across features, focusing on just a single datapoint. To illustrate this difference, imagine that we had a batch of 4 datapoints being fed into the network, each with 3 features. We can represent this batch using the tensor below:

\[ \mathbf{Batch} = \begin{bmatrix} 1 & 2 & 3 \\ 2 & 3 & 4 \\ 3 & 4 & 5 \\ 4 & 5 & 6 \end{bmatrix} \]

Batch normalisation would compute the mean and variance along the columns (i.e., across the batch), whereas layer normalisation would do so along the rows (i.e., across the features). One can picture each feature of a datapoint being processed by a single neuron of the layer, so normalising across the features amounts to normalising that layer's activations before passing them on to the next layer, hence the name.
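To make the axis difference concrete, here is a small sketch computing both sets of statistics for the batch above (standalone PyTorch, for illustration only):

```python
import torch

batch = torch.tensor([[1., 2., 3.],
                      [2., 3., 4.],
                      [3., 4., 5.],
                      [4., 5., 6.]])   # 4 datapoints, 3 features

# Batch normalisation statistics: one mean/variance per feature (column)
bn_mean = batch.mean(dim=0)                 # tensor([2.5000, 3.5000, 4.5000])
bn_var  = batch.var(dim=0, unbiased=False)

# Layer normalisation statistics: one mean/variance per datapoint (row)
ln_mean = batch.mean(dim=1)                 # tensor([2., 3., 4., 5.])
ln_var  = batch.var(dim=1, unbiased=False)
```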
The pros of layer normalisation are:

- Works well with small batch sizes, even a batch size of 1.
- Great for recurrent and transformer models (like GPT or BERT).
- Doesn't depend on other examples – just normalises each input on its own.
One thing to keep in mind when implementing layer normalisation is that the scale and shift parameters are learned by the model during training. We use the `nn.Parameter` class to tell PyTorch that the tensor wrapped in it is a trainable parameter to be updated during optimisation. Therefore, when we call `.parameters()` on the model, these tensors will be included.
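As a quick illustration (a hypothetical toy module, not part of the transformer code), only tensors wrapped in `nn.Parameter` show up in `.parameters()`:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))  # registered as trainable
        self.plain = torch.ones(4)                # ordinary tensor, ignored by the optimiser

toy = Toy()
print([name for name, _ in toy.named_parameters()])   # ['scale']
```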
Now let’s implement a class for layer normalisation.
```python
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    '''
    Construct a layer normalization module.

    The forward pass implements the approach described in the original
    layer norm paper: https://arxiv.org/abs/1607.06450

    y = gamma * (x - mean) / (std + eps) + beta -----> where
    gamma & beta are learnable parameters
    '''
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gamma: learnable scale
        self.b_2 = nn.Parameter(torch.zeros(features))  # beta: learnable shift
        self.eps = eps

    def forward(self, x: torch.Tensor):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
```
```python
x = torch.randn(2, 3, 4)
norm = LayerNorm(4)

norm(x).shape
>>> torch.Size([2, 3, 4])
```
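As a quick sanity check, with the scale still at its initial value of ones and the shift at zeros, every feature vector should come out roughly zero-mean and unit-variance:

```python
out = norm(x)
print(out.mean(-1))   # ~0 at every position
print(out.std(-1))    # ~1 at every position (up to the eps in the denominator)
```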
Residual connections in a Transformer help preserve important information across layers by allowing the model to pass its inputs directly to deeper layers, bypassing transformations that might distort the original signal. This mechanism also mitigates the vanishing gradient problem, keeping learning stable during training. Specifically, residual connections work by adding the original input of a layer to its output before applying normalisation, which maintains the integrity of the information while enabling effective backpropagation. This significantly boosts the ability of Transformers to retain essential features while refining representations across multiple layers.
To facilitate residual connections, all sub-layers in the model, including embedding layers, produce outputs of dimension \(d_{model}=512\). These connections add the input of a layer to its output before applying layer normalization, ensuring stable gradients and preserving original information. Each sub-layer, such as self-attention or feed-forward networks, is wrapped with a residual connection. The sum of the input and transformed output is passed to layer normalization before continuing.
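The gradient argument can be made concrete: for \(y = x + f(x)\) we have \(\partial y / \partial x = 1 + f'(x)\), so even when \(f'\) is tiny the gradient stays close to one. Below is a minimal autograd sketch with a made-up squashing function \(f\) (illustration only, not part of the transformer code):

```python
import torch

x = torch.ones(4, requires_grad=True)

# A "sub-layer" that heavily damps its input, so its own gradient is tiny
f = lambda t: 1e-3 * torch.tanh(t)

y_plain = f(x).sum()            # no residual connection
y_residual = (x + f(x)).sum()   # with residual connection

g_plain, = torch.autograd.grad(y_plain, x)
g_residual, = torch.autograd.grad(y_residual, x)

print(g_plain)     # ~4e-4: the gradient through f alone nearly vanishes
print(g_residual)  # ~1.0004: the identity path keeps the gradient close to 1
```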
Let's implement a class for a residual connection followed by layer normalisation.
```python
class SublayerConnection(nn.Module):
    """
    Implements a residual connection followed by layer normalisation.
    Note that, for code simplicity, the norm is applied to the sub-layer's
    input rather than to its output.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        Apply layer normalisation to x, pass it through some sublayer of
        the same size, drop some connections (dropout),
        and then apply a residual connection
        """
        return x + self.dropout(sublayer(self.norm(x)))
```
```python
conn = SublayerConnection(4, 0.1)
sublayer = nn.Linear(4, 4)

conn(x, sublayer).shape
>>> torch.Size([2, 3, 4])
```
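To see the residual path in isolation, here is a toy check using the objects defined above: if the sub-layer contributes nothing, the block simply returns its input unchanged.

```python
conn.eval()   # put dropout in eval mode so the check is deterministic
passthrough = conn(x, lambda t: torch.zeros_like(t))
print(torch.allclose(passthrough, x))   # True: only the residual path contributes
```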
In this tutorial, we explored the role of residual connections and layer normalisation in Transformers. We saw how residual connections help preserve information across layers, while layer normalisation stabilises training by standardising activations. Together, these techniques enhance deep learning efficiency and performance. Now we have all the building blocks of the encoder and decoder stacks. Next, we shall build out these stacks!
This series of tutorials has benefited a lot from Harvard NLP's codebase for The Annotated Transformer.