Chapter 7: Long Short-Term Memory (LSTM)
Solving the vanishing gradient problem with gated memory cells
Learning Objectives
- Understand why LSTMs solve the vanishing gradient problem
- Master the three gates: forget, input, and output
- Learn the cell state and hidden state mechanism
- Understand the complete LSTM forward pass
- Implement an LSTM from scratch
- Compare LSTMs with standard RNNs
What is LSTM?
🧠 Gated Memory Networks
Long Short-Term Memory (LSTM) networks are a special type of RNN designed to solve the vanishing gradient problem. They use gating mechanisms to selectively remember or forget information, allowing them to learn long-term dependencies.
Think of LSTM as an advanced memory system:
- Standard RNN: Like a person with only short-term memory - they remember recent things but forget older information
- LSTM: Like a person with both short-term and long-term memory - they can remember important things from much earlier
- Key difference: LSTMs have explicit mechanisms (gates) to decide what to remember and what to forget
Why LSTMs Were Invented
The fundamental problem with standard RNNs:
Real-World Example: Understanding Long Sentences
Sentence: "The cat, which was very fluffy and had been sleeping on the mat all afternoon, finally woke up and stretched."
Standard RNN problem:
- By the time we reach "stretched", the RNN has processed many words
- The information about "cat" from the beginning has been multiplied many times
- Like a game of telephone - the message gets distorted and lost
- Result: The RNN might forget that "stretched" refers to the "cat"
LSTM solution:
- LSTM has a special "cell state" that acts like a conveyor belt
- Important information (like "cat") can be stored in the cell state
- This information flows through time without being repeatedly multiplied by weight matrices
- Result: LSTM remembers "cat" even after many words, so "stretched" correctly refers to it
Key Innovations of LSTMs
1. Cell State: The Information Highway
The cell state is like a conveyor belt that runs through the entire sequence:
- Information can be added to it or removed from it
- But it flows through time without being pushed through a weight matrix at every step
- This is the key to solving vanishing gradients!
- Analogy: Like a river - water flows continuously, and you can add or remove things, but the flow itself doesn't shrink
2. Gates: Selective Memory Control
LSTMs have three gates that control information flow:
- Forget Gate: "What should I forget from the past?"
- Input Gate: "What new information should I remember?"
- Output Gate: "What information should I use for the current output?"
Key insight: These gates are learned - the network learns what to remember and what to forget!
3. Additive Updates: No Gradient Shrinking
Standard RNN: h_t = f(h_{t-1}, x_t) - the whole state is transformed (multiplied by the recurrent weights) at every step
LSTM: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t - information is added
- Addition allows gradients to flow through without shrinking
- This is the mathematical key to solving vanishing gradients
- Analogy: Like adding items to a list vs transforming the entire list - addition preserves information (a quick numerical sketch follows this list)
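To see this numerically, here is a minimal sketch in plain Python (the toy scalar "memory", the weight 0.9, and the forget-gate value 0.99 are made-up numbers, not part of any real network): repeated multiplication by a weight below 1 drives a stored value toward zero, while the gated additive update keeps it alive when the forget gate sits near 1.

steps = 50

# Multiplicative path (RNN-style): the stored value is rescaled by a
# weight below 1 at every time step, so it fades away exponentially.
w = 0.9
rnn_value = 1.0
for _ in range(steps):
    rnn_value *= w
print(f"After {steps} multiplicative steps: {rnn_value:.5f}")  # ≈ 0.00515

# Additive path (LSTM-style cell state): with a forget gate saturated
# near 1 and nothing new being written, the value is carried along almost intact.
forget_gate = 0.99  # toy value; a sigmoid gate can saturate close to 1
cell_value = 1.0
for _ in range(steps):
    cell_value = forget_gate * cell_value + 0.0  # f_t ⊙ C_{t-1} + i_t ⊙ C̃_t with i_t = 0
print(f"After {steps} gated additive steps: {cell_value:.5f}")  # ≈ 0.60501

When the forget gate outputs exactly 1, the value passes through completely unchanged - which is why the cell state is often described as a conveyor belt.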
📚 Detailed Analogy: LSTM as a Smart Filing System
Think of LSTM like an intelligent filing system:
The Filing Cabinet (Cell State)
- This is your long-term storage - information stays here across time
- Like a physical filing cabinet, you can add files or remove files
- But the cabinet itself (the structure) remains constant
- Key: Information in the cabinet doesn't degrade over time
The Three Gates (Smart Decisions)
- Forget Gate: Like a secretary who reviews old files and decides which ones to archive/delete
- Looks at current context and old information
- Decides: "This old information is no longer relevant"
- Removes it from the cabinet
- Input Gate: Like a secretary who reviews new documents and decides which ones to file
- Looks at new information and current context
- Decides: "This new information is important"
- Adds it to the cabinet
- Output Gate: Like a secretary who decides what information to show you
- Looks at what's in the cabinet
- Decides: "This information is relevant for the current task"
- Makes it available (hidden state)
🔄 LSTM Gate Flow Visualization
[Diagram: the current input x_t, the previous hidden state h_{t-1}, and the previous cell state C_{t-1} feed into three gates - the Forget Gate f_t = σ(...) ("What to forget?"), the Input Gate i_t = σ(...) ("What to remember?"), and the Output Gate o_t = σ(...) ("What to output?"). The cell state update C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t combines (forget old) + (add new) and produces the new hidden state h_t and the new cell state C_t.]
💡 Key: The three gates control information flow. Forget gate removes old info, Input gate adds new info, Output gate decides what to use. The cell state (C_t) flows through time without degradation!
The Complete Process
When processing a new word:
- Forget Gate: Reviews old files, decides what to remove
- Input Gate: Reviews new information, decides what to add
- Update Cabinet: Remove old files, add new files
- Output Gate: Decides what information from the cabinet to use
- Result: You have an updated understanding that combines old and new information
The Problem LSTMs Solve
⚠️ Vanishing Gradients in RNNs
Standard RNNs suffer from vanishing gradients because:
- Gradients are multiplied by W_hh at each time step
- If |W_hh| < 1, gradients shrink exponentially
- Early time steps receive almost no gradient
- Can't learn long-term dependencies
RNN Gradient Problem
In RNNs, the gradient flowing back from time t to time t-k is scaled by k repeated factors of the recurrent weights:
∂h_t / ∂h_{t-k} ∝ (W_hh)^k
Problem: If |W_hh| < 1, (W_hh)^k → 0 as k increases
LSTM Solution:
- Uses additive updates instead of multiplicative
- Cell state: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
- Gradients can flow through addition (no shrinking!) - see the sketch below
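Here is the same contrast from the gradient's point of view, as a rough NumPy sketch (the matrix size, weight scale, number of steps, and the forget-gate value 0.98 are arbitrary toy choices, and activation-function derivatives are ignored for simplicity): backpropagating through a standard RNN multiplies the error signal by W_hh at every step, while the direct path through the LSTM cell state only rescales it element-wise by forget-gate activations that can stay close to 1.

import numpy as np

rng = np.random.default_rng(0)
hidden, k = 20, 40  # hidden size and number of steps the gradient travels back

# RNN path: the gradient picks up a factor of W_hh^T at every time step,
# so it shrinks exponentially when the recurrent weights are small.
W_hh = rng.normal(size=(hidden, hidden)) * 0.1  # small recurrent weights
grad_rnn = np.ones(hidden)
for _ in range(k):
    grad_rnn = W_hh.T @ grad_rnn
print(f"RNN gradient norm after {k} steps:  {np.linalg.norm(grad_rnn):.2e}")

# LSTM cell-state path: along the direct route, dC_t/dC_{t-1} is just
# diag(f_t) (ignoring the indirect routes through the gates), so the
# gradient is only rescaled element-wise by the forget-gate values.
f_t = np.full(hidden, 0.98)  # toy forget gate saturated near 1
grad_lstm = np.ones(hidden)
for _ in range(k):
    grad_lstm = f_t * grad_lstm
print(f"LSTM gradient norm after {k} steps: {np.linalg.norm(grad_lstm):.2e}")

With these toy numbers, the RNN-path norm collapses to a vanishingly small value, while the cell-state path keeps a gradient of roughly the same order as where it started.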
The Three Gates
🚪 Gate Mechanism
LSTMs use three gates to control information flow:
1. Forget Gate
Forget Gate Formula
f_t = σ(W_f × [h_{t-1}, x_t] + b_f)
Purpose:
- Decides what information to forget from cell state
- Output: 0 (forget) to 1 (keep)
- Applied to previous cell state: f_t ⊙ C_{t-1}
2. Input Gate
Input Gate Formula
i_t = σ(W_i × [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_C × [h_{t-1}, x_t] + b_C)
Purpose:
- i_t: Decides which values to update
- C̃_t: New candidate values
- Together: i_t ⊙ C̃_t (what new information to add)
3. Output Gate
Output Gate Formula
o_t = σ(W_o × [h_{t-1}, x_t] + b_o)
Purpose:
- Decides what parts of cell state to output
- Applied to tanh(C_t): o_t ⊙ tanh(C_t)
- This becomes the hidden state h_t (see the gate sketch after this list)
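To tie the three gate formulas together, here is a small NumPy sketch of the gate computations for a single time step; the sizes, random weights, and seed are arbitrary toy choices, and the class-based implementation later in this chapter does the same thing in full.

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

hidden_size, input_size = 4, 3
rng = np.random.default_rng(1)

# Toy parameters: each gate has its own weights over the concatenation [h_{t-1}, x_t]
W_f = rng.normal(size=(hidden_size, hidden_size + input_size)) * 0.5
W_i = rng.normal(size=(hidden_size, hidden_size + input_size)) * 0.5
W_o = rng.normal(size=(hidden_size, hidden_size + input_size)) * 0.5
b_f = b_i = b_o = np.zeros((hidden_size, 1))

h_prev = np.zeros((hidden_size, 1))     # previous hidden state h_{t-1}
x_t = rng.normal(size=(input_size, 1))  # current input x_t
concat = np.vstack([h_prev, x_t])       # [h_{t-1}, x_t]

f_t = sigmoid(W_f @ concat + b_f)  # forget gate: how much of C_{t-1} to keep
i_t = sigmoid(W_i @ concat + b_i)  # input gate: how much new information to write
o_t = sigmoid(W_o @ concat + b_o)  # output gate: how much of the cell state to expose

print("forget gate:", f_t.ravel())  # every entry lies strictly between 0 and 1
print("input gate: ", i_t.ravel())
print("output gate:", o_t.ravel())

Because the sigmoid outputs lie between 0 and 1, each gate acts as a soft, per-element switch rather than a hard on/off decision.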
Cell State and Hidden State
Two-State System
LSTMs maintain two states:
- Cell State (C_t): Long-term memory (gradients flow easily)
- Hidden State (h_t): Short-term memory (used for predictions)
Cell State Update
C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
Breakdown:
- f_t ⊙ C_{t-1}: What to keep from previous state
- i_t ⊙ C̃_t: What new information to add
- Addition: Key! Gradients flow through addition (worked example below)
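A tiny worked example of this update, using idealized gate values (a real sigmoid never outputs exactly 0 or 1) just to make the element-wise arithmetic visible:

import numpy as np

C_prev  = np.array([2.0, -1.0, 0.5])   # previous cell state C_{t-1}
f_t     = np.array([1.0,  0.0, 0.9])   # forget gate: keep, drop, mostly keep
i_t     = np.array([0.0,  1.0, 0.5])   # input gate: ignore, write, half-write
C_tilde = np.array([0.7,  0.8, -0.4])  # candidate values C̃_t

C_t = f_t * C_prev + i_t * C_tilde     # C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
print(C_t)  # [2.   0.8  0.25]

The first element is carried over untouched, the second is completely replaced by new information, and the third is a blend of old and new.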
Hidden State Update
h_t = o_t ⊙ tanh(C_t)
Purpose:
- Hidden state is a filtered version of cell state
- Output gate controls what information to expose
- Used for predictions and for the next time step (continued in the example below)
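Continuing the toy numbers from the cell-state example above, again with idealized gate values:

import numpy as np

C_t = np.array([2.0, 0.8, 0.25])  # cell state from the previous example
o_t = np.array([1.0, 0.5, 0.0])   # output gate: expose, half-expose, hide

h_t = o_t * np.tanh(C_t)          # h_t = o_t ⊙ tanh(C_t)
print(h_t)  # ≈ [0.964  0.332  0.   ]

The tanh squashes the cell state back into the range -1 to 1, and the output gate then decides how much of each squashed value the rest of the network gets to see.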
Complete LSTM Formulas
Full LSTM Forward Pass
At each time step t:
Step 1: Compute the Gates
f_t = σ(W_f × [h_{t-1}, x_t] + b_f)
i_t = σ(W_i × [h_{t-1}, x_t] + b_i)
o_t = σ(W_o × [h_{t-1}, x_t] + b_o)
Step 2: Candidate Values
C̃_t = tanh(W_C × [h_{t-1}, x_t] + b_C)
Step 3: Update Cell State
C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
Step 4: Update Hidden State
h_t = o_t ⊙ tanh(C_t)
Notation:
- ⊙: Element-wise multiplication (Hadamard product)
- [h_{t-1}, x_t]: Concatenation of hidden state and input
- σ: Sigmoid function (outputs 0-1)
- tanh: Hyperbolic tangent (outputs -1 to 1)
LSTM Implementation
Complete LSTM Implementation
import numpy as np

class LSTM:
    """Long Short-Term Memory Network"""

    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Weight matrices for forget gate
        self.W_f = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_f = np.zeros((hidden_size, 1))

        # Weight matrices for input gate
        self.W_i = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_i = np.zeros((hidden_size, 1))

        # Weight matrices for candidate values
        self.W_C = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_C = np.zeros((hidden_size, 1))

        # Weight matrices for output gate
        self.W_o = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_o = np.zeros((hidden_size, 1))

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-np.clip(x, -250, 250)))

    def tanh(self, x):
        return np.tanh(x)

    def forward_step(self, x, h_prev, C_prev):
        """One LSTM forward step"""
        # Concatenate hidden state and input
        concat = np.vstack([h_prev, x])

        # Forget gate
        f_t = self.sigmoid(np.dot(self.W_f, concat) + self.b_f)

        # Input gate
        i_t = self.sigmoid(np.dot(self.W_i, concat) + self.b_i)

        # Candidate values
        C_tilde = self.tanh(np.dot(self.W_C, concat) + self.b_C)

        # Update cell state
        C_t = f_t * C_prev + i_t * C_tilde

        # Output gate
        o_t = self.sigmoid(np.dot(self.W_o, concat) + self.b_o)

        # Update hidden state
        h_t = o_t * self.tanh(C_t)

        return h_t, C_t

    def forward(self, sequence):
        """Forward pass through sequence"""
        h = np.zeros((self.hidden_size, 1))
        C = np.zeros((self.hidden_size, 1))
        outputs = []

        for x in sequence:
            h, C = self.forward_step(x, h, C)
            outputs.append(h)

        return outputs

# Example usage
lstm = LSTM(input_size=10, hidden_size=20)
sequence = [np.random.randn(10, 1) for _ in range(5)]
outputs = lstm.forward(sequence)
print(f"Processed {len(outputs)} time steps")