---
title: "LSTMs: Long Short-Term Memory"
sidebar_label: LSTM
description: "A deep dive into the LSTM architecture, cell states, and the gating mechanisms that prevent vanishing gradients."
tags: [deep-learning, rnn, lstm, sequence-modeling, nlp]
---

Standard [RNNs](./rnn-basics) have a major weakness: a very short memory. Because of the **Vanishing Gradient** problem, they struggle to connect information that is far apart in a sequence.

**LSTMs**, introduced by Hochreiter & Schmidhuber (1997), were specifically designed to overcome this. They introduce a "Cell State" (a long-term memory track) and a series of "Gates" that control what information is kept and what is discarded.

## 1. The Core Innovation: The Cell State

The "Secret Sauce" of the LSTM is the **Cell State ($C_t$)**. You can imagine it as a conveyor belt that runs straight along the entire sequence, with only minor linear interactions. It is very easy for information to flow along it unchanged.

## 2. The Three Gates of LSTM

An LSTM uses three specialized gates to protect and control the cell state. Each gate is composed of a **Sigmoid** neural net layer and a point-wise multiplication operation. The three gates are defined below, followed by a worked single-step example.

### A. The Forget Gate ($f_t$)
This gate decides what information to throw away from the cell state.
* **Input:** $h_{t-1}$ (previous hidden state) and $x_t$ (current input).
* **Output:** A value between 0 (completely forget) and 1 (completely keep) for each element of the cell state.

$$
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
$$

### B. The Input Gate ($i_t$)
This gate decides which new information to store in the cell state. It works in tandem with a **tanh** layer that creates a vector of new candidate values ($\tilde{C}_t$).

$$
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)
$$
$$
\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
$$

The forget and input gates then combine to update the cell state: the old state is scaled by $f_t$ and the new candidates are added in, scaled by $i_t$.

$$
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
$$

Because this update is additive and element-wise, gradients flowing back along the cell state are rescaled by the forget gate rather than repeatedly multiplied by the same weight matrix, which is why LSTMs resist the vanishing gradient problem.

### C. The Output Gate ($o_t$)
This gate decides which parts of the (filtered) cell state become the next hidden state ($h_t$). The hidden state carries information about previous inputs and is also what the network uses for predictions.

$$
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
$$
$$
h_t = o_t \odot \tanh(C_t)
$$

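To make the equations concrete, here is a minimal from-scratch sketch of a single LSTM time step. The sizes, random weights, and the `lstm_step` helper are invented for illustration only; in practice you would use `nn.LSTM` (see Section 5).

```python
import torch

# Toy dimensions, chosen only for illustration
input_size, hidden_size = 4, 3

# Each weight matrix acts on the concatenation [h_{t-1}, x_t]
W_f = torch.randn(hidden_size, hidden_size + input_size); b_f = torch.zeros(hidden_size)
W_i = torch.randn(hidden_size, hidden_size + input_size); b_i = torch.zeros(hidden_size)
W_C = torch.randn(hidden_size, hidden_size + input_size); b_C = torch.zeros(hidden_size)
W_o = torch.randn(hidden_size, hidden_size + input_size); b_o = torch.zeros(hidden_size)

def lstm_step(x_t, h_prev, C_prev):
    """One LSTM time step, following the gate equations above."""
    z = torch.cat([h_prev, x_t])           # [h_{t-1}, x_t]
    f_t = torch.sigmoid(W_f @ z + b_f)     # forget gate
    i_t = torch.sigmoid(W_i @ z + b_i)     # input gate
    C_tilde = torch.tanh(W_C @ z + b_C)    # candidate values
    C_t = f_t * C_prev + i_t * C_tilde     # cell state update
    o_t = torch.sigmoid(W_o @ z + b_o)     # output gate
    h_t = o_t * torch.tanh(C_t)            # new hidden state
    return h_t, C_t

# Run a 5-step toy sequence through the cell
h, C = torch.zeros(hidden_size), torch.zeros(hidden_size)
for x in torch.randn(5, input_size):
    h, C = lstm_step(x, h, C)
print(h.shape, C.shape)  # torch.Size([3]) torch.Size([3])
```
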
## 3. Advanced Architectural Logic (Mermaid)

The flow within a single LSTM cell is highly structured. The "Cell State" acts as the horizontal spine, while the gates regulate the vertical flow of information.

```mermaid
graph LR
    subgraph LSTM_Cell [LSTM Cell at Time $$t$$]
        direction LR
        X(($$x_t$$)) --> ForgetGate{Forget Gate}
        X --> InputGate{Input Gate}
        X --> OutputGate{Output Gate}
        X --> Candidate[$$\tanh$$]

        H_prev(($$h_{t-1}$$)) --> ForgetGate
        H_prev --> InputGate
        H_prev --> OutputGate
        H_prev --> Candidate

        C_prev(($$C_{t-1}$$)) --> Forget_Mult(($$\odot$$))
        ForgetGate -- "$$f_t$$" --> Forget_Mult

        InputGate -- "$$i_t$$" --> Input_Mult(($$\odot$$))
        Candidate --> Input_Mult

        Forget_Mult --> State_Add((+))
        Input_Mult --> State_Add

        State_Add --> C_out(($$C_t$$))
        C_out --> Tanh_Final[$$\tanh$$]

        OutputGate -- "$$o_t$$" --> Output_Mult(($$\odot$$))
        Tanh_Final --> Output_Mult
        Output_Mult --> H_out(($$h_t$$))
    end
```

## 4. LSTM vs. Standard RNN

| Feature | Standard RNN | LSTM |
| --- | --- | --- |
| **Architecture** | Simple (single tanh layer) | Complex (4 interacting layers) |
| **Memory** | Short-term only | Long and short-term |
| **Gradient Flow** | Suffers from Vanishing Gradient | Resists Vanishing Gradient via the Cell State |
| **Complexity** | Low | High (roughly 4x the parameters of a comparable RNN) |

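The parameter gap in the last row is easy to verify: an LSTM layer learns four weight sets (forget, input, candidate, and output) where a vanilla RNN learns one, so for the same input and hidden sizes it has 4x the parameters. A quick sketch (the layer sizes here are arbitrary):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20)
lstm = nn.LSTM(input_size=10, hidden_size=20)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(rnn), count(lstm))  # 640 2560 -> a 4x ratio
```
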
## 5. Implementation with PyTorch

In PyTorch, the `nn.LSTM` module handles the gating logic and cell state management automatically.

```python
import torch
import torch.nn as nn

# input_size=10, hidden_size=20, num_layers=1
lstm = nn.LSTM(10, 20, batch_first=True)

# Input shape: (batch_size, seq_len, input_size)
input_seq = torch.randn(1, 5, 10)

# Initial hidden state (h0) and cell state (c0),
# each of shape (num_layers, batch_size, hidden_size)
h0 = torch.zeros(1, 1, 20)
c0 = torch.zeros(1, 1, 20)

# Forward pass returns the per-step outputs and a tuple (hn, cn)
output, (hn, cn) = lstm(input_seq, (h0, c0))

print(f"Output shape: {output.shape}")        # [1, 5, 20]
print(f"Final Cell State shape: {cn.shape}")  # [1, 1, 20]
```
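Note the difference between the two return values: `output` stacks the hidden state $h_t$ from every time step, while `hn` and `cn` hold only the final hidden and cell states. For this single-layer, unidirectional setup with `batch_first=True`, `output[:, -1, :]` contains the same values as `hn[0]`.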
## References

* **Colah's Blog:** [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) (Essential Reading)
* **Stanford CS224N:** [RNNs and LSTMs](http://web.stanford.edu/class/cs224n/)

---

**LSTMs are powerful but computationally expensive because of their three gates. Is there a way to simplify this without losing the memory benefits?**