Transformers


5. Position-wise Feed-Forward Networks

5.1 Introduction

Each layer in our encoder and decoder also contains a fully-connected feed-forward network, which is applied to each position separately and identically. It comprises two linear transformations with a ReLU activation between them.

\[FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2\]

The linear transformations are the same across different positions: the same weights and biases are used for every token in an input sequence. However, the parameters change from layer to layer, and so do the transformations applied to the tokens.

The input to the network is the output of the multi-head attention, which has dimensionality \(d_{model}=512\). The output (second) layer of the FFN also has this dimensionality. The first (inner) layer of the FFN has dimensionality \(d_{ff}=2048\), effectively expanding each token's representation to a vector of length 2048, which helps to capture more complex patterns.
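As a quick aside (simple arithmetic from the dimensions above rather than a figure quoted in the paper), these two linear transformations give each FFN

\[512 \times 2048 + 2048 + 2048 \times 512 + 512 = 2{,}099{,}712\]

parameters, counting both weights and biases.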

The feed-forward network appears once in every layer of the encoder and decoder stacks, as shown in the figure below.

[Figure: ffn.png]

5.2 Implementation

Implementing a position-wise FFN is a straightforward exercise: it is effectively a multi-layer perceptron with a single hidden layer.

import torch
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
    """
    Implements the FFN equation: FFN(x) = max(0, xW_1 + b_1)W_2 + b_2.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # inner layer: d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)   # output layer: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Linear -> ReLU -> dropout -> linear, applied along the last
        # dimension, i.e. to every position with the same weights.
        return self.w_2(self.dropout(self.w_1(x).relu()))

p = PositionwiseFeedForward(512, 2048)
x = torch.randn(10, 5, 512)  # batch of 10 sequences, 5 tokens each, d_model=512
p(x).shape  # the output keeps the model dimensionality
>>> torch.Size([10, 5, 512])

# Check the shape of inner layer's output
p.w_1(x).shape
>>> torch.Size([10, 5, 2048])
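
As a quick extra check (not part of the original snippet), we can verify the position-wise weight sharing described earlier: with dropout disabled, passing a single position through the FFN gives the same result as slicing that position out of the full output.

# Every position is transformed by the same weights
p.eval()  # disable dropout so the comparison is deterministic
torch.allclose(p(x[:, 2, :]), p(x)[:, 2, :])
>>> True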

5.3 Conclusion

This tutorial has been brief but just as important as those before it. We have seen how the position-wise FFN receives the multi-head attention's output and expands its dimensionality before reducing it again. Now we are left with only two building blocks of the encoder and decoder stacks: layer normalisation and residual connections. These are what we tackle next!

Acknowledgements

This series of tutorials has benefited a lot from Harvard NLP's codebase for The Annotated Transformer.