Chapter 9: Complete Transformer Architecture

Putting It All Together

Learning Objectives

  • Understand complete transformer architecture fundamentals
  • Master the mathematical foundations
  • Learn practical implementation
  • Apply knowledge through examples
  • Recognize real-world applications

Complete Transformer Architecture

The Complete Picture

The Transformer combines all the components we have covered so far (embeddings, positional encoding, multi-head attention, feed-forward networks, residual connections, and layer normalization) into a single, powerful architecture.

Complete Transformer Architecture Diagram

    Input Tokens: ["The", "cat", "sat"]
            ↓
    Token Embeddings + Positional Encoding
            ↓
    Encoder Stack (N layers)
        Encoder Layer 1:  Multi-Head Attention → Feed-Forward Network
        ... (N-2 more layers) ...
        Encoder Layer N:  Multi-Head Attention → Feed-Forward Network
            ↓
    Output Representations (rich, contextualized)

Key Components: Each encoder layer has Multi-Head Attention + FFN, both with residual connections and layer normalization. The stack processes input through N layers, building increasingly rich representations!

Information Flow Through Transformer

Complete data flow (a minimal code sketch follows this list):

  1. Input: Tokenized text → ["The", "cat", "sat"]
  2. Embeddings: Each token → dense vector (512 dimensions)
  3. Positional Encoding: Add position information
  4. Encoder Layer 1: Multi-head attention + FFN → Refined representations
  5. Encoder Layer 2-N: Further refinement through each layer
  6. Output: Rich, contextualized representations ready for tasks
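
The encoder-side portion of this flow can be sketched directly in PyTorch. This is a minimal illustration, not a full implementation: the toy vocabulary and the hand-rolled sinusoidal positional encoding helper are stand-ins, and PyTorch's built-in nn.TransformerEncoder supplies the N-layer stack.

import math
import torch
import torch.nn as nn

# Hypothetical toy vocabulary; a real model would use a learned tokenizer
vocab = {"<pad>": 0, "The": 1, "cat": 2, "sat": 3}
token_ids = torch.tensor([[vocab["The"], vocab["cat"], vocab["sat"]]])  # shape (batch=1, seq=3)

d_model = 512
embedding = nn.Embedding(len(vocab), d_model)

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encoding, one vector per position."""
    pos = torch.arange(seq_len).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

# Steps 2-3: embeddings plus positional information
x = embedding(token_ids) + positional_encoding(token_ids.size(1), d_model)

# Steps 4-5: the encoder stack (N = 6 identical layers of self-attention + FFN)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Step 6: rich, contextualized representations, one vector per input token
contextual = encoder(x)
print(contextual.shape)  # torch.Size([1, 3, 512])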

Key Concepts

Complete Transformer Architecture

The full transformer combines all components (a minimal sketch in code follows this list):

  • Input Processing: Tokenization → Embeddings → Positional Encoding
  • Encoder Stack: Multiple encoder layers (self-attention + FFN)
  • Decoder Stack: Multiple decoder layers (masked self-attention + cross-attention + FFN)
  • Output Generation: Linear projection → Softmax → Token prediction
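
The four stages above can be wired together with PyTorch's built-in nn.Transformer. The TransformerModel class below is a hypothetical minimal sketch (learned positional embeddings instead of sinusoidal ones, no dropout tuning or weight tying), written so its constructor matches the commented-out usage example in the training loop later in this chapter.

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Minimal encoder-decoder sketch: embeddings -> nn.Transformer -> vocabulary logits."""
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_len=4096):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)   # learned positional embeddings, for simplicity
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            batch_first=True,
        )
        self.out = nn.Linear(d_model, vocab_size)   # linear projection to vocabulary logits

    def add_pos(self, ids):
        positions = torch.arange(ids.size(1), device=ids.device)
        return self.embed(ids) + self.pos(positions)

    def forward(self, src, tgt):
        # Causal mask: each decoder position may only attend to earlier target positions
        causal = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(src.device)
        h = self.transformer(self.add_pos(src), self.add_pos(tgt), tgt_mask=causal)
        return self.out(h)   # (batch, tgt_len, vocab_size); the softmax is applied by the loss

The model returns raw logits rather than probabilities, because nn.CrossEntropyLoss (used in the training loop below) applies the softmax internally.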

Information Flow Through Layers

Early layers: Capture local patterns, syntax, word-level relationships

Middle layers: Build phrase-level understanding, semantic relationships

Deep layers: Develop high-level abstractions, task-specific features

Each layer refines and abstracts the representation from the previous layer.

Training Challenges

Key challenges in training transformers:

  • Memory: Large models require significant GPU memory
  • Compute: Training takes weeks to months on many GPUs
  • Data: Need massive, high-quality datasets
  • Hyperparameters: Learning rate, warmup, batch size all critical
  • Stability: Deep networks prone to gradient issues

Mathematical Formulations

Complete Transformer Forward Pass

\[\text{Transformer}(X_{\text{enc}}, X_{\text{dec}}) = \text{Decoder}(\text{Encoder}(X_{\text{enc}}), X_{\text{dec}})\]
Encoder Stack:
  • \(Z_{\text{enc}} = \text{Embed}(X_{\text{enc}}) + PE\)
  • \(H = \text{EncoderLayer}_N(\ldots\text{EncoderLayer}_1(Z_{\text{enc}})\ldots)\)
Decoder Stack:
  • \(Z_{\text{dec}} = \text{Embed}(X_{\text{dec}}) + PE\)
  • \(O = \text{DecoderLayer}_N(\ldots\text{DecoderLayer}_1(Z_{\text{dec}}, H)\ldots, H)\) (every decoder layer cross-attends to the encoder output \(H\))
Output:
  • \(\text{Output} = \text{Softmax}(O \cdot W_{\text{out}})\)

Training Loss

\[L = -\sum_{i=1}^{N} \log P(y_i | x_i, \theta)\]
Where:
  • \(N\): Number of training examples
  • \(y_i\): Target token at position i
  • \(x_i\): Input context up to position i
  • \(\theta\): Model parameters
  • \(P(y_i | x_i, \theta)\): Model's predicted probability

This is the standard cross-entropy loss for language modeling.
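
In code this is the familiar next-token cross-entropy. A tiny sketch with random logits, just to make the shapes concrete (the dimensions here are arbitrary):

import torch
import torch.nn as nn

batch, seq_len, vocab_size = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab_size)          # model outputs: one distribution per position
targets = torch.randint(0, vocab_size, (batch, seq_len))  # the true next token at each position

# CrossEntropyLoss applies log-softmax internally and averages -log P(y_i | x_i, theta) over positions
criterion = nn.CrossEntropyLoss()
loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())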

Learning Rate Schedule

\[\text{lr}(t) = \begin{cases} \text{lr}_{\text{max}} \times \frac{t}{T_{\text{warmup}}} & \text{if } t < T_{\text{warmup}} \\ \text{lr}_{\text{max}} & \text{if } T_{\text{warmup}} \leq t < T_{\text{decay}} \\ \text{lr}_{\text{max}} \times \frac{T_{\text{total}} - t}{T_{\text{total}} - T_{\text{decay}}} & \text{if } t \geq T_{\text{decay}} \end{cases}\]
Phases:
  • Warmup: Gradually increase from 0 to max learning rate
  • Constant: Maintain max learning rate
  • Decay: Gradually decrease learning rate
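
A direct translation of the piecewise schedule above; the peak rate and phase boundaries used here are illustrative, not prescribed:

def lr_schedule(t, lr_max=3e-4, t_warmup=4000, t_decay=80000, t_total=100000):
    """Piecewise learning rate: linear warmup, constant plateau, linear decay to zero."""
    if t < t_warmup:
        return lr_max * t / t_warmup                          # warmup phase
    elif t < t_decay:
        return lr_max                                         # constant phase
    else:
        return lr_max * (t_total - t) / (t_total - t_decay)   # decay phase

For reference, the original Transformer paper used an inverse-square-root decay after warmup instead of a constant-then-linear shape; the warmup phase is the common ingredient.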

Detailed Examples

Example: Complete Transformer Processing

Task: Translate "Hello world" to French

Step 1: Input Processing

  • Encoder input: "Hello world" → Token IDs → Embeddings → + Positional Encoding
  • Decoder input: [START] → Token ID → Embedding → + Positional Encoding

Step 2: Encoder Processing

  • Layer 1: Self-attention captures relationships between "Hello" and "world"
  • Layers 2-6: Further refinement of the representations through the remaining encoder layers
  • Output: Rich contextualized representation of input

Step 3: Decoder Processing

  • Masked self-attention: [START] attends to itself
  • Cross-attention: [START] attends to encoder output
  • FFN: Processes the combined information
  • Output: Probability distribution over French vocabulary

Step 4: Generation

  • Select "Bonjour" (the highest-probability token)
  • Add to decoder input: [START, "Bonjour"]
  • Repeat until [END] token or max length
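
Step 4 is the standard autoregressive generation loop. A greedy-decoding sketch, assuming a model with the encoder-decoder interface used in this chapter and placeholder IDs for the [START] and [END] tokens:

import torch

@torch.no_grad()
def greedy_decode(model, src_ids, start_id, end_id, max_len=50):
    """Feed the growing target sequence back into the decoder, always picking the most likely next token."""
    model.eval()
    tgt = torch.tensor([[start_id]], device=src_ids.device)   # decoder starts with [START]
    for _ in range(max_len):
        logits = model(src_ids, tgt)                 # (1, tgt_len, vocab_size)
        next_id = logits[:, -1].argmax(dim=-1)       # highest-probability token at the last position
        tgt = torch.cat([tgt, next_id.unsqueeze(1)], dim=1)
        if next_id.item() == end_id:                 # stop at [END]
            break
    return tgt

In practice, beam search or sampling strategies (top-k, nucleus) are often used instead of pure greedy decoding to improve output quality.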

Example: Training Setup

Typical configuration for large transformer:

  • Model size: 12 layers, 768 dimensions, 12 heads
  • Batch size: 256 (with gradient accumulation)
  • Learning rate: warmup to a peak in the 1e-4 to 1e-3 range, then decay
  • Optimizer: AdamW with weight decay 0.01
  • Training time: 1-2 weeks on 8 GPUs
  • Data: Millions of sentence pairs
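
The same setup expressed as a plain configuration dictionary (every value is illustrative and mirrors the bullets above):

train_config = {
    "num_layers": 12,
    "d_model": 768,
    "num_heads": 12,
    "batch_size": 256,      # effective size, reached via gradient accumulation
    "peak_lr": 1e-4,        # reached after warmup, then decayed
    "optimizer": "AdamW",
    "weight_decay": 0.01,
    "warmup_steps": 4000,   # illustrative; not stated in the bullets above
}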

Implementation

Complete Transformer Training Loop

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def train_transformer(model, train_loader, num_epochs, warmup_steps, total_steps):
    """
    Training loop for transformer model
    """
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    
    # Learning rate schedule: linear warmup, then linear decay (LambdaLR scales the base lr)
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # Warmup: ramp the multiplier from 0 to 1
        # Linear decay; clamp at 0 so the multiplier never goes negative past total_steps
        return max(0.0, (total_steps - step) / (total_steps - warmup_steps))
    
    scheduler = LambdaLR(optimizer, lr_lambda)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)  # positions labeled -1 (e.g. padding) are excluded from the loss
    
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, (src, tgt) in enumerate(train_loader):
            # Forward pass (teacher forcing: the decoder sees the gold tokens shifted right)
            output = model(src, tgt[:, :-1])  # decoder input excludes the last token
            target = tgt[:, 1:]               # targets shifted by one for next-token prediction
            
            # Compute loss
            loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        print(f'Epoch {epoch} average loss: {total_loss / len(train_loader):.4f}')

# Example usage
# model = TransformerModel(vocab_size=50000, d_model=512, nhead=8, num_layers=6)
# train_transformer(model, train_loader, num_epochs=10, warmup_steps=4000, total_steps=100000)
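
When GPU memory limits the per-step batch size, the inner loop above can accumulate gradients over several batches before each optimizer update. A sketch of the modified inner loop, reusing the names from train_transformer (accumulation_steps is an illustrative choice):

accumulation_steps = 4  # effective batch size = DataLoader batch size * 4

optimizer.zero_grad()
for batch_idx, (src, tgt) in enumerate(train_loader):
    output = model(src, tgt[:, :-1])
    loss = criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
    (loss / accumulation_steps).backward()   # scale so the accumulated gradient is an average
    if (batch_idx + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()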

Real-World Applications

Complete Transformer Applications

Encoder-decoder transformers are used for:

  • Machine Translation: Google Translate, DeepL use transformer architectures
  • Text Summarization: Generating concise summaries from long documents
  • Question Answering: Systems that answer questions based on context
  • Dialogue Systems: Conversational AI that maintains context

Training Infrastructure

Large-scale training requires:

  • Distributed Training: Multiple GPUs/TPUs working together
  • Data Pipeline: Efficient data loading and preprocessing
  • Monitoring: Track loss, learning rate, gradient norms
  • Checkpointing: Save model state regularly
  • Mixed Precision: Use float16 (or bfloat16) for speed and memory savings
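
Mixed precision, in particular, is usually a small change to the training step in PyTorch: wrap the forward pass in autocast and scale the loss so that float16 gradients do not underflow. A sketch of the modified step, again reusing the names from the training loop above:

scaler = torch.cuda.amp.GradScaler()             # maintains a dynamic loss scale for float16 gradients

for src, tgt in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():              # run the forward pass in float16 where it is safe
        output = model(src, tgt[:, :-1])
        loss = criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
    scaler.scale(loss).backward()                # scaled backward pass
    scaler.unscale_(optimizer)                   # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                       # skips the update if gradients overflowed
    scaler.update()
    scheduler.step()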

Production Deployment

Deploying trained transformers:

  • Model Optimization: Quantization, pruning for efficiency
  • Inference Optimization: Batch processing, caching
  • Scalability: Handle multiple concurrent requests
  • Monitoring: Track latency, throughput, accuracy
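
As one concrete example of model optimization, PyTorch's dynamic quantization converts the linear layers of a trained model to 8-bit integers with a single call; it is often the simplest first step before heavier techniques such as static quantization or pruning. A minimal sketch (the small Sequential model is a stand-in for a trained transformer):

import torch
import torch.nn as nn

# Stand-in for a trained transformer; any module containing nn.Linear layers works the same way
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))

quantized_model = torch.quantization.quantize_dynamic(
    model,               # the trained float32 model
    {nn.Linear},         # which module types to quantize
    dtype=torch.qint8,   # store weights as 8-bit integers
)
print(quantized_model)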

Test Your Understanding

Question 1: What are the key challenges in training transformers?

A) Large model size requiring significant memory, long training times, need for large datasets, careful hyperparameter tuning, and managing computational costs
B) No challenges
C) Only small datasets needed
D) Very fast training

Question 2: What is the typical learning rate schedule for transformers?

A) Warmup phase (gradually increase), then constant or decay. Warmup helps stabilize training early on, then learning rate may decay to fine-tune convergence
B) Always constant
C) Always decreasing
D) Random schedule

Question 3: How do you handle memory constraints when training large transformers?

A) Use gradient checkpointing (recompute activations), mixed precision training, smaller batch sizes, model parallelism, or gradient accumulation to simulate larger batches
B) Just use larger batches
C) Remove layers
D) No solutions

Question 4: What is gradient accumulation and why use it?

A) Gradient accumulation computes gradients over multiple small batches before updating weights, allowing you to simulate larger batch sizes when memory is limited, improving training stability
B) It reduces gradients
C) It's not useful
D) Only for small models

Question 5: What optimizer is commonly used for training transformers?

A) Adam or AdamW (Adam with weight decay) are most common, as they adapt learning rates per parameter and work well for large models. Some use SGD with momentum for fine-tuning
B) Only SGD
C) Only RMSprop
D) Random optimizer

Question 6: What is mixed precision training?

A) Using both float16 and float32 precision - forward pass and gradients in float16 for speed and memory savings, but maintaining float32 master weights for numerical stability
B) Using only float16
C) Using only float32
D) Random precision

Question 7: How do you choose batch size for transformer training?

A) Balance between memory constraints and training stability. Larger batches provide more stable gradients but require more memory. Use gradient accumulation if needed. Typical sizes range from 16 to 512 depending on model size
B) Always use largest possible
C) Always use smallest
D) Doesn't matter

Question 8: What is the purpose of learning rate warmup?

A) Warmup gradually increases learning rate from small value to target, preventing large gradient updates early in training that could destabilize the model, especially important for large models
B) To decrease learning rate
C) Not needed
D) Only for small models

Question 9: How do you monitor transformer training?

A) Track loss curves (training and validation), learning rate, gradient norms, perplexity for language models, and task-specific metrics. Watch for overfitting, underfitting, or training instability
B) Only loss
C) Don't monitor
D) Only at end

Question 10: What is gradient clipping and when do you use it?

A) Gradient clipping caps gradient magnitude to prevent exploding gradients. Use it when gradients become very large, especially in deep networks or when using large learning rates
B) To increase gradients
C) Not needed
D) Only for small models

Question 11: How do you handle long sequences in transformer training?

A) Use gradient checkpointing, chunk sequences, use sparse attention patterns, or train with shorter sequences and fine-tune on longer ones. O(n²) attention complexity makes very long sequences expensive
B) Always use full length
C) Always truncate
D) No solutions

Question 12: What is the typical training setup for large language models?

A) Train on massive text corpora, use large batch sizes with gradient accumulation, AdamW optimizer with warmup and decay, mixed precision, distributed training across many GPUs, train for many epochs or until convergence
B) Small dataset, single GPU
C) No warmup needed
D) Train for one epoch