Chapter 5: Feed-Forward Networks

Processing After Attention

Learning Objectives

  • Understand the fundamentals of feed-forward networks
  • Master the mathematical foundations
  • Learn how to implement an FFN in practice
  • Apply your knowledge through worked examples
  • Recognize real-world applications

Feed-Forward Networks

What is a Feed-Forward Network in Transformers?

Feed-Forward Networks (FFNs) are two-layer neural networks applied independently to each position after attention. They process the information gathered by attention, adding non-linearity and capacity to learn complex transformations.

Think of the FFN as a post-processing step:

  • Attention: Gathers relevant information from other positions (like collecting ingredients)
  • FFN: Processes and transforms that information (like cooking the ingredients into a dish)
  • Key: Each position processes its information independently, but uses the same transformation

Why FFN After Attention?

Attention and FFN serve different purposes:

Attention's Role
  • Mixes information BETWEEN positions
  • Decides "what information to gather"
  • Like asking: "Which other words are relevant?"
  • Result: Each position has a weighted combination of all positions
FFN's Role
  • Processes information WITHIN each position
  • Decides "how to transform the gathered information"
  • Like asking: "What should I do with this information?"
  • Result: Each position gets a non-linear transformation

Together: Attention gathers context, FFN processes it!

📚 Real-World Analogy: The Research Process

Imagine writing a research paper:

  1. Attention (Gathering): You read multiple sources, identify relevant information from each
  2. FFN (Processing): You synthesize, analyze, and transform that information into your own understanding
  3. Result: You have both the gathered context (attention) and your processed understanding (FFN)

Key Concepts

🔑 FFN Architecture Components

FFN consists of three main components:

1. First Linear Layer (Expansion)

  • Input: d_model dimensions (e.g., 512)
  • Output: d_ff dimensions (e.g., 2048) - typically 4× expansion
  • Purpose: Expands the representation space
  • Why expand? More dimensions = more capacity to learn complex patterns

2. Activation Function (ReLU)

  • Function: ReLU(x) = max(0, x)
  • Purpose: Introduces non-linearity
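  • Why ReLU? Simple and cheap to compute; its gradient is 1 for every positive input, so it does not saturate there, which helps gradient flow
  • Variants: the original Transformer uses ReLU; BERT and GPT use the smoother GELU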
  • Effect: Allows the network to learn non-linear transformations

3. Second Linear Layer (Projection)

  • Input: d_ff dimensions (e.g., 2048)
  • Output: d_model dimensions (e.g., 512)
  • Purpose: Projects back to original dimension
  • Why project back? Maintains consistent dimensions for residual connection
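To see why the activation in component 2 matters, the sketch below checks that without ReLU the expansion and projection collapse into a single linear map (weight W₁W₂, bias b₁W₂ + b₂), so the extra layer would add no expressive power. The sizes (d_model = 8, d_ff = 32) and random weights are illustrative assumptions, not values from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 8, 32  # small illustrative sizes (assumption)

W1 = rng.standard_normal((d_model, d_ff))
b1 = rng.standard_normal(d_ff)
W2 = rng.standard_normal((d_ff, d_model))
b2 = rng.standard_normal(d_model)
x = rng.standard_normal((5, d_model))  # 5 positions

# Two linear layers with no activation in between
no_act = (x @ W1 + b1) @ W2 + b2

# The equivalent single linear layer: W = W1 W2, b = b1 W2 + b2
W_single = W1 @ W2
b_single = b1 @ W2 + b2
single = x @ W_single + b_single
print(np.allclose(no_act, single))  # True: the stack is still linear

# With ReLU in between, the single-matrix shortcut no longer matches
with_relu = np.maximum(0, x @ W1 + b1) @ W2 + b2
print(np.allclose(with_relu, single))  # False for generic inputs
```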

The 4× Expansion Rule

Why do we expand by 4×?

Common Configurations

  • BERT-base: d_model=768, d_ff=3072 (4×)
  • GPT-3: d_model=12288, d_ff=49152 (4×)
  • Original Transformer: d_model=512, d_ff=2048 (4×)

Why 4×?

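  • Empirical: 4× worked well in the original Transformer and has been widely reused since
  • Balances capacity against parameter count and compute
  • Smaller (e.g., 2×): fewer parameters, but less capacity per layer
  • Larger (e.g., 8×): twice the FFN parameters and compute of 4×, typically with diminishing returns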
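The cost side of this trade-off is easy to quantify: FFN parameters (and the per-position matrix-multiply work) grow linearly with the expansion ratio. A sketch for d_model = 512, using the parameter-count formula from the Mathematical Formulations section; whether the extra capacity pays off is an empirical question this code cannot answer.

```python
def ffn_param_count(d_model, d_ff):
    # W1 (d_model x d_ff) + b1 (d_ff) + W2 (d_ff x d_model) + b2 (d_model)
    return 2 * d_model * d_ff + d_ff + d_model

d_model = 512
for ratio in (2, 4, 8):
    d_ff = ratio * d_model
    print(f"{ratio}x: d_ff={d_ff}, params={ffn_param_count(d_model, d_ff):,}")
# 2x: d_ff=1024, params=1,050,112
# 4x: d_ff=2048, params=2,099,712
# 8x: d_ff=4096, params=4,198,912
```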

Position-Wise Processing

FFN processes each position independently:

  • Same weights applied to all positions
  • Like using the same "processing function" for every word
  • Efficient: Can process all positions in parallel
  • Key difference from attention: No interaction between positions
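Both properties above can be checked directly. In this sketch (random weights and small sizes are illustrative assumptions), applying the FFN to a whole sequence gives the same result as applying it to each position separately, and perturbing one position changes only that position's output.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_ff, seq_len = 16, 64, 6  # small illustrative sizes (assumption)
W1 = rng.standard_normal((d_model, d_ff)) * 0.1
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) * 0.1
b2 = np.zeros(d_model)

def ffn(x):
    # Same weights for every position; works on any array whose last axis is d_model
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

x = rng.standard_normal((seq_len, d_model))

# 1) Processing the whole sequence at once matches processing each position alone
all_at_once = ffn(x)
one_by_one = np.stack([ffn(x[i]) for i in range(seq_len)])
print(np.allclose(all_at_once, one_by_one))  # True

# 2) Perturbing position 2 changes only row 2 of the output (no cross-position mixing)
x_changed = x.copy()
x_changed[2] += 1.0
changed = ~np.all(np.isclose(ffn(x_changed), all_at_once), axis=1)
print(changed)  # only position 2 is True
```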

Mathematical Formulations

Feed-Forward Network Formula

\[\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2\]
Breaking Down the Formula:
  1. xW₁ + b₁: First linear transformation (expansion)
    • x: Input vector (d_model dimensions)
    • W₁: Weight matrix (d_model × d_ff)
    • b₁: Bias vector (d_ff dimensions)
    • Result: Expanded representation (d_ff dimensions)
  2. ReLU(...): Activation function
    • Applies ReLU element-wise
    • ReLU(x) = max(0, x)
    • Introduces non-linearity
  3. ...W₂ + b₂: Second linear transformation (projection)
    • W₂: Weight matrix (d_ff × d_model)
    • b₂: Bias vector (d_model dimensions)
    • Result: Projected back to d_model dimensions
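A hand-checkable instance of the formula: the weights below are made up (hypothetical) and the sizes are tiny (d_model = 2, d_ff = 4), so every intermediate value can be verified on paper.

```python
import numpy as np

# Made-up weights for d_model = 2, d_ff = 4 (hypothetical, for hand-checking only)
x = np.array([1.0, 2.0])
W1 = np.array([[1.0, -1.0, 0.0, 2.0],
               [0.0, 1.0, -1.0, 1.0]])   # (d_model, d_ff) = (2, 4)
b1 = np.array([0.0, 0.0, 1.0, -1.0])      # (d_ff,)
W2 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0],
               [1.0, -1.0]])              # (d_ff, d_model) = (4, 2)
b2 = np.array([0.5, 0.0])                 # (d_model,)

z1 = x @ W1 + b1        # expansion:  [1, 1, -1, 3]
h = np.maximum(0, z1)   # ReLU:       [1, 1, 0, 3]
out = h @ W2 + b2       # projection: [4.5, -2.0]
print(z1, h, out)
```

The third hidden unit is the only one ReLU switches off; its row of W₂ therefore contributes nothing to the output.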

Dimension Flow

\[\text{Input: } (batch, seq\_len, d_{model})\] \[\rightarrow \text{Expand: } (batch, seq\_len, d_{ff})\] \[\rightarrow \text{ReLU: } (batch, seq\_len, d_{ff})\] \[\rightarrow \text{Project: } (batch, seq\_len, d_{model})\]
Example with Numbers:
  • Input: (2, 10, 512) - batch=2, seq_len=10, d_model=512
  • After W₁: (2, 10, 2048) - expanded to d_ff=2048
  • After ReLU: (2, 10, 2048) - same shape, non-linear
  • After W₂: (2, 10, 512) - projected back to d_model=512
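The same shape trace in NumPy. The `@` operator multiplies over the last axis of a 3-D input, so the batch and sequence axes pass through unchanged without an explicit reshape. The random weights and their 0.02 scale are arbitrary choices for illustration.

```python
import numpy as np

batch, seq_len, d_model, d_ff = 2, 10, 512, 2048
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, seq_len, d_model))
W1 = rng.standard_normal((d_model, d_ff)) * 0.02  # scale is an arbitrary choice
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) * 0.02
b2 = np.zeros(d_model)

expanded = x @ W1 + b1                # (2, 10, 2048)
activated = np.maximum(0, expanded)   # (2, 10, 2048)
projected = activated @ W2 + b2       # (2, 10, 512)

for name, arr in [("input", x), ("after W1", expanded),
                  ("after ReLU", activated), ("after W2", projected)]:
    print(f"{name:>10}: {arr.shape}")
```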

Parameter Count

\[\text{Parameters} = (d_{model} \times d_{ff}) + d_{ff} + (d_{ff} \times d_{model}) + d_{model}\] \[= 2 \times d_{model} \times d_{ff} + d_{model} + d_{ff}\]
Example Calculation:

For d_model=512, d_ff=2048:

  • W₁: 512 × 2048 = 1,048,576 parameters
  • b₁: 2048 parameters
  • W₂: 2048 × 512 = 1,048,576 parameters
  • b₂: 512 parameters
  • Total: 2,099,712 parameters per FFN

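Note: With d_ff = 4 × d_model, the FFN is the largest component of a transformer layer, with about twice as many parameters as the attention projections (8 × d_model² vs 4 × d_model² weights).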
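Applying the same formula to the configurations listed under "Common Configurations" gives the FFN parameter count per layer for each model. The sketch also cross-checks the closed form against the sizes of the four arrays for d_model = 512.

```python
import math

def ffn_param_count(d_model, d_ff):
    # Closed form: 2 * d_model * d_ff + d_ff + d_model
    return 2 * d_model * d_ff + d_ff + d_model

configs = {
    "Original Transformer": (512, 2048),
    "BERT-base": (768, 3072),
    "GPT-3": (12288, 49152),
}
for name, (d_model, d_ff) in configs.items():
    print(f"{name}: {ffn_param_count(d_model, d_ff):,} FFN parameters per layer")
# Original Transformer: 2,099,712 / BERT-base: 4,722,432 / GPT-3: 1,208,020,992

# Cross-check: sum the element counts of W1, b1, W2, b2 for the 512/2048 case
d_model, d_ff = 512, 2048
shapes = [(d_model, d_ff), (d_ff,), (d_ff, d_model), (d_model,)]
counted = sum(math.prod(s) for s in shapes)
print(counted == ffn_param_count(d_model, d_ff))  # True
```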

Detailed Examples

Step-by-Step Example: Processing a Word

Let's trace through how FFN processes the word "cat" in a sentence:

Step 1: Input from Attention
  • After attention, "cat" has a representation: [0.2, -0.5, 0.8, ..., 0.1] (512 dimensions)
  • This representation includes information from other words (via attention)
  • Now FFN will process this information
Step 2: First Linear Transformation (Expansion)
  • Input: [0.2, -0.5, 0.8, ..., 0.1] (512 dims)
  • Multiply by W₁ (512 × 2048): Matrix multiplication
  • Add bias b₁ (2048 dims), giving an expanded vector such as [0.1, 0.3, -0.2, ..., 0.5] (illustrative values)
  • Result: Expanded representation with more dimensions
Step 3: ReLU Activation
  • Apply ReLU element-wise: ReLU(x) = max(0, x)
  • Positive values stay, negative values become 0
  • Example: [0.1, 0.3, -0.2, 0.5] → [0.1, 0.3, 0.0, 0.5]
  • This introduces non-linearity
Step 4: Second Linear Transformation (Projection)
  • Input: [0.1, 0.3, 0.0, ..., 0.5] (2048 dims)
  • Multiply by W₂ (2048 × 512): Matrix multiplication
  • Add bias b₂ (512 dims), giving an output such as [0.15, -0.3, 0.6, ..., 0.2] (illustrative values)
  • Result: Processed representation back to original dimension
Step 5: Final Output
  • The output is a transformed version of the input
  • It has learned to process the information gathered by attention
  • Ready to be added to the residual connection
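The five steps can be traced for a single 512-dimensional vector. This is a sketch under stated assumptions: the "cat" vector is random (a stand-in for a real attention output), and the weights use the same Xavier/Glorot uniform initialization as the implementation later in this chapter, so they are untrained.

```python
import numpy as np

rng = np.random.default_rng(42)
d_model, d_ff = 512, 2048

# Step 1: a random stand-in for the attention output at "cat" (hypothetical)
cat_vec = rng.standard_normal(d_model)

# Untrained weights, Xavier/Glorot uniform init
limit = np.sqrt(6.0 / (d_model + d_ff))
W1 = rng.uniform(-limit, limit, (d_model, d_ff))
b1 = np.zeros(d_ff)
W2 = rng.uniform(-limit, limit, (d_ff, d_model))
b2 = np.zeros(d_model)

z1 = cat_vec @ W1 + b1      # Step 2: expand to 2048 dims
h = np.maximum(0, z1)       # Step 3: ReLU sets the negative entries to 0
out = h @ W2 + b2           # Step 4: project back to 512 dims
residual = cat_vec + out    # Step 5: add to the residual connection

print(z1.shape, h.shape, out.shape)  # (2048,) (2048,) (512,)
print(f"fraction of hidden units zeroed by ReLU: {np.mean(h == 0):.2f}")
```

With symmetric random weights and zero bias, roughly half of the hidden units are zeroed; in a trained model this fraction depends on the learned weights and the input.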

What Does FFN Learn?

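FFN learns to transform representations. Nothing specifies in advance which transformations it learns, and individual hidden units are often hard to interpret, so the examples below are intuitions about the kinds of features an FFN may help compute, not documented behavior of a specific model:

Illustrative Transformations:
  • Noun → Verb: "cat" (noun) → a verb-form concept
  • Singular → Plural: "cat" → "cats" (plural concept)
  • Base → Derived: "happy" → "happiness" (derived form)
  • Semantic Relations: "king" → "royalty" (related concepts)

Key: These transformations are not hand-designed; the FFN weights acquire them through training.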
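One way to make "what the FFN learns" more concrete is to rewrite the formula itself. Writing kᵢ for column i of W₁ and vᵢ for row i of W₂, each hidden unit acts as a pattern detector (key) paired with a vector it writes back (value):

\[\text{FFN}(x) = \sum_{i=1}^{d_{ff}} \text{ReLU}(x \cdot k_i + b_{1,i})\, v_i + b_2\]

This "key-value memory" view is an interpretation used in interpretability research, not a guarantee of what any trained unit encodes. The sketch below only checks the algebraic identity numerically, with random weights and illustrative sizes.

```python
import numpy as np

rng = np.random.default_rng(3)
d_model, d_ff = 8, 32  # illustrative sizes (assumption)
W1 = rng.standard_normal((d_model, d_ff))
b1 = rng.standard_normal(d_ff)
W2 = rng.standard_normal((d_ff, d_model))
b2 = rng.standard_normal(d_model)
x = rng.standard_normal(d_model)

# Matrix form: ReLU(x W1 + b1) W2 + b2
matrix_form = np.maximum(0, x @ W1 + b1) @ W2 + b2

# Sum form: hidden unit i has key W1[:, i] and value W2[i, :]
sum_form = b2.copy()
for i in range(d_ff):
    activation = max(0.0, x @ W1[:, i] + b1[i])  # how strongly key i matches x
    sum_form += activation * W2[i, :]           # add value i, scaled

print(np.allclose(matrix_form, sum_form))  # True
```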

Implementation

Feed-Forward Network Implementation

import numpy as np

class FeedForwardNetwork:
    """Feed-Forward Network for Transformer"""
    
    def __init__(self, d_model, d_ff):
        """
        Initialize FFN
        
        Parameters:
        d_model: Model dimension (e.g., 512)
        d_ff: Feed-forward dimension (e.g., 2048, typically 4× d_model)
        """
        self.d_model = d_model
        self.d_ff = d_ff
        
        # Initialize weights with Xavier/Glorot initialization
        # W1: (d_model, d_ff)
        limit = np.sqrt(6.0 / (d_model + d_ff))
        self.W1 = np.random.uniform(-limit, limit, (d_model, d_ff))
        self.b1 = np.zeros(d_ff)
        
        # W2: (d_ff, d_model)
        limit = np.sqrt(6.0 / (d_ff + d_model))
        self.W2 = np.random.uniform(-limit, limit, (d_ff, d_model))
        self.b2 = np.zeros(d_model)
    
    def relu(self, x):
        """ReLU activation function"""
        return np.maximum(0, x)
    
    def forward(self, x):
        """
        Forward pass through FFN
        
        Parameters:
        x: Input (batch, seq_len, d_model) or (seq_len, d_model)
        
        Returns:
        Output: (batch, seq_len, d_model) or (seq_len, d_model)
        """
        # Handle (seq_len, d_model) input by adding a batch dimension,
        # and remember to remove it again before returning
        squeeze_batch = x.ndim == 2
        if squeeze_batch:
            x = x[np.newaxis, :, :]
        batch_size, seq_len, d_model = x.shape
        
        # Reshape for matrix multiplication: (batch * seq_len, d_model)
        x_reshaped = x.reshape(-1, d_model)
        
        # First linear layer: (batch * seq_len, d_model) × (d_model, d_ff)
        # = (batch * seq_len, d_ff)
        z1 = np.dot(x_reshaped, self.W1) + self.b1
        
        # ReLU activation
        a1 = self.relu(z1)
        
        # Second linear layer: (batch * seq_len, d_ff) × (d_ff, d_model)
        # = (batch * seq_len, d_model)
        z2 = np.dot(a1, self.W2) + self.b2
        
        # Reshape back: (batch, seq_len, d_model)
        output = z2.reshape(batch_size, seq_len, d_model)
        
        # Remove the batch dimension if the input didn't have one
        # (checking x.shape here would fail: x was made 3-D above)
        if squeeze_batch:
            output = output.squeeze(0)
        
        return output

# Example usage
d_model, d_ff = 512, 2048
ffn = FeedForwardNetwork(d_model, d_ff)

# Input: (batch=2, seq_len=10, d_model=512)
x = np.random.randn(2, 10, 512)

# Forward pass
output = ffn.forward(x)
print(f"Input shape: {x.shape}")  # (2, 10, 512)
print(f"Output shape: {output.shape}")  # (2, 10, 512)
n_params = ffn.W1.size + ffn.b1.size + ffn.W2.size + ffn.b2.size
print(f"Total parameters: {n_params:,}")  # Total parameters: 2,099,712

FFN in Transformer Layer Context

import numpy as np

def transformer_ffn_layer(x, W1, b1, W2, b2):
    """
    FFN as part of complete transformer layer
    
    Parameters:
    x: Input after attention (batch, seq_len, d_model)
    W1, b1: First linear layer weights and bias
    W2, b2: Second linear layer weights and bias
    """
    # First linear transformation
    expanded = np.dot(x, W1) + b1  # (batch, seq_len, d_ff)
    
    # ReLU activation
    activated = np.maximum(0, expanded)  # (batch, seq_len, d_ff)
    
    # Second linear transformation
    output = np.dot(activated, W2) + b2  # (batch, seq_len, d_model)
    
    return output

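# Simplified transformer layer with residual connections
def transformer_layer(x, attention_output, W1, b1, W2, b2):
    """
    Simplified transformer layer: attention + FFN, each wrapped in a
    residual connection. attention_output is assumed to be precomputed
    from x. Layer normalization (applied after each residual addition
    in the original Transformer) is omitted here for brevity.
    """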
    # After attention (with residual)
    x_after_attention = x + attention_output
    
    # FFN
    ffn_output = transformer_ffn_layer(x_after_attention, W1, b1, W2, b2)
    
    # Residual connection
    output = x_after_attention + ffn_output
    
    return output
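The original Transformer wraps each sub-layer as LayerNorm(x + Sublayer(x)), a post-LN arrangement. The sketch below adds that normalization to the layer above. Assumptions: the layer norm is parameter-free (learnable gain and bias fixed at 1 and 0 for brevity), the attention output is a random stand-in passed in precomputed, and the weights are random.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each position's vector to zero mean and unit variance
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def ffn(x, W1, b1, W2, b2):
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

def post_ln_layer(x, attention_output, W1, b1, W2, b2):
    # Sub-layer 1: residual around attention, then normalize
    h = layer_norm(x + attention_output)
    # Sub-layer 2: residual around the FFN, then normalize
    return layer_norm(h + ffn(h, W1, b1, W2, b2))

rng = np.random.default_rng(0)
batch, seq_len, d_model, d_ff = 2, 10, 512, 2048
x = rng.standard_normal((batch, seq_len, d_model))
attn_out = rng.standard_normal((batch, seq_len, d_model))  # stand-in for real attention
W1 = rng.standard_normal((d_model, d_ff)) * 0.02
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) * 0.02
b2 = np.zeros(d_model)

out = post_ln_layer(x, attn_out, W1, b1, W2, b2)
print(out.shape)                            # (2, 10, 512)
print(np.allclose(out.mean(axis=-1), 0.0))  # True
```

Many later models (GPT-2 onward) instead normalize the input of each sub-layer, x + Sublayer(LayerNorm(x)), which is known as pre-LN.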

Real-World Applications

Where FFN is Critical

FFN is a core component in all transformer-based models:

1. Language Models (GPT, BERT)

  • Processes contextualized word representations
  • Learns semantic transformations
  • Critical for understanding and generation

2. Machine Translation

  • Transforms source language representations
  • Prepares information for target language generation
  • Learns cross-lingual patterns

3. Text Classification

  • Processes document representations
  • Learns task-specific transformations
  • Prepares features for classification

4. Question Answering

  • Processes question and context representations
  • Learns to extract relevant information
  • Prepares for answer generation

FFN vs Other Components

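Parameter distribution within one transformer layer (with d_ff = 4 × d_model):

  • FFN: ~2/3 of the layer's parameters (8 × d_model² weights in W₁ and W₂), the largest component
  • Attention: ~1/3 (4 × d_model² weights in the Q, K, V, and output projections)
  • Layer Norm: negligible (2 × d_model parameters per layer norm)
  • Embeddings: not part of the layers; their share of the whole model depends on vocabulary size and the number of layers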

Why so many parameters? FFN needs capacity to learn complex transformations!
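The per-layer breakdown can be computed directly for BERT-base sizes. The sketch assumes standard multi-head attention with biased Q, K, V, and output projections and two layer norms per layer (each with a gain and a bias vector); embeddings are excluded because they sit outside the layers.

```python
def layer_param_breakdown(d_model, d_ff):
    # Attention: Q, K, V and output projections, each d_model x d_model plus bias
    attention = 4 * (d_model * d_model + d_model)
    # FFN: W1, b1, W2, b2
    ffn = 2 * d_model * d_ff + d_ff + d_model
    # Two layer norms, each with a gain and a bias vector of size d_model
    layer_norm = 2 * 2 * d_model
    return {"attention": attention, "ffn": ffn, "layer_norm": layer_norm}

counts = layer_param_breakdown(768, 3072)  # BERT-base sizes
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:>10}: {n:>9,} ({n / total:.1%})")
print(f"{'total':>10}: {total:>9,}")
```

For these sizes the FFN holds about 67% of the layer's parameters, attention about 33%, and the layer norms under 0.1%.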

Test Your Understanding

Question 1: What is the typical expansion factor for FFN hidden dimension?

A) 2×
B) 4×
C) 8×
D) Same as input

Question 2: What is the purpose of the second linear layer in FFN?

A) To expand dimensions
B) To project back to original dimension
C) To add non-linearity
D) To gather information

Question 3: How does FFN differ from attention in transformers?

A) FFN processes each position independently, attention mixes between positions
B) FFN mixes between positions, attention processes independently
C) They are the same
D) FFN is not used in transformers

Question 4: What is the formula for a feed-forward network in transformers?

A) \(FFN(x) = ReLU(xW_1 + b_1)W_2 + b_2\) where W_1 expands dimensions, W_2 projects back, and ReLU adds non-linearity
B) \(FFN(x) = x\)
C) \(FFN(x) = xW\)
D) \(FFN(x) = W\)

Question 5: Why does FFN use ReLU activation?

A) ReLU introduces non-linearity, is computationally efficient, helps with gradient flow, and allows the network to learn complex patterns by activating different neurons for different inputs
B) It's the only option
C) It's faster than other activations
D) No reason

Question 6: What is the typical size of FFN hidden layer in transformers?

A) Usually 4× the embedding dimension (d_ff = 4 × d_model), so for d_model=768, hidden size is 3072
B) Same as embedding
C) Half of embedding
D) Always 1024

Question 7: Why does FFN come after attention in transformer layers?

A) Attention gathers relevant information from other positions, then FFN processes and transforms that information, allowing the model to learn complex feature combinations
B) Order doesn't matter
C) FFN should come first
D) They're parallel

Question 8: How would you implement FFN from scratch?

A) Define weight matrices W1 (d_model × d_ff) and W2 (d_ff × d_model), biases b1 and b2. Compute: hidden = ReLU(xW1 + b1), output = hiddenW2 + b2. In the surrounding transformer layer, add the residual connection and apply layer norm
B) Just return input
C) Use only W1
D) Random operations

Question 9: What happens if you remove FFN from transformer layers?

A) The model loses capacity to learn complex transformations, becomes essentially just attention operations, and performance degrades significantly
B) Nothing changes
C) It improves
D) It becomes faster

Question 10: How does FFN contribute to transformer model capacity?

A) FFN contains most of the parameters in each transformer layer (about 2/3 when d_ff = 4 × d_model), providing the model with capacity to learn complex non-linear transformations and feature combinations
B) It has no parameters
C) It has few parameters
D) Same as attention

Question 11: Why is the FFN pattern "expand then contract" (wide hidden layer)?

A) The wide hidden layer gives each position many more non-linear features (hidden units) to compute, and projecting back to the original dimension keeps the output compatible with the residual connection and the next layer
B) To reduce computation
C) To save memory
D) No reason

Question 12: What is the purpose of the expansion in FFN?

A) To increase model capacity and allow learning complex transformations by providing more parameters and a wider representation space
B) To reduce parameters
C) To make it faster
D) To reduce memory