Chapter 10: Transformer Variants & Optimizations

Beyond the Original

Learning Objectives

  • Understand the fundamentals of transformer variants and optimizations
  • Master the mathematical foundations
  • Learn practical implementation
  • Apply knowledge through examples
  • Recognize real-world applications

Introduction

This chapter surveys the main transformer variants (BERT, GPT, T5, RoBERTa, ALBERT) and the techniques used to make transformers cheaper to train and run, with mathematical formulations, code implementations, and worked examples.

📚 Why This Matters

Nearly every modern language model builds on one of these variants, and training or serving large models in practice relies on the optimizations covered here. This chapter breaks down complex concepts into digestible explanations with step-by-step examples.

Key Concepts

Transformer Variants

Many architectures extend the original transformer:

  • BERT: Bidirectional encoder for understanding tasks
  • GPT: Autoregressive decoder for generation
  • T5: Encoder-decoder for text-to-text tasks
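  • RoBERTa: BERT retrained with more data, larger batches, longer training, dynamic masking, and no next-sentence prediction objective
  • ALBERT: Cross-layer parameter sharing and factorized embeddings to reduce parameter count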
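One concrete way to see the BERT/GPT split is through the self-attention mask. Encoder-only models such as BERT let every position attend to every other position (bidirectional), while decoder-only models such as GPT use a causal mask so position i only sees positions up to i; T5 uses the bidirectional form in its encoder and the causal form in its decoder's self-attention. The sketch below is illustrative only: `attention_mask` and `show` are hypothetical helpers, and padding masks are ignored.

```python
def attention_mask(seq_len, causal):
    """Return mask[i][j] = True if position i may attend to position j.

    Hypothetical helper for illustration; padding is not handled.
    """
    return [[(not causal) or j <= i for j in range(seq_len)]
            for i in range(seq_len)]


def show(mask):
    # '1' = allowed, '.' = blocked
    for row in mask:
        print("".join("1" if allowed else "." for allowed in row))


print("Bidirectional (BERT-style encoder):")
show(attention_mask(4, causal=False))
print("Causal (GPT-style decoder):")
show(attention_mask(4, causal=True))
```

The causal mask is what makes next-token prediction a meaningful training objective: without it, position i could simply look at token i+1 in its own input.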

Optimization Techniques

Key optimizations improve efficiency:

  • Sparse Attention: Reduce computation by attending to fewer positions
  • Linear Attention: Approximate attention with linear complexity
  • Quantization: Use lower precision to reduce memory
  • Pruning: Remove less important weights
  • Knowledge Distillation: Train smaller models from larger ones
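Pruning can be illustrated with unstructured magnitude pruning, one common criterion (not the only one) in which the weights with the smallest absolute values are treated as least important. This is a minimal sketch on a flat Python list; `magnitude_prune` is a hypothetical helper, and in practice the pruned model is usually fine-tuned afterwards to recover accuracy.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the fraction `sparsity` of weights with the smallest |w|.

    Hypothetical helper: weights is a flat list of floats, sparsity in [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    n_prune = int(len(weights) * sparsity)
    # Indices ordered from smallest to largest magnitude
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


w = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2]
print(magnitude_prune(w, sparsity=0.5))  # [0.9, 0.0, 0.3, -0.7, 0.0, 0.0]
```

Zeroed weights only save memory or compute if they are stored in a sparse format or the hardware can skip them; a dense tensor full of zeros costs the same as before.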

Scaling Strategies

Methods to scale transformers:

  • Model Parallelism: Distribute model across devices
  • Pipeline Parallelism: Split layers across GPUs
  • Mixed Precision: Use float16 for speed, float32 for stability
  • Gradient Checkpointing: Trade compute for memory
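The "float32 for stability" part of mixed precision can be seen with a tiny experiment: repeatedly adding a small update to a weight of 1.0. In float16 the gap between 1.0 and the next representable value is 2^-10 (about 0.001), so an update of 1e-4 is rounded away every time, while float32 accumulates it. This is why mixed-precision training commonly keeps a float32 copy of the weights. The sketch emulates both formats with Python's `struct` module (format `e` is IEEE 754 half precision, `f` is single precision); it models only rounding, not GPU speed.

```python
import struct


def to_fp16(x):
    # Round a Python float to the nearest IEEE 754 half-precision value
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x):
    # Round a Python float to the nearest IEEE 754 single-precision value
    return struct.unpack("<f", struct.pack("<f", x))[0]


update = 1e-4
w16 = to_fp16(1.0)
w32 = to_fp32(1.0)
for _ in range(100):
    # For these magnitudes the sum is exact in a Python float (double),
    # so rounding it once emulates a single float16 / float32 addition
    w16 = to_fp16(w16 + to_fp16(update))
    w32 = to_fp32(w32 + to_fp32(update))

print(f"float16 weight after 100 updates: {w16}")
print(f"float32 weight after 100 updates: {w32:.6f}")
```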

Mathematical Formulations

Efficient Attention Variants

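\[\text{Standard attention: } O(n^2 d) \text{ time and } O(n^2) \text{ memory for the score matrix}\]
\[\text{Sparse attention: e.g. } O(n w) \text{ (local window)}, \; O(n\sqrt{n}) \text{ (strided)}, \; O(n \log n) \text{ (LSH bucketing)}\]
\[\text{Linear attention: } O(n) \text{ in sequence length instead of } O(n^2)\]
\[\text{FlashAttention: exact attention, still } O(n^2) \text{ compute, but } O(n) \text{ extra memory}\]

Sparse and linear attention reduce the quadratic cost of standard attention by restricting or approximating it, which makes long sequences affordable. FlashAttention takes a different route: it computes exact attention in tiles without materializing the n×n score matrix, cutting memory use and memory traffic rather than arithmetic.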
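Linear attention can be sketched by replacing the softmax similarity exp(q·k) with a kernel φ(q)·φ(k). Because matrix products are associative, the key-value summary Σ_j φ(k_j) v_jᵀ and the normalizer Σ_j φ(k_j) can be computed once in a single pass over the sequence, so the cost grows linearly in n instead of quadratically. The feature map φ(x) = elu(x) + 1 is an assumption borrowed from Katharopoulos et al. (2020); other maps are possible. This pure-Python sketch is non-causal and unbatched, and `quadratic_reference` is included only to check that both orderings give the same result.

```python
import math
import random


def phi(vec):
    # Feature map elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive)
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def linear_attention(Q, K, V):
    """Q, K: n vectors of size d; V: n vectors of size d_v. Cost O(n * d * d_v)."""
    d, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d)]  # S = sum_j phi(k_j) v_j^T
    z = [0.0] * d                        # z = sum_j phi(k_j)
    for k, v in zip(K, V):
        pk = phi(k)
        for a in range(d):
            z[a] += pk[a]
            for b in range(d_v):
                S[a][b] += pk[a] * v[b]
    out = []
    for q in Q:
        pq = phi(q)
        denom = sum(pq[a] * z[a] for a in range(d))
        out.append([sum(pq[a] * S[a][b] for a in range(d)) / denom
                    for b in range(d_v)])
    return out


def quadratic_reference(Q, K, V):
    """Same kernel, but builds every phi(q_i) . phi(k_j) weight: O(n^2)."""
    out = []
    for q in Q:
        pq = phi(q)
        w = [sum(x * y for x, y in zip(pq, phi(k))) for k in K]
        total = sum(w)
        out.append([sum(w[j] * V[j][b] for j in range(len(V))) / total
                    for b in range(len(V[0]))])
    return out


random.seed(0)
n, d, d_v = 6, 4, 3
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d_v)] for _ in range(n)]
fast, slow = linear_attention(Q, K, V), quadratic_reference(Q, K, V)
print(max(abs(a - b) for rf, rs in zip(fast, slow) for a, b in zip(rf, rs)))
```

A causal variant would keep running prefix sums of S and z instead of full sums, which is what allows such models to be run step by step like an RNN at inference time.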

Model Parallelism

\[\text{Pipeline Parallelism: Split layers across devices}\]
\[\text{Tensor Parallelism: Split matrix operations}\]
\[\text{Data Parallelism: Replicate model, split data}\]

Parallelization strategies enable training and inference of very large models across multiple GPUs/TPUs.
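Tensor parallelism can be illustrated by splitting a weight matrix column-wise: each device holds a slice of the columns, computes its slice of the output independently, and the slices are concatenated (an all-gather). This is a single-process simulation with plain lists; `column_parallel_matmul` is a hypothetical name, and real systems (Megatron-LM style layers, for example) also use row-wise splits whose partial results are summed instead of concatenated.

```python
def matmul(X, W):
    """X: rows x d_in, W: d_in x d_out, both as lists of lists."""
    return [[sum(x[i] * W[i][j] for i in range(len(W))) for j in range(len(W[0]))]
            for x in X]


def column_parallel_matmul(X, W, n_devices):
    d_out = len(W[0])
    if d_out % n_devices != 0:
        raise ValueError("d_out must divide evenly across devices")
    size = d_out // n_devices
    # Each simulated device stores only its own column shard of W
    shards = [[row[s:s + size] for row in W] for s in range(0, d_out, size)]
    partials = [matmul(X, shard) for shard in shards]
    # "All-gather": concatenate each row's output slices in device order
    return [[v for p in partials for v in p[r]] for r in range(len(X))]


X = [[1, 2], [3, 4]]
W = [[1, 0, 2, 1],
     [0, 1, 1, 3]]
print(matmul(X, W))                     # [[1, 2, 4, 7], [3, 4, 10, 15]]
print(column_parallel_matmul(X, W, 2))  # [[1, 2, 4, 7], [3, 4, 10, 15]]
```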

Quantization

\[W_q = \text{round}(W \times s)\]
\[W \approx W_q / s\]

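Where \(W_q\) is the quantized (integer) weight, \(W\) is the original weight, and \(s\) is the quantization scale factor. For symmetric \(b\)-bit quantization a common choice is \(s = (2^{b-1}-1)/\max|W|\), which maps the largest-magnitude weight to the edge of the integer range (±127 for int8).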

Quantization reduces model size and memory by using lower precision (e.g., int8 instead of float32) while maintaining reasonable accuracy.
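As a worked example, take int8 symmetric quantization in which the largest weight magnitude, assumed here to be 2.54, maps to 127, and quantize a weight of 0.333:

\[s = \frac{127}{2.54} = 50, \qquad W_q = \text{round}(0.333 \times 50) = \text{round}(16.65) = 17, \qquad W \approx \frac{17}{50} = 0.34\]

The reconstruction error is 0.34 - 0.333 = 0.007, within half of one quantization step (1/s = 0.02).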

Knowledge Distillation

\[L = \alpha L_{\text{task}} + (1-\alpha) L_{\text{distill}}\]
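\[L_{\text{distill}} = T^2 \, \text{KL}\left(P_{\text{teacher}}^{(T)} \,\|\, P_{\text{student}}^{(T)}\right), \qquad P^{(T)} = \text{softmax}(z / T)\]

Where \(P_{\text{teacher}}^{(T)}\) and \(P_{\text{student}}^{(T)}\) are the teacher's and student's output distributions, softened by a temperature \(T\) applied to their logits \(z\); \(\alpha\) balances the task loss and the distillation loss, and the \(T^2\) factor keeps the distillation gradients on a scale comparable to the task loss.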

Knowledge distillation trains smaller student models to mimic larger teacher models, transferring knowledge while reducing size.
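A minimal sketch of the distillation loss for a single example, in pure Python. It assumes the common formulation from Hinton et al. (2015): both models' logits are softened with a temperature T, the distillation term is KL(P_teacher ∥ P_student) scaled by T², and the task loss is ordinary cross-entropy against the hard label. The default values of `alpha` and `T`, and the example logits, are illustrative assumptions, not recommendations.

```python
import math


def softmax(logits, T=1.0):
    # Temperature-scaled softmax; subtracting the max avoids overflow
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, T=2.0):
    # Task loss: cross-entropy of the student (T = 1) against the hard label
    task = -math.log(softmax(student_logits)[label])
    # Distillation loss: KL(teacher || student) on softened distributions
    distill = T ** 2 * kl_divergence(softmax(teacher_logits, T),
                                     softmax(student_logits, T))
    return alpha * task + (1 - alpha) * distill


teacher = [4.0, 1.0, 0.5]
student = [2.0, 1.5, 0.2]
print(distillation_loss(student, teacher, label=0))
```

Raising T spreads the teacher's probability mass onto the incorrect classes, exposing which mistakes it considers "nearly right"; that relative information is what the student learns beyond the hard label.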

Detailed Examples

Example: Sparse Attention Pattern

Instead of full attention (all-to-all), sparse attention might use:

  • Local attention: Each token attends to nearby tokens (window of size w)
  • Strided attention: Attend to every k-th token
  • Global attention: A few designated tokens attend to all positions, and all positions attend to them

This reduces computation from O(n²) to O(n×w) for local attention.
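The three patterns can be combined into one boolean mask, and the allowed query-key pairs counted against the n² of full attention. This is a sketch: `sparse_mask` and its parameters are hypothetical, "strided" is taken to mean positions whose distance is a multiple of k, and global tokens are assumed to both attend to and be attended by every position (as in Longformer-style global attention).

```python
def sparse_mask(n, window=3, stride=None, global_tokens=()):
    """mask[i][j] = True if query i may attend to key j."""
    half = window // 2
    g = set(global_tokens)
    mask = []
    for i in range(n):
        row = []
        for j in range(n):
            local = abs(i - j) <= half                              # sliding window
            strided = stride is not None and (i - j) % stride == 0  # every k-th
            glob = i in g or j in g                                 # global rows/cols
            row.append(local or strided or glob)
        mask.append(row)
    return mask


def allowed_pairs(mask):
    return sum(sum(row) for row in mask)


n = 16
print("full:", n * n)                                                  # 256
print("local (w=3):", allowed_pairs(sparse_mask(n, window=3)))         # 46
print("local + stride 4:", allowed_pairs(sparse_mask(n, window=3, stride=4)))  # 94
print("local + global[0]:",
      allowed_pairs(sparse_mask(n, window=3, global_tokens=[0])))      # 74
```

Building a dense n×n mask like this only demonstrates the pattern; efficient implementations never compute the blocked pairs at all, which is where the savings come from.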

Example: Model Variants Comparison

BERT vs GPT vs T5:

  • BERT: "The cat sat" → the final hidden state of the [CLS] token feeds a classification head
  • GPT: "The cat sat" → predicts "on" (next token)
  • T5: "translate English to French: The cat sat" → "Le chat s'est assis"

Each architecture is optimized for different task types.

Example: Quantization Impact

Original model: 1 billion parameters × 4 bytes (float32) = 4 GB

Quantized (int8): 1 billion parameters × 1 byte = 1 GB

4× reduction in memory with minimal accuracy loss when done carefully.
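The same arithmetic extends to other formats. A small hypothetical helper, counting weights only (no activations, optimizer state, KV cache, or the small overhead of storing quantization scales) and using decimal gigabytes (10⁹ bytes) as the example does; int4 at 0.5 bytes assumes two 4-bit values packed per byte:

```python
# Bytes per parameter; int4 assumes two values packed into one byte
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}


def weight_memory_gb(num_params, dtype):
    # Weights only, decimal GB (1 GB = 1e9 bytes)
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("float32", "float16", "int8", "int4"):
    print(f"{dtype:>8}: {weight_memory_gb(1_000_000_000, dtype):.1f} GB")
```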

Implementation

Sparse Attention Implementation

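import torch

def sparse_attention(Q, K, V, window_size=3):
    """
    Sparse attention with a local (sliding) window

    Args:
        Q, K, V: Query, Key, Value tensors (batch, seq_len, d_model)
        window_size: Width of the local window; each token attends to
            window_size // 2 tokens on each side plus itself (use odd values)

    Note: this builds the full (seq_len x seq_len) score matrix and masks it,
    so it shows the sparsity pattern without saving compute. Efficient
    implementations compute scores only inside the window.
    """
    seq_len, d_k = Q.size(1), Q.size(-1)

    # Scaled dot-product scores: (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Band mask, True where |i - j| <= window_size // 2; created on the same
    # device as Q so masked_fill also works for GPU tensors
    idx = torch.arange(seq_len, device=Q.device)
    mask = (idx[None, :] - idx[:, None]).abs() <= window_size // 2

    # Block positions outside the window (each row keeps at least itself)
    scores = scores.masked_fill(~mask, float('-inf'))
    attn_weights = torch.softmax(scores, dim=-1)

    return torch.matmul(attn_weights, V)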

Quantization Example

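import numpy as np

def quantize_weights(weights, bits=8):
    """
    Symmetric per-tensor quantization to signed integers (stored as int8).
    Returns the quantized array and the scale s, so that W ≈ W_q / s.
    """
    assert 2 <= bits <= 8, "values are stored in int8, so bits must be 2..8"
    qmax = 2 ** (bits - 1) - 1  # e.g. 127 for 8 bits

    # Scale maps the largest-magnitude weight to ±qmax (guard against all zeros)
    max_val = np.abs(weights).max()
    scale = qmax / max_val if max_val > 0 else 1.0

    # Quantize: W_q = round(W * s), clipped to the representable range
    quantized = np.clip(np.round(weights * scale), -qmax, qmax).astype(np.int8)

    return quantized, scale

def dequantize_weights(quantized, scale):
    """
    Dequantize back to float: W ≈ W_q / s
    """
    return quantized.astype(np.float32) / scale

# Example
weights = np.random.randn(100, 100).astype(np.float32)
quantized, scale = quantize_weights(weights)
reconstructed = dequantize_weights(quantized, scale)
print(f"Original size: {weights.nbytes} bytes")
print(f"Quantized size: {quantized.nbytes} bytes")
print(f"Compression: {weights.nbytes / quantized.nbytes:.1f}x")
print(f"Max abs error: {np.abs(weights - reconstructed).max():.4f}")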

Real-World Applications

BERT Applications

Understanding tasks:

  • Sentiment analysis in customer reviews
  • Named entity recognition in documents
  • Question answering systems
  • Text classification and tagging

GPT Applications

Generation tasks:

  • Chatbots and conversational AI
  • Code generation and completion
  • Content creation and writing assistance
  • Text summarization

T5 Applications

Text-to-text tasks:

  • Machine translation
  • Text summarization
  • Question answering
  • Text classification (formatted as generation)

Optimization Benefits

Efficiency improvements enable:

  • Faster inference for real-time applications
  • Deployment on edge devices
  • Reduced computational costs
  • Handling longer sequences

Test Your Understanding

Question 1: What are the main applications of transformer models?

A) Natural language understanding (BERT), text generation (GPT), translation, summarization, question answering, chatbots, code generation, and many NLP tasks
B) Only classification
C) Only generation
D) Limited applications

Question 2: What is BERT and how does it work?

A) BERT is a bidirectional encoder-only transformer pre-trained on masked language modeling and next sentence prediction, then fine-tuned for downstream tasks like classification and QA
B) It's a decoder model
C) It only generates
D) It's not a transformer

Question 3: What is GPT and how does it differ from BERT?

A) GPT is a decoder-only transformer pre-trained on next token prediction, generating text autoregressively. BERT is encoder-only for understanding, GPT is for generation
B) They're the same
C) GPT is encoder-only
D) BERT generates text

Question 4: How do you fine-tune a pre-trained transformer?

A) Start with pre-trained weights, add a task-specific head if needed, and train on your task data with a lower learning rate than pre-training; all layers are commonly updated, though early layers are sometimes frozen to save compute or reduce overfitting
B) Train from scratch
C) Don't use pre-trained weights
D) Same as pre-training

Question 5: What is transfer learning in transformers?

A) Using knowledge learned from large-scale pre-training on general tasks and applying it to specific downstream tasks, allowing models to perform well with less task-specific data
B) Training from scratch
C) Only using task data
D) No pre-training

Question 6: What tasks can encoder models like BERT handle?

A) Classification, named entity recognition, question answering, sentiment analysis, text similarity - tasks requiring understanding and processing input rather than generation
B) Only generation
C) Only translation
D) Limited tasks

Question 7: What tasks can decoder models like GPT handle?

A) Text generation, completion, story writing, code generation, chatbots, creative writing - tasks requiring generating new text from context
B) Only classification
C) Only understanding
D) Limited tasks

Question 8: What is T5 and how does it work?

A) T5 is an encoder-decoder transformer that frames all NLP tasks as text-to-text problems, using the same architecture for translation, summarization, QA, and classification by converting tasks to text generation
B) It's encoder-only
C) It's decoder-only
D) Not a transformer

Question 9: How do you choose between BERT, GPT, and T5 for a task?

A) BERT for understanding/classification tasks, GPT for generation tasks, T5 for seq2seq tasks like translation/summarization. Consider your task type, data availability, and computational resources
B) Always use BERT
C) Always use GPT
D) They're interchangeable

Question 10: What is prompt engineering and why is it important?

A) Prompt engineering is crafting input prompts to guide model behavior, especially for large language models. Good prompts can significantly improve performance without fine-tuning
B) Not important
C) Only for small models
D) Random prompts work

Question 11: How do modern transformers scale to billions of parameters?

A) Through model parallelism, pipeline parallelism, efficient attention mechanisms, mixture-of-experts architectures, and distributed training across many GPUs/TPUs
B) Single GPU training
C) No special techniques
D) Can't scale

Question 12: What are some limitations of transformer models?

A) O(n²) attention complexity limits sequence length, high computational and memory requirements, need for large datasets, potential for generating biased or incorrect information, and difficulty with very long-range dependencies
B) No limitations
C) Only computational
D) Perfect for all tasks