长短期记忆网络(LSTM)

长短期记忆网络(LSTM)基本原理详解

一、LSTM核心思想

目标 :解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新 :引入细胞状态(Cell State)门控机制 ,通过三个门结构精确控制信息流动


二、网络结构分解

1. 核心组件(四个关键部分)

组件 符号 功能描述
遗忘门 f t f_t ft 决定从细胞状态中丢弃哪些信息
输入门 i t i_t it 确定新信息存入细胞状态的比例
候选值 C ~ t \tilde{C}_t C~t 生成待存入细胞状态的新候选值
输出门 o t o_t ot 控制细胞状态到隐藏状态的输出比例

2. 数学公式推导

遗忘门(Forget Gate)

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)

  • σ \sigma σ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)

候选细胞状态

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)

细胞状态更新

C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • ⊙ \odot ⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)

隐藏状态计算

h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)


三、PyTorch实现

1. LSTM单元实现

python 复制代码
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并计算四个门的参数矩阵
        self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
        
    def forward(self, x, state):
        # state = (h, c)
        h_prev, c_prev = state
        
        # 合并输入与隐藏状态
        combined = torch.cat((x, h_prev), dim=1)
        gates = self.W(combined)
        
        # 分割四个门计算结果
        f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
        
        # 激活函数应用
        f = torch.sigmoid(f)  # 遗忘门
        i = torch.sigmoid(i)  # 输入门
        o = torch.sigmoid(o)  # 输出门
        g = torch.tanh(g)     # 候选值
        
        # 更新细胞状态
        c = f * c_prev + i * g
        # 更新隐藏状态
        h = o * torch.tanh(c)
        
        return (h, c)
相关推荐
L、2184 分钟前
深入理解CANN:面向AI加速的异构计算架构详解
人工智能·架构
chaser&upper10 分钟前
预见未来:在 AtomGit 解码 CANN ops-nn 的投机采样加速
人工智能·深度学习·神经网络
松☆13 分钟前
CANN与大模型推理:在边缘端高效运行7B参数语言模型的实践指南
人工智能·算法·语言模型
结局无敌20 分钟前
深度探究cann仓库下的infra:AI计算的底层基础设施底座
人工智能
m0_4665252920 分钟前
绿盟科技风云卫AI安全能力平台成果重磅发布
大数据·数据库·人工智能·安全
慢半拍iii21 分钟前
从零搭建CNN:如何高效调用ops-nn算子库
人工智能·神经网络·ai·cnn·cann
晟诺数字人26 分钟前
2026年海外直播变革:数字人如何改变游戏规则
大数据·人工智能·产品运营
蛋王派27 分钟前
DeepSeek-OCR-v2 模型解析和部署应用
人工智能·ocr
禁默32 分钟前
基于CANN的ops-cv仓库-多模态场景理解与实践
人工智能·cann
禁默40 分钟前
【硬核入门】无需板卡也能造 AI 算子?深度玩转 CANN ops-math 通用数学库
人工智能·aigc·cann