The transformer uses learned embeddings to convert input tokens and target tokens to vectors of size \(d_{model}\). Learned linear transformations and the softmax function are also used to convert the decoder output to predicted next-token probabilities. For the model implemented in the paper, the same weight matrix is shared between the two embedding layers and the pre-softmax linear transformation. For the embedding layers, the weights are multiplied by \(\sqrt{d_{model}}\).
Similarly to other sequence transduction models, we use learned embeddings to convert the input tokens and output tokens to vectors of dimension \(d_{model}\). Attention is all you need, Vaswani et al., 2017
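As a side note before we zoom in, here is a minimal sketch of how the weight sharing mentioned above could be expressed in PyTorch. It ties one embedding table to the pre-softmax projection; the name generator and the shapes are illustrative assumptions, not part of the code we build in this tutorial:
import torch.nn as nn

# Minimal sketch of the paper's weight sharing (illustrative only).
# "generator" is an assumed name for the pre-softmax linear projection.
vocab, d_model = 100, 4
embedding = nn.Embedding(vocab, d_model)            # token index -> vector
generator = nn.Linear(d_model, vocab, bias=False)   # vector -> vocabulary logits
generator.weight = embedding.weight                 # one shared (vocab, d_model) matrix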
We are interested in the section of the transformer architecture shown in the figure below.
Suppose we are building a machine translation model to translate English into French. Each language has a vocabulary size, which is the total number of tokens (full words or parts of words) the model can work with. Let’s call this vocab. To process these tokens, the model represents each one as a meaningful vector of a fixed dimensionality, referred to as d_model. These vector representations, called embeddings, capture the semantic meaning of the tokens, where similar ones are geometrically closer in the embedding space.
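To make “geometrically closer” concrete, closeness between two embedding vectors is often measured with cosine similarity. The snippet below peeks ahead at PyTorch (introduced properly further down) and uses a randomly initialized table, so the value it prints is meaningless; real similarities only emerge after training, and the token indices for “cat” and “dog” are made up:
import torch
import torch.nn.functional as F

# Toy, randomly initialized table: 10 tokens, 4 dimensions (values are random).
emb = torch.nn.Embedding(10, 4)
v_a = emb(torch.tensor(0))   # embedding of a hypothetical token "cat"
v_b = emb(torch.tensor(1))   # embedding of a hypothetical token "dog"
print(F.cosine_similarity(v_a, v_b, dim=0))   # a value in [-1, 1]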
During training, the model learns to represent each token in the
vocabulary as a vector of size d_model. For example, if the
vocabulary size is \(5000\), tokens can
be indexed from \(0\) to \(4999\). Obviously, individual sentences
being translated contain only a small subset of the tokens in the
vocabulary. Each token is processed based on its index.
Now that we understand the basics, we can implement the transformer’s
embedding layers. To start, we use PyTorch’s Embedding module to generate preliminary embeddings for tokens. For illustration,
let’s assume our language has a vocabulary size of \(100\) tokens, each represented by an
embedding of size \(4\). The embeddings
can be created as follows:
import torch
import torch.nn as nn

d_model = 4
vocab = 100

lut = nn.Embedding(vocab, d_model)
We’ve just created a lookup table lut that stores the embeddings of all \(100\) tokens, where the embedding of each token can be accessed using an integer index. It is important to point out that these embeddings are only preliminary: the final embeddings in this part of the transformer are arrived at only after training is completed. Before proceeding, let’s see what some of these embeddings look like by accessing the weight attribute of the lookup table.
W = lut.weight
print(W.shape)
>>> torch.Size([100, 4])

print(W[:2, :])
# Example output
>>> tensor([[ 1.3037, -0.3994, -1.6429,  1.6953],
>>>         [-0.1786,  0.4978, -1.4669, -1.3677]])
We have \(100\) vectors, each of
size \(4\), resulting in the shape
\((100, 4)\). Here, we use
PyTorch
tensors instead of NumPy arrays because they offer
benefits like automatic differentiation, which is crucial for training
the model through backpropagation.
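For example, the lookup table’s weights are a trainable parameter tracked by autograd. The indices and the dummy scalar “loss” below are illustrative assumptions, used only to trigger a backward pass; they are not part of the actual training procedure:
# The embedding weights are a learnable parameter tracked by autograd.
print(W.requires_grad)
>>> True

# Dummy scalar "loss" just to demonstrate that gradients flow into the table.
loss = lut(torch.tensor([0, 1])).sum()
loss.backward()
print(lut.weight.grad.shape)   # same shape as W; only the looked-up rows get non-zero gradients
>>> torch.Size([100, 4])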
When printing the embeddings above, we used integer indices to access specific rows of the weight matrix. Conveniently, PyTorch’s Embedding module allows us to
pass a tensor of token indices directly as input. Each element
in this tensor is an integer ranging from \(0\) to \(vocab-1\) (in this case, \(0\) to \(99\)). The tensor’s shape is
(batch_size, seq_len)
, where batch_size
is the
number of sequences to vectorize and seq_len
is the number
of tokens per sequence. Consequently, the output tensor has three
dimensions: the number of sequences, the number of tokens in each
sequence, and the dimensionality of each token’s representation.
To obtain the preliminary embeddings, we pass a tensor of token
indices, \(x\), into the lookup table.
Each row of \(x\) represents an input
sequence, and the elements correspond to token indices in the
vocabulary. The Embedding module uses these indices as keys to
retrieve the appropriate embeddings. Algorithmically, if the output
tensor is out
and W
is the Embedding module’s
weight
matrix, the mapping is as follows:
\[ out[i,j] = W[x[i, j]] \]
The resulting tensor will have the shape
(batch_size, seq_len, d_model)
. This means there are
“batch_size
sequences, each containing seq_len
tokens, with each token represented as a vector of dimensionality
d_model
”. Before proceeding with the implementation, recall
this …
In the embedding layers, we multiply those weights by \(\sqrt{d_{model}}\). Attention is all you need, Vaswani et al., 2017
Those embeddings will be multiplied by \(\sqrt{d_{model}}\). Let’s now implement this in code:
import math

x = torch.tensor([
    [23, 37, 3, 45, 82],   # sequence 1 --> 5 tokens long
    [97, 61, 19, 73, 53]   # sequence 2 --> 5 tokens long
])

out = lut(x) * math.sqrt(d_model)
print(out.shape)
>>> torch.Size([2, 5, 4])

# Confirm mapping from indices to embeddings (out[i,j] == W[x[i, j]])
print(lut(x)[1, 3] == W[x[1, 3]])
>>> tensor([True, True, True, True])

# Mapping is consistent throughout the tensor
torch.all(lut(x) == W[x])
>>> tensor(True)
In a real-world implementation, the tensor \(x\) would come from a text tokenizer (e.g., byte-pair encoding or word-piece tokenization) that takes input sentences and then converts them to integer indices based on the input vocabulary. More on this later when we are done with the architecture.
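As a rough, hypothetical sketch of that step (a real subword tokenizer is considerably more involved), a whitespace tokenizer plus a vocabulary lookup might look like the following; the toy_word_to_id mapping and the naive_tokenize function are made up for illustration:
# Hypothetical stand-in for a real subword tokenizer (e.g., BPE or WordPiece)
toy_word_to_id = {"<unk>": 0, "the": 23, "cat": 37, "sat": 3}

def naive_tokenize(sentence):
    """Split on whitespace and map each word to its vocabulary index."""
    return [toy_word_to_id.get(w, toy_word_to_id["<unk>"]) for w in sentence.lower().split()]

print(naive_tokenize("The cat sat"))
>>> [23, 37, 3]
# A batch of such index lists is what we wrapped into the tensor x above.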
Now it’s time to tie everything together in a class.
We can encapsulate the code discussed earlier into a class that represents the transformer’s Embedding layer. This makes the implementation more modular and reusable:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)  # creates a lookup table
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

# Example Usage
e = Embeddings(4, 100)
x = torch.randint(0, 100, (2, 5))

e(x).shape
>>> torch.Size([2, 5, 4])
In this tutorial, we explored embedding layers, understanding how tokens are vectorized and scaled. Next, we delve into positional encodings!
This series of tutorials has benefited a lot from Harvard NLP’s codebase for The Annotated Transformer.