长短期记忆网络(LSTM)

长短期记忆网络(LSTM)基本原理详解

一、LSTM核心思想

目标 :解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新 :引入细胞状态(Cell State)门控机制 ,通过三个门结构精确控制信息流动


二、网络结构分解

1. 核心组件(四个关键部分)

组件 符号 功能描述
遗忘门 f t f_t ft 决定从细胞状态中丢弃哪些信息
输入门 i t i_t it 确定新信息存入细胞状态的比例
候选值 C ~ t \tilde{C}_t C~t 生成待存入细胞状态的新候选值
输出门 o t o_t ot 控制细胞状态到隐藏状态的输出比例

2. 数学公式推导

遗忘门(Forget Gate)

f t = σ ( W f ⋅ h t − 1 , x t + b f ) f_t = \sigma(W_f \cdot h_{t-1}, x_t + b_f) ft=σ(Wf⋅ht−1,xt+bf)

  • σ \sigma σ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)

i t = σ ( W i ⋅ h t − 1 , x t + b i ) i_t = \sigma(W_i \cdot h_{t-1}, x_t + b_i) it=σ(Wi⋅ht−1,xt+bi)

候选细胞状态

C ~ t = tanh ⁡ ( W C ⋅ h t − 1 , x t + b C ) \tilde{C}_t = \tanh(W_C \cdot h_{t-1}, x_t + b_C) C~t=tanh(WC⋅ht−1,xt+bC)

细胞状态更新

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • ⊙ \odot ⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)

o t = σ ( W o ⋅ h t − 1 , x t + b o ) o_t = \sigma(W_o \cdot h_{t-1}, x_t + b_o) ot=σ(Wo⋅ht−1,xt+bo)

隐藏状态计算

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)


三、PyTorch实现

1. LSTM单元实现

python 复制代码
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并计算四个门的参数矩阵
        self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
        
    def forward(self, x, state):
        # state = (h, c)
        h_prev, c_prev = state
        
        # 合并输入与隐藏状态
        combined = torch.cat((x, h_prev), dim=1)
        gates = self.W(combined)
        
        # 分割四个门计算结果
        f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
        
        # 激活函数应用
        f = torch.sigmoid(f)  # 遗忘门
        i = torch.sigmoid(i)  # 输入门
        o = torch.sigmoid(o)  # 输出门
        g = torch.tanh(g)     # 候选值
        
        # 更新细胞状态
        c = f * c_prev + i * g
        # 更新隐藏状态
        h = o * torch.tanh(c)
        
        return (h, c)
相关推荐
火山引擎开发者社区4 小时前
没有长期记忆,Agent 谈何持续进化?一图看懂火山 Mem0:解锁 Agent 持续学习与进化之路
人工智能
冬奇Lab8 小时前
Workflow 系列(06):安全——跨步骤注入传播与四层防御
人工智能·工作流引擎
冬奇Lab8 小时前
每日一个开源项目(第149篇):RAG-Anything - 把图片、表格、公式当成一等公民的多模态 RAG 框架
人工智能·开源
米小虾8 小时前
AI Agent 安全实战指南:当智能体开始"不听话",开发者该如何应对?
人工智能·安全·agent
IT_陈寒10 小时前
Vite的热更新突然不香了,排查三小时差点砸键盘
前端·人工智能·后端
阿里云大数据AI技术12 小时前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu122712 小时前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
字节跳动视频云技术团队12 小时前
让 Agent 成为音视频工作台:AI MediaKit CLI + Skill 发布
人工智能·音视频开发
魏祖潇12 小时前
framework 整合实战——DDD/TDD/SDD 三件套在 framework 仓的真实落地
人工智能·后端