长短期记忆网络(LSTM)基本原理详解
一、LSTM核心思想
目标 :解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新 :引入细胞状态(Cell State)和门控机制 ,通过三个门结构精确控制信息流动
二、网络结构分解
1. 核心组件(四个关键部分)
组件 | 符号 | 功能描述 |
---|---|---|
遗忘门 | f t f_t ft | 决定从细胞状态中丢弃哪些信息 |
输入门 | i t i_t it | 确定新信息存入细胞状态的比例 |
候选值 | C ~ t \tilde{C}_t C~t | 生成待存入细胞状态的新候选值 |
输出门 | o t o_t ot | 控制细胞状态到隐藏状态的输出比例 |
2. 数学公式推导
遗忘门(Forget Gate)
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- σ \sigma σ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
候选细胞状态
C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
细胞状态更新
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
- ⊙ \odot ⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
隐藏状态计算
h t = o t ⊙ tanh ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
三、PyTorch实现
1. LSTM单元实现
python
import torch
import torch.nn as nn
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 合并计算四个门的参数矩阵
self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
def forward(self, x, state):
# state = (h, c)
h_prev, c_prev = state
# 合并输入与隐藏状态
combined = torch.cat((x, h_prev), dim=1)
gates = self.W(combined)
# 分割四个门计算结果
f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
# 激活函数应用
f = torch.sigmoid(f) # 遗忘门
i = torch.sigmoid(i) # 输入门
o = torch.sigmoid(o) # 输出门
g = torch.tanh(g) # 候选值
# 更新细胞状态
c = f * c_prev + i * g
# 更新隐藏状态
h = o * torch.tanh(c)
return (h, c)