The Opaque Box · Chapter 4

Many Eyes

Chapter 3 built one head — one clean mechanism for selective gathering. But a sentence is not one thing. “The lawyer who the jury believed told the truth.” To read that, something has to track syntax, something else has to track who believes whom, something else has to track the gap in the relative clause. One head cannot do all of that at once. The fix is embarrassingly direct: run several heads in parallel, let each look in its own direction, then pool the results.

4.1  The limit of one head

At the end of Chapter 3, SelfAttentionHead produced (B, T, head_size) from (B, T, C). That head computed exactly one pattern of relationships across the sequence: one set of Q/K/V projections, one attention matrix, one weighted average of values. It is, in one sense, one question the model asks of the sequence per forward pass.

Consider the sentence: “The cat sat on the mat because it was tired.” At the position "it", the model needs to resolve at least two things at once: that "it" refers to the cat (coreference, backward-looking), and that "tired" is the predicate that governs it (grammatical dependency, forward-looking in meaning). A single head generates one scalar score between "it" and every token it is allowed to see. It can learn to weight “cat” highly or weight “tired” highly, but a single weighted-average output conflates the two. (Under the causal mask from Chapter 3, "tired" lies in the future and is hidden from "it" entirely; picture the unmasked case, or any two earlier tokens competing for the same head.) There is no mechanism for the head to “return” both a coreference signal and a dependency signal as separate results through one head_size-dimensional vector.

Language routinely requires several such simultaneous lookups. The single-head design is a bottleneck not because the math is wrong, but because one attention pattern is one degree of freedom, and a sequence has many independent things to say about itself at once.


4.2  The fix: parallel smaller heads

The solution in Vaswani et al. (“Attention Is All You Need,” 2017, §3.2.2) is to run h heads in parallel, each with its own Q, K, and V projection matrices, each projecting into a smaller subspace of size head_size = d_model / h. The outputs of all heads — each (B, T, head_size) — are then concatenated back into a (B, T, d_model) tensor, and a final learned linear layer mixes the concatenated result.

In the paper’s base model: d_model = 512, h = 8, so head_size = d_k = d_v = 512 / 8 = 64. Each head projects the full 512-dimensional input into its own learned 64-dimensional subspace; concatenated, the eight head outputs fill 512 dimensions again. The total parameter count stays comparable to one large head, but now the model has eight independent subspaces to specialize in.
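
The parameter claim can be checked with plain arithmetic. The sketch below assumes bias-free Q, K, and V projections (as in the Chapter 3 head) and ignores W^O; the variable names are illustrative, not from the paper.

```python
# Base-model sizes quoted from Vaswani et al. (2017).
d_model, h = 512, 8
head_size = d_model // h          # 64

# One full-width head: three (d_model x d_model) projections for Q, K, V.
one_big_head = 3 * d_model * d_model

# h small heads: each has three (d_model x head_size) projections.
many_small_heads = h * 3 * d_model * head_size

print(head_size)                  # 64
print(one_big_head)               # 786432
print(many_small_heads)           # 786432
```

Under these assumptions the two budgets are identical, because h * head_size is d_model by construction.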

Formally (Vaswani et al. 2017, §3.2.2):

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

where  head_i = Attention(Q W_i^Q,  K W_i^K,  V W_i^V)

Each W_i^Q, W_i^K, W_i^V is a learned projection into a head_size-dimensional subspace. W^O is the final projection that takes the concatenated (B, T, h * head_size) = (B, T, d_model) back to (B, T, d_model). Everything in the projection matrices is learned by gradient descent — there is no hand-assignment of which head covers coreference and which covers syntax.
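
As a reading aid, the shapes in the formula can be written out for a single sequence of length T (the batch dimension B is dropped here for brevity). With d_k = d_v = head_size, the product h d_v equals d_model.

```latex
\begin{aligned}
W_i^Q,\ W_i^K &\in \mathbb{R}^{d_{\text{model}} \times d_k}, &
W_i^V &\in \mathbb{R}^{d_{\text{model}} \times d_v}, &
W^O &\in \mathbb{R}^{h d_v \times d_{\text{model}}} \\
\text{head}_i &\in \mathbb{R}^{T \times d_v}, &
\mathrm{Concat}(\text{head}_1, \dots, \text{head}_h) &\in \mathbb{R}^{T \times h d_v}, &
\mathrm{MultiHead}(Q, K, V) &\in \mathbb{R}^{T \times d_{\text{model}}}
\end{aligned}
```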

Why not just use one big head at full d_model? You could, and that single head would have more expressive capacity than any one small head. But it would have one attention matrix and therefore one routing pattern. Multiple smaller heads each produce their own (T, T) attention matrix, which means h independent routing patterns operating in parallel. The parameter and compute cost is roughly the same; the diversity is new.
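
To make "one routing pattern versus h routing patterns" concrete, here is a small sketch with random, untrained projection matrices. The helper attention_pattern and the toy sizes are assumptions chosen for illustration; only the shapes matter.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_model, h = 5, 32, 4
head_size = d_model // h
x = torch.randn(T, d_model)               # one sequence, no batch dim

def attention_pattern(x, w_q, w_k):
    """Causal attention weights for one head: (T, d_model) -> (T, T)."""
    q, k = x @ w_q, x @ w_k
    scores = q @ k.T * k.shape[-1] ** -0.5
    mask = torch.tril(torch.ones(len(x), len(x)))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1)

# One full-width head: a single (T, T) pattern.
big = attention_pattern(x, torch.randn(d_model, d_model),
                        torch.randn(d_model, d_model))

# h narrow heads: h separate (T, T) patterns.
small = torch.stack([
    attention_pattern(x, torch.randn(d_model, head_size),
                      torch.randn(d_model, head_size))
    for _ in range(h)
])

print(big.shape)     # torch.Size([5, 5])
print(small.shape)   # torch.Size([4, 5, 5])
```

Every row of every pattern is its own probability distribution over the visible positions, so the multi-head version has h distributions per position where the single head has one.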

4.3  The code

Here is MultiHeadAttention built directly on top of the SelfAttentionHead from Chapter 3. Read the shapes in every comment.

import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Brought forward from Chapter 3 ────────────────────────────────────────────
class SelfAttentionHead(nn.Module):
    """One head of causal self-attention. Input (B,T,C) -> output (B,T,head_size)."""
    def __init__(self, d_model, head_size, context_length, dropout=0.0):
        super().__init__()
        self.key   = nn.Linear(d_model, head_size, bias=False)
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                        # x: (B, T, C)
        B, T, C = x.shape
        k = self.key(x)                          # (B, T, head_size)
        q = self.query(x)                        # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5   # (B, T, T), scaled
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)             # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x)                        # (B, T, head_size)
        return wei @ v                           # (B, T, head_size)

# ── Chapter 4 addition ─────────────────────────────────────────────────────────
class MultiHeadAttention(nn.Module):
    """
    h independent heads of causal self-attention, concatenated + projected.
    Input:  (B, T, d_model)
    Output: (B, T, d_model)   — same shape as input, so it can stack cleanly.
    """
    def __init__(self, d_model, num_heads, context_length, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        head_size = d_model // num_heads         # each head works in a subspace of this width

        # num_heads independent heads, each (B,T,d_model) -> (B,T,head_size)
        self.heads = nn.ModuleList([
            SelfAttentionHead(d_model, head_size, context_length, dropout)
            for _ in range(num_heads)
        ])
        # after concat: (B, T, num_heads * head_size) = (B, T, d_model)
        # W^O in the paper: mix the concatenated head outputs back into d_model
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                        # x: (B, T, d_model)
        # apply each head to the same x; conceptually parallel, though this loop runs them in turn.
        # each head returns (B, T, head_size)
        head_outputs = [h(x) for h in self.heads]   # list of num_heads tensors

        # concatenate along the last dim: (B, T, num_heads * head_size) = (B, T, d_model)
        out = torch.cat(head_outputs, dim=-1)    # (B, T, d_model)

        # final linear mix (W^O): (B, T, d_model) -> (B, T, d_model)
        out = self.proj(out)                     # (B, T, d_model)
        out = self.dropout(out)
        return out                               # (B, T, d_model)

# ── sanity check ──────────────────────────────────────────────────────────────
B, T, C = 4, 8, 384            # batch=4, sequence_len=8, d_model=384
x = torch.randn(B, T, C)

mha = MultiHeadAttention(d_model=384, num_heads=6, context_length=256)
out = mha(x)
print(out.shape)               # torch.Size([4, 8, 384]) — same shape as input
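
The list-of-heads version above is written for readability. A common alternative, sketched here under the assumption that you want the same computation in fewer matrix multiplies, stacks every head's Q, K, and V weights into one linear layer and carries the head index as an extra tensor axis. The class name FusedMultiHeadAttention and its layout are illustrative, not the chapter's reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedMultiHeadAttention(nn.Module):
    """Hypothetical batched variant: one matmul yields Q, K, V for all heads."""
    def __init__(self, d_model, num_heads, context_length, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_size = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # all heads stacked
        self.proj = nn.Linear(d_model, d_model)                  # W^O
        self.register_buffer(
            "tril", torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                               # x: (B, T, d_model)
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)          # each (B, T, d_model)

        def split(t):
            # (B, T, d_model) -> (B, num_heads, T, head_size)
            return t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

        q, k, v = split(q), split(k), split(v)
        wei = q @ k.transpose(-2, -1) * self.head_size ** -0.5   # (B, h, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = self.dropout(F.softmax(wei, dim=-1))
        out = wei @ v                                    # (B, h, T, head_size)
        # undo the split: this is the concatenation step
        out = out.transpose(1, 2).contiguous().view(B, T, C)     # (B, T, d_model)
        return self.dropout(self.proj(out))              # (B, T, d_model)

torch.manual_seed(0)
x = torch.randn(4, 8, 384)
fused = FusedMultiHeadAttention(d_model=384, num_heads=6, context_length=256)
out = fused(x)
print(out.shape)   # torch.Size([4, 8, 384])
```

The input and output shapes match MultiHeadAttention. The trade-off is that individual heads are no longer separate modules, which makes them slightly less convenient to inspect one at a time.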


4.4  What heads actually learn — and what we honestly know

It is tempting to say: head 1 learns syntax, head 2 learns coreference, head 3 learns distance. That story is cleaner than the evidence supports.

What we can say with reasonable confidence: heads empirically develop differentiated behavior. Researchers have visualized attention matrices in trained transformers and found that some heads consistently attend to adjacent tokens, some attend to sentence-initial positions, some attend to matching part-of-speech tags. These are observations from probing and visualization, not designed roles — nothing in the loss function says “head 4, you cover syntax.” Specialization, when it appears, is an emergent property of gradient descent on a prediction task.

One of the better-studied examples is the induction head, characterized in Olsson et al. (2022, “In-context Learning and Induction Heads,” Transformer Circuits Thread, transformer-circuits.pub). An induction head implements a specific copying pattern: if the model has seen the token sequence [A][B] earlier in the context, an induction head at a later occurrence of [A] attends strongly to the token that followed the earlier [A], namely [B], and thereby predicts [B] as the next token. Olsson et al. found that these heads emerge sharply at a specific point in training, a phase change visible as a bump in the training-loss curve. It is one of the cleaner mechanistic stories we have about what individual heads do.
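
The input-output behavior of that copying rule can be caricatured in a few lines of plain Python. This is a hand-written stand-in for what a trained head approximates, not a model of how the head computes it; the function name induction_predict is made up for this sketch.

```python
def induction_predict(tokens):
    """Toy induction rule: find the most recent earlier occurrence of the
    current (last) token and return the token that followed it.
    Returns None when the current token has not been seen before."""
    current = tokens[-1]
    # scan earlier positions right to left, skipping the current position
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None

print(induction_predict(["A", "B", "C", "D", "A"]))   # B
print(induction_predict(["A", "B", "C"]))             # None
```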

But the honest picture is messier: most heads in a large trained model do not resolve into one neat linguistic role. Many heads appear to collaborate — their individual attention patterns are not obviously interpretable, but together they enable the model to route the right information to the right positions. The multi-head design gives the model the capacity for differentiated attention. Whether and how that capacity is used is determined by training, not by the architecture.

Heads can specialize. Some do. We can observe it in trained models. We did not build it in. That gap between design and behavior is the thing that makes transformer interpretability an open research problem.

4.5  An honest note on concatenation

Why concatenate, specifically? It is worth being direct: concatenation is an engineering choice that works, not a result derived from first principles.

The alternatives are obvious: you could average the head outputs, sum them, or feed them into a more complex combiner. Concatenation followed by a linear projection was the choice Vaswani et al. made, and it has held up across years of scaled models. The linear projection after the concat (W^O) means the network can learn any linear combination of the heads' outputs. It can up-weight some heads and down-weight others for each output dimension, which makes it strictly more general than a fixed average. The same W^O is applied at every position; any position-dependence comes from the head outputs themselves. That flexibility, at the cost of one (d_model, d_model) weight matrix, is probably why it has stuck.
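
One way to see that a fixed average is a special case of concatenation plus a linear map: build the projection by hand from scaled identity blocks. The matrix w_avg below is constructed for illustration (a trained W^O would be learned and d_model wide), and the head outputs are random stand-ins.

```python
import torch

torch.manual_seed(0)
B, T, h, head_size = 2, 5, 4, 8
d_model = h * head_size
head_outputs = [torch.randn(B, T, head_size) for _ in range(h)]

# Option 1: average the heads directly.
average = torch.stack(head_outputs).mean(dim=0)         # (B, T, head_size)

# Option 2: concatenate, then apply a hand-built linear map made of
# h identity blocks stacked vertically, each scaled by 1/h.
concat = torch.cat(head_outputs, dim=-1)                # (B, T, d_model)
w_avg = torch.cat([torch.eye(head_size) / h for _ in range(h)], dim=0)
recovered = concat @ w_avg                              # (B, T, head_size)

print(w_avg.shape)                                      # torch.Size([32, 8])
print(torch.allclose(average, recovered, atol=1e-5))    # True
```

Averaging corresponds to one particular setting of the projection weights; a learned projection is free to land on that setting or on any other.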

The W^O layer also does something geometrically useful: it projects the concatenated output back into the same d_model-dimensional residual stream that will persist through the rest of the network. Keeping the residual stream’s shape constant at d_model across all sub-layers is a key architectural choice we will revisit in Chapter 5.


4.6  The thing to actually understand


4.7  Exercises

  1. Confirm the shape contract. Instantiate MultiHeadAttention(d_model=256, num_heads=4, context_length=128) and pass a random (2, 16, 256) tensor through it. Assert that output.shape == (2, 16, 256). Then change num_heads to 8 and repeat. The output shape should not change.
  2. Audit the parameter count. Call sum(p.numel() for p in mha.parameters()) for your multi-head module. Compare it to a single SelfAttentionHead(d_model=256, head_size=256, context_length=128) without the final projection. Where does most of the parameter budget live? (Hint: the Q/K/V projections across all heads add up; proj is another d_model × d_model block.)
  3. Remove W^O. Comment out self.proj and return torch.cat(head_outputs, dim=-1) directly. Check the shape is still (B, T, d_model). Now train a tiny character-level model for 100 steps with and without the projection. Does loss differ? Why might the projection matter even early in training?
  4. Watch the attention matrices diverge. The attention weights are discarded once forward returns, so rewrite SelfAttentionHead.forward to also return wei (alongside the value output) and update MultiHeadAttention.forward to unpack the pair. Run the same input through four heads and print each head’s wei[0] (the first example’s attention matrix). Are they identical? Should they be?
  5. Break the divisibility constraint. Try MultiHeadAttention(d_model=256, num_heads=3, ...). The assert should raise before any tensor is created. Then fix it: find the nearest num_heads that divides evenly and verify the module builds.

A 37th-Chamber original. Methods cited: Vaswani et al. (2017), “Attention Is All You Need,” arXiv:1706.03762, §3.2.2 (h=8, d_k=d_v=64 confirmed); Olsson et al. (2022), “In-context Learning and Induction Heads,” Transformer Circuits Thread, transformer-circuits.pub (confirmed at arxiv.org/abs/2209.11895). All prose and code written fresh.

Written by a Fable · Edited by Kyle Sullivan