RNNs and LSTMs process sequences strictly left-to-right, forcing every piece of information to squeeze through a single hidden-state bottleneck. Attention removes the bottleneck: each token can look directly at every other token in the sequence and weigh its relevance. This one idea, 'let me pay more attention to the words that matter', is the foundation of nearly every modern large language model.
Query · Key · Value — the Three Roles
Every token in the sequence is projected into three vectors via learned matrices (X holds one token embedding per row):
Q = X · W_Q (queries — 'what am I looking for?')
K = X · W_K (keys — 'what do I offer?')
V = X · W_V (values — 'what do I actually carry?')
Attention output for one query q:
scores = q · Kᵀ (how similar am I to each key?)
weights = softmax(scores / √dₖ) (normalise to sum to 1; dₖ is the key dimension)
output = weights · V (weighted sum of values)
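The three steps above can be traced on a toy example. This is a minimal NumPy sketch; the vectors `q`, `K`, `V` and their sizes are made up for illustration and do not come from a trained model.

```python
import numpy as np

def softmax(x):
    # Subtracting the max improves numerical stability and leaves the result unchanged.
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Illustrative values only: 3 tokens, key dimension d_k = 2, value dimension 2.
q = np.array([1.0, 0.0])                  # one query
K = np.array([[1.0, 0.0],                 # key 0: aligned with q
              [0.0, 1.0],                 # key 1: orthogonal to q
              [-1.0, 0.0]])               # key 2: opposite to q
V = np.array([[10.0, 0.0],
              [0.0, 10.0],
              [5.0, 5.0]])

d_k = K.shape[1]
scores = q @ K.T                          # shape (3,): one similarity score per key
weights = softmax(scores / np.sqrt(d_k))  # non-negative, sums to 1
output = weights @ V                      # shape (2,): weighted sum of the value rows

print(weights)
print(output)
```

The key that points in the same direction as the query receives the largest weight, so its value row contributes most to the output.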
Full self-attention in matrix form:
Attention(Q, K, V) = softmax(Q·Kᵀ / √dₖ) · V
Every token is simultaneously a query (asking), a key (offering), and a value (carrying content). Self-attention = Q, K, V all come from the same sequence.
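The matrix form can be written in a few lines of NumPy. The sketch below is an illustration under stated assumptions: the function name `self_attention`, the sizes, and the random matrices standing in for learned weights are all hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Row-wise softmax with max-subtraction for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Scaled dot-product self-attention for one sequence X of shape (N, d_model)."""
    Q = X @ W_Q                                          # (N, d_k)
    K = X @ W_K                                          # (N, d_k)
    V = X @ W_V                                          # (N, d_v)
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)   # (N, N), each row sums to 1
    return weights @ V, weights

# Random matrices stand in for learned parameters (assumption for illustration).
rng = np.random.default_rng(0)
N, d_model, d_k = 5, 8, 4
X = rng.normal(size=(N, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(X, W_Q, W_K, W_V)
print(out.shape, weights.shape)   # (5, 4) (5, 5)
```

Row i of `weights` is the attention distribution of token i over all tokens, and row i of `out` is the corresponding weighted sum of value vectors.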
Multi-Head Attention: Many Lenses at Once
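A single attention head computes one set of attention weights per token, so it tends to capture one kind of relationship at a time. Multi-head attention runs H parallel heads, each with its own W_Q, W_K, W_V matrices, then concatenates their outputs and mixes them with an output projection W_O: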
headᵢ = Attention(X·W_Q^i, X·W_K^i, X·W_V^i)
MultiHead(X) = Concat(head₁, ..., head_H) · W_O
Typical: H = 8 or 12 or 16 heads.
Heads often end up specialising in different patterns, for example:
Head 1 learns subject ↔ verb agreement
Head 2 learns adjective ↔ noun binding
Head 3 learns coreference (pronoun → antecedent)
Head 4 attends to the previous token (local context)
Like multiple convolutional filters each detecting a different pattern.
Not manually designed: each head discovers its specialty during training.
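The MultiHead formula can be sketched directly from its definition. This is a minimal NumPy illustration, not an optimised implementation: the helper names, the sizes, and the per-head width d_model / H (a common convention, assumed here) are hypothetical, and random matrices stand in for learned weights.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V

def multi_head(X, heads, W_O):
    """heads is a list of (W_Q, W_K, W_V) tuples, one tuple per head."""
    outputs = [attention(X @ W_Q, X @ W_K, X @ W_V) for W_Q, W_K, W_V in heads]
    # Concatenate along the feature axis, then mix the heads with W_O.
    return np.concatenate(outputs, axis=-1) @ W_O

rng = np.random.default_rng(0)
N, d_model, H = 6, 16, 4
d_head = d_model // H            # assumed convention: each head works in d_model / H dims

X = rng.normal(size=(N, d_model))
heads = [tuple(rng.normal(size=(d_model, d_head)) for _ in range(3)) for _ in range(H)]
W_O = rng.normal(size=(H * d_head, d_model))

Y = multi_head(X, heads, W_O)
print(Y.shape)   # (6, 16)
```

Production libraries usually compute all heads in one batched matrix multiplication rather than a Python loop; the loop here only mirrors the formula term by term.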
Why Attention Beats RNNs
- Parallelisable: RNN must process xₜ before xₜ₊₁ (sequential). Attention computes all positions at once — massive GPU speedup on modern hardware.
- Direct long-range dependencies: path length from token 1 to token 1000 is 1 hop in attention vs roughly 1000 hops in an RNN, so the signal and its gradient do not have to survive a long chain of recurrent steps.
- Interpretable: attention weights can be visualised to show which tokens each position weighted most heavily, although a high weight is not always a faithful explanation of the model's decision.
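- Cost (the trade-off): O(N²·d) time for sequence length N, plus O(N²) memory if the attention matrix is stored in full. The N² term makes very long sequences expensive. As a rough guide, N ≤ ~10,000 is manageable; beyond that, sparse or linear-attention variants are commonly used.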
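To make the N² term concrete, the short sketch below estimates the memory needed to store the full N × N attention-weight matrix. The helper `score_matrix_bytes` is hypothetical, and the figures assume float32 values, a single head, and a naive implementation that materialises the whole matrix.

```python
def score_matrix_bytes(n_tokens, n_heads=1, bytes_per_value=4):
    """Bytes needed to hold the N x N attention weights of one layer in full."""
    return n_tokens * n_tokens * n_heads * bytes_per_value

for n in (1_000, 10_000, 100_000):
    gib = score_matrix_bytes(n) / 2**30
    print(f"N = {n:>7,}: {gib:8.3f} GiB per head (float32)")
```

Multiplying the sequence length by 10 multiplies this memory by 100, which is why long-context models usually avoid storing the full matrix.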
Masking: Causal vs Bidirectional
Not all tasks should let every token see every other token:
Bidirectional attention (used in BERT, for understanding tasks):
Every token attends to every other token — full context.
Attention matrix is fully populated.
Causal (masked) attention (used in GPT, for generation):
Token at position t can only attend to positions ≤ t.
Entries strictly above the diagonal of Q·Kᵀ are set to −∞ before softmax, so their attention weights become exactly 0.
Prevents 'cheating' by looking at future tokens during training.
mask = [1 0 0 0] (token 1 sees only itself)
       [1 1 0 0] (token 2 sees 1 & 2)
       [1 1 1 0] (token 3 sees 1, 2, 3)
       [1 1 1 1] (token 4 sees all)
BERT = bidirectional (fill-in-blank) · GPT = causal (predict next token)
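The causal mask described above can be sketched in NumPy as follows. The function name `causal_attention` and the random inputs are assumptions for illustration; the lower-triangular pattern is the same 4 × 4 mask shown in this section.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # allowed[t, s] is True when key position s <= query position t.
    allowed = np.tril(np.ones((N, N), dtype=bool))
    # Future positions get -inf, so softmax assigns them a weight of exactly 0.
    scores = np.where(allowed, scores, -np.inf)
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

rng = np.random.default_rng(0)
N, d_k = 4, 8                                   # illustrative sizes
Q, K, V = (rng.normal(size=(N, d_k)) for _ in range(3))

out, weights = causal_attention(Q, K, V)
print(np.round(weights, 2))
```

The printed matrix is lower-triangular: the first row puts all of its weight on token 1, and only the last row spreads weight over all four tokens. Dropping the `np.where` line gives the bidirectional case.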