一、长短期记忆网络LSTM
1. LSTM 是什么?为什么需要它?
LSTM(Long Short-Term Memory)是一类改进的循环神经网络(RNN)。它的目标是解决普通 RNN 在长序列中常见的两个问题:
-
梯度消失/梯度爆炸:信息跨越很多时间步后很难有效传播,模型学不到长期依赖。
-
记忆不稳定:普通 RNN 的隐藏状态每一步都会被新的输入强烈"覆盖",长期信息容易被冲掉。
LSTM 的关键思想是引入一条更适合长期保存信息的"记忆通道",并用门控机制对信息进行选择性保留、写入与输出,从而更稳定地学习长程依赖。
2. LSTM 的两个状态:h 和 c
LSTM 在每个时间步维护两种状态:
-
隐藏状态 ht:对外输出的状态,常用于预测(比如接线性层输出 vocab 概率)。
-
细胞状态 ct:更像"长期记忆存储",专门用来跨时间保存信息。
理解上可以这样分工:
-
ct:像笔记本,内容可以长期保存;
-
ht:像你此刻说出口的话,是笔记本内容的一部分投影/展示。
3. LSTM 的门:忘记门、输入门、输出门
LSTM 通过三个门控制记忆的"删、写、读"。三个门的值都在 0~1 (由 sigmoid 输出),并且是逐元素控制的:隐藏维度的每一维都有自己的开关。
3.1 忘记门(Forget Gate)ft:决定"旧记忆留多少"
忘记门控制上一时刻的细胞状态 ct−1 有多少要保留到当前:
-
ft≈1:保留旧记忆(长期依赖更容易学到)
-
ft≈0:丢弃旧记忆(清除不相关信息,防止污染)
影响:
-
让模型能"选择性遗忘",避免无关信息一直累积;
-
当需要长期信息时,忘记门可以让记忆更稳定地跨越很多步。
3.2 输入门(Input Gate)it:决定"新记忆写多少"
输入门控制当前输入生成的新信息是否写入记忆。LSTM 会先生成候选记忆(由 tanh 输出),再用输入门决定写多少:
-
it≈1:写入较多新信息
-
it≈0:几乎不写入(说明当前输入可能是噪声/短期波动)
影响:
-
抑制噪声,避免每一步都把输入强行写入记忆;
-
让"记忆更新"变得可控,不会被短期变化频繁扰动。
3.3 输出门(Output Gate)ot:决定"记忆对外说多少"
输出门控制细胞状态 ct中哪些信息会被"展示"为隐藏状态 ht。即使某些信息保存在记忆里,也不一定要立刻输出:
-
ot≈1:输出更多记忆内容 → ht 更"公开"
-
ot≈0:暂时不输出 → "心里记着但不说"
影响:
-
记忆与输出解耦:模型可以保存信息但选择在合适的时刻再用;
-
对生成任务(语言模型)尤其重要:模型可控地决定何时利用长期信息。

二、代码
import torch
from torch import nn
from torch.nn import functional as F
import test_55RNNesay_realize
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
def get_lstm_params(vocab_size,num_hiddens,device):
num_inputs=num_outputs=vocab_size
def normal(shape):
return torch.randn(shape,device=device)*0.01
def three():#参数值初始化
return (normal((num_inputs,num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens,device=device))
W_xi, W_hi, b_i=three()#输入门参数
W_xf, W_hf, b_f = three()#遗忘门参数
W_xo, W_ho, b_o = three()#输出门参数
W_xc, W_hc, b_c = three()#候选记忆元参数
#输出层参数
W_hq=normal((num_hiddens,num_outputs))
b_q=torch.zeros(num_outputs,device=device)
#附加梯度
params=[W_xi, W_hi, b_i,W_xf, W_hf, b_f,W_xo, W_ho, b_o,W_xc, W_hc, b_c,W_hq,b_q]#参数有3种,更新门、重置门和模型参数
for param in params:#令所有参数都可以求梯度
param.requires_grad=True
return params
def init_LSTM_state(batch_size,num_hiddens,device):
return (torch.zeros((batch_size,num_hiddens),device=device),
torch.zeros((batch_size,num_hiddens),device=device))#单元素(x,)表示元组
def lstm(inputs,state,params):
W_xi, W_hi, b_i,W_xf, W_hf, b_f,W_xo, W_ho, b_o,W_xc, W_hc, b_c,W_hq,b_q=params
(H, C)=state
outputs=[]
for X in inputs:
I=torch.sigmoid((X @ W_xi)+(H @ W_hi)+b_i)
F=torch.sigmoid((X @ W_xf)+(H @ W_hf)+b_f)
O=torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
C_tilda=torch.tanh((X @ W_xc)+(H @ W_hc)+b_c)
C=F*C+I*C_tilda
H=O*(torch.tanh(C))
y=(H @ W_hq)+b_q
outputs.append(y)
return torch.cat(outputs,dim=0),(H,C)
vocab_size,num_hiddens,device=len(vocab),256,d2l.try_gpu()
num_epochs,lr=500,1
model=test_55RNNdifficult_realize.RNNModelScratch(len(vocab),num_hiddens,d2l.try_gpu(),get_lstm_params,init_LSTM_state,lstm)
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
#简约实现
num_inputs=vocab_size
lstm=nn.LSTM(input_size=num_inputs,hidden_size=num_hiddens)
model=test_55RNNesay_realize.RNNModel(lstm,len(vocab))
model=model.to(d2l.try_gpu())
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
三、总结
代码结构跟gru几乎类似,是gru升级,加入了新的门,但是模型框架几乎相同,需要注意state初始化需要更新两层,隐藏层和记忆层