《动手学深度学习》-57长短期记忆网络LSTM

一、长短期记忆网络LSTM

1. LSTM 是什么?为什么需要它?

LSTM(Long Short-Term Memory)是一类改进的循环神经网络(RNN)。它的目标是解决普通 RNN 在长序列中常见的两个问题:

  1. 梯度消失/梯度爆炸:信息跨越很多时间步后很难有效传播,模型学不到长期依赖。

  2. 记忆不稳定:普通 RNN 的隐藏状态每一步都会被新的输入强烈"覆盖",长期信息容易被冲掉。

LSTM 的关键思想是引入一条更适合长期保存信息的"记忆通道",并用门控机制对信息进行选择性保留、写入与输出,从而更稳定地学习长程依赖。


2. LSTM 的两个状态:h 和 c

LSTM 在每个时间步维护两种状态:

  • 隐藏状态 ht:对外输出的状态,常用于预测(比如接线性层输出 vocab 概率)。

  • 细胞状态 ct:更像"长期记忆存储",专门用来跨时间保存信息。

理解上可以这样分工:

  • ct:像笔记本,内容可以长期保存;

  • ht:像你此刻说出口的话,是笔记本内容的一部分投影/展示。

3. LSTM 的门:忘记门、输入门、输出门

LSTM 通过三个门控制记忆的"删、写、读"。三个门的值都在 0~1 (由 sigmoid 输出),并且是逐元素控制的:隐藏维度的每一维都有自己的开关。

3.1 忘记门(Forget Gate)ft:决定"旧记忆留多少"

忘记门控制上一时刻的细胞状态 ct−1 有多少要保留到当前:

  • ft≈1:保留旧记忆(长期依赖更容易学到)

  • ft≈0:丢弃旧记忆(清除不相关信息,防止污染)

影响:

  • 让模型能"选择性遗忘",避免无关信息一直累积;

  • 当需要长期信息时,忘记门可以让记忆更稳定地跨越很多步。

3.2 输入门(Input Gate)it:决定"新记忆写多少"

输入门控制当前输入生成的新信息是否写入记忆。LSTM 会先生成候选记忆(由 tanh 输出),再用输入门决定写多少:

  • it≈1:写入较多新信息

  • it≈0:几乎不写入(说明当前输入可能是噪声/短期波动)

影响:

  • 抑制噪声,避免每一步都把输入强行写入记忆;

  • 让"记忆更新"变得可控,不会被短期变化频繁扰动。

3.3 输出门(Output Gate)ot:决定"记忆对外说多少"

输出门控制细胞状态 ct中哪些信息会被"展示"为隐藏状态 ht。即使某些信息保存在记忆里,也不一定要立刻输出:

  • ot≈1:输出更多记忆内容 → ht 更"公开"

  • ot≈0:暂时不输出 → "心里记着但不说"

影响:

  • 记忆与输出解耦:模型可以保存信息但选择在合适的时刻再用;

  • 对生成任务(语言模型)尤其重要:模型可控地决定何时利用长期信息。

二、代码

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import test_55RNNesay_realize
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
def get_lstm_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    def normal(shape):
        return torch.randn(shape,device=device)*0.01
    def three():#参数值初始化
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens,device=device))
    W_xi, W_hi, b_i=three()#输入门参数
    W_xf, W_hf, b_f = three()#遗忘门参数
    W_xo, W_ho, b_o = three()#输出门参数
    W_xc, W_hc, b_c = three()#候选记忆元参数
    #输出层参数
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    #附加梯度
    params=[W_xi, W_hi, b_i,W_xf, W_hf, b_f,W_xo, W_ho, b_o,W_xc, W_hc, b_c,W_hq,b_q]#参数有3种,更新门、重置门和模型参数
    for param in params:#令所有参数都可以求梯度
        param.requires_grad=True
    return params
def init_LSTM_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),
            torch.zeros((batch_size,num_hiddens),device=device))#单元素(x,)表示元组
def lstm(inputs,state,params):
    W_xi, W_hi, b_i,W_xf, W_hf, b_f,W_xo, W_ho, b_o,W_xc, W_hc, b_c,W_hq,b_q=params
    (H, C)=state
    outputs=[]
    for X in inputs:
        I=torch.sigmoid((X @ W_xi)+(H @ W_hi)+b_i)
        F=torch.sigmoid((X @ W_xf)+(H @ W_hf)+b_f)
        O=torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda=torch.tanh((X @ W_xc)+(H @ W_hc)+b_c)
        C=F*C+I*C_tilda
        H=O*(torch.tanh(C))
        y=(H @ W_hq)+b_q
        outputs.append(y)
    return torch.cat(outputs,dim=0),(H,C)
vocab_size,num_hiddens,device=len(vocab),256,d2l.try_gpu()
num_epochs,lr=500,1
model=test_55RNNdifficult_realize.RNNModelScratch(len(vocab),num_hiddens,d2l.try_gpu(),get_lstm_params,init_LSTM_state,lstm)
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
#简约实现
num_inputs=vocab_size
lstm=nn.LSTM(input_size=num_inputs,hidden_size=num_hiddens)
model=test_55RNNesay_realize.RNNModel(lstm,len(vocab))
model=model.to(d2l.try_gpu())
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)

三、总结

代码结构跟gru几乎类似,是gru升级,加入了新的门,但是模型框架几乎相同,需要注意state初始化需要更新两层,隐藏层和记忆层

相关推荐
LASDAaaa12312 小时前
基于DETR的花卉种类识别与分类系统详解
人工智能·数据挖掘
数琨创享TQMS质量数智化2 小时前
国有大型交通运输设备制造集团QMS质量管理平台案例
大数据·人工智能·物联网
yhdata2 小时前
绿色能源新动力:硫酸亚铁助力锂电池产业,年复合增长率攀升至14.8%
大数据·人工智能
围炉聊科技2 小时前
从机械扫描到逻辑阅读:DeepSeek-OCR 2的技术革新
人工智能·ocr
范桂飓2 小时前
Transformer 大模型架构深度解析(5)GPT 与 LLM 大语言模型技术解析
人工智能·gpt·语言模型·transformer
charlie1145141912 小时前
机器学习概论:一门教计算机如何“不确定地正确”的学问
人工智能·笔记·机器学习·工程实践
凡。。。2962 小时前
APS概念-EOQ模型
人工智能·制造
FreeBuf_2 小时前
MEDUSA安全测试工具:集成74种扫描器与180余项AI Agent安全规则
人工智能·安全
迅为电子2 小时前
迅为iTOP-Hi3403开发板:解锁多目拼接相机的10.4TOPS强“芯”动力,开启4K智能视觉新纪元
人工智能·itop-hi3403开发板·海思hi3403·多目拼接相机