Chapter 7: Long Short-Term Memory (LSTM)

Solving the vanishing gradient problem with gated memory cells

Learning Objectives

  • Understand why LSTMs solve the vanishing gradient problem
  • Master the three gates: forget, input, and output
  • Learn the cell state and hidden state mechanism
  • Understand the complete LSTM forward pass
  • Implement an LSTM from scratch
  • Compare LSTMs with standard RNNs

What is LSTM?

🧠 Gated Memory Networks

Long Short-Term Memory (LSTM) networks are a special type of RNN designed to mitigate the vanishing gradient problem. They use gating mechanisms to selectively remember or forget information, which lets them learn long-term dependencies that standard RNNs struggle with.

Think of LSTM as an advanced memory system:

  • Standard RNN: Like a person with short-term memory - they remember recent things but forget older information
  • LSTM: Like a person with both short-term and long-term memory - they can remember important things from much earlier
  • Key difference: LSTMs have explicit mechanisms (gates) to decide what to remember and what to forget

Why LSTMs Were Invented

The fundamental problem with standard RNNs:

Real-World Example: Understanding Long Sentences

Sentence: "The cat, which was very fluffy and had been sleeping on the mat all afternoon, finally woke up and stretched."

Standard RNN problem:

  • By the time we reach "stretched", the RNN has processed many words
  • The information about "cat" from the beginning has been pushed through the same weight matrix and squashing function at every step
  • Like a game of telephone - the message gets distorted and lost
  • Result: The RNN might forget that "stretched" refers to the "cat"

LSTM solution:

  • LSTM has a special "cell state" that acts like a conveyor belt
  • Important information (like "cat") can be stored in the cell state
  • This information flows through time scaled only by the forget gate, not repeatedly transformed by a weight matrix
  • Result: LSTM remembers "cat" even after many words, so "stretched" correctly refers to it

Key Innovations of LSTMs

1. Cell State: The Information Highway

The cell state is like a conveyor belt that runs through the entire sequence:

  • Information can be added to it or removed from it
  • Between steps it is only scaled element-wise by the forget gate, not pushed through a weight matrix and nonlinearity
  • This is the key to mitigating vanishing gradients!
  • Analogy: Like a river - water flows continuously, and you can add or remove things, but the flow itself doesn't shrink

2. Gates: Selective Memory Control

LSTMs have three gates that control information flow:

  • Forget Gate: "What should I forget from the past?"
  • Input Gate: "What new information should I remember?"
  • Output Gate: "What information should I use for the current output?"

Key insight: These gates are learned - the network learns what to remember and what to forget!

3. Additive Updates: No Gradient Shrinking

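Standard RNN: h_t = tanh(W_hh × h_{t-1} + W_xh × x_t + b) - the whole state is re-transformed (matrix multiply plus squashing) at every step

LSTM: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t - the old state is gated element-wise and new information is added

  • Along the cell-state path, the backward gradient is scaled only by f_t, so it passes almost unchanged when the forget gate is close to 1
  • This is the mathematical key to mitigating vanishing gradients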
  • Analogy: Like adding items to a list vs transforming the entire list - addition preserves information
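
The claim about the additive update can be checked numerically. The sketch below is illustrative only: the gate values are made up, and the gates are held fixed (in a real LSTM they also depend on h_{t-1}). Under those assumptions, a finite-difference estimate of ∂C_t/∂C_{t-1} recovers the forget gate itself.

```python
import numpy as np

# Made-up gate activations for a 3-unit cell (illustrative values only).
f_t = np.array([0.9, 0.5, 0.1])        # forget gate
i_t = np.array([0.2, 0.7, 0.9])        # input gate
C_tilde = np.array([0.5, -0.3, 0.8])   # candidate values
C_prev = np.array([1.0, -2.0, 0.5])    # previous cell state

def cell_update(C):
    """Cell state update with the gates held fixed."""
    return f_t * C + i_t * C_tilde

# The update is element-wise, so perturbing every entry at once
# yields the diagonal of the Jacobian dC_t / dC_{t-1}.
eps = 1e-6
numeric_grad = (cell_update(C_prev + eps) - cell_update(C_prev - eps)) / (2 * eps)

print("numeric dC_t/dC_prev:", numeric_grad)
print("forget gate f_t:     ", f_t)
```

Units whose forget gate is near 1 pass the gradient back almost unchanged, while units with a small f_t deliberately cut it off.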

📚 Detailed Analogy: LSTM as a Smart Filing System

Think of LSTM like an intelligent filing system:

The Filing Cabinet (Cell State)
  • This is your long-term storage - information stays here across time
  • Like a physical filing cabinet, you can add files or remove files
  • But the cabinet itself (the structure) remains constant
  • Key: Information in the cabinet stays intact until the forget gate decides to remove it

The Three Gates (Smart Decisions)
  • Forget Gate: Like a secretary who reviews old files and decides which ones to archive/delete
    • Looks at current context and old information
    • Decides: "This old information is no longer relevant"
    • Removes it from the cabinet
  • Input Gate: Like a secretary who reviews new documents and decides which ones to file
    • Looks at new information and current context
    • Decides: "This new information is important"
    • Adds it to the cabinet
  • Output Gate: Like a secretary who decides what information to show you
    • Looks at what's in the cabinet
    • Decides: "This information is relevant for the current task"
    • Makes it available (hidden state)

🔄 LSTM Gate Flow Visualization

[Diagram: the inputs x_t (current word), h_{t-1} (previous hidden state), and C_{t-1} (previous cell state) feed three gates: the forget gate f_t = σ(...) ("What to forget?"), the input gate i_t = σ(...) ("What to remember?"), and the output gate o_t = σ(...) ("What to output?"). The cell state update C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t, i.e. (forget old) + (add new), produces the new cell state C_t and the new hidden state h_t.]

💡 Key: The three gates control information flow. The forget gate removes old info, the input gate adds new info, and the output gate decides what to use. The cell state (C_t) carries information through time and loses it only where the forget gate chooses to drop it.

The Complete Process

When processing a new word:

  1. Forget Gate: Reviews old files, decides what to remove
  2. Input Gate: Reviews new information, decides what to add
  3. Update Cabinet: Remove old files, add new files
  4. Output Gate: Decides what information from the cabinet to use
  5. Result: You have an updated understanding that combines old and new information

The Problem LSTMs Solve

⚠️ Vanishing Gradients in RNNs

Standard RNNs suffer from vanishing gradients because:

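  • Backpropagating through one time step multiplies the gradient by W_hh and by tanh'(z), which is never larger than 1
  • If the largest singular value of W_hh is below 1, gradients shrink exponentially with the number of steps
  • Early time steps receive almost no gradient
  • The network can't learn long-term dependencies

RNN Gradient Problem

In RNNs, the gradient reaching time step t-k is:

∂L/∂h_{t-k} = ∂L/∂h_t × ∏_{j=t-k+1}^{t} [diag(tanh'(z_j)) × W_hh]

Problem: each of the k factors has norm at most ‖W_hh‖, so if ‖W_hh‖ < 1 the product shrinks at least as fast as ‖W_hh‖^k → 0 as k increases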

LSTM Solution:
  • Uses an additive cell-state update instead of repeated matrix multiplication
  • Cell state: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
  • Along this path the gradient is scaled only by f_t, so it survives many steps whenever the forget gate stays close to 1
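
To make the contrast concrete, here is a small numerical sketch. All of it is assumed for illustration: a random W_hh rescaled to spectral norm 0.9, the best case tanh'(z) = 1 for the RNN, and a constant forget gate of 0.99 for the LSTM. It compares how much gradient signal is left after 50 steps on each path.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, k = 8, 50

# Random recurrent matrix, rescaled so its largest singular value is 0.9.
W_hh = rng.standard_normal((hidden_size, hidden_size))
W_hh *= 0.9 / np.linalg.norm(W_hh, 2)

# Best case for the RNN: tanh'(z) = 1 at every step, so the product of
# Jacobians reduces to W_hh raised to the k-th power.
rnn_factor = np.linalg.norm(np.linalg.matrix_power(W_hh, k), 2)

# LSTM cell-state path: the gradient is scaled by the forget gate at each step.
forget_gates = np.full(k, 0.99)
lstm_factor = np.prod(forget_gates)

print(f"RNN factor after {k} steps:  {rnn_factor:.2e}")   # at most 0.9**50, about 5e-3
print(f"LSTM factor after {k} steps: {lstm_factor:.3f}")  # 0.99**50, about 0.605
```

With tanh'(z) below 1, as it is in practice, the RNN factor would be smaller still. The LSTM figure holds only while the network keeps its forget gates near 1, which is something it has to learn.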

The Three Gates

🚪 Gate Mechanism

LSTMs use three gates to control information flow:

1. Forget Gate

Forget Gate Formula

f_t = σ(W_f × [h_{t-1}, x_t] + b_f)
Purpose:
  • Decides what information to forget from cell state
  • Outputs a value between 0 (forget) and 1 (keep) for each element of the cell state
  • Applied to previous cell state: f_t ⊙ C_{t-1}
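
As a small illustration of the forget gate at work, the sketch below skips the weight matrix and starts from assumed pre-activation values (standing in for W_f × [h_{t-1}, x_t] + b_f), then applies the resulting gate to a made-up cell state.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Assumed pre-activations: strongly negative, neutral, strongly positive.
pre_activation = np.array([-4.0, 0.0, 4.0])
f_t = sigmoid(pre_activation)          # approximately [0.018, 0.5, 0.982]

C_prev = np.array([2.0, 2.0, 2.0])     # made-up previous cell state
kept = f_t * C_prev                    # approximately [0.036, 1.0, 1.964]

print("forget gate:", np.round(f_t, 3))
print("kept memory:", np.round(kept, 3))
```

The first unit is almost erased, the second is halved, and the third is kept nearly intact.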

2. Input Gate

Input Gate Formula

i_t = σ(W_i × [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_C × [h_{t-1}, x_t] + b_C)
Purpose:
  • i_t: Decides which values to update
  • C̃_t: New candidate values
  • Together: i_t ⊙ C̃_t (what new information to add)
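
The two formulas above can be sketched in a few lines of NumPy. The sizes, the random weights, and the 0.1 scaling are assumptions chosen for illustration. The main things to notice are the shape of the concatenated vector and the value ranges of each term.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
hidden_size, input_size = 3, 2

h_prev = rng.standard_normal((hidden_size, 1))
x_t = rng.standard_normal((input_size, 1))
concat = np.vstack([h_prev, x_t])               # [h_{t-1}, x_t], shape (5, 1)

# Randomly initialised parameters (illustrative, untrained).
W_i = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
b_i = np.zeros((hidden_size, 1))
W_C = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
b_C = np.zeros((hidden_size, 1))

i_t = sigmoid(W_i @ concat + b_i)               # how much to write, in (0, 1)
C_tilde = np.tanh(W_C @ concat + b_C)           # what to write, in (-1, 1)
new_info = i_t * C_tilde                        # contribution added to the cell state

print("i_t:     ", i_t.ravel())
print("C_tilde: ", C_tilde.ravel())
print("new_info:", new_info.ravel())
```

Because i_t lies strictly between 0 and 1, the gated contribution can never exceed the candidate in magnitude.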

3. Output Gate

Output Gate Formula

o_t = σ(W_o × [h_{t-1}, x_t] + b_o)
Purpose:
  • Decides what parts of cell state to output
  • Applied to tanh(C_t): o_t ⊙ tanh(C_t)
  • This becomes the hidden state h_t
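
A short sketch with made-up numbers shows the order of operations: the cell state is squashed by tanh first, and the output gate then scales the result element-wise. The gate values here are chosen by hand rather than computed from weights.

```python
import numpy as np

C_t = np.array([5.0, -3.0, 0.2])   # made-up cell state; it can grow beyond [-1, 1]
o_t = np.array([1.0, 0.5, 0.0])    # hand-picked output gate: open, half open, closed

h_t = o_t * np.tanh(C_t)           # approximately [0.9999, -0.4975, 0.0]

print("tanh(C_t):", np.round(np.tanh(C_t), 4))
print("h_t:      ", np.round(h_t, 4))
```

The third unit still holds 0.2 in its cell state, but the closed output gate hides it from the hidden state for this step.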

Cell State and Hidden State

Two-State System

LSTMs maintain two states:

  • Cell State (C_t): Long-term memory (gradients flow easily)
  • Hidden State (h_t): Short-term memory (used for predictions)

Cell State Update

C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
Breakdown:
  • f_t ⊙ C_{t-1}: What to keep from previous state
  • i_t ⊙ C̃_t: What new information to add
  • Addition: Key! Gradients flow through addition
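
A worked example with hand-picked gate values (assumed for illustration, using the extremes 0 and 1 so the effect is easy to read) shows the three behaviours the update can produce: keep, overwrite, and accumulate.

```python
import numpy as np

C_prev = np.array([0.8, 0.8, 0.8])       # previous cell state
C_tilde = np.array([-0.5, -0.5, -0.5])   # candidate values

f_t = np.array([1.0, 0.0, 1.0])          # keep, forget, keep
i_t = np.array([0.0, 1.0, 1.0])          # ignore, write, write

C_t = f_t * C_prev + i_t * C_tilde

# Unit 0: keep old, ignore new  -> 0.8
# Unit 1: forget old, write new -> -0.5
# Unit 2: keep old and add new  -> 0.3 (up to floating-point rounding)
print(C_t)
```

Real gates produce values strictly between 0 and 1, so each unit blends these behaviours rather than picking exactly one.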

Hidden State Update

h_t = o_t ⊙ tanh(C_t)
Purpose:
  • Hidden state is a filtered version of cell state
  • Output gate controls what information to expose
  • Used for predictions and fed into the gates at the next time step

Complete LSTM Formulas

Full LSTM Forward Pass

At each time step t:

Step 1: Compute Gates
f_t = σ(W_f × [h_{t-1}, x_t] + b_f)
i_t = σ(W_i × [h_{t-1}, x_t] + b_i)
o_t = σ(W_o × [h_{t-1}, x_t] + b_o)

Step 2: Candidate Values
C̃_t = tanh(W_C × [h_{t-1}, x_t] + b_C)

Step 3: Update Cell State
C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t

Step 4: Update Hidden State
h_t = o_t ⊙ tanh(C_t)
Notation:
  • ⊙: Element-wise multiplication (Hadamard product)
  • [h_{t-1}, x_t]: Concatenation of hidden state and input
  • σ: Sigmoid function (outputs 0-1)
  • tanh: Hyperbolic tangent (outputs -1 to 1)
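
The formulas above use four weight matrices of identical shape, which makes the parameter count easy to work out. The sketch below assumes the comparison RNN is a single tanh layer over the same concatenated input [h_{t-1}, x_t]. The function names are hypothetical helpers, not part of any library.

```python
def rnn_param_count(input_size, hidden_size):
    """Parameters of a single-layer tanh RNN over [h_{t-1}, x_t]."""
    return hidden_size * (hidden_size + input_size) + hidden_size

def lstm_param_count(input_size, hidden_size):
    """Four blocks (forget, input, candidate, output), each RNN-sized."""
    return 4 * rnn_param_count(input_size, hidden_size)

input_size, hidden_size = 10, 20
rnn_params = rnn_param_count(input_size, hidden_size)    # 20 * 30 + 20 = 620
lstm_params = lstm_param_count(input_size, hidden_size)  # 4 * 620 = 2480

print(f"RNN parameters:  {rnn_params}")
print(f"LSTM parameters: {lstm_params}")
```

Under these assumptions an LSTM layer carries four times the parameters of an RNN layer of the same size, and roughly four times the matrix-multiply work per time step.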

LSTM Implementation

Complete LSTM Implementation

import numpy as np

class LSTM:
    """Long Short-Term Memory Network"""
    
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Weight matrices for forget gate
        self.W_f = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_f = np.zeros((hidden_size, 1))
        
        # Weight matrices for input gate
        self.W_i = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_i = np.zeros((hidden_size, 1))
        
        # Weight matrices for candidate values
        self.W_C = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_C = np.zeros((hidden_size, 1))
        
        # Weight matrices for output gate
        self.W_o = np.random.randn(hidden_size, hidden_size + input_size) * 0.1
        self.b_o = np.zeros((hidden_size, 1))
    
    def sigmoid(self, x):
        return 1 / (1 + np.exp(-np.clip(x, -250, 250)))
    
    def tanh(self, x):
        return np.tanh(x)
    
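    def forward_step(self, x, h_prev, C_prev):
        """One LSTM forward step.

        x: input column vector, shape (input_size, 1)
        h_prev, C_prev: previous hidden and cell state, shape (hidden_size, 1)
        Returns the new (h_t, C_t), each of shape (hidden_size, 1).
        """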
        # Concatenate hidden state and input
        concat = np.vstack([h_prev, x])
        
        # Forget gate
        f_t = self.sigmoid(np.dot(self.W_f, concat) + self.b_f)
        
        # Input gate
        i_t = self.sigmoid(np.dot(self.W_i, concat) + self.b_i)
        
        # Candidate values
        C_tilde = self.tanh(np.dot(self.W_C, concat) + self.b_C)
        
        # Update cell state
        C_t = f_t * C_prev + i_t * C_tilde
        
        # Output gate
        o_t = self.sigmoid(np.dot(self.W_o, concat) + self.b_o)
        
        # Update hidden state
        h_t = o_t * self.tanh(C_t)
        
        return h_t, C_t
    
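    def forward(self, sequence):
        """Forward pass through a sequence of (input_size, 1) column vectors.

        Starts from zero hidden and cell states and returns the list of
        hidden states, one per time step.
        """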
        h = np.zeros((self.hidden_size, 1))
        C = np.zeros((self.hidden_size, 1))
        outputs = []
        
        for x in sequence:
            h, C = self.forward_step(x, h, C)
            outputs.append(h)
        
        return outputs

# Example usage
lstm = LSTM(input_size=10, hidden_size=20)
sequence = [np.random.randn(10, 1) for _ in range(5)]
outputs = lstm.forward(sequence)
print(f"Processed {len(outputs)} time steps")
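
One common variation on the implementation above, shown here as a sketch rather than as part of the original class, stacks the four weight matrices into a single matrix so that each time step needs only one matrix multiply. The function name lstm_step_fused and the block order (forget, input, candidate, output) are assumptions made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step_fused(x, h_prev, C_prev, W, b):
    """One LSTM step with a single stacked weight matrix.

    W has shape (4 * hidden_size, hidden_size + input_size) and stacks the
    forget, input, candidate, and output blocks in that (assumed) order.
    """
    H = h_prev.shape[0]
    z = W @ np.vstack([h_prev, x]) + b     # all four pre-activations at once

    f_t = sigmoid(z[0 * H:1 * H])          # forget gate
    i_t = sigmoid(z[1 * H:2 * H])          # input gate
    C_tilde = np.tanh(z[2 * H:3 * H])      # candidate values
    o_t = sigmoid(z[3 * H:4 * H])          # output gate

    C_t = f_t * C_prev + i_t * C_tilde     # update cell state
    h_t = o_t * np.tanh(C_t)               # update hidden state
    return h_t, C_t

# Example usage with the same sizes as above.
rng = np.random.default_rng(0)
input_size, hidden_size = 10, 20
W = rng.standard_normal((4 * hidden_size, hidden_size + input_size)) * 0.1
b = np.zeros((4 * hidden_size, 1))

h = np.zeros((hidden_size, 1))
C = np.zeros((hidden_size, 1))
for _ in range(5):
    x = rng.standard_normal((input_size, 1))
    h, C = lstm_step_fused(x, h, C, W, b)

print("final hidden state shape:", h.shape)
```

The arithmetic is the same as in the class above. Only the layout of the parameters changes, which tends to be friendlier to vectorised hardware.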

Test Your Understanding

Question 1: How do LSTMs solve the vanishing gradient problem?

A) By using fewer layers
B) By using additive updates to cell state instead of multiplicative
C) By removing hidden states
D) By using larger learning rates

Question 2: What does the forget gate do?

A) Decides what information to forget from the cell state
B) Removes all previous information
C) Controls the output
D) Adds new information

Question 3: What is the difference between cell state and hidden state?

A) Cell state is long-term memory, hidden state is short-term memory used for predictions
B) They are the same thing
C) Cell state is input, hidden state is output
D) There is no difference

Question 4: How does the forget gate work in an LSTM?

A) The forget gate decides what information to discard from the cell state by outputting values between 0 and 1, where 0 means completely forget and 1 means keep everything
B) It forgets everything
C) It remembers everything
D) It only works on inputs

Question 5: What is the purpose of the cell state in LSTM?

A) The cell state acts as a highway for information to flow through the sequence with minimal modification, allowing gradients to flow better and enabling long-term memory
B) It stores only current input
C) It's the same as hidden state
D) It's not important

Question 6: How does LSTM solve the vanishing gradient problem better than basic RNNs?

A) The cell state provides a direct path for gradients to flow through time with minimal multiplication, and the gates allow selective information flow, preventing gradients from vanishing as they propagate backward
B) It doesn't solve it
C) By using more layers
D) By using fewer parameters

Question 7: What is the difference between LSTM and GRU?

A) LSTM has separate forget and input gates with a cell state, while GRU combines forget and input gates into a single update gate and merges cell state with hidden state, making GRU simpler and faster but sometimes less powerful
B) They're identical
C) GRU is always better
D) LSTM has no gates

Question 8: How do you compute the LSTM output at each time step?

A) The cell state is passed through tanh, and the output gate (a sigmoid of the previous hidden state and current input) multiplies the result element-wise to produce the hidden state, which becomes the output
B) Just use cell state directly
C) Use only input
D) Random value

Question 9: When would you choose LSTM over a basic RNN?

A) When you need to learn long-term dependencies in sequences, when basic RNNs struggle with vanishing gradients, or when the task requires remembering information over many time steps
B) Always use basic RNN
C) Only for short sequences
D) They're interchangeable

Question 10: What is the computational cost of LSTM compared to basic RNN?

A) LSTM is more expensive due to multiple gates and the cell state, requiring more parameters and computations per time step, but often worth it for better performance on long sequences
B) LSTM is faster
C) They're the same
D) LSTM uses fewer parameters

Question 11: How does the input gate decide what new information to store?

A) The input gate uses a sigmoid to decide which values to update, then a tanh creates candidate values, and these are combined to update the cell state selectively
B) It stores everything
C) It stores nothing
D) Random selection

Question 12: How would you implement an LSTM cell from scratch?

A) Implement forget gate (sigmoid of weighted sum), input gate (sigmoid), candidate values (tanh), update cell state (forget old, add new), output gate (sigmoid), compute hidden state (output gate times tanh of cell state). Each gate has its own weight matrices
B) Just copy RNN code
C) Use only one gate
D) No implementation needed