1. Why Transformers?
Before Transformers, sequence models (RNNs, LSTMs) processed tokens one at a time. This made training slow, because the steps could not be parallelised, and made it hard to relate tokens that are far apart in a sequence. Transformers address both problems by processing all tokens in parallel and letting attention connect any two positions directly.
RNNs & LSTMs
Before Transformers
- Process tokens sequentially — slow to train
- Hidden state is a bottleneck for long sequences
- Vanishing gradients limit long-range memory
- Hard to parallelise across GPUs/TPUs
Transformers
Attention-based
- Process all tokens in parallel — fast training
- Attention directly connects any two positions
- Scale efficiently to billions of parameters
- Generalisable to vision, audio, and multimodal tasks
2. Key Concepts
Self-Attention
Each token attends to every other token in the sequence to compute a context-aware representation. Given queries Q, keys K, and values V:
Attention(Q, K, V) = softmax(QKᵀ / √dₖ) · V
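The formula above can be traced step by step in a few lines of NumPy. This is a minimal sketch, not a production implementation: it handles a single unbatched sequence, and the function names and toy shapes are chosen here for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum first for numerical stability.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v) -> output (n_q, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # (n_q, n_k) similarity scores
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))   # 3 queries, d_k = 4
k = rng.normal(size=(5, 4))   # 5 keys
v = rng.normal(size=(5, 2))   # 5 values, d_v = 2
output, weights = scaled_dot_product_attention(q, k, v)
print(output.shape, weights.shape)  # (3, 2) (3, 5)
```

Dividing by √dₖ keeps the scores from growing with the key dimension, which would otherwise push the softmax towards near one-hot weights.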
Multi-Head Attention
Runs h attention heads in parallel, each learning different relationships (syntax, coreference, proximity). Their outputs are concatenated and projected.
- Each head sees a different linear projection of Q, K, V
- Captures diverse patterns simultaneously
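As a rough illustration of the split, attend, concatenate, project sequence, the NumPy sketch below runs multi-head self-attention on one unbatched sequence. The weight names (`w_q`, `w_k`, `w_v`, `w_o`) and the random toy weights are assumptions made for this example; biases are omitted for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (seq_len, d_model); each weight matrix: (d_model, d_model)."""
    seq_len, d_model = x.shape
    depth = d_model // num_heads  # per-head dimension

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, depth)
        return t.reshape(seq_len, num_heads, depth).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(depth)  # (heads, seq, seq)
    weights = softmax(scores, axis=-1)
    context = weights @ v                               # (heads, seq, depth)
    # Concatenate the heads back to (seq_len, d_model), then project.
    concat = context.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ w_o

rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 6, 8, 2
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v, w_o = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))
out = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape)  # (6, 8)
```

Because each head works in a `d_model // num_heads` subspace, the total cost stays close to that of a single full-width head.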
Positional Encoding
Since attention has no built-in notion of order, positional information is injected by adding sinusoidal (or learned) encodings to the token embeddings before the first layer.
- Sinusoidal: fixed, with no trainable parameters; may extrapolate to sequences longer than those seen in training
- Learned: trained as embedding parameters
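The sinusoidal variant can be computed directly. The sketch below assumes an even `d_model` and uses the base 10000 from the original Transformer formulation; the function name and the zero stand-in embeddings are illustrative only.

```python
import numpy as np

def sinusoidal_encoding(max_len, d_model):
    """Return a (max_len, d_model) table; d_model is assumed to be even."""
    position = np.arange(max_len)[:, None]  # (max_len, 1)
    # Frequencies 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = np.cos(position * div_term)  # odd dimensions
    return pe

pe = sinusoidal_encoding(max_len=50, d_model=16)
embeddings = np.zeros((50, 16))  # stand-in for real token embeddings
x = embeddings + pe              # positions are added, not concatenated
print(pe.shape)   # (50, 16)
print(pe[0, :4])  # [0. 1. 0. 1.]
```

Position 0 gives sin(0) = 0 in the even dimensions and cos(0) = 1 in the odd ones, which is why the first row alternates between 0 and 1.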
Feed-Forward Sub-layer
After each attention sub-layer, a two-layer fully connected network is applied to each position independently, expanding and compressing the representation.
FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂
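A small NumPy sketch of this formula, with toy shapes chosen for illustration, also shows what "each position independently" means: running one row on its own gives the same result as running the whole sequence.

```python
import numpy as np

def ffn(x, w1, b1, w2, b2):
    """x: (seq_len, d_model); w1: (d_model, d_ff); w2: (d_ff, d_model)."""
    hidden = np.maximum(0.0, x @ w1 + b1)  # ReLU, expands to d_ff
    return hidden @ w2 + b2                # compresses back to d_model

rng = np.random.default_rng(0)
seq_len, d_model, d_ff = 5, 8, 32
x = rng.normal(size=(seq_len, d_model))
w1, b1 = rng.normal(size=(d_model, d_ff)) * 0.1, np.zeros(d_ff)
w2, b2 = rng.normal(size=(d_ff, d_model)) * 0.1, np.zeros(d_model)

out = ffn(x, w1, b1, w2, b2)
# Position independence: row 2 alone matches row 2 of the full result.
row_alone = ffn(x[2:3], w1, b1, w2, b2)
print(out.shape, np.allclose(out[2], row_alone[0]))  # (5, 8) True
```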
Residual Connections & Layer Norm
Each sub-layer (attention and FFN) wraps its output in a residual connection followed by layer normalisation, which stabilises training and enables very deep stacks.
output = LayerNorm(x + SubLayer(x))
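The wrapper above can be sketched in NumPy as follows. The `gamma` and `beta` parameters stand for the learned scale and shift of layer normalisation, and the toy sub-layer is a placeholder for attention or the FFN; all names here are illustrative.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalise each position over its feature dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def add_and_norm(x, sublayer, gamma, beta):
    # Post-LN form: output = LayerNorm(x + SubLayer(x))
    return layer_norm(x + sublayer(x), gamma, beta)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
gamma, beta = np.ones(8), np.zeros(8)
out = add_and_norm(x, lambda t: 0.5 * t, gamma, beta)  # toy sub-layer
print(out.shape)  # (4, 8)
```

With `gamma = 1` and `beta = 0`, every position of the output has approximately zero mean and unit variance across its features.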
Encoder & Decoder
The original Transformer has two stacks. Encoder-only models (BERT) are good for understanding; decoder-only models (GPT) for generation; encoder-decoder (T5, Whisper) for translation and summarisation.
- Encoder: bidirectional self-attention
- Decoder: masked (causal) self-attention + cross-attention
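The difference between bidirectional and causal self-attention comes down to a mask applied to the attention scores before the softmax. A minimal NumPy sketch, with uniform scores used purely for illustration:

```python
import numpy as np

def causal_mask(seq_len):
    # True where attention is allowed: position i may see positions 0..i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def masked_softmax(scores, mask):
    # Disallowed positions get -inf, so their softmax weight is exactly 0.
    scores = np.where(mask, scores, -np.inf)
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))  # uniform raw scores for illustration
weights = masked_softmax(scores, causal_mask(4))
print(weights.round(2))
```

Row i spreads its weight evenly over positions 0..i and assigns zero to later positions. An encoder skips the mask, so every row covers the full sequence.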
3. Architecture Overview
A Transformer encoder applies these steps in order; steps 2 to 5 make up one encoder layer:
| Step | Component | Purpose |
|---|---|---|
| 1 | Token Embedding + Positional Encoding | Convert token IDs to dense vectors with position information |
| 2 | Multi-Head Self-Attention | Let every token attend to every other token |
| 3 | Add & Norm (residual) | Stabilise gradients and preserve input signal |
| 4 | Position-wise Feed-Forward Network | Transform each position independently |
| 5 | Add & Norm (residual) | Stabilise gradients |
| 6 | Repeat N times | Stack N encoder layers (steps 2 to 5) for richer representations |
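To see how the steps in the table compose, here is a deliberately simplified NumPy sketch. It assumes single-head attention, no biases, no dropout, and layer normalisation without learned scale or shift; the parameter names and sizes are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-6):
    # No learned scale/shift, for brevity.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def init_layer(d_model, d_ff):
    shapes = {"wq": (d_model, d_model), "wk": (d_model, d_model),
              "wv": (d_model, d_model), "w1": (d_model, d_ff), "w2": (d_ff, d_model)}
    return {name: rng.normal(size=shape) * 0.1 for name, shape in shapes.items()}

def encoder_layer(x, p):
    q, k, v = x @ p["wq"], x @ p["wk"], x @ p["wv"]
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v  # step 2
    x = layer_norm(x + attn)                            # step 3
    ffn = np.maximum(0.0, x @ p["w1"]) @ p["w2"]        # step 4
    return layer_norm(x + ffn)                          # step 5

vocab_size, max_len, d_model, d_ff, num_layers = 100, 16, 8, 32, 3
token_table = rng.normal(size=(vocab_size, d_model))
position_table = rng.normal(size=(max_len, d_model))  # learned-style positions
layers = [init_layer(d_model, d_ff) for _ in range(num_layers)]

token_ids = np.array([5, 42, 7, 99])
x = token_table[token_ids] + position_table[:len(token_ids)]  # step 1
for p in layers:                                              # step 6
    x = encoder_layer(x, p)
print(x.shape)  # (4, 8)
```

The shape stays `(seq_len, d_model)` through every layer, which is what allows layers to be stacked freely.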
4. TensorFlow (Keras) Example
A Transformer encoder for text classification built from scratch with custom Keras layers.
Multi-Head Self-Attention Layer
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class MultiHeadSelfAttention(layers.Layer):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.depth = embed_dim // num_heads
        self.embed_dim = embed_dim
        self.wq = layers.Dense(embed_dim)
        self.wk = layers.Dense(embed_dim)
        self.wv = layers.Dense(embed_dim)
        self.out = layers.Dense(embed_dim)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x):
        batch_size = tf.shape(x)[0]
        q = self.split_heads(self.wq(x), batch_size)
        k = self.split_heads(self.wk(x), batch_size)
        v = self.split_heads(self.wv(x), batch_size)
        # Scaled dot-product attention
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(
            tf.cast(self.depth, tf.float32)
        )
        weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(weights, v)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        context = tf.reshape(context, (batch_size, -1, self.embed_dim))
        return self.out(context)
Transformer Encoder Block
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ffn = keras.Sequential([
            layers.Dense(ff_dim, activation='relu'),
            layers.Dense(embed_dim),
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = layers.Dropout(dropout)
        self.drop2 = layers.Dropout(dropout)

    def call(self, x, training=False):
        attn = self.drop1(self.attention(x), training=training)
        x = self.norm1(x + attn)  # residual + norm
        ffn = self.drop2(self.ffn(x), training=training)
        return self.norm2(x + ffn)  # residual + norm
Token + Positional Embedding
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(vocab_size, embed_dim)
        self.pos_emb = layers.Embedding(maxlen, embed_dim)

    def call(self, x):
        positions = tf.range(tf.shape(x)[-1])
        return self.token_emb(x) + self.pos_emb(positions)
Full Classification Model
# Hyperparameters
VOCAB_SIZE = 20000
MAXLEN = 200
EMBED_DIM = 64
NUM_HEADS = 4
FF_DIM = 128
NUM_CLASSES = 2
inputs = layers.Input(shape=(MAXLEN,))
x = TokenAndPositionEmbedding(MAXLEN, VOCAB_SIZE, EMBED_DIM)(inputs)
x = TransformerBlock(EMBED_DIM, NUM_HEADS, FF_DIM)(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(32, activation='relu')(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()
# model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.2)
5. PyTorch Example
A Transformer encoder classifier using PyTorch's built-in nn.TransformerEncoderLayer with sinusoidal positional encoding.
Sinusoidal Positional Encoding
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
Transformer Classifier Module
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers,
                 dim_ff, num_classes, max_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_ff, dropout=dropout,
            batch_first=True,
            norm_first=True  # Pre-LN: LayerNorm runs before each sub-layer
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )
        self.d_model = d_model
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Xavier init also overwrites the zeroed padding row, so reset it.
        with torch.no_grad():
            self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(self, x, padding_mask=None):
        # padding_mask: optional bool tensor (batch, seq_len), True at padded positions
        # Scale embeddings (standard Transformer practice)
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        if padding_mask is None:
            x = x.mean(dim=1)  # global average pooling over sequence
        else:
            # Average over real tokens only, so padding does not dilute the result
            keep = (~padding_mask).unsqueeze(-1).to(x.dtype)
            x = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.classifier(x)
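When batches contain padding, a plain mean over the sequence dimension also averages the padded positions. The NumPy sketch below shows one way to build a padding mask from token IDs and pool over real tokens only. It follows PyTorch's convention that `True` marks a padded position, and `PAD_ID = 0` is an assumption chosen to match `padding_idx=0` in the embedding above.

```python
import numpy as np

def masked_mean_pool(x, padding_mask):
    """x: (batch, seq_len, d_model); padding_mask: (batch, seq_len), True = padding."""
    keep = (~padding_mask)[..., None].astype(x.dtype)  # 1.0 for real tokens
    summed = (x * keep).sum(axis=1)
    counts = np.maximum(keep.sum(axis=1), 1.0)         # avoid division by zero
    return summed / counts

PAD_ID = 0  # assumed padding token ID
tokens = np.array([[4, 9, 2, 0, 0],
                   [7, 3, 0, 0, 0]])
padding_mask = tokens == PAD_ID

x = np.ones((2, 5, 3))   # pretend encoder output
x[padding_mask] = 100.0  # junk values at padded positions
pooled = masked_mean_pool(x, padding_mask)
print(pooled)  # all ones: the junk at padded positions is ignored
```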
Training Setup & Forward Pass
# Model hyperparameters
VOCAB_SIZE = 20000
D_MODEL = 128
NHEAD = 4
NUM_LAYERS = 2
DIM_FF = 256
NUM_CLASSES = 2
MAX_LEN = 200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerClassifier(
    VOCAB_SIZE, D_MODEL, NHEAD, NUM_LAYERS, DIM_FF, NUM_CLASSES, MAX_LEN
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9
)
# Example forward pass
batch_size, seq_len = 8, 50
tokens = torch.randint(1, VOCAB_SIZE, (batch_size, seq_len)).to(device)
logits = model(tokens)
print(f"Output shape: {logits.shape}") # (8, 2)
# Training step
def train_step(tokens, labels):
    model.train()
    optimizer.zero_grad()
    logits = model(tokens)
    loss = criterion(logits, labels)
    loss.backward()
    # Gradient clipping is standard practice for Transformers
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()
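To make the clipping step less of a black box, here is a NumPy sketch of the idea behind clipping by global norm: all gradients are scaled by one common factor so that their joint L2 norm does not exceed `max_norm`. The function name and toy gradients are illustrative, and the sketch is not a drop-in replacement for the PyTorch utility.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # joint norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float((g ** 2).sum()) for g in clipped))
print(norm_before, round(norm_after, 4))  # 5.0 1.0
```

Because every gradient is multiplied by the same factor, the direction of the overall update is preserved and only its length is capped.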
6. Notable Pre-trained Transformer Models
In practice, training a Transformer from scratch requires large datasets and significant compute. Pre-trained models are almost always fine-tuned instead.
| Model | Type | Best For | Source |
|---|---|---|---|
| BERT | Encoder-only | Classification, NER, question answering | Google |
| GPT-4 | Decoder-only | Text generation, chat, reasoning | OpenAI |
| T5 | Encoder-Decoder | Translation, summarisation, generation | Google |
| RoBERTa | Encoder-only | Classification, improved BERT training | Meta AI |
| ViT | Encoder-only | Image classification (vision Transformer) | Google |
| Whisper | Encoder-Decoder | Speech recognition & translation | OpenAI |
| CLIP | Dual Encoder | Image–text matching, zero-shot vision | OpenAI |
7. Further Reading
For production work, see the Hugging Face transformers and datasets libraries.