长短期记忆网络(LSTM)

长短期记忆网络(LSTM)基本原理详解

一、LSTM核心思想

目标 :解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新 :引入细胞状态(Cell State)门控机制 ,通过三个门结构精确控制信息流动


二、网络结构分解

1. 核心组件(四个关键部分)

组件 符号 功能描述
遗忘门 f t f_t ft 决定从细胞状态中丢弃哪些信息
输入门 i t i_t it 确定新信息存入细胞状态的比例
候选值 C ~ t \tilde{C}_t C~t 生成待存入细胞状态的新候选值
输出门 o t o_t ot 控制细胞状态到隐藏状态的输出比例

2. 数学公式推导

遗忘门(Forget Gate)

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)

  • σ \sigma σ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)

候选细胞状态

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)

细胞状态更新

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • ⊙ \odot ⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)

隐藏状态计算

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)


三、PyTorch实现

1. LSTM单元实现

python 复制代码
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并计算四个门的参数矩阵
        self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
        
    def forward(self, x, state):
        # state = (h, c)
        h_prev, c_prev = state
        
        # 合并输入与隐藏状态
        combined = torch.cat((x, h_prev), dim=1)
        gates = self.W(combined)
        
        # 分割四个门计算结果
        f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
        
        # 激活函数应用
        f = torch.sigmoid(f)  # 遗忘门
        i = torch.sigmoid(i)  # 输入门
        o = torch.sigmoid(o)  # 输出门
        g = torch.tanh(g)     # 候选值
        
        # 更新细胞状态
        c = f * c_prev + i * g
        # 更新隐藏状态
        h = o * torch.tanh(c)
        
        return (h, c)
相关推荐
OpenCSG39 分钟前
OpenCSG 2025年11月月报:智能体平台、AI技术合作与开源生态进展
人工智能·开源·opencsg·csghub
围炉聊科技1 小时前
当AI成为“大脑”:人类如何在机器时代找到不可替代的价值?
人工智能
لا معنى له1 小时前
残差网络论文学习笔记:Deep Residual Learning for Image Recognition全文翻译
网络·人工智能·笔记·深度学习·学习·机器学习
菜只因C1 小时前
深度学习:从技术本质到未来图景的全面解析
人工智能·深度学习
工业机器视觉设计和实现1 小时前
lenet改vgg训练cifar10突破71分
人工智能·机器学习
咚咚王者1 小时前
人工智能之数据分析 Matplotlib:第四章 图形类型
人工智能·数据分析·matplotlib
TTGGGFF2 小时前
人工智能:用Gemini 3一键生成3D粒子电子手部映射应用
人工智能·3d·交互
LitchiCheng2 小时前
Mujoco 基础:获取模型中所有 body 的 name, id 以及位姿
人工智能·python
Allen_LVyingbo2 小时前
面向医学影像检测的深度学习模型参数分析与优化策略研究
人工智能·深度学习
CareyWYR2 小时前
每周AI论文速递(251124-251128)
人工智能