Chapter 3: Multi-Head Attention
Learning Multiple Relationships
Learning Objectives
- Understand multi-head attention fundamentals
- Master the mathematical foundations
- Learn practical implementation
- Apply knowledge through examples
- Recognize real-world applications
Multi-Head Attention
Why Multiple Heads?
Multi-head attention allows the model to attend to information from different representation subspaces simultaneously. Instead of one attention mechanism, we use multiple parallel attention "heads", each learning different types of relationships.
Think of multi-head attention like having multiple experts:
- Single-head attention: Like one person trying to understand everything - they might miss some relationships
- Multi-head attention: Like a team of specialists - one focuses on syntax, another on semantics, another on long-range dependencies
- Result: Each head learns different patterns, then we combine their insights
Real-World Analogy: The Research Team
Imagine analyzing a complex document:
- Head 1 (Syntax Expert): Focuses on grammatical relationships - "What is the subject? What is the verb?"
- Head 2 (Semantic Expert): Focuses on meaning - "What concepts are related? What is the topic?"
- Head 3 (Reference Expert): Focuses on references - "What does 'it' refer to? What does 'this' mean?"
- Head 4 (Temporal Expert): Focuses on time relationships - "What happened first? What's the sequence?"
Multi-head attention: All experts work in parallel, then we combine their findings for a complete understanding!
How Multi-Head Attention Works
The process:
- Split: Divide the embedding dimension into multiple heads
- Parallel Attention: Each head computes attention independently
- Specialization: Each head learns different relationships
- Concatenate: Combine all head outputs
- Project: Linear transformation to final dimension
Key Concepts
Head Specialization
Different attention heads learn to focus on different aspects of the input:
Example: Sentence Analysis
Sentence: "The cat, which was very fluffy, sat on the mat."
- Head 1 (Subject-Verb): High attention from "sat" to "cat" (subject-verb relationship)
- Head 2 (Modifier): High attention from "fluffy" to "cat" (adjective-noun relationship)
- Head 3 (Relative Clause): High attention from "which" to "cat" (relative pronoun reference)
- Head 4 (Preposition): High attention from "on" to "mat" (preposition-object relationship)
Key insight: Each head specializes in a different linguistic relationship!
Dimension Splitting
How dimensions are divided:
- If embedding dimension = 512 and num_heads = 8
- Each head gets: 512 / 8 = 64 dimensions
- Each head operates in its own 64-dimensional subspace
- This allows parallel computation and specialization
Detailed Dimension Breakdown
Example with d_model = 512, num_heads = 8:
- Input: X shape (batch, seq_len, 512)
- Q projection: W_Q shape (512, 512) → Q shape (batch, seq_len, 512)
- Split Q: 8 heads, each (batch, seq_len, 64)
- Same for K and V: Each split into 8 heads of 64 dimensions
- Each head: Computes attention independently in 64-dim space
- After attention: Each head outputs (batch, seq_len, 64)
- Concatenate: 8 heads × 64 dims = 512 dims (batch, seq_len, 512)
- Output projection: W_O shape (512, 512) → Final (batch, seq_len, 512)
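To make the splitting step concrete, here is a minimal NumPy sketch (the variable names are illustrative, not from any particular library) that views a (batch, seq_len, 512) tensor as 8 heads of 64 dimensions each and then merges it back unchanged:

import numpy as np

batch, seq_len, d_model, num_heads = 2, 6, 512, 8
d_k = d_model // num_heads  # 512 / 8 = 64

X = np.random.randn(batch, seq_len, d_model)

# Split: view the last dimension as (num_heads, d_k), then move heads forward
heads = X.reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
print(heads.shape)  # (2, 8, 6, 64): each head sees its own 64-dim subspace

# Merge: inverse of the split, recovering the original layout
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
print(merged.shape)            # (2, 6, 512)
print(np.allclose(merged, X))  # True: splitting and merging lose nothing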
Why Split Dimensions?
Advantages of dimension splitting:
- Computational efficiency: each head works in a small 64-dimensional subspace (a 512×64 projection per head), so eight heads together cost roughly the same as one full 512-dimensional head
- Parallel processing: All heads can compute simultaneously
- Specialization: Each head learns different patterns in its subspace
- Parameter efficiency: Total parameters similar to single-head, but more expressive
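As a quick sanity check on the parameter-efficiency point, the snippet below compares the weight count of one full 512×512 projection with eight per-head 512×64 projections (biases ignored); the totals match exactly:

d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64

# One full-width projection matrix (single-head style)
single_head_params = d_model * d_model         # 512 * 512 = 262,144

# Eight per-head projection matrices (multi-head style)
multi_head_params = num_heads * d_model * d_k  # 8 * 512 * 64 = 262,144

print(single_head_params == multi_head_params)  # True: same parameter budget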
Head Specialization in Practice
Research shows that different heads learn different patterns:
Observed Head Behaviors
- Syntactic heads: Focus on grammatical relationships (subject-verb, adjective-noun)
- Semantic heads: Focus on meaning and topic relationships
- Positional heads: Focus on relative positions and distances
- Long-range heads: Capture dependencies between distant words
- Local heads: Focus on immediate neighbors
Example: Multi-Head Analysis
Sentence: "The bank that the river flows by is closed."
- Head 1: "bank" attends to "closed" (subject-predicate)
- Head 2: "bank" attends to "river" (disambiguates "bank" = financial, not riverbank)
- Head 3: "that" attends to "bank" (relative pronoun)
- Head 4: "flows" attends to "river" (verb-subject)
- Result: Multiple perspectives combined for complete understanding
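With trained weights, this kind of specialization can be inspected directly by extracting each head's attention matrix. The sketch below shows one way to do that in NumPy; the helper name per_head_attention_weights is just illustrative, and the weights here are random and untrained, so the resulting maps show no real specialization.

import numpy as np

def per_head_attention_weights(X, W_Q, W_K, num_heads):
    """Return the (num_heads, seq_len, seq_len) attention maps for one sequence."""
    seq_len, d_model = X.shape
    d_k = d_model // num_heads
    maps = []
    for h in range(num_heads):
        Q, K = X @ W_Q[h], X @ W_K[h]    # (seq_len, d_k) each
        scores = Q @ K.T / np.sqrt(d_k)  # (seq_len, seq_len)
        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
        maps.append(e / e.sum(axis=-1, keepdims=True))
    return np.stack(maps)

# Toy sequence; with trained weights, the "bank" row of head h would show
# where head h looks when encoding "bank".
tokens = "The bank that the river flows by is closed".split()
d_model, num_heads = 64, 4
d_k = d_model // num_heads
X = np.random.randn(len(tokens), d_model)
W_Q = np.random.randn(num_heads, d_model, d_k) * 0.1
W_K = np.random.randn(num_heads, d_model, d_k) * 0.1

maps = per_head_attention_weights(X, W_Q, W_K, num_heads)
print(maps.shape)  # (4, 9, 9): one seq_len x seq_len map per head
bank = tokens.index("bank")
print(dict(zip(tokens, maps[1][bank].round(2))))  # head 2's attention from "bank"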
Mathematical Formulations
Multi-Head Attention Formula
The formula:
- MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
- head_i = Attention(QW^Q_i, KW^K_i, VW^V_i) = softmax(QW^Q_i (KW^K_i)^T / √d_k) VW^V_i
Notation:
- h: Number of attention heads
- QW^Q_i: Query projection for head i
- KW^K_i: Key projection for head i
- VW^V_i: Value projection for head i
- Concat: Concatenate all head outputs
- W^O: Output projection matrix
Dimension Breakdown:
- Input: (batch, seq_len, d_model)
- Each head: (batch, seq_len, d_model/h)
- After concat: (batch, seq_len, d_model)
- After W^O: (batch, seq_len, d_model)
Step-by-Step Computation:
- Create Q, K, V: Q = XW_Q, K = XW_K, V = XW_V (all shape: batch × seq_len × d_model)
- Split into heads: Reshape to (batch × h, seq_len, d_model/h) for each of Q, K, V
- Compute attention per head: Each head computes scaled dot-product attention independently
- Concatenate: Combine all h heads → (batch, seq_len, d_model)
- Output projection: Multiply by W^O to get final output
Efficient Implementation Formula
Efficient Computation:
- All heads can be computed in parallel using batched matrix operations
- Instead of h separate computations, use one large batched computation
- Q, K, V are reshaped to include head dimension: (batch, h, seq_len, d_k)
- Attention computed for all heads simultaneously
- Much faster on GPUs than sequential head computation
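A minimal NumPy sketch of this batched formulation is shown below. It assumes fused (d_model, d_model) projection matrices (rather than one matrix per head, as in the loop implementation later in this chapter), and the function name multi_head_attention_batched is illustrative:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(X, W_Q, W_K, W_V, W_O, num_heads):
    """All heads attended in one batched matmul; weights are (d_model, d_model)."""
    batch, seq_len, d_model = X.shape
    d_k = d_model // num_heads

    def split_heads(T):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        return T.reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

    # One batched matmul covers every head: (batch, h, seq_len, seq_len)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    heads = weights @ V  # (batch, h, seq_len, d_k)

    # Merge heads back and apply the output projection
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    return concat @ W_O

# Example: same shapes as elsewhere in the chapter
X = np.random.randn(2, 10, 512)
W_Q, W_K, W_V, W_O = (np.random.randn(512, 512) * 0.1 for _ in range(4))
print(multi_head_attention_batched(X, W_Q, W_K, W_V, W_O, num_heads=8).shape)  # (2, 10, 512)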
Detailed Examples
Example: Multi-Head Attention on "The cat sat on the mat"
Input: Sequence of 6 tokens with embedding dimension 512, using 8 heads
Step 1: Create Q, K, V for each head
- Head 1: Q₁, K₁, V₁ (each 64-dim, from 512/8)
- Head 2: Q₂, K₂, V₂ (64-dim)
- ... Head 8: Q₈, K₈, V₈ (64-dim)
Step 2: Compute attention for each head independently
- Head 1 might focus on subject-verb: "sat" → "cat"
- Head 2 might focus on preposition: "on" → "mat"
- Head 3 might focus on articles: "the" → "cat", "the" → "mat"
- Each head produces its own attention output (6×64)
Step 3: Concatenate all heads
- Concat[head₁, head₂, ..., head₈] → (6×512)
- All 8 heads' outputs combined
Step 4: Apply output projection
- Multiply by W^O (512×512) → Final output (6×512)
- This combines information from all heads
Result: Each token now has a representation that incorporates information from all tokens, with different heads capturing different relationships.
Example: Why Multiple Heads Help
Single-head attention limitation:
With one head, the model must learn all relationships in one attention pattern. This is like asking one person to be an expert in grammar, syntax, semantics, and discourse all at once.
Multi-head advantage:
With 8 heads, each head can specialize:
- Head 1: Subject-verb relationships
- Head 2: Adjective-noun relationships
- Head 3: Preposition-object relationships
- Head 4: Pronoun-antecedent relationships
- Head 5: Long-range dependencies
- Head 6: Negation scope
- Head 7: Temporal relationships
- Head 8: Causal relationships
This specialization allows the model to capture more nuanced relationships than a single head could.
Implementation
Multi-Head Attention Implementation
import numpy as np

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """
    Multi-head attention (per-head loop, one projection matrix per head).

    Parameters:
        X: Input (batch, seq_len, d_model)
        W_Q: Query weight matrices, one per head (num_heads, d_model, d_k)
        W_K: Key weight matrices, one per head (num_heads, d_model, d_k)
        W_V: Value weight matrices, one per head (num_heads, d_model, d_v)
        W_O: Output projection (d_model, d_model)
        num_heads: Number of attention heads
    """
    batch_size, seq_len, d_model = X.shape
    d_k = d_model // num_heads

    # List to store outputs from each head
    head_outputs = []

    # Process each head independently
    for h in range(num_heads):
        # Project to Q, K, V for this head
        Q_h = X @ W_Q[h]  # (batch, seq_len, d_k)
        K_h = X @ W_K[h]  # (batch, seq_len, d_k)
        V_h = X @ W_V[h]  # (batch, seq_len, d_v)

        # Scaled dot-product attention scores for this head
        scores = Q_h @ K_h.transpose(0, 2, 1)  # (batch, seq_len, seq_len)
        scores = scores / np.sqrt(d_k)

        # Softmax over the key dimension
        attention_weights = softmax(scores, axis=-1)

        # Weighted sum of values
        head_output = attention_weights @ V_h  # (batch, seq_len, d_v)
        head_outputs.append(head_output)

    # Concatenate all heads along the feature dimension
    multi_head_output = np.concatenate(head_outputs, axis=-1)  # (batch, seq_len, d_model)

    # Final output projection mixes information across heads
    output = multi_head_output @ W_O  # (batch, seq_len, d_model)
    return output

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Example usage
batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
X = np.random.randn(batch_size, seq_len, d_model)
d_k = d_model // num_heads

# Initialize one weight matrix per head
W_Q = np.random.randn(num_heads, d_model, d_k) * 0.1
W_K = np.random.randn(num_heads, d_model, d_k) * 0.1
W_V = np.random.randn(num_heads, d_model, d_k) * 0.1
W_O = np.random.randn(d_model, d_model) * 0.1

output = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads)
print(f"Output shape: {output.shape}")  # (2, 10, 512)
Real-World Applications
Multi-Head Attention in Modern NLP
Multi-head attention is the core mechanism in:
- BERT: Uses 12 heads to understand bidirectional context, with different heads capturing different linguistic patterns
- GPT: Uses 12-96 heads depending on model size, with heads specializing in different aspects of language generation
- Translation models: Encoder heads focus on source language relationships, decoder heads on target language and cross-lingual alignment
- Question answering: Different heads attend to question tokens, context tokens, and relationships between them
Why Multi-Head is Essential
Single-head attention cannot capture the complexity of language:
- Language has multiple simultaneous relationships (syntax, semantics, discourse)
- Different relationships require different attention patterns
- Multi-head allows parallel processing of different relationship types
- Empirically, models with more heads perform better (up to a point)
Head Specialization in Practice
Research shows heads naturally specialize:
- Some heads focus on local patterns (adjacent words)
- Others focus on long-range dependencies (distant relationships)
- Some heads capture syntactic structure
- Others capture semantic relationships
- This specialization emerges during training, not by design