1. LSTM 简介
循环神经网络(RNN)在处理序列数据(如文本、时间序列、语音)时具有天然优势,但传统 RNN 存在 梯度消失和梯度爆炸 的问题,难以捕捉长距离依赖关系。
长短期记忆网络(Long Short-Term Memory,LSTM) 是一种特殊的 RNN 变体,由 Hochreiter 和 Schmidhuber 在 1997 年提出。它通过 门控机制(Gates) 控制信息的保留与遗忘,解决了 RNN 在长序列训练中的问题。
2. LSTM 的核心结构
LSTM 的基本单元包含三个门和一个单元状态(Cell State):
-
遗忘门(Forget Gate):决定丢弃多少历史信息。
-
输入门(Input Gate):决定当前输入信息保留多少。
-
输出门(Output Gate):决定当前时刻的隐藏状态输出。
-
单元状态(Cell State):类似"传送带",携带长期信息。
结构示意
可以把 LSTM 看作在传统 RNN 的基础上,多了一条 信息高速通道(Cell State),通过门控机制有选择地更新。
3. LSTM 数学公式
设输入为 ,上一时刻的隐藏状态为
,上一时刻的单元状态为
,则 LSTM 的计算公式如下:
(1)遗忘门(Forget Gate):
(2)输入门(Input Gate):
(3)候选单元状态(Cell Candidate):
(4)单元状态更新(Cell State):
(5)输出门(Output Gate):
(6)隐藏状态更新(Hidden State):
其中:
-
表示 Sigmoid 激活函数;
-
表示双曲正切激活函数;
-
表示元素逐位相乘。
4. LSTM 的优点
-
解决长依赖问题:能够捕捉数百步的时间依赖关系。
-
避免梯度消失/爆炸:门控机制使得梯度能在长序列中稳定传播。
-
广泛应用:文本生成、机器翻译、语音识别、金融时间序列预测等。
5. LSTM 的缺点
-
结构复杂:相较于 RNN,参数更多,计算开销更大。
-
训练速度慢:长序列数据下,训练时间成本较高。
-
难以并行:依赖序列前后顺序,难以像 Transformer 那样并行化。
6. Python 实现示例(PyTorch)
python
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=1):
super(LSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x shape: (batch_size, seq_len, input_size)
out, (h_n, c_n) = self.lstm(x)
out = self.fc(out[:, -1, :]) # 取最后时间步的隐藏状态
return out
# 示例
input_size = 10 # 每个时间步输入维度
hidden_size = 50 # 隐藏层维度
output_size = 1 # 输出维度
model = LSTMModel(input_size, hidden_size, output_size)
x = torch.randn(32, 5, 10) # batch_size=32, seq_len=5, input_size=10
y = model(x)
print(y.shape) # (32, 1)
7. 总结
-
RNN 善于处理短期依赖,但难以记忆长期信息。
-
LSTM 引入门控机制,通过遗忘门、输入门、输出门和单元状态,有效解决了梯度消失问题,能更好地建模长序列。
-
但 LSTM 依然存在计算慢、难以并行的缺点,这也是后来 GRU 和 Transformer 出现的原因。