长短期记忆网络(LSTM)

长短期记忆网络(LSTM)基本原理详解

一、LSTM核心思想

目标 :解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新 :引入细胞状态(Cell State)门控机制 ,通过三个门结构精确控制信息流动


二、网络结构分解

1. 核心组件(四个关键部分)

组件 符号 功能描述
遗忘门 f t f_t ft 决定从细胞状态中丢弃哪些信息
输入门 i t i_t it 确定新信息存入细胞状态的比例
候选值 C ~ t \tilde{C}_t C~t 生成待存入细胞状态的新候选值
输出门 o t o_t ot 控制细胞状态到隐藏状态的输出比例

2. 数学公式推导

遗忘门(Forget Gate)

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)

  • σ \sigma σ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)

候选细胞状态

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)

细胞状态更新

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • ⊙ \odot ⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)

隐藏状态计算

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)


三、PyTorch实现

1. LSTM单元实现

python 复制代码
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并计算四个门的参数矩阵
        self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
        
    def forward(self, x, state):
        # state = (h, c)
        h_prev, c_prev = state
        
        # 合并输入与隐藏状态
        combined = torch.cat((x, h_prev), dim=1)
        gates = self.W(combined)
        
        # 分割四个门计算结果
        f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
        
        # 激活函数应用
        f = torch.sigmoid(f)  # 遗忘门
        i = torch.sigmoid(i)  # 输入门
        o = torch.sigmoid(o)  # 输出门
        g = torch.tanh(g)     # 候选值
        
        # 更新细胞状态
        c = f * c_prev + i * g
        # 更新隐藏状态
        h = o * torch.tanh(c)
        
        return (h, c)
相关推荐
AI视觉网奇12 分钟前
公式动画软件学习笔记
人工智能·公式绘图
天天代码码天天15 分钟前
C# OnnxRuntime 部署 DDColor
人工智能·ddcolor
惠惠软件16 分钟前
豆包 AI 学习投喂与排名优化指南
人工智能·学习·语音识别
数据中心的那点事儿16 分钟前
从设计到运营全链破局 恒华智算专场解锁产业升级密码
大数据·人工智能
FluxMelodySun20 分钟前
机器学习(三十三) 概率图模型与隐马尔可夫模型
人工智能·机器学习
深兰科技25 分钟前
深兰科技与淡水河谷合作推进:矿区示范加速落地
java·人工智能·python·c#·scala·symfony·深兰科技
V搜xhliang024629 分钟前
OpenClaw、AI大模型赋能数据分析与学术科研 学习
人工智能·深度学习·学习·机器学习·数据挖掘·数据分析
PHOSKEY31 分钟前
3D工业相机对焊后缺陷全检——机械手焊接系统质量控制的最后关口
人工智能
Aaron158832 分钟前
8通道测向系统演示科研套件
人工智能·算法·fpga开发·硬件工程·信息与通信·信号处理·基带工程
每天进步一点点️37 分钟前
AI芯片制造的“择优录用”:解读 APU Cluster4 的 Harvesting 机制
人工智能·soc片上系统·半导体芯片