Chapter 6: Residual Connections & Layer Normalization

Stabilizing Deep Networks

Learning Objectives

  • Understand residual connections & layer normalization fundamentals
  • Master the mathematical foundations
  • Learn practical implementation
  • Apply knowledge through examples
  • Recognize real-world applications

Residual Connections & Layer Normalization

Why These Components Are Critical

Residual connections and layer normalization are essential for training deep transformer networks. They solve two critical problems: vanishing gradients and training instability.

Think of them like safety features in a car:

  • Residual connections: Like a backup system - ensures information can always flow through, even if a layer doesn't learn well
  • Layer normalization: Like cruise control - keeps the network stable and prevents it from going too fast or too slow
  • Together: They enable training of very deep networks (50+ layers) that would otherwise be impossible

⚠️ The Problem Without These Components

Problem 1: Vanishing Gradients

In deep networks without residuals:

  • Gradients get smaller as they flow backward through layers
  • Early layers receive almost no gradient signal
  • Result: Early layers don't learn, only later layers learn
  • Like a message getting distorted as it passes through many people

Problem 2: Training Instability

Without layer normalization:

  • Activations can become very large or very small
  • Gradients can explode or vanish
  • Training becomes unstable and may not converge
  • Like a car without cruise control - speed fluctuates wildly

📚 Historical Context

These innovations came from different sources:

  • Residual Connections: Introduced in ResNet (2015) for image recognition
  • Layer Normalization: Introduced in 2016, motivated by recurrent sequence models where batch normalization works poorly
  • Transformer (2017): Combined both, enabling deep language models
  • Result: Enabled training of models with 100+ layers!

Key Concepts

🔑 Residual Connections (Skip Connections)

Residual connections add the input directly to the output of a layer:

How It Works

  • Without residual: output = Layer(input)
  • With residual: output = input + Layer(input)
  • Key insight: The layer learns the "residual" (difference) rather than the full transformation

Why This Helps

  • Identity mapping: If Layer(input) = 0, output = input (information preserved!)
  • Gradient flow: Gradients can flow directly through the addition
  • Easier learning: Layer can learn small refinements rather than complete transformations
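
The identity-mapping point above can be checked in a few lines of NumPy. This is a minimal sketch, not part of any real model: tiny_layer is a made-up linear sublayer whose weights are set to zero to mimic a layer that has learned nothing.

import numpy as np

def tiny_layer(x, W):
    """A hypothetical sublayer: a simple linear map."""
    return x @ W

d_model = 4
x = np.random.randn(3, d_model)       # 3 tokens, 4 features each
W = np.zeros((d_model, d_model))      # a layer that has learned nothing

without_residual = tiny_layer(x, W)   # all zeros - the input is lost
with_residual = x + tiny_layer(x, W)  # residual connection

print(np.allclose(with_residual, x))  # True: identity mapping preserved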

Layer Normalization

Layer normalization normalizes activations across features for each sample:

What It Does

  • Normalizes each token's features independently
  • Mean = 0, Variance = 1 for each token
  • Stabilizes training by keeping activations in a reasonable range

Layer Norm vs Batch Norm

  • Batch Norm: Normalizes across batch dimension (problematic for sequences of different lengths)
  • Layer Norm: Normalizes across feature dimension (works for any sequence length)
  • Why Layer Norm for transformers: Sequences have variable lengths, so batch statistics are unreliable and batch norm works poorly (see the comparison sketch below)
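
The difference is easiest to see in which axes the statistics are computed over. A minimal NumPy sketch; the tensor shape here is arbitrary, chosen only for illustration:

import numpy as np

x = np.random.randn(4, 10, 16)  # (batch, seq_len, d_model)

# Layer norm: statistics per token, across the feature dimension
ln_mean = x.mean(axis=-1, keepdims=True)      # shape (4, 10, 1)
ln_var = x.var(axis=-1, keepdims=True)        # shape (4, 10, 1)

# Batch norm: statistics per feature, across the batch (and sequence) dimension
bn_mean = x.mean(axis=(0, 1), keepdims=True)  # shape (1, 1, 16)
bn_var = x.var(axis=(0, 1), keepdims=True)    # shape (1, 1, 16)

print(ln_mean.shape, bn_mean.shape)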

Pre-Norm vs Post-Norm

Two common architectures:

Post-Norm (Original Transformer)

Structure: Sublayer → Residual → LayerNorm

  • x → Attention → LayerNorm(x + attention_output)
  • Used in the original Transformer paper

Pre-Norm (Modern Standard)

Structure: LayerNorm → Sublayer → Residual

  • x → LayerNorm → Attention → x + attention_output
  • More stable, easier to train
  • Used in modern models (GPT, BERT variants)
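
The two orderings differ only in where the normalization sits relative to the residual addition. Here is a minimal sketch, assuming sublayer_fn stands for attention or the FFN and layer_norm is any object with a forward method, such as the LayerNormalization class implemented later in this chapter:

def post_norm_sublayer(x, sublayer_fn, layer_norm):
    # Original Transformer: sublayer, then residual, then normalize
    return layer_norm.forward(x + sublayer_fn(x))

def pre_norm_sublayer(x, sublayer_fn, layer_norm):
    # Modern standard: normalize, then sublayer, with the residual outside
    return x + sublayer_fn(layer_norm.forward(x))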

Mathematical Formulations

Residual Connection Formula

\[\text{output} = x + \text{Sublayer}(x)\]
Breaking it down:
  • x: Input to the layer
  • Sublayer(x): Output of attention or FFN
  • x + Sublayer(x): Residual connection (element-wise addition)
  • Key: If Sublayer(x) = 0, output = x (identity mapping preserved)
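
Differentiating this formula shows why gradients flow so well through residual connections:

\[\frac{\partial\,\text{output}}{\partial x} = I + \frac{\partial\,\text{Sublayer}(x)}{\partial x}\]

The identity matrix I gives the backward pass a direct path: even if the sublayer's Jacobian is close to zero, the gradient arriving from above still passes through unchanged along the identity term, in addition to whatever flows through the sublayer.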

Layer Normalization Formula

\[\text{LayerNorm}(x) = \gamma \odot \left(\frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\right) + \beta\]
Step-by-Step:
  1. μ = mean(x): Compute mean across features
  2. σ² = var(x): Compute variance across features
  3. (x - μ) / √(σ² + ε): Normalize (mean=0, std=1)
  4. γ ⊙ ... + β: Scale and shift (learnable parameters)
Notation:
  • μ: Mean vector (computed per token)
  • σ²: Variance vector (computed per token)
  • ε: Small constant (e.g., 1e-5) to prevent division by zero
  • γ: Learnable scale parameter
  • β: Learnable shift parameter
  • ⊙: Element-wise multiplication

Complete Transformer Sublayer (Pre-Norm)

\[\text{output} = x + \text{Sublayer}(\text{LayerNorm}(x))\]
For Attention Sublayer:
  • x → LayerNorm(x) → MultiHeadAttention(...) → x + attention_output
For FFN Sublayer:
  • x → LayerNorm(x) → FFN(...) → x + ffn_output
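
Stacking the two sublayers gives the complete pre-norm transformer block:

\[x' = x + \text{MultiHeadAttention}(\text{LayerNorm}_1(x))\]
\[\text{output} = x' + \text{FFN}(\text{LayerNorm}_2(x'))\]

Each sublayer has its own layer normalization (its own γ and β), and the residual stream is only ever updated by addition.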

Detailed Examples

Example: Residual Connection in Action

Let's see how residual connections preserve information:

Scenario: Deep Network (10 layers)

Without Residual:

  • Input: x = [1.0, 2.0, 3.0]
  • After Layer 1: [0.9, 1.8, 2.7] (slight change)
  • After Layer 2: [0.8, 1.6, 2.4] (more change)
  • ... (information degrades through layers)
  • After Layer 10: [0.1, 0.2, 0.3] (most information lost!)

With Residual:

  • Input: x = [1.0, 2.0, 3.0]
  • After Layer 1: x + Layer1(x) = [1.0, 2.0, 3.0] + [-0.1, -0.2, -0.3] = [0.9, 1.8, 2.7]
  • After Layer 2: previous + Layer2(...) = [0.9, 1.8, 2.7] + [-0.1, -0.2, -0.3] = [0.8, 1.6, 2.4]
  • ... (but original x is always accessible through the residual path!)
  • Even if layers learn nothing, output ≈ input (information preserved)
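
The numbers in this toy example are easy to reproduce in NumPy; the per-layer outputs of [-0.1, -0.2, -0.3] are made up purely for illustration:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
layer_output = np.array([-0.1, -0.2, -0.3])   # hypothetical sublayer output

after_layer_1 = x + layer_output               # [0.9, 1.8, 2.7]
after_layer_2 = after_layer_1 + layer_output   # [0.8, 1.6, 2.4]

# If a layer learns nothing and outputs zeros, the input passes through intact
after_idle_layer = after_layer_2 + np.zeros(3)
print(after_layer_1, after_layer_2, after_idle_layer)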

Example: Layer Normalization

Normalizing a token's features:

Input Token Features

Before normalization: x = [10.0, -5.0, 2.0, 8.0, -3.0]

Step 1: Compute mean

  • μ = (10.0 + (-5.0) + 2.0 + 8.0 + (-3.0)) / 5 = 2.4

Step 2: Compute variance

  • σ² = ((10.0-2.4)² + (-5.0-2.4)² + (2.0-2.4)² + (8.0-2.4)² + (-3.0-2.4)²) / 5
  • σ² = (57.76 + 54.76 + 0.16 + 31.36 + 29.16) / 5 = 34.64
  • σ = √34.64 ≈ 5.89

Step 3: Normalize

  • normalized = (x - μ) / σ
  • normalized = [(10.0-2.4)/5.89, (-5.0-2.4)/5.89, (2.0-2.4)/5.89, (8.0-2.4)/5.89, (-3.0-2.4)/5.89]
  • normalized ≈ [1.29, -1.26, -0.07, 0.95, -0.92]
  • Mean ≈ 0, Std ≈ 1 ✓

Step 4: Scale and shift (if learned)

  • If γ = [1.0, 1.0, 1.0, 1.0, 1.0] and β = [0.0, 0.0, 0.0, 0.0, 0.0]
  • Output = normalized (no change)
  • But γ and β are learned, so they can adjust the distribution
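
The whole calculation can be reproduced in a few lines of NumPy (np.var uses the same divide-by-n variance as the example, and ε matches the formula above):

import numpy as np

x = np.array([10.0, -5.0, 2.0, 8.0, -3.0])
eps = 1e-5

mu = x.mean()                        # 2.4
var = x.var()                        # 34.64
normalized = (x - mu) / np.sqrt(var + eps)

print(normalized.round(2))           # [ 1.29 -1.26 -0.07  0.95 -0.92]
print(normalized.mean(), normalized.std())  # mean ≈ 0, std ≈ 1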

Implementation

Layer Normalization Implementation

import numpy as np

class LayerNormalization:
    """Layer Normalization for Transformers"""
    
    def __init__(self, d_model, eps=1e-5):
        """
        Initialize Layer Normalization
        
        Parameters:
        d_model: Model dimension
        eps: Small constant for numerical stability
        """
        self.d_model = d_model
        self.eps = eps
        
        # Learnable parameters
        self.gamma = np.ones(d_model)  # Scale parameter
        self.beta = np.zeros(d_model)  # Shift parameter
    
    def forward(self, x):
        """
        Forward pass through layer normalization
        
        Parameters:
        x: Input (batch, seq_len, d_model) or (seq_len, d_model)
        
        Returns:
        Normalized output (same shape as input)
        """
        # Handle different input shapes
        if len(x.shape) == 2:
            x = x.reshape(1, *x.shape)
            squeeze_output = True
        else:
            squeeze_output = False
        
        batch_size, seq_len, d_model = x.shape
        
        # Compute mean and variance across features (last dimension)
        # Shape: (batch, seq_len, 1)
        mean = np.mean(x, axis=-1, keepdims=True)
        variance = np.var(x, axis=-1, keepdims=True)
        
        # Normalize
        x_normalized = (x - mean) / np.sqrt(variance + self.eps)
        
        # Scale and shift
        output = self.gamma * x_normalized + self.beta
        
        if squeeze_output:
            output = output.squeeze(0)
        
        return output

# Example usage
d_model = 512
layer_norm = LayerNormalization(d_model)

# Input: (batch=2, seq_len=10, d_model=512)
x = np.random.randn(2, 10, 512) * 5  # Large values

# Normalize
output = layer_norm.forward(x)
print(f"Input mean: {np.mean(x):.2f}, std: {np.std(x):.2f}")
print(f"Output mean: {np.mean(output):.2f}, std: {np.std(output):.2f}")
# Output should have mean ≈ 0, std ≈ 1

Residual Connection Implementation

import numpy as np

def residual_connection(x, sublayer_output):
    """
    Apply residual connection
    
    Parameters:
    x: Input (batch, seq_len, d_model)
    sublayer_output: Output from attention or FFN (batch, seq_len, d_model)
    
    Returns:
    x + sublayer_output (element-wise addition)
    """
    return x + sublayer_output

# Complete transformer sublayer with pre-norm
def transformer_sublayer(x, sublayer_fn, layer_norm):
    """
    Transformer sublayer with pre-norm and residual
    
    Parameters:
    x: Input (batch, seq_len, d_model)
    sublayer_fn: Function for attention or FFN
    layer_norm: LayerNormalization instance
    
    Returns:
    Output with residual connection
    """
    # Pre-norm: normalize before sublayer
    x_norm = layer_norm.forward(x)
    
    # Apply sublayer (attention or FFN)
    sublayer_output = sublayer_fn(x_norm)
    
    # Residual connection
    output = residual_connection(x, sublayer_output)
    
    return output

# Example: Complete attention sublayer
def attention_sublayer(x, attention_fn, layer_norm):
    """Attention sublayer with pre-norm and residual"""
    x_norm = layer_norm.forward(x)
    attention_out = attention_fn(x_norm)
    return x + attention_out

# Example: Complete FFN sublayer
def ffn_sublayer(x, ffn_fn, layer_norm):
    """FFN sublayer with pre-norm and residual"""
    x_norm = layer_norm.forward(x)
    ffn_out = ffn_fn(x_norm)
    return x + ffn_out
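
A quick end-to-end check of the helpers above. The sublayer here is a hypothetical stand-in (a small linear projection), not real attention or an FFN, and the LayerNormalization class from the previous listing is assumed to be in scope:

d_model = 8
W = np.random.randn(d_model, d_model) * 0.02   # stand-in "sublayer" weights

def dummy_sublayer(h):
    """Hypothetical sublayer: a small linear projection."""
    return h @ W

layer_norm = LayerNormalization(d_model)
x = np.random.randn(2, 4, d_model)             # (batch=2, seq_len=4, d_model=8)

out = transformer_sublayer(x, dummy_sublayer, layer_norm)
print(out.shape)              # (2, 4, 8) - same shape as the input
print(np.abs(out - x).max())  # small: the sublayer adds only a small residual update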

Real-World Applications

Critical for All Deep Transformers

Residual connections and layer normalization are used in virtually every transformer model:

1. BERT (Bidirectional Encoder)

  • 12-24 layers, all use residuals and layer norm
  • Enables training deep bidirectional encoders
  • Without these, BERT couldn't train effectively

2. GPT Models

  • GPT-3: 96 layers! Impossible without residuals
  • Layer norm stabilizes training across all layers
  • Enables autoregressive generation at scale

3. Vision Transformers (ViT)

  • Apply transformers to images
  • Residuals and layer norm critical for deep vision models
  • Enable state-of-the-art image classification

Impact on Training

These components enable:

  • Deeper networks: 100+ layers vs 10-20 without residuals
  • Faster convergence: Layer norm stabilizes gradients
  • Better performance: Deeper = more capacity = better results
  • Stable training: Prevents gradient explosion/vanishing

Test Your Understanding

Question 1: What is the main purpose of residual connections?

A) To reduce parameters
B) To enable gradient flow and preserve information through deep networks
C) To add non-linearity
D) To normalize activations

Question 2: What does layer normalization normalize across?

A) Batch dimension
B) Feature dimension (for each token independently)
C) Sequence dimension
D) All dimensions

Question 3: In pre-norm architecture, where is layer normalization applied?

A) Before the sublayer (attention or FFN)
B) After the sublayer
C) Both before and after
D) Not used in pre-norm

Question 4: What is the formula for layer normalization?

A) \(LayerNorm(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\) where μ and σ are mean and std across features, γ and β are learnable parameters
B) \(LayerNorm(x) = x\)
C) \(LayerNorm(x) = x - \mu\)
D) \(LayerNorm(x) = \mu\)

Question 5: Why are residual connections important in transformers?

A) Residual connections allow gradients to flow directly through layers, preventing vanishing gradients, enabling training of very deep networks, and allowing layers to learn residual mappings
B) They reduce parameters
C) They make it faster
D) They're not important

Question 6: What is the difference between pre-norm and post-norm in transformers?

A) Pre-norm applies layer norm before sublayers (more stable, common in modern transformers), post-norm applies after sublayers (original Transformer). Pre-norm generally trains better for deep networks
B) They're the same
C) Post-norm is always better
D) Pre-norm is never used

Question 7: How does layer normalization differ from batch normalization?

A) Layer norm normalizes across features for each sample independently, while batch norm normalizes across batch dimension. Layer norm works better for sequences and variable batch sizes
B) They're identical
C) Batch norm is always better
D) Layer norm normalizes batches

Question 8: What happens if you remove residual connections from transformers?

A) Gradients vanish in deep networks, training becomes unstable, and the model struggles to learn effectively, especially with many layers
B) Nothing changes
C) It improves
D) It becomes faster

Question 9: How would you implement residual connection and layer norm?

A) For pre-norm: x_norm = LayerNorm(x), output = x + Sublayer(x_norm). For post-norm: output = LayerNorm(x + Sublayer(x)). Compute mean and std, normalize, scale and shift with learnable parameters
B) Just add x
C) Just normalize
D) Random operations

Question 10: Why is epsilon added in layer normalization formula?

A) Epsilon (small constant like 1e-5) prevents division by zero when variance is very small, ensuring numerical stability during computation
B) To increase variance
C) To decrease variance
D) It's not needed

Question 11: How do residual connections help with gradient flow?

A) Residual connections create direct paths for gradients to flow backward, allowing gradients to bypass layers and reach earlier layers without being multiplied many times, preventing vanishing gradients
B) They block gradients
C) They don't affect gradients
D) They only help forward pass

Question 12: What is the complete transformer block structure with residual and norm?

A) Pre-norm: x = x + Attention(LayerNorm(x)), x = x + FFN(LayerNorm(x)). Post-norm: x = LayerNorm(x + Attention(x)), x = LayerNorm(x + FFN(x)). Both use residual connections
B) No residual connections
C) No layer norm
D) Sequential without connections