Chapter 3: Multi-Head Attention

Learning Multiple Relationships

Learning Objectives

  • Understand multi-head attention fundamentals
  • Master the mathematical foundations
  • Learn practical implementation
  • Apply knowledge through examples
  • Recognize real-world applications

Multi-Head Attention

Why Multiple Heads?

Multi-head attention allows the model to attend to information from different representation subspaces simultaneously. Instead of one attention mechanism, we use multiple parallel attention "heads", each learning different types of relationships.

Think of multi-head attention like having multiple experts:

  • Single-head attention: Like one person trying to understand everything - they might miss some relationships
  • Multi-head attention: Like a team of specialists - one focuses on syntax, another on semantics, another on long-range dependencies
  • Result: Each head learns different patterns, then we combine their insights

📚 Real-World Analogy: The Research Team

Imagine analyzing a complex document:

  • Head 1 (Syntax Expert): Focuses on grammatical relationships - "What is the subject? What is the verb?"
  • Head 2 (Semantic Expert): Focuses on meaning - "What concepts are related? What is the topic?"
  • Head 3 (Reference Expert): Focuses on references - "What does 'it' refer to? What does 'this' mean?"
  • Head 4 (Temporal Expert): Focuses on time relationships - "What happened first? What's the sequence?"

Multi-head attention: All experts work in parallel, then we combine their findings for a complete understanding!

How Multi-Head Attention Works

The process:

  1. Split: Divide the embedding dimension into multiple heads
  2. Parallel Attention: Each head computes attention independently
  3. Specialization: Each head learns different relationships
  4. Concatenate: Combine all head outputs
  5. Project: Linear transformation to final dimension

Key Concepts

🔑 Head Specialization

Different attention heads learn to focus on different aspects of the input:

Example: Sentence Analysis

Sentence: "The cat, which was very fluffy, sat on the mat."

  • Head 1 (Subject-Verb): High attention from "sat" to "cat" (subject-verb relationship)
  • Head 2 (Modifier): High attention from "fluffy" to "cat" (adjective-noun relationship)
  • Head 3 (Relative Clause): High attention from "which" to "cat" (relative pronoun reference)
  • Head 4 (Preposition): High attention from "on" to "mat" (preposition-object relationship)

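Key insight: Each head can specialize in a different linguistic relationship. The head roles listed here are illustrative; in a trained model the roles are learned rather than assigned.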

Dimension Splitting

How dimensions are divided:

  • If embedding dimension = 512 and num_heads = 8
  • Each head gets: 512 / 8 = 64 dimensions
  • Each head operates in its own 64-dimensional subspace
  • This allows parallel computation and specialization
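
The arithmetic above assumes the embedding dimension divides evenly by the number of heads. A minimal sketch of that check follows; head_dim is a hypothetical helper name used only for illustration, not part of any library:

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Return the per-head dimension, assuming an even split (hypothetical helper)."""
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    return d_model // num_heads


print(head_dim(512, 8))   # 64
print(head_dim(768, 12))  # 64
```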

Detailed Dimension Breakdown

Example with d_model = 512, num_heads = 8:

  • Input: X shape (batch, seq_len, 512)
  • Q projection: W_Q shape (512, 512) → Q shape (batch, seq_len, 512)
  • Split Q: 8 heads, each (batch, seq_len, 64)
  • Same for K and V: Each split into 8 heads of 64 dimensions
  • Each head: Computes attention independently in 64-dim space
  • After attention: Each head outputs (batch, seq_len, 64)
  • Concatenate: 8 heads × 64 dims = 512 dims (batch, seq_len, 512)
  • Output projection: W_O shape (512, 512) → Final (batch, seq_len, 512)
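
The split and concatenate steps in this breakdown can be sketched as a reshape followed by a transpose. This is one common way to do it, not necessarily the only one; split_heads and merge_heads are illustrative names, and the random array merely stands in for an already projected Q:

```python
import numpy as np

def split_heads(x, num_heads):
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads
    x = x.reshape(batch, seq_len, num_heads, d_k)
    return x.transpose(0, 2, 1, 3)

def merge_heads(x):
    """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads * d_k)."""
    batch, num_heads, seq_len, d_k = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * d_k)

rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 10, 512))    # stand-in for X @ W_Q
Q_heads = split_heads(Q, num_heads=8)

print(Q_heads.shape)                     # (2, 8, 10, 64)
print(merge_heads(Q_heads).shape)        # (2, 10, 512)
# Head 0 holds the first 64 features of every token; merging undoes the split.
print(np.array_equal(Q_heads[:, 0], Q[:, :, :64]))  # True
print(np.array_equal(merge_heads(Q_heads), Q))      # True
```

With this layout, head i simply sees features 64·i through 64·i + 63 of the projected tensor, which is why concatenation restores the original 512 dimensions.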

Why Split Dimensions?

Advantages of dimension splitting:

  • Computational efficiency: Each head works with 64-dimensional queries, keys, and values instead of 512-dimensional ones, so each head's dot products are cheaper
  • Parallel processing: All heads can compute simultaneously
  • Specialization: Each head learns different patterns in its subspace
  • Parameter efficiency: Total parameters similar to single-head, but more expressive
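
The parameter-efficiency point can be checked with a quick count. The sketch below assumes d_k = d_v = d_model / num_heads and ignores bias terms; attention_param_count is a hypothetical helper written for this illustration:

```python
def attention_param_count(d_model: int, num_heads: int) -> int:
    """Count weights (biases ignored) in the Q, K, V, and output projections."""
    d_k = d_model // num_heads
    per_head_qkv = 3 * d_model * d_k            # W_Q, W_K, W_V for one head
    output_proj = (num_heads * d_k) * d_model   # W_O
    return num_heads * per_head_qkv + output_proj

for h in (1, 8, 16):
    print(h, attention_param_count(512, h))
# 1 1048576
# 8 1048576
# 16 1048576
```

Under these assumptions the total stays at 4 × d_model² regardless of the number of heads, because each head's projections shrink as the head count grows.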

Head Specialization in Practice

Research shows that different heads learn different patterns:

Observed Head Behaviors

  • Syntactic heads: Focus on grammatical relationships (subject-verb, adjective-noun)
  • Semantic heads: Focus on meaning and topic relationships
  • Positional heads: Focus on relative positions and distances
  • Long-range heads: Capture dependencies between distant words
  • Local heads: Focus on immediate neighbors

Example: Multi-Head Analysis

Sentence: "The bank that the river flows by is closed."

  • Head 1: "bank" attends to "closed" (subject-predicate)
  • Head 2: "bank" attends to "river" (context that helps resolve which sense of "bank" is meant)
  • Head 3: "that" attends to "bank" (relative pronoun)
  • Head 4: "flows" attends to "river" (verb-subject)
  • Result: Multiple perspectives combined for complete understanding

Mathematical Formulations

Multi-Head Attention Formula

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O\]
\[\text{where } \text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)\]
Notation:
  • h: Number of attention heads
  • QW^Q_i: Query projection for head i
  • KW^K_i: Key projection for head i
  • VW^V_i: Value projection for head i
  • Concat: Concatenate all head outputs
  • W^O: Output projection matrix
Dimension Breakdown:
  • Input: (batch, seq_len, d_model)
  • Each head: (batch, seq_len, d_model/h)
  • After concat: (batch, seq_len, d_model)
  • After W^O: (batch, seq_len, d_model)
Step-by-Step Computation:
  1. Create Q, K, V: Q = XW_Q, K = XW_K, V = XW_V (all shape: batch × seq_len × d_model)
  2. Split into heads: Reshape each of Q, K, V to (batch, h, seq_len, d_model/h), or equivalently fold the heads into the batch axis as (batch × h, seq_len, d_model/h)
  3. Compute attention per head: Each head computes scaled dot-product attention independently
  4. Concatenate: Combine all h heads → (batch, seq_len, d_model)
  5. Output projection: Multiply by W^O to get final output
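
The formula uses one projection matrix per head, while the step-by-step computation uses a single W_Q followed by a split. The sketch below suggests why the two views agree, assuming the single matrix is formed by placing the per-head matrices side by side along the columns. The sizes are deliberately small and all variable names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

X = rng.standard_normal((batch, seq_len, d_model))
# One projection matrix per head, as in the head_i formula.
W_Q_heads = rng.standard_normal((num_heads, d_model, d_k))

# Fused matrix: per-head matrices placed side by side along the columns.
W_Q_full = np.concatenate(list(W_Q_heads), axis=-1)    # (d_model, d_model)

# View 1: project once, then split into heads.
Q_full = X @ W_Q_full                                  # (batch, seq_len, d_model)
Q_split = Q_full.reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)

# View 2: project separately for every head.
Q_per_head = np.stack([X @ W_Q_heads[h] for h in range(num_heads)], axis=1)

print(Q_split.shape)                      # (2, 4, 5, 4)
print(np.allclose(Q_split, Q_per_head))   # True
```

The same argument applies to K and V, which is why implementations often store one (d_model, d_model) matrix per projection.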

Efficient Implementation Formula

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O\]
\[\text{where } \text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i\]
Efficient Computation:
  • All heads can be computed in parallel using batched matrix operations
  • Instead of h separate computations, use one large batched computation
  • Q, K, V are reshaped to include head dimension: (batch, h, seq_len, d_k)
  • Attention computed for all heads simultaneously
  • Much faster on GPUs than sequential head computation
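
A minimal NumPy sketch of the batched computation described above, assuming self-attention (Q, K, V all derived from the same X), fused (d_model, d_model) projection matrices, and no masking or dropout. The function name and the small random initialization are illustrative choices, not taken from any library:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def batched_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """Self-attention over X with all heads computed in one batched matmul.

    X: (batch, seq_len, d_model); W_Q, W_K, W_V, W_O: (d_model, d_model).
    Returns the output and the per-head attention weights.
    """
    batch, seq_len, d_model = X.shape
    d_k = d_model // num_heads

    def split(t):  # (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        return t.reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)

    # (batch, h, seq_len, d_k) @ (batch, h, d_k, seq_len) -> (batch, h, seq_len, seq_len)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                  # (batch, h, seq_len, d_k)

    # Merge heads back to (batch, seq_len, d_model), then apply W_O.
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 10, 512))
W_Q, W_K, W_V, W_O = (rng.standard_normal((512, 512)) * 0.05 for _ in range(4))
out, weights = batched_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads=8)
print(out.shape)      # (2, 10, 512)
print(weights.shape)  # (2, 8, 10, 10)
```

There is no Python loop over heads here: the head axis is treated as an extra batch axis by the matrix multiplications, which is the property that makes this form fast on GPUs.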

Detailed Examples

Example: Multi-Head Attention on "The cat sat on the mat"

Input: Sequence of 6 tokens with embedding dimension 512, using 8 heads

Step 1: Create Q, K, V for each head

  • Head 1: Q₁, K₁, V₁ (each 64-dim, from 512/8)
  • Head 2: Q₂, K₂, V₂ (64-dim)
  • ... Head 8: Q₈, K₈, V₈ (64-dim)

Step 2: Compute attention for each head independently

  • Head 1 might focus on subject-verb: "sat" → "cat"
  • Head 2 might focus on preposition: "on" → "mat"
  • Head 3 might focus on articles: "the" → "cat", "the" → "mat"
  • Each head produces its own attention output (6×64)

Step 3: Concatenate all heads

  • Concat[head₁, head₂, ..., head₈] → (6×512)
  • All 8 heads' outputs combined

Step 4: Apply output projection

  • Multiply by W^O (512×512) → Final output (6×512)
  • This combines information from all heads

Result: Each token now has a representation that incorporates information from all tokens, with different heads capturing different relationships.

Example: Why Multiple Heads Help

Single-head attention limitation:

With one head, the model must learn all relationships in one attention pattern. This is like asking one person to be an expert in grammar, syntax, semantics, and discourse all at once.

Multi-head advantage:

With 8 heads, each head can specialize:

  • Head 1: Subject-verb relationships
  • Head 2: Adjective-noun relationships
  • Head 3: Preposition-object relationships
  • Head 4: Pronoun-antecedent relationships
  • Head 5: Long-range dependencies
  • Head 6: Negation scope
  • Head 7: Temporal relationships
  • Head 8: Causal relationships

This specialization allows the model to capture more nuanced relationships than a single head could.

Implementation

Multi-Head Attention Implementation

import numpy as np

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """
    Multi-head attention implementation
    
    Parameters:
    X: Input (batch, seq_len, d_model)
    W_Q: Query weight matrices for each head (num_heads, d_model, d_k)
    W_K: Key weight matrices for each head (num_heads, d_model, d_k)
    W_V: Value weight matrices for each head (num_heads, d_model, d_v)
    W_O: Output projection (d_model, d_model)
    num_heads: Number of attention heads
    """
    batch_size, seq_len, d_model = X.shape
    d_k = d_model // num_heads
    
    # List to store outputs from each head
    head_outputs = []
    
    # Process each head
    for h in range(num_heads):
        # Project to Q, K, V for this head
        Q_h = X @ W_Q[h]  # (batch, seq_len, d_k)
        K_h = X @ W_K[h]  # (batch, seq_len, d_k)
        V_h = X @ W_V[h]  # (batch, seq_len, d_v)
        
        # Scaled dot-product scores. Use batched matmul (@) here: np.dot on two
        # 3-D arrays pairs every batch with every other batch instead of
        # multiplying batch-wise, which would give the wrong shape.
        scores = Q_h @ K_h.transpose(0, 2, 1)  # (batch, seq_len, seq_len)
        scores = scores / np.sqrt(d_k)
        
        # Softmax over the key axis
        attention_weights = softmax(scores, axis=-1)
        
        # Weighted sum of values
        head_output = attention_weights @ V_h  # (batch, seq_len, d_v)
        head_outputs.append(head_output)
    
    # Concatenate all heads
    multi_head_output = np.concatenate(head_outputs, axis=-1)  # (batch, seq_len, d_model)
    
    # Final output projection
    output = np.dot(multi_head_output, W_O)  # (batch, seq_len, d_model)
    
    return output

def softmax(x, axis=-1):
    """Softmax function"""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Example usage
batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
X = np.random.randn(batch_size, seq_len, d_model)
d_k = d_model // num_heads

# Initialize weight matrices for each head
W_Q = np.random.randn(num_heads, d_model, d_k) * 0.1
W_K = np.random.randn(num_heads, d_model, d_k) * 0.1
W_V = np.random.randn(num_heads, d_model, d_k) * 0.1
W_O = np.random.randn(d_model, d_model) * 0.1

output = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads)
print(f"Output shape: {output.shape}")  # (2, 10, 512)

Real-World Applications

Multi-Head Attention in Modern NLP

Multi-head attention is the core mechanism in:

  • BERT: BERT-base uses 12 heads per layer (BERT-large uses 16) to understand bidirectional context, with different heads capturing different linguistic patterns
  • GPT: Uses 12 to 96 heads per layer depending on model size (12 in GPT-2 small, 96 in the largest GPT-3 model), with heads specializing in different aspects of language generation
  • Translation models: Encoder heads focus on source language relationships, decoder heads on target language and cross-lingual alignment
  • Question answering: Different heads attend to question tokens, context tokens, and relationships between them

Why Multi-Head is Essential

A single attention head struggles to capture the full complexity of language:

  • Language has multiple simultaneous relationships (syntax, semantics, discourse)
  • Different relationships require different attention patterns
  • Multi-head allows parallel processing of different relationship types
  • Empirically, adding heads tends to improve performance up to a point, after which extra heads become redundant

Head Specialization in Practice

Research shows heads naturally specialize:

  • Some heads focus on local patterns (adjacent words)
  • Others focus on long-range dependencies (distant relationships)
  • Some heads capture syntactic structure
  • Others capture semantic relationships
  • This specialization emerges during training, not by design
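
One rough way to tell local heads from long-range heads is to measure how far, on average, each head looks. The sketch below is an illustrative diagnostic rather than an established metric from this chapter; mean_attention_distance is a hypothetical helper, and the two attention patterns are hand-built toy examples, not weights from a trained model:

```python
import numpy as np

def mean_attention_distance(weights):
    """Average |query position - key position| weighted by attention, per head.

    weights: (num_heads, seq_len, seq_len), each row summing to 1.
    Returns an array of shape (num_heads,).
    """
    num_heads, seq_len, _ = weights.shape
    pos = np.arange(seq_len)
    dist = np.abs(pos[:, None] - pos[None, :])        # (seq_len, seq_len)
    return (weights * dist).sum(axis=-1).mean(axis=-1)

seq_len = 6
local_head = np.eye(seq_len)                   # every token attends to itself
first_token_head = np.zeros((seq_len, seq_len))
first_token_head[:, 0] = 1.0                   # every token attends to token 0
weights = np.stack([local_head, first_token_head])

distances = mean_attention_distance(weights)
print(distances)   # head 0 -> 0.0, head 1 -> 2.5
```

A small value suggests a head that attends mostly to nearby tokens, and a large value suggests one that reaches across the sequence.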

Test Your Understanding

Question 1: What is multi-head attention?

A) Running multiple attention mechanisms in parallel with different learned projections
B) Using multiple layers
C) Processing multiple sequences
D) Using larger matrices

Question 2: Why use multiple attention heads instead of one?

A) Different heads can learn to focus on different types of relationships (syntax, semantics, long-range dependencies), allowing the model to capture diverse patterns simultaneously
B) It's faster
C) It uses less memory
D) No reason

Question 3: How do you combine outputs from multiple attention heads?

A) Concatenate all head outputs, then apply a linear projection to get the final output with the desired dimension
B) Average them
C) Use only the first head
D) Multiply them

Question 4: What is the formula for multi-head attention?

A) \(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O\) where each \(\text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)\)
B) \(MultiHead = Q + K + V\)
C) \(MultiHead = Q \times K\)
D) \(MultiHead = Attention\)

Question 5: How do you split dimensions for multiple heads?

A) If embedding dimension is d_model and you have h heads, each head gets d_model/h dimensions. For example, 768 dimensions with 12 heads gives 64 dimensions per head
B) Each head gets full dimension
C) Random split
D) Only first head gets dimensions

Question 6: What types of relationships can different attention heads learn?

A) Some heads focus on syntactic relationships (subject-verb), others on semantic (word meaning), positional (distance), or task-specific patterns, creating a rich representation
B) All heads learn the same
C) Only positional
D) Only semantic

Question 7: How does the number of heads affect model capacity?

A) More heads increase capacity by allowing the model to learn diverse attention patterns, but too many heads can lead to redundancy. Typical values are 8-16 heads
B) More heads always better
C) Fewer heads always better
D) Number doesn't matter

Question 8: What is the computational cost of multi-head attention compared to single-head?

A) Similar cost because dimensions are split across heads. With h heads, each head processes d/h dimensions, so total computation is roughly the same as single head with full dimension
B) h times more expensive
C) h times less expensive
D) No difference

Question 9: How would you visualize what different attention heads learn?

A) Create separate attention heatmaps for each head, showing which positions each head focuses on. Compare patterns across heads to see if they specialize in different relationships
B) Only visualize one head
C) Average all heads
D) Can't visualize

Question 10: What happens if you use too many attention heads?

A) Heads may become redundant, learning similar patterns, wasting parameters, and potentially hurting performance. There's a sweet spot based on model size and task
B) Always improves performance
C) No effect
D) Makes it faster

Question 11: How do you implement multi-head attention from scratch?

A) Split Q, K, V into h heads (reshape to add head dimension), compute attention for each head independently, concatenate all head outputs, apply output projection matrix W_O to combine heads
B) Just use single head
C) Stack heads sequentially
D) Random operations

Question 12: Why is multi-head attention a key innovation in transformers?

A) It allows the model to attend to information from different representation subspaces simultaneously, capturing multiple types of relationships in parallel, which is crucial for understanding complex language patterns
B) It's just faster
C) It uses less memory
D) No special reason