前言
LSTM(Long Short-Term Memory)是深度学习中处理序列数据的经典模型。虽然现在 Transformer 风头正盛,但 LSTM 仍然是理解序列建模的基石,面试中也是常客。
很多人用 PyTorch 的 nn.LSTM 用得很溜,但如果问到:
- LSTM 内部到底有几个门?每个门的作用是什么?
- 权重矩阵的形状是怎样的?为什么是这个形状?
- 遗忘门和输入门是怎么配合工作的?
可能就答不上来了。
本文将带你用 NumPy 手撕 LSTM 的前向传播过程,并与 PyTorch 的结果进行对比验证。读完这篇文章,你将彻底理解 LSTM 的每一个计算细节。
一、为什么需要 LSTM?
1.1 RNN 的困境:梯度消失
传统 RNN 的结构很简单:
ht=tanh(Whhht−1+Wxhxt+b)h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)ht=tanh(Whhht−1+Wxhxt+b)
但这种结构有一个致命问题:梯度消失。
当序列很长时,梯度在反向传播过程中会不断相乘,导致:
-
梯度指数级衰减 → 无法学习长期依赖
-
或梯度指数级爆炸 → 训练不稳定
时间步: t=1 → t=2 → t=3 → ... → t=100
梯度: 1.0 → 0.5 → 0.25 → ... → ≈0 (消失了!)
1.2 LSTM 的解决方案:门控机制
LSTM 引入了门控机制 和细胞状态,让信息可以选择性地遗忘或记住:
┌─────────────────────────────────────────────────────┐
│ LSTM Cell │
│ │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ │
│ │遗忘门 │ │输入门 │ │候选值 │ │输出门 │ │
│ │ fₜ │ │ iₜ │ │ gₜ │ │ oₜ │ │
│ └──┬───┘ └──┬───┘ └──┬───┘ └──┬───┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌──────────────────────────────────────────┐ │
│ │ 细胞状态 Cₜ(信息高速公路) │ │
│ └──────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 隐藏状态 hₜ │
└─────────────────────────────────────────────────────┘
核心思想: 细胞状态 CtC_tCt 就像一条信息高速公路,信息可以几乎无损地传递,解决了梯度消失问题。
二、LSTM 的数学原理
2.1 四个核心组件
LSTM 在每个时间步有 4 个计算组件:
| 组件 | 符号 | 激活函数 | 作用 |
|---|---|---|---|
| 遗忘门 | ftf_tft | Sigmoid | 决定丢弃多少旧信息 |
| 输入门 | iti_tit | Sigmoid | 决定接收多少新信息 |
| 候选值 | C~t\tilde{C}_tC~t | Tanh | 生成新的候选信息 |
| 输出门 | oto_tot | Sigmoid | 决定输出多少信息 |
2.2 计算公式
Step 1:计算四个门
ft=σ(Wf⋅[ht−1,xt]+bf)遗忘门f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{遗忘门}ft=σ(Wf⋅[ht−1,xt]+bf)遗忘门
it=σ(Wi⋅[ht−1,xt]+bi)输入门i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{输入门}it=σ(Wi⋅[ht−1,xt]+bi)输入门
C~t=tanh(WC⋅[ht−1,xt]+bC)候选值\tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) \quad \text{候选值}C~t=tanh(WC⋅[ht−1,xt]+bC)候选值
ot=σ(Wo⋅[ht−1,xt]+bo)输出门o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{输出门}ot=σ(Wo⋅[ht−1,xt]+bo)输出门
Step 2:更新细胞状态
Ct=ft⊙Ct−1+it⊙C~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tCt=ft⊙Ct−1+it⊙C~t
- ft⊙Ct−1f_t \odot C_{t-1}ft⊙Ct−1:遗忘部分旧信息
- it⊙C~ti_t \odot \tilde{C}_tit⊙C~t:加入部分新信息
Step 3:计算隐藏状态输出
ht=ot⊙tanh(Ct)h_t = o_t \odot \tanh(C_t)ht=ot⊙tanh(Ct)
2.3 数据流图示
Cₜ₋₁ ─────────────────────────────→ × ────→ + ────→ Cₜ
↑ ↑
│ │
┌─────────────────────────────┐ fₜ iₜ×g̃ₜ
│ │ ↑ ↑
xₜ ─────────→│ [hₜ₋₁, xₜ] 拼接 │─────┴───────┘
│ │
hₜ₋₁ ───────→│ 经过4组不同的权重 │─────→ oₜ
│ │ │
└─────────────────────────────┘ │
↓
oₜ × tanh(Cₜ) = hₜ
三、PyTorch LSTM 权重结构
在手撕之前,我们先看看 PyTorch 的 LSTM 权重是怎么存储的:
python
import torch.nn as nn
input_dim = 12
hidden_size = 7
torch_lstm = nn.LSTM(input_dim, hidden_size, batch_first=True)
for key, weight in torch_lstm.state_dict().items():
print(key, weight.shape)
输出:
weight_ih_l0 torch.Size([28, 12]) # 4 * hidden_size × input_dim
weight_hh_l0 torch.Size([28, 7]) # 4 * hidden_size × hidden_size
bias_ih_l0 torch.Size([28]) # 4 * hidden_size
bias_hh_l0 torch.Size([28]) # 4 * hidden_size
3.1 权重形状解析
为什么是 28? 因为 28=4×728 = 4 \times 728=4×7(4个门 × hidden_size)
PyTorch 将 4 个门的权重拼接存储:
weight_ih_l0 [28, 12]:
┌────────────────────┐
│ W_i_x [7, 12] │ ← 输入门对 x 的权重
├────────────────────┤
│ W_f_x [7, 12] │ ← 遗忘门对 x 的权重
├────────────────────┤
│ W_g_x [7, 12] │ ← 候选值对 x 的权重
├────────────────────┤
│ W_o_x [7, 12] │ ← 输出门对 x 的权重
└────────────────────┘
weight_hh_l0 [28, 7]:
┌────────────────────┐
│ W_i_h [7, 7] │ ← 输入门对 h 的权重
├────────────────────┤
│ W_f_h [7, 7] │ ← 遗忘门对 h 的权重
├────────────────────┤
│ W_g_h [7, 7] │ ← 候选值对 h 的权重
├────────────────────┤
│ W_o_h [7, 7] │ ← 输出门对 h 的权重
└────────────────────┘
注意: PyTorch 的顺序是 i, f, g, o(输入门、遗忘门、候选值、输出门),不是 f, i, g, o!
四、手撕 LSTM:NumPy 实现
现在我们来用 NumPy 实现 LSTM 的前向传播:
4.1 辅助函数
python
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
4.2 完整实现
python
def numpy_lstm(x, state_dict):
"""
用 NumPy 实现 LSTM 前向传播
参数:
x: 输入序列,形状 [seq_len, input_dim]
state_dict: PyTorch LSTM 的权重字典
返回:
sequence_output: 所有时间步的隐藏状态 [seq_len, 1, hidden_size]
(h_t, c_t): 最后一个时间步的隐藏状态和细胞状态
"""
# 1. 提取权重
weight_ih = state_dict["weight_ih_l0"].numpy() # [4*hidden, input_dim]
weight_hh = state_dict["weight_hh_l0"].numpy() # [4*hidden, hidden_size]
bias_ih = state_dict["bias_ih_l0"].numpy() # [4*hidden]
bias_hh = state_dict["bias_hh_l0"].numpy() # [4*hidden]
hidden_size = weight_hh.shape[1]
# 2. 拆分四个门的权重(注意顺序是 i, f, g, o)
# 对输入 x 的权重
w_i_x = weight_ih[0:hidden_size, :]
w_f_x = weight_ih[hidden_size:hidden_size*2, :]
w_g_x = weight_ih[hidden_size*2:hidden_size*3, :]
w_o_x = weight_ih[hidden_size*3:hidden_size*4, :]
# 对隐藏状态 h 的权重
w_i_h = weight_hh[0:hidden_size, :]
w_f_h = weight_hh[hidden_size:hidden_size*2, :]
w_g_h = weight_hh[hidden_size*2:hidden_size*3, :]
w_o_h = weight_hh[hidden_size*3:hidden_size*4, :]
# 偏置
b_i_x = bias_ih[0:hidden_size]
b_f_x = bias_ih[hidden_size:hidden_size*2]
b_g_x = bias_ih[hidden_size*2:hidden_size*3]
b_o_x = bias_ih[hidden_size*3:hidden_size*4]
b_i_h = bias_hh[0:hidden_size]
b_f_h = bias_hh[hidden_size:hidden_size*2]
b_g_h = bias_hh[hidden_size*2:hidden_size*3]
b_o_h = bias_hh[hidden_size*3:hidden_size*4]
# 3. 合并权重(方便后续计算)
# 将 [h, x] 拼接后,一次矩阵乘法完成计算
w_i = np.concatenate([w_i_h, w_i_x], axis=1) # [hidden, hidden+input]
w_f = np.concatenate([w_f_h, w_f_x], axis=1)
w_g = np.concatenate([w_g_h, w_g_x], axis=1)
w_o = np.concatenate([w_o_h, w_o_x], axis=1)
b_i = b_i_h + b_i_x
b_f = b_f_h + b_f_x
b_g = b_g_h + b_g_x
b_o = b_o_h + b_o_x
# 4. 初始化隐藏状态和细胞状态
c_t = np.zeros((1, hidden_size))
h_t = np.zeros((1, hidden_size))
# 5. 逐时间步计算
sequence_output = []
for x_t in x:
x_t = x_t[np.newaxis, :] # [1, input_dim]
# 拼接 h 和 x
hx = np.concatenate([h_t, x_t], axis=1) # [1, hidden+input]
# 计算四个门
f_t = sigmoid(np.dot(hx, w_f.T) + b_f) # 遗忘门
i_t = sigmoid(np.dot(hx, w_i.T) + b_i) # 输入门
g_t = np.tanh(np.dot(hx, w_g.T) + b_g) # 候选值
o_t = sigmoid(np.dot(hx, w_o.T) + b_o) # 输出门
# 更新细胞状态
c_t = f_t * c_t + i_t * g_t
# 计算隐藏状态输出
h_t = o_t * np.tanh(c_t)
sequence_output.append(h_t)
return np.array(sequence_output), (h_t, c_t)
4.3 代码逐行解读
让我们逐步拆解关键部分:
Step 1:拆分权重
python
# PyTorch 把 4 个门的权重拼在一起存储
# 顺序是 i(输入门), f(遗忘门), g(候选值), o(输出门)
w_i_x = weight_ih[0:hidden_size, :] # 第 0~6 行
w_f_x = weight_ih[hidden_size:hidden_size*2, :] # 第 7~13 行
w_g_x = weight_ih[hidden_size*2:hidden_size*3, :] # 第 14~20 行
w_o_x = weight_ih[hidden_size*3:hidden_size*4, :] # 第 21~27 行
Step 2:计算四个门
python
# 拼接 h 和 x,然后一次矩阵乘法
hx = np.concatenate([h_t, x_t], axis=1) # [1, hidden+input]
# 遗忘门:决定丢弃多少旧信息
f_t = sigmoid(np.dot(hx, w_f.T) + b_f) # 输出范围 (0, 1)
# 输入门:决定接收多少新信息
i_t = sigmoid(np.dot(hx, w_i.T) + b_i) # 输出范围 (0, 1)
# 候选值:生成新的候选信息
g_t = np.tanh(np.dot(hx, w_g.T) + b_g) # 输出范围 (-1, 1)
# 输出门:决定输出多少信息
o_t = sigmoid(np.dot(hx, w_o.T) + b_o) # 输出范围 (0, 1)
Step 3:更新状态
python
# 细胞状态更新:遗忘旧的 + 加入新的
c_t = f_t * c_t + i_t * g_t
# 隐藏状态:从细胞状态中筛选输出
h_t = o_t * np.tanh(c_t)
五、验证:与 PyTorch 结果对比
python
import torch
# 构造输入
length = 6
input_dim = 12
hidden_size = 7
x = np.random.random((length, input_dim))
# PyTorch LSTM
torch_lstm = nn.LSTM(input_dim, hidden_size, batch_first=True)
torch_output, (torch_h, torch_c) = torch_lstm(torch.Tensor([x]))
# NumPy LSTM
numpy_output, (numpy_h, numpy_c) = numpy_lstm(x, torch_lstm.state_dict())
# 对比结果
print("=== 序列输出对比 ===")
print("PyTorch:", torch_output.detach().numpy().round(4))
print("NumPy: ", numpy_output.squeeze().round(4))
print("\n=== 最终隐藏状态对比 ===")
print("PyTorch:", torch_h.detach().numpy().round(4))
print("NumPy: ", numpy_h.round(4))
print("\n=== 最终细胞状态对比 ===")
print("PyTorch:", torch_c.detach().numpy().round(4))
print("NumPy: ", numpy_c.round(4))
运行结果(示例):
=== 序列输出对比 ===
PyTorch: [[[-0.0762 0.1234 -0.0891 ...]]]
NumPy: [[-0.0762 0.1234 -0.0891 ...]]
=== 最终隐藏状态对比 ===
PyTorch: [[[ 0.0892 -0.1567 0.2341 ...]]]
NumPy: [[ 0.0892 -0.1567 0.2341 ...]]
=== 最终细胞状态对比 ===
PyTorch: [[[ 0.1823 -0.3456 0.4521 ...]]]
NumPy: [[ 0.1823 -0.3456 0.4521 ...]]
结果完全一致! 这证明我们的实现是正确的。
六、门控机制的直观理解
为了更好地理解 LSTM 的门控机制,我们来看一个具体例子:
6.1 情景:处理句子 "我爱北京天安门"
时间步 t=1: x="我"
├── 遗忘门 f≈0.1 (几乎忘记之前的信息,因为这是开头)
├── 输入门 i≈0.9 (大量接收新信息)
├── 候选值 g (编码"我"的语义)
└── 细胞状态 C₁ = 0.1×C₀ + 0.9×g ≈ 0.9×g
时间步 t=2: x="爱"
├── 遗忘门 f≈0.8 (保留"我"的信息)
├── 输入门 i≈0.7 (接收"爱"的信息)
├── 候选值 g (编码"爱"的语义)
└── 细胞状态 C₂ = 0.8×C₁ + 0.7×g (同时记住"我"和"爱")
时间步 t=3: x="北京"
├── 遗忘门 f≈0.9 (保留"我爱"的信息)
├── 输入门 i≈0.8 (接收"北京"的信息)
└── 细胞状态包含了"我爱北京"的语义
6.2 门的值域与含义
| 门 | 值域 | 接近 0 | 接近 1 |
|---|---|---|---|
| 遗忘门 ftf_tft | (0, 1) | 忘记所有旧信息 | 保留所有旧信息 |
| 输入门 iti_tit | (0, 1) | 不接收新信息 | 完全接收新信息 |
| 输出门 oto_tot | (0, 1) | 不输出任何信息 | 完全输出信息 |
| 候选值 C~t\tilde{C}_tC~t | (-1, 1) | 负向信息 | 正向信息 |
七、扩展:手撕 GRU
GRU 是 LSTM 的简化版本,只有 2 个门(重置门、更新门),计算更高效。
7.1 GRU 公式
zt=σ(Wz⋅[ht−1,xt]+bz)更新门z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \quad \text{更新门}zt=σ(Wz⋅[ht−1,xt]+bz)更新门
rt=σ(Wr⋅[ht−1,xt]+br)重置门r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \quad \text{重置门}rt=σ(Wr⋅[ht−1,xt]+br)重置门
h~t=tanh(Wh⋅[rt⊙ht−1,xt]+bh)候选隐藏状态\tilde{h}t = \tanh(W_h \cdot [r_t \odot h{t-1}, x_t] + b_h) \quad \text{候选隐藏状态}h~t=tanh(Wh⋅[rt⊙ht−1,xt]+bh)候选隐藏状态
ht=(1−zt)⊙h~t+zt⊙ht−1最终隐藏状态h_t = (1 - z_t) \odot \tilde{h}t + z_t \odot h{t-1} \quad \text{最终隐藏状态}ht=(1−zt)⊙h~t+zt⊙ht−1最终隐藏状态
7.2 NumPy 实现
python
def numpy_gru(x, state_dict):
"""用 NumPy 实现 GRU 前向传播"""
weight_ih = state_dict["weight_ih_l0"].numpy()
weight_hh = state_dict["weight_hh_l0"].numpy()
bias_ih = state_dict["bias_ih_l0"].numpy()
bias_hh = state_dict["bias_hh_l0"].numpy()
hidden_size = weight_hh.shape[1]
# GRU 只有 3 组权重:r(重置门), z(更新门), h(候选值)
w_r_x = weight_ih[0:hidden_size, :]
w_z_x = weight_ih[hidden_size:hidden_size*2, :]
w_h_x = weight_ih[hidden_size*2:hidden_size*3, :]
w_r_h = weight_hh[0:hidden_size, :]
w_z_h = weight_hh[hidden_size:hidden_size*2, :]
w_h_h = weight_hh[hidden_size*2:hidden_size*3, :]
b_r_x = bias_ih[0:hidden_size]
b_z_x = bias_ih[hidden_size:hidden_size*2]
b_h_x = bias_ih[hidden_size*2:hidden_size*3]
b_r_h = bias_hh[0:hidden_size]
b_z_h = bias_hh[hidden_size:hidden_size*2]
b_h_h = bias_hh[hidden_size*2:hidden_size*3]
# 合并权重
w_z = np.concatenate([w_z_h, w_z_x], axis=1)
w_r = np.concatenate([w_r_h, w_r_x], axis=1)
b_z = b_z_h + b_z_x
b_r = b_r_h + b_r_x
h_t = np.zeros((1, hidden_size))
sequence_output = []
for x_t in x:
x_t = x_t[np.newaxis, :]
hx = np.concatenate([h_t, x_t], axis=1)
# 更新门和重置门
z_t = sigmoid(np.dot(hx, w_z.T) + b_z)
r_t = sigmoid(np.dot(hx, w_r.T) + b_r)
# 候选隐藏状态(注意:r_t 只作用于 h_t 部分)
h_tilde = np.tanh(
r_t * (np.dot(h_t, w_h_h.T) + b_h_h) +
np.dot(x_t, w_h_x.T) + b_h_x
)
# 最终隐藏状态
h_t = (1 - z_t) * h_tilde + z_t * h_t
sequence_output.append(h_t)
return np.array(sequence_output), h_t
7.3 LSTM vs GRU 对比
| 特性 | LSTM | GRU |
|---|---|---|
| 门的数量 | 4个 | 2个 |
| 状态数量 | 2个(h 和 C) | 1个(h) |
| 参数量 | 更多 | 较少 |
| 计算速度 | 较慢 | 较快 |
| 长序列表现 | 更好 | 稍逊 |
| 适用场景 | 长序列、复杂依赖 | 短序列、资源受限 |
八、常见面试问题
Q1:LSTM 如何解决梯度消失问题?
答: 通过细胞状态 CtC_tCt 和门控机制。细胞状态的更新是加法操作 (Ct=ft⊙Ct−1+it⊙C~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tCt=ft⊙Ct−1+it⊙C~t),而不是 RNN 中的乘法操作。加法操作在反向传播时梯度不会连乘,避免了梯度消失。
Q2:遗忘门的初始偏置为什么通常设为 1?
答: 将遗忘门的初始偏置设为 1(或较大值),使得 ft≈1f_t \approx 1ft≈1,即默认保留所有信息。这有助于在训练初期让梯度顺利流动,防止过早遗忘。这是 Jozefowicz 等人在 2015 年的研究发现。
Q3:LSTM 能完全解决梯度消失吗?
答: 不能完全解决,但能大大缓解。对于非常长的序列(如上万步),仍然可能出现问题。实践中通常配合以下技术:
- 梯度裁剪(Gradient Clipping)
- 残差连接(Residual Connection)
- 注意力机制(Attention)
Q4:为什么 PyTorch 要把 4 个门的权重拼在一起?
答: 为了计算效率。将 4 个矩阵乘法合并成 1 个大矩阵乘法,可以更好地利用 GPU 的并行计算能力,减少 kernel 调用开销。
九、总结
本文我们:
- 理解了 LSTM 的设计动机:解决 RNN 的梯度消失问题
- 掌握了 LSTM 的数学原理:4 个门的计算公式
- 分析了 PyTorch 的权重存储结构:4 个门拼接存储
- 用 NumPy 手撕了 LSTM:逐行实现前向传播
- 验证了实现的正确性:与 PyTorch 结果完全一致
- 扩展学习了 GRU:LSTM 的简化版本
理解模型的最好方式就是亲手实现一遍。希望这篇文章能帮你彻底搞懂 LSTM 的每一个计算细节。
十、完整代码
python
import torch
import torch.nn as nn
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def numpy_lstm(x, state_dict):
weight_ih = state_dict["weight_ih_l0"].numpy()
weight_hh = state_dict["weight_hh_l0"].numpy()
bias_ih = state_dict["bias_ih_l0"].numpy()
bias_hh = state_dict["bias_hh_l0"].numpy()
hidden_size = weight_hh.shape[1]
w_i_x, w_f_x, w_g_x, w_o_x = np.split(weight_ih, 4, axis=0)
w_i_h, w_f_h, w_g_h, w_o_h = np.split(weight_hh, 4, axis=0)
b_i_x, b_f_x, b_g_x, b_o_x = np.split(bias_ih, 4)
b_i_h, b_f_h, b_g_h, b_o_h = np.split(bias_hh, 4)
w_i = np.concatenate([w_i_h, w_i_x], axis=1)
w_f = np.concatenate([w_f_h, w_f_x], axis=1)
w_g = np.concatenate([w_g_h, w_g_x], axis=1)
w_o = np.concatenate([w_o_h, w_o_x], axis=1)
b_i, b_f = b_i_h + b_i_x, b_f_h + b_f_x
b_g, b_o = b_g_h + b_g_x, b_o_h + b_o_x
c_t = np.zeros((1, hidden_size))
h_t = np.zeros((1, hidden_size))
sequence_output = []
for x_t in x:
x_t = x_t[np.newaxis, :]
hx = np.concatenate([h_t, x_t], axis=1)
f_t = sigmoid(np.dot(hx, w_f.T) + b_f)
i_t = sigmoid(np.dot(hx, w_i.T) + b_i)
g_t = np.tanh(np.dot(hx, w_g.T) + b_g)
o_t = sigmoid(np.dot(hx, w_o.T) + b_o)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * np.tanh(c_t)
sequence_output.append(h_t)
return np.array(sequence_output), (h_t, c_t)
# 测试
if __name__ == "__main__":
length, input_dim, hidden_size = 6, 12, 7
x = np.random.random((length, input_dim))
torch_lstm = nn.LSTM(input_dim, hidden_size, batch_first=True)
torch_out, (torch_h, torch_c) = torch_lstm(torch.Tensor([x]))
numpy_out, (numpy_h, numpy_c) = numpy_lstm(x, torch_lstm.state_dict())
print("Match:", np.allclose(torch_out.detach().numpy(), numpy_out.squeeze(), atol=1e-6))