一、为什么需要LSTM?
想象你要记住一个重要的电话号码。普通循环神经网络(RNN)就像一个容易分心的人:当新信息不断输入时,旧的号码很快会被遗忘。这种现象称为长期依赖问题 。长短期记忆网络(LSTM)的设计灵感来自带便签本的聪明人:它可以选择性地记录重要信息,还能随时擦除无用内容。
二、LSTM的核心组件:记忆本与三道门
2.1 记忆本的结构
每个LSTM单元都携带两个关键信息:
- 隐状态 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> H t H_t </math>Ht):对外传递的短期记忆,类似「当前要说的话」
- 记忆元 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> C t C_t </math>Ct):内部保存的长期记忆,类似「随身携带的笔记本」
2.2 控制记忆的三道门
LSTM 的核心在于对记忆信息的"读写控制",这一机制借鉴了计算机逻辑门的思想。主要有三个门:
注:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> X t X_t </math>Xt 表示当前时间步的输入;
- <math xmlns="http://www.w3.org/1998/Math/MathML"> H t − 1 H_{t-1} </math>Ht−1 表示前一时刻的隐状态;
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ \sigma </math>σ 是 sigmoid 激活函数;
- <math xmlns="http://www.w3.org/1998/Math/MathML"> W x i W_{xi} </math>Wxi、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W h i W_{hi} </math>Whi、 <math xmlns="http://www.w3.org/1998/Math/MathML"> b i b_i </math>bi 等均为各门的权重和偏置参数。
(1)输入门:决定写什么
决定当前输入中有多少信息需要写入记忆元。
公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> I t = σ ( X t W x i + H t − 1 W h i + b i ) \boxed{I_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i)} </math>It=σ(XtWxi+Ht−1Whi+bi)
示例:当输入是重要名字时,输入门会完全打开(值接近1),确保信息被记录。
(2)遗忘门:决定擦除什么
控制保留多少来自过去记忆元的信息。
公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> F t = σ ( X t W x f + H t − 1 W h f + b f ) \boxed{F_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f)} </math>Ft=σ(XtWxf+Ht−1Whf+bf)
示例:遇到「但是」等转折词时,遗忘门可能关闭(值接近0),清空之前的状态。
(3)输出门:决定读什么
决定记忆元中有多少信息通过处理后参与到最终输出的隐状态中。
公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> O t = σ ( X t W x o + H t − 1 W h o + b o ) \boxed{O_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o)} </math>Ot=σ(XtWxo+Ht−1Who+bo)
2.3 记忆更新过程
(1)候选记忆内容
除了上述三个门控之外,LSTM 还引入了一个候选记忆元,用于生成可供更新记忆元的新信息。候选记忆元的计算与门控类似,但采用了 <math xmlns="http://www.w3.org/1998/Math/MathML"> tanh \tanh </math>tanh 激活函数,其值域在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( − 1 , 1 ) (-1,1) </math>(−1,1) 内:
<math xmlns="http://www.w3.org/1998/Math/MathML"> C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) \boxed{\tilde{C}t = \text{tanh}(X_t W{xc} + H_{t-1} W_{hc} + b_c)} </math>C~t=tanh(XtWxc+Ht−1Whc+bc)
(2)更新记忆本
记忆元的更新由遗忘门和输入门共同控制,其更新公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML"> C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boxed{C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t} </math>Ct=Ft⊙Ct−1+It⊙C~t
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t-1} </math>Ct−1 为上一时刻的记忆元;
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ \odot </math>⊙ 表示按元素乘法(Hadamard 乘积)。
这意味着当遗忘门输出接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1 且输入门输出接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0 时,过去的记忆会被大部分保留;反之,当输入门输出较高时,新信息会更多地写入记忆元。
(3)生成新隐状态
最终,LSTM 利用经过门控调制后的记忆元来计算当前的隐状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> H t H_t </math>Ht。隐状态不仅作为下一时刻计算的输入,还会参与最终的预测输出,其计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML"> H t = O t ⊙ tanh ( C t ) \boxed{H_t = O_t \odot \text{tanh}(C_t)} </math>Ht=Ot⊙tanh(Ct)
这样设计既保证了隐状态的数值稳定性(值域为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( − 1 , 1 ) (-1,1) </math>(−1,1)),又确保输出层能够根据输出门的控制,灵活地选择传递多少记忆信息。
三、动手实现LSTM
3.1 初始化参数(PyTorch示例)
python
import torch
from torch import nn
import d2l
batch_size, num_steps = 32, 25
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.rand(shape, device=device) * 0.01
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens, device=device))
W_xi, W_hi, b_i = three() # 输入门参数
W_xf, W_hf, b_f = three() # 遗忘门参数
W_xo, W_ho, b_o = three() # 输出门参数
W_xc, W_hc, b_c = three() # 候选记忆元参数
# 输出层参数
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
b_c, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
3.2 前向传播过程
python
def init_lstm_state(batch_size, num_hiddens, device):
"""初始化长短期记忆网络的隐状态"""
return (torch.zeros(size=(batch_size, num_hiddens), device=device),
torch.zeros(size=(batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):
W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
H, C = state
outputs = []
for X in inputs:
I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i) # 输入门
F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f) # 遗忘门
O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o) # 输出门
C_tilda = torch.tanh(X @ W_xc + H @ W_hc + b_c) # 候选记忆元
C = F * C + I * C_tilda # 记忆元
H = O * torch.tanh(C) # 隐状态
Y = H @ W_hq + b_q # 输出
outputs.append(Y)
return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
3.3 使用高级API快速搭建
python
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
四、总结
本文详细介绍了长短期记忆网络(LSTM)的基本机制及其实现方法,主要内容包括:
- 门控记忆元 :通过输入门、遗忘门和输出门,LSTM 控制了信息的写入、保留和输出。 例如,输入门的计算公式为: <math xmlns="http://www.w3.org/1998/Math/MathML"> I t = σ ( X t W x i + H t − 1 W h i + b i ) \boxed{I_t = \sigma\bigl(X_t W_{xi} + H_{t-1} W_{hi} + b_i\bigr)} </math>It=σ(XtWxi+Ht−1Whi+bi)
- 候选记忆元与记忆元更新 :通过生成候选记忆元 <math xmlns="http://www.w3.org/1998/Math/MathML"> C ~ t \tilde{C}t </math>C~t 并结合遗忘门和输入门,LSTM 实现了对过去和新信息的平衡更新。更新公式为: <math xmlns="http://www.w3.org/1998/Math/MathML"> C t = F t ⊙ C t − 1 + I t ⊙ C ~ t \boxed{C_t = F_t \odot C{t-1} + I_t \odot \tilde{C}_t} </math>Ct=Ft⊙Ct−1+It⊙C~t
- 隐状态的计算 :依靠输出门对记忆元的调控,隐状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> H t H_t </math>Ht 的计算公式为: <math xmlns="http://www.w3.org/1998/Math/MathML"> H t = O t ⊙ tanh ( C t ) \boxed{H_t = O_t \odot \tanh(C_t)} </math>Ht=Ot⊙tanh(Ct)
- 模型实现:不仅展示了从零实现 LSTM 的详细过程,还可以利用 PyTorch 内置的高级 API 轻松构建 LSTM 模型,以便在实际项目中快速部署和训练。
长短期记忆网络作为解决长距离依赖问题的重要模型,其思想在自然语言处理、时间序列分析等领域都有广泛应用。希望本篇博客能够帮助你轻松理解 LSTM 的核心原理,并激发你进一步探索深度学习技术的兴趣!