Chapter 5: Feed-Forward Networks
Processing After Attention
Learning Objectives
- Understand what feed-forward networks do inside a transformer layer
- Master the mathematical formulation: expansion, activation, projection
- Learn a practical implementation in NumPy
- Apply the ideas through step-by-step examples
- Recognize where FFNs matter in real-world models
Feed-Forward Networks
What is a Feed-Forward Network in Transformers?
Feed-Forward Networks (FFNs) are two-layer neural networks applied independently to each position after attention. They process the information gathered by attention, adding non-linearity and capacity to learn complex transformations.
Think of FFN like a post-processing step:
- Attention: Gathers relevant information from other positions (like collecting ingredients)
- FFN: Processes and transforms that information (like cooking the ingredients into a dish)
- Key: Each position processes its information independently, but uses the same transformation
Why FFN After Attention?
Attention and FFN serve different purposes:
Attention's Role
- Mixes information BETWEEN positions
- Decides "what information to gather"
- Like asking: "Which other words are relevant?"
- Result: Each position has a weighted combination of all positions
FFN's Role
- Processes information WITHIN each position
- Decides "how to transform the gathered information"
- Like asking: "What should I do with this information?"
- Result: Each position gets a non-linear transformation
Together: Attention gathers context, FFN processes it!
📚 Real-World Analogy: The Research Process
Imagine writing a research paper:
- Attention (Gathering): You read multiple sources, identify relevant information from each
- FFN (Processing): You synthesize, analyze, and transform that information into your own understanding
- Result: You have both the gathered context (attention) and your processed understanding (FFN)
Key Concepts
🔑 FFN Architecture Components
FFN consists of three main components:
1. First Linear Layer (Expansion)
- Input: d_model dimensions (e.g., 512)
- Output: d_ff dimensions (e.g., 2048) - typically 4× expansion
- Purpose: Expands the representation space
- Why expand? More dimensions = more capacity to learn complex patterns
2. Activation Function (ReLU)
- Function: ReLU(x) = max(0, x)
- Purpose: Introduces non-linearity
- Why ReLU? Simple, fast, helps with gradient flow
- Effect: Allows the network to learn non-linear transformations
3. Second Linear Layer (Projection)
- Input: d_ff dimensions (e.g., 2048)
- Output: d_model dimensions (e.g., 512)
- Purpose: Projects back to original dimension
- Why project back? Maintains consistent dimensions for the residual connection (see the sketch after this list)
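To make the three components concrete, here is a minimal sketch (separate from the full implementation later in this chapter) that applies expansion, ReLU, and projection to a single position vector; the tiny sizes and random weights are placeholders for illustration only.

import numpy as np

# Toy sizes chosen only for readability; real models use e.g. d_model=512, d_ff=2048
d_model, d_ff = 4, 16
x = np.random.randn(d_model)                                 # one position's vector after attention

W1, b1 = np.random.randn(d_model, d_ff), np.zeros(d_ff)      # expansion weights (random stand-ins)
W2, b2 = np.random.randn(d_ff, d_model), np.zeros(d_model)   # projection weights (random stand-ins)

h = x @ W1 + b1          # 1. Expansion: 4 -> 16 dimensions
h = np.maximum(0, h)     # 2. ReLU: element-wise non-linearity
y = h @ W2 + b2          # 3. Projection: 16 -> 4 dimensions

print(x.shape, y.shape)  # (4,) (4,): output has the same dimension as the input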
The 4× Expansion Rule
Why do we expand by 4×?
Common Configurations
- BERT-base: d_model=768, d_ff=3072 (4×)
- GPT-3: d_model=12288, d_ff=49152 (4×)
- Original Transformer: d_model=512, d_ff=2048 (4×)
Why 4×?
- Empirically found to work well
- Balance between capacity and efficiency
- Too small (2×): Limited capacity
- Too large (8×): Many more parameters for diminishing returns (see the comparison below)
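As a rough illustration of this trade-off, the short calculation below (an illustrative sketch, not a benchmark) counts FFN parameters at d_model=512 for 2×, 4×, and 8× expansion; it captures only parameter cost, not model quality.

# Compare FFN parameter counts at d_model = 512 for different expansion factors
d_model = 512
for factor in (2, 4, 8):
    d_ff = factor * d_model
    params = d_model * d_ff + d_ff + d_ff * d_model + d_model  # W1 + b1 + W2 + b2
    print(f"{factor}x expansion: d_ff={d_ff}, parameters={params:,}")
# 2x: 1,050,112   4x: 2,099,712   8x: 4,198,912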
Position-Wise Processing
FFN processes each position independently:
- Same weights applied to all positions
- Like using the same "processing function" for every word
- Efficient: Can process all positions in parallel
- Key difference from attention: No interaction between positions (the check below verifies this)
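The small check below (with randomly initialized stand-in weights) verifies this property: running the FFN on the whole sequence at once gives the same result as running it on each position separately.

import numpy as np

d_model, d_ff, seq_len = 8, 32, 5
W1, b1 = np.random.randn(d_model, d_ff), np.zeros(d_ff)
W2, b2 = np.random.randn(d_ff, d_model), np.zeros(d_model)

def ffn(x):
    """Apply the same expansion -> ReLU -> projection to whatever positions x contains."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

x = np.random.randn(seq_len, d_model)                           # a short sequence
whole = ffn(x)                                                  # all positions at once
one_at_a_time = np.stack([ffn(x[i]) for i in range(seq_len)])   # each position separately

print(np.allclose(whole, one_at_a_time))  # True: positions never interact inside the FFN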
Mathematical Formulations
Feed-Forward Network Formula
FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂
Breaking Down the Formula:
- xW₁ + b₁: First linear transformation (expansion)
  - x: Input vector (d_model dimensions)
  - W₁: Weight matrix (d_model × d_ff)
  - b₁: Bias vector (d_ff dimensions)
  - Result: Expanded representation (d_ff dimensions)
- ReLU(...): Activation function
  - Applies ReLU element-wise: ReLU(x) = max(0, x)
  - Introduces non-linearity
- ...W₂ + b₂: Second linear transformation (projection)
  - W₂: Weight matrix (d_ff × d_model)
  - b₂: Bias vector (d_model dimensions)
  - Result: Projected back to d_model dimensions
Dimension Flow
d_model → (×W₁ + b₁) → d_ff → ReLU → d_ff → (×W₂ + b₂) → d_model
Example with Numbers:
- Input: (2, 10, 512) - batch=2, seq_len=10, d_model=512
- After W₁: (2, 10, 2048) - expanded to d_ff=2048
- After ReLU: (2, 10, 2048) - same shape, non-linear
- After W₂: (2, 10, 512) - projected back to d_model=512
Parameter Count
Total FFN parameters = (d_model × d_ff + d_ff) + (d_ff × d_model + d_model) = 2 × d_model × d_ff + d_model + d_ff
Example Calculation:
For d_model=512, d_ff=2048:
- W₁: 512 × 2048 = 1,048,576 parameters
- b₁: 2048 parameters
- W₂: 2048 × 512 = 1,048,576 parameters
- b₂: 512 parameters
- Total: 2,099,712 parameters per FFN
Note: This is typically the largest component in transformer layers!
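A quick sanity check of this arithmetic, using the compact formula above:

d_model, d_ff = 512, 2048
total = 2 * d_model * d_ff + d_model + d_ff  # weights of W1 and W2 plus both bias vectors
print(f"{total:,} parameters per FFN")       # 2,099,712 parameters per FFN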
Detailed Examples
Step-by-Step Example: Processing a Word
Let's trace through how FFN processes the word "cat" in a sentence:
Step 1: Input from Attention
- After attention, "cat" has a representation: [0.2, -0.5, 0.8, ..., 0.1] (512 dimensions)
- This representation includes information from other words (via attention)
- Now FFN will process this information
Step 2: First Linear Transformation (Expansion)
- Input: [0.2, -0.5, 0.8, ..., 0.1] (512 dims)
- Multiply by W₁ (512 × 2048): Matrix multiplication
- Add bias b₁: [0.1, 0.3, -0.2, ..., 0.5] (2048 dims)
- Result: Expanded representation with more dimensions
Step 3: ReLU Activation
- Apply ReLU element-wise: ReLU(x) = max(0, x)
- Positive values stay, negative values become 0
- Example: [0.1, 0.3, -0.2, 0.5] → [0.1, 0.3, 0.0, 0.5]
- This introduces non-linearity
Step 4: Second Linear Transformation (Projection)
- Input: [0.1, 0.3, 0.0, ..., 0.5] (2048 dims)
- Multiply by W₂ (2048 × 512): Matrix multiplication
- Add bias b₂: [0.15, -0.3, 0.6, ..., 0.2] (512 dims)
- Result: Processed representation back to original dimension
Step 5: Final Output
- The output is a transformed version of the input
- It has learned to process the information gathered by attention
- Ready to be added to the residual connection (the code trace below mirrors these steps)
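The short trace below mirrors Steps 2 to 4 in code; the input vector and weights are random placeholders rather than a trained "cat" representation, so only the shapes are meaningful.

import numpy as np

d_model, d_ff = 512, 2048
cat_vector = np.random.randn(d_model)          # random stand-in for "cat" after attention

# Hypothetical small random weights; a trained model would have learned values here
W1, b1 = 0.02 * np.random.randn(d_model, d_ff), np.zeros(d_ff)
W2, b2 = 0.02 * np.random.randn(d_ff, d_model), np.zeros(d_model)

expanded = cat_vector @ W1 + b1      # Step 2: (512,) -> (2048,)
activated = np.maximum(0, expanded)  # Step 3: ReLU zeroes out negative values
output = activated @ W2 + b2         # Step 4: (2048,) -> (512,)

print(expanded.shape, activated.shape, output.shape)  # (2048,) (2048,) (512,)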
What Does FFN Learn?
FFN learns to transform representations:
Example transformations (conceptual patterns in the representation space, not literal word-to-word rules):
- Morphological relations: linking "cat" to "cats", or "happy" to "happiness"
- Semantic relations: moving "king" toward related concepts such as "royalty"
- Feature extraction: emphasizing the aspects of the gathered context that matter for the task
Key: FFN learns these transformations through training!
Implementation
Feed-Forward Network Implementation
import numpy as np

class FeedForwardNetwork:
    """Position-wise Feed-Forward Network for a Transformer layer."""

    def __init__(self, d_model, d_ff):
        """
        Initialize FFN weights.

        Parameters:
            d_model: Model dimension (e.g., 512)
            d_ff: Feed-forward dimension (e.g., 2048, typically 4x d_model)
        """
        self.d_model = d_model
        self.d_ff = d_ff

        # Initialize weights with Xavier/Glorot initialization
        # W1: (d_model, d_ff)
        limit = np.sqrt(6.0 / (d_model + d_ff))
        self.W1 = np.random.uniform(-limit, limit, (d_model, d_ff))
        self.b1 = np.zeros(d_ff)

        # W2: (d_ff, d_model)
        limit = np.sqrt(6.0 / (d_ff + d_model))
        self.W2 = np.random.uniform(-limit, limit, (d_ff, d_model))
        self.b2 = np.zeros(d_model)

    def relu(self, x):
        """ReLU activation function: max(0, x) element-wise."""
        return np.maximum(0, x)

    def forward(self, x):
        """
        Forward pass through the FFN.

        Parameters:
            x: Input of shape (batch, seq_len, d_model) or (seq_len, d_model)

        Returns:
            Output with the same shape as the input
        """
        # Remember whether the input had a batch dimension
        squeeze_output = (x.ndim == 2)
        if squeeze_output:
            # (seq_len, d_model) -> (1, seq_len, d_model)
            x = x[np.newaxis, ...]
        batch_size, seq_len, d_model = x.shape

        # Flatten batch and sequence dims for matrix multiplication: (batch * seq_len, d_model)
        x_reshaped = x.reshape(-1, d_model)

        # First linear layer: (batch * seq_len, d_model) @ (d_model, d_ff) = (batch * seq_len, d_ff)
        z1 = x_reshaped @ self.W1 + self.b1

        # ReLU activation
        a1 = self.relu(z1)

        # Second linear layer: (batch * seq_len, d_ff) @ (d_ff, d_model) = (batch * seq_len, d_model)
        z2 = a1 @ self.W2 + self.b2

        # Reshape back: (batch, seq_len, d_model)
        output = z2.reshape(batch_size, seq_len, d_model)

        # Remove the batch dimension if the input did not have one
        if squeeze_output:
            output = output.squeeze(0)
        return output

# Example usage
d_model, d_ff = 512, 2048
ffn = FeedForwardNetwork(d_model, d_ff)

# Input: (batch=2, seq_len=10, d_model=512)
x = np.random.randn(2, 10, 512)

# Forward pass
output = ffn.forward(x)
print(f"Input shape: {x.shape}")        # (2, 10, 512)
print(f"Output shape: {output.shape}")  # (2, 10, 512)

n_params = ffn.W1.size + ffn.b1.size + ffn.W2.size + ffn.b2.size
print(f"Total parameters: {n_params:,}")  # 2,099,712
FFN in Transformer Layer Context
import numpy as np

def transformer_ffn_layer(x, W1, b1, W2, b2):
    """
    FFN as part of a complete transformer layer.

    Parameters:
        x: Input after attention (batch, seq_len, d_model)
        W1, b1: First linear layer weights and bias
        W2, b2: Second linear layer weights and bias
    """
    # First linear transformation
    expanded = np.dot(x, W1) + b1        # (batch, seq_len, d_ff)
    # ReLU activation
    activated = np.maximum(0, expanded)  # (batch, seq_len, d_ff)
    # Second linear transformation
    output = np.dot(activated, W2) + b2  # (batch, seq_len, d_model)
    return output

# Complete transformer layer with residual connections
def transformer_layer(x, attention_output, W1, b1, W2, b2):
    """
    Complete transformer layer: Attention + FFN with residual connections.
    (Layer normalization is omitted here to keep the focus on the FFN.)
    """
    # Residual connection around attention
    x_after_attention = x + attention_output
    # FFN sub-layer
    ffn_output = transformer_ffn_layer(x_after_attention, W1, b1, W2, b2)
    # Residual connection around the FFN
    output = x_after_attention + ffn_output
    return output
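A usage sketch for the functions above, with random stand-in weights and a dummy tensor in place of a real attention output:

# Exercise the layer with dummy tensors; shapes follow the running example
batch, seq_len, d_model, d_ff = 2, 10, 512, 2048
x = np.random.randn(batch, seq_len, d_model)
attention_output = np.random.randn(batch, seq_len, d_model)  # placeholder for a real attention sub-layer

W1, b1 = 0.02 * np.random.randn(d_model, d_ff), np.zeros(d_ff)
W2, b2 = 0.02 * np.random.randn(d_ff, d_model), np.zeros(d_model)

out = transformer_layer(x, attention_output, W1, b1, W2, b2)
print(out.shape)  # (2, 10, 512): same shape in and out, thanks to the projection back to d_model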
Real-World Applications
Where FFN is Critical
FFN is a core component in all transformer-based models:
1. Language Models (GPT, BERT)
- Processes contextualized word representations
- Learns semantic transformations
- Critical for understanding and generation
2. Machine Translation
- Transforms source language representations
- Prepares information for target language generation
- Learns cross-lingual patterns
3. Text Classification
- Processes document representations
- Learns task-specific transformations
- Prepares features for classification
4. Question Answering
- Processes question and context representations
- Learns to extract relevant information
- Prepares for answer generation
FFN vs Other Components
Approximate parameter distribution in a standard transformer layer (with d_ff = 4 × d_model):
- FFN: roughly two-thirds of the layer's weights (2 × d_model × d_ff ≈ 8 × d_model²)
- Attention: roughly one-third (4 × d_model² for the query, key, value, and output projections)
- Layer Norm: negligible (a few thousand parameters per layer)
Embedding tables sit outside the layers; they can dominate in small models, but in large models the FFN blocks account for most of the total parameter count, as the calculation below shows.
Why so many parameters? FFN needs capacity to learn complex transformations!
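To see where the rough two-thirds/one-third split comes from, here is a small back-of-the-envelope calculation (biases and layer norm parameters omitted, d_ff = 4 × d_model assumed):

d_model = 512
d_ff = 4 * d_model

attn_params = 4 * d_model * d_model   # W_Q, W_K, W_V and the output projection
ffn_params = 2 * d_model * d_ff       # W1 and W2
total = attn_params + ffn_params

print(f"FFN share of layer weights: {ffn_params / total:.0%}")        # 67%
print(f"Attention share of layer weights: {attn_params / total:.0%}") # 33%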