Many Eyes
Chapter 3 built one head — one clean mechanism for selective gathering. But a sentence is not one thing. “The lawyer who the jury believed told the truth.” To read that, something has to track syntax, something else has to track who believes whom, something else has to track the gap in the relative clause. One head cannot do all of that at once. The fix is embarrassingly direct: run several heads in parallel, let each look in its own direction, then pool the results.
4.1 The limit of one head
At the end of Chapter 3, SelfAttentionHead produced (B, T, head_size) from (B, T, C). That head computed exactly one pattern of relationships across the sequence: one set of Q/K/V projections, one attention matrix, one weighted average of values. It is, in one sense, one question the model asks of the sequence per forward pass.
Consider the sentence: “The cat sat on the mat because it was tired.” At the position "it", the model needs to resolve at least two things simultaneously: that it refers to the cat (coreference, backward-looking), and that tired is the predicate that governs it (grammatical dependency, forward-looking in meaning). Under the causal mask from Chapter 3, "it" cannot attend to the later token "tired" at all, so in a causal model the second lookup has to happen at the position "tired", which then faces two targets of its own ("it" as its subject, "cat" as the referent). The argument is the same at either position. A single head generates one scalar score between the current token and each token it is allowed to see. It can learn to weight one target highly or the other, but a single weighted-average output blends the two. The head has no clean way to “return” both a coreference signal and a dependency signal as separate results through one head_size-dimensional vector.
Language routinely requires several such simultaneous lookups. The single-head design is a bottleneck not because the math is wrong, but because one attention pattern is one degree of freedom, and a sequence has many independent things to say about itself at once.
4.2 The fix: parallel smaller heads
The solution in Vaswani et al. (“Attention Is All You Need,” 2017, §3.2.2) is to run h heads in parallel, each with its own Q, K, and V projection matrices, each projecting into a smaller subspace of size head_size = d_model / h. The outputs of all heads — each (B, T, head_size) — are then concatenated back into a (B, T, d_model) tensor, and a final learned linear layer mixes the concatenated result.
In the paper’s base model: d_model = 512, h = 8, so head_size = d_k = d_v = 512 / 8 = 64. Each head projects the full 512-dimensional input into its own learned 64-dimensional subspace; concatenated, the eight outputs fill 512 dimensions again. The Q/K/V parameter count is the same as for one full-width head (W^O adds one more d_model × d_model matrix on top), but now the model has eight independent subspaces to specialize in.
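The parameter claim can be checked with plain arithmetic. The sketch below assumes bias-free Q/K/V projections, as in the Chapter 3 head; the helper name qkv_params is made up for this illustration.

```python
def qkv_params(d_model: int, num_heads: int) -> int:
    """Weights in the Q, K and V projections summed over all heads (no biases)."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    head_size = d_model // num_heads
    per_head = 3 * d_model * head_size  # three (d_model, head_size) matrices
    return num_heads * per_head


d_model = 512
one_big_head = qkv_params(d_model, num_heads=1)  # 3 * 512 * 512
eight_heads = qkv_params(d_model, num_heads=8)   # 8 * 3 * 512 * 64
w_o = d_model * d_model                          # W^O weight matrix, bias not counted

print(one_big_head, eight_heads, w_o)  # 786432 786432 262144
```

Splitting into heads leaves the Q/K/V budget untouched; the only extra weights are the ones in W^O.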
Formally (Vaswani et al. 2017, §3.2.2):
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Each W_i^Q, W_i^K, W_i^V is a learned projection into a head_size-dimensional subspace. W^O is the final projection that takes the concatenated (B, T, h * head_size) = (B, T, d_model) back to (B, T, d_model). Everything in the projection matrices is learned by gradient descent — there is no hand-assignment of which head covers coreference and which covers syntax.
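For reference, here is the same definition with the matrix dimensions written out as §3.2.2 of the paper gives them. The paper's d_k and d_v are what this chapter calls head_size, and Attention is the scaled dot-product attention of Chapter 3 (shown without the causal mask).

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}\!\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right) \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O} \\
W_i^{Q},\, W_i^{K} &\in \mathbb{R}^{d_{\mathrm{model}} \times d_k}, \qquad
W_i^{V} \in \mathbb{R}^{d_{\mathrm{model}} \times d_v}, \qquad
W^{O} \in \mathbb{R}^{h d_v \times d_{\mathrm{model}}}
\end{aligned}
```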
Why not just use one big head at full d_model? You could, and that head would have more expressive capacity than any one small head. But it would have one attention matrix and therefore one routing pattern. Multiple smaller heads each produce their own (T, T) attention matrix, which means h independent routing patterns operating in parallel. The cost is roughly the same; the diversity is new.
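A small shape experiment makes the contrast concrete. This is a sketch under stated assumptions: the projection matrices are random stand-ins for learned weights, the mask and softmax are left out, and the heads are formed by slicing one full-width projection, which is just one convenient way to get h projections of width head_size.

```python
import torch

torch.manual_seed(0)
B, T, d_model, h = 2, 5, 32, 4
head_size = d_model // h
x = torch.randn(B, T, d_model)

# Random stand-ins for learned query/key projections (illustrative only).
W_q = torch.randn(d_model, d_model) * d_model**-0.5
W_k = torch.randn(d_model, d_model) * d_model**-0.5
q, k = x @ W_q, x @ W_k                                   # each (B, T, d_model)

# One full-width head: a single (T, T) score matrix per example.
one_head_raw = q @ k.transpose(-2, -1)                    # (B, T, T)

# h small heads: split the same projections into h slices of width head_size.
q_h = q.view(B, T, h, head_size).transpose(1, 2)          # (B, h, T, head_size)
k_h = k.view(B, T, h, head_size).transpose(1, 2)          # (B, h, T, head_size)
multi_head_raw = q_h @ k_h.transpose(-2, -1)              # (B, h, T, T)

print(one_head_raw.shape)    # torch.Size([2, 5, 5])
print(multi_head_raw.shape)  # torch.Size([2, 4, 5, 5])

# The full-width raw score is exactly the sum of the per-slice raw scores:
# one big head adds the h patterns together before the softmax, while
# multi-head attention keeps them apart and normalizes each one separately.
summed = multi_head_raw.sum(dim=1)                        # (B, T, T)
print(torch.allclose(one_head_raw, summed, atol=1e-4))
```

Both versions perform the same number of multiply-adds for the scores; the multi-head version simply refuses to add the h partial patterns together.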
4.3 The code
Here is MultiHeadAttention built directly on top of the SelfAttentionHead from Chapter 3. Read the shapes in every comment.
import torch
import torch.nn as nn
import torch.nn.functional as F


# ── Brought forward from Chapter 3 ────────────────────────────────────────────
class SelfAttentionHead(nn.Module):
    """One head of causal self-attention. Input (B,T,C) -> output (B,T,head_size)."""

    def __init__(self, d_model, head_size, context_length, dropout=0.0):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                                    # x: (B, T, C)
        B, T, C = x.shape
        k = self.key(x)                                      # (B, T, head_size)
        q = self.query(x)                                    # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5    # (B, T, T), scaled
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)                         # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x)                                    # (B, T, head_size)
        return wei @ v                                       # (B, T, head_size)


# ── Chapter 4 addition ─────────────────────────────────────────────────────────
class MultiHeadAttention(nn.Module):
    """
    h independent heads of causal self-attention, concatenated + projected.
    Input:  (B, T, d_model)
    Output: (B, T, d_model), the same shape as the input, so it can stack cleanly.
    """

    def __init__(self, d_model, num_heads, context_length, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        head_size = d_model // num_heads   # each head works in a subspace of this width
        # num_heads independent heads, each (B,T,d_model) -> (B,T,head_size)
        self.heads = nn.ModuleList([
            SelfAttentionHead(d_model, head_size, context_length, dropout)
            for _ in range(num_heads)
        ])
        # after concat: (B, T, num_heads * head_size) = (B, T, d_model)
        # W^O in the paper: mix the concatenated head outputs back into d_model
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                                    # x: (B, T, d_model)
        # run every head on the same input; each returns (B, T, head_size)
        head_outputs = [h(x) for h in self.heads]            # list of num_heads tensors
        # concatenate along the last dim: (B, T, num_heads * head_size) = (B, T, d_model)
        out = torch.cat(head_outputs, dim=-1)                # (B, T, d_model)
        # final linear mix (W^O): (B, T, d_model) -> (B, T, d_model)
        out = self.proj(out)                                 # (B, T, d_model)
        out = self.dropout(out)
        return out                                           # (B, T, d_model)


# ── sanity check ──────────────────────────────────────────────────────────────
B, T, C = 4, 8, 384   # batch=4, sequence_len=8, d_model=384
x = torch.randn(B, T, C)
mha = MultiHeadAttention(d_model=384, num_heads=6, context_length=256)
out = mha(x)
print(out.shape)      # torch.Size([4, 8, 384]), the same shape as the input
Line-by-line walk
- head_size = d_model // num_heads: the width of each head’s private subspace. With d_model=384, num_heads=6 that is 64; with the paper’s d_model=512, num_heads=8 it is also 64.
- nn.ModuleList([SelfAttentionHead(...) for _ in range(num_heads)]): creates num_heads independent heads. Each has its own key, query, value weight matrices and its own tril buffer. They share nothing. PyTorch sees all of them as sub-modules and will include their parameters in model.parameters().
- [h(x) for h in self.heads]: runs all heads on the same input x. This loop is sequential Python; production implementations usually compute all heads at once with a single batched matrix multiply. Each head gets the full (B, T, d_model) input and returns (B, T, head_size).
- torch.cat(head_outputs, dim=-1): joins those num_heads tensors along the channel axis: (B, T, head_size) × num_heads → (B, T, num_heads * head_size) = (B, T, d_model). This is the Concat in the paper’s formula.
- self.proj(out): this is W^O, a (d_model, d_model) linear layer that mixes across the concatenated heads. Without it, each head’s head_size-wide slice of the output would be independent of the other heads, with no cross-head interaction. The projection is where the heads’ findings are combined into one view.
- Output shape: (B, T, d_model), the same shape as the input. This is intentional. The transformer block (Chapter 5) feeds the output of multi-head attention directly into a residual connection that adds it back onto the input. Same shape is the contract.
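The Python loop over heads is easy to read but leaves parallelism on the table. Below is a minimal sketch of a batched variant that computes every head in one pass. It is an illustration under assumptions, not the chapter's module: the class name BatchedMultiHeadAttention and the fused qkv layer are choices made for this example, and its weights are laid out differently from the ModuleList version, so the two are not interchangeable without rearranging parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedMultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with all heads computed in one pass.

    Hypothetical name; a sketch of a common idiom, not the chapter's module.
    """

    def __init__(self, d_model, num_heads, context_length, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_size = d_model // num_heads
        # One fused layer produces Q, K and V for every head at once.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model)               # W^O
        self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        # (B, T, d_model) -> (B, num_heads, T, head_size)
        B, T, _ = t.shape
        return t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x):                                     # x: (B, T, d_model)
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)                # each (B, T, d_model)
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)
        wei = q @ k.transpose(-2, -1) * self.head_size**-0.5  # (B, h, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)                          # one (T, T) pattern per head
        wei = self.dropout(wei)
        out = wei @ v                                         # (B, h, T, head_size)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # concat heads: (B, T, d_model)
        return self.dropout(self.proj(out))                   # (B, T, d_model)


torch.manual_seed(0)
bmha = BatchedMultiHeadAttention(d_model=384, num_heads=6, context_length=256).eval()
y = bmha(torch.randn(4, 8, 384))
print(y.shape)  # torch.Size([4, 8, 384])
```

The transpose-then-view at the end plays the role of torch.cat: it lays the heads side by side along the channel axis before W^O mixes them.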
4.4 What heads actually learn — and what we honestly know
It is tempting to say: head 1 learns syntax, head 2 learns coreference, head 3 learns distance. That story is cleaner than the evidence supports.
What we can say with reasonable confidence: heads empirically develop differentiated behavior. Researchers have visualized attention matrices in trained transformers and found that some heads consistently attend to adjacent tokens, some attend to sentence-initial positions, some attend to matching part-of-speech tags. These are observations from probing and visualization, not designed roles — nothing in the loss function says “head 4, you cover syntax.” Specialization, when it appears, is an emergent property of gradient descent on a prediction task.
One of the better-studied examples is the induction head, characterized in Olsson et al. (2022, “In-context Learning and Induction Heads,” Transformer Circuits Thread, transformer-circuits.pub). An induction head implements a specific copying pattern: if the model has seen the token sequence [A][B] earlier in the context, then at a later occurrence of [A] the induction head attends to the token that followed the earlier [A], namely [B], and thereby raises the prediction of [B] as the next token. Olsson et al. found that these heads emerge sharply at a specific point in training, a phase change visible as a bump in the training-loss curve. It is one of the cleaner mechanistic stories we have about what individual heads do.
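The input-output behavior of that pattern is simple enough to write down directly. The function below is a toy sketch of the rule an idealized induction head follows; it says nothing about how a trained network implements it, and the function name is made up for this illustration.

```python
def induction_lookup(tokens):
    """Apply the idealized induction rule at the last position of `tokens`.

    Find the most recent earlier occurrence of the current token and return
    (index_attended, predicted_next_token), or None if there is no earlier
    occurrence. Toy illustration of the behavior only.
    """
    current = tokens[-1]
    # Scan backward over earlier positions, most recent first.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            # Attend to the token right after the earlier match and copy it.
            return i + 1, tokens[i + 1]
    return None


print(induction_lookup(["A", "B", "C", "A"]))  # (1, 'B')
print(induction_lookup(["A", "B", "C", "D"]))  # None
```

In the first call the context ends on a second [A]; the rule looks at position 1, the token after the first [A], and proposes [B].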
But the honest picture is messier: most heads in a large trained model do not resolve into one neat linguistic role. Many heads appear to collaborate — their individual attention patterns are not obviously interpretable, but together they enable the model to route the right information to the right positions. The multi-head design gives the model the capacity for differentiated attention. Whether and how that capacity is used is determined by training, not by the architecture.
Heads can specialize. Some do. We can observe it in trained models. We did not build it in. That gap between design and behavior is the thing that makes transformer interpretability an open research problem.
4.5 An honest note on concatenation
Why concatenate, specifically? It is worth being direct: concatenation is an engineering choice that works, not a result derived from first principles.
The alternatives are obvious: you could average the head outputs, sum them, or feed them into a more complex combiner. Concatenation followed by a linear projection was the choice Vaswani et al. made, and it has held up across years of scaled models. The linear projection after the concat (W^O) means the network can learn any linear combination of the heads’ output features: it can up-weight some heads and down-weight others, which makes it strictly more flexible than a fixed average. The same W^O is applied at every position, so any position-dependence comes from the attention patterns themselves, not from the projection. That flexibility, at the cost of one (d_model, d_model) weight matrix, is probably why it has stuck.
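One way to see the "more flexible than a fixed average" claim: averaging is a single, hand-picked setting of the weights in a projection applied after the concat. The sketch below uses random tensors as stand-ins for head outputs; W_avg is a fixed matrix built for the illustration, not a learned one.

```python
import torch

torch.manual_seed(0)
B, T, h, head_size = 2, 5, 4, 8

# Random stand-ins for the outputs of h heads, each (B, T, head_size).
heads = [torch.randn(B, T, head_size) for _ in range(h)]

# Option 1: a fixed average over heads.
averaged = torch.stack(heads).mean(dim=0)                # (B, T, head_size)

# Option 2: concat, then a linear map. Stacking h copies of I/h gives the
# one weight matrix that reproduces the average exactly.
concat = torch.cat(heads, dim=-1)                        # (B, T, h * head_size)
W_avg = torch.cat([torch.eye(head_size) / h for _ in range(h)], dim=0)  # (h * head_size, head_size)
via_projection = concat @ W_avg                          # (B, T, head_size)

print(torch.allclose(averaged, via_projection, atol=1e-6))  # True
```

A learned projection can land on this matrix if averaging happens to be the best thing to do, and on any other linear mix if it is not.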
The W^O layer also does something geometrically useful: it projects the concatenated output back into the same d_model-dimensional residual stream that will persist through the rest of the network. Keeping the residual stream’s shape constant at d_model across all sub-layers is a key architectural choice we will revisit in Chapter 5.
4.6 The thing to actually understand
- Many heads = many attention matrices. Each head produces its own (B, T, T) routing pattern independently. The expressiveness gain is not in wider projections but in multiple independent views of the sequence.
- head_size = d_model / num_heads keeps the total computation comparable to one head at full width. You trade width per head for breadth of coverage.
- Concat + W^O is the glue. Concatenation gathers the heads; the final linear projection mixes them and restores the (B, T, d_model) shape the residual stream expects.
- Specialization is observed, not designed. The architecture creates the capacity for heads to diverge; gradient descent decides whether they do and what they cover. Induction heads (Olsson et al. 2022) are one of the cleaner mechanistic examples; most heads are harder to read.
- (B, T, d_model) in, (B, T, d_model) out. Multi-head attention is shape-preserving. Chapter 5 wraps it, along with a feed-forward net and two layer norms, into the transformer block, which is also shape-preserving. That’s what lets them stack.
4.7 Exercises
- Confirm the shape contract. Instantiate MultiHeadAttention(d_model=256, num_heads=4, context_length=128) and pass a random (2, 16, 256) tensor through it. Assert that output.shape == (2, 16, 256). Then change num_heads to 8 and repeat. The output shape should not change.
- Audit the parameter count. Call sum(p.numel() for p in mha.parameters()) for your multi-head module. Compare it to a single SelfAttentionHead(d_model=256, head_size=256, context_length=128) without the final projection. Where does most of the parameter budget live? (Hint: the Q/K/V projections across all heads add up; proj is another d_model × d_model block.)
- Remove W^O. Comment out self.proj and return torch.cat(head_outputs, dim=-1) directly. Check the shape is still (B, T, d_model). Now train a tiny character-level model for 100 steps with and without the projection. Does loss differ? Why might the projection matter even early in training?
- Watch the attention matrices diverge. Look at where mha.heads[0].dropout is applied inside SelfAttentionHead.forward, then rewrite forward to also return wei (alongside the value output). Run the same input through four heads and print each head’s wei[0] (the first example’s attention matrix). Are they identical? Should they be?
- Break the divisibility constraint. Try MultiHeadAttention(d_model=256, num_heads=3, ...). The assert should raise before any tensor is created. Then fix it: find the nearest num_heads that divides evenly and verify the module builds.
A 37th-Chamber original. Methods cited: Vaswani et al. (2017), “Attention Is All You Need,” arXiv:1706.03762, §3.2.2 (h=8, d_k=d_v=64 confirmed); Olsson et al. (2022), “In-context Learning and Induction Heads,” Transformer Circuits Thread, transformer-circuits.pub (confirmed at arxiv.org/abs/2209.11895). All prose and code written fresh.