LSTM(Long Short-Term Memory)个人理解

作为学习长短期记忆网络的笔记,方便复习与理解。

1.LSTM细胞(LSTM Cell)

每个LSTM层都由若干个LSTM细胞(LSTM Cell)组成,其结构如下所示:

其公式如图所示:

看起来非常复杂难懂,先从图下手,把图中的三个门和两个细胞状态打包起来,如图所示:

这样就简单了许多。

①三个门

简单来说,三个门的输出范围都在[0,1],控制的是比例,三门分别对应:

  • 遗忘门 (Forget Gate) : 控制过去记忆的保留比例,其输出为
  • 输入门 (Input Gate) : 控制新信息的学习比例,其输出为
  • 输出门 (Output Gate) :控制内部记忆转化为当前输出的比例,其输出为

门的结构如图所示:

其输入通过各自对应的投影矩阵(,角标对应其英文名开头)投影变成预激活值,再通过激活函数Sigmoid变为最终的输出(门控信号)。

PS:Sigmoid激活函数的作用

核心作用就是将输入映射到 (0, 1) 区间,输出一个控制信息的比例系数**。**

②两个细胞状态

候选细胞状态(Candidate Cell State) 是LSTM细胞根据当前的新输入和上一时刻的短期记忆计算出的一个潜在的新记忆内容。

其计算方式为:将输入通过自己对应的投影矩阵(,角标对应其英文名开头)投影变成预激活值,再通过激活函数Tanh变为最终的输出(细胞状态)。

其结构如下图所示:

是不是和前面的三个门很相似?事实上它们的计算方式确实一模一样,区别只在于投影矩阵和激活函数。

然后是压缩细胞状态(The squashed cell state,或者 ,需要注意的是,压缩细胞状态并不是一个专业的称呼,这里只是为了把原图中的所有激活函数都打包起来,让原图看起来不是那么复杂才进行的命名,它的计算方式非常简单,就是直接将作为预激活值放入激活函数tanh得到的输出(细胞状态),所以一般直接称呼它为

PS:Tanh激活函数的作用

  1. 引入非线性,学习更复杂的知识。
  2. 压缩范围,防止梯度爆炸。

针对第一点,个人理解是,如果把现实比作一个复杂函数,那么只有线性,就像是只允许画直线,只能拟合其中简单、平坦的规律,而非线性激活函数则相当于提供了可以自由描绘曲线的能力,使得神经网络能够弯曲、转折,从而拟合这个函数中可能遇到的各种复杂、曲折的映射关系。

针对第二点,则是由于细胞状态c_t在理论上由于遗忘门和输入门的累积效应可以无限增长,如果不加处理直接给h_t输出,可能导致h_t反向传播时梯度爆炸。

在候选细胞状态的计算中,tanh还能将候选细胞状态限制在[-1, 1],防止新添加的信息值域过大,直接破坏当前细胞状态。

③门和细胞状态的对应

如上文所述,门的输出是控制各个信息和记忆的比例,而LSTM中的记忆都由细胞状态(Cell State)储存,所以实际上每个门要发挥作用都要和对应的细胞状态进行对应元素相乘的计算。

  1. 遗忘门:遗忘门用于控制过去记忆的保留比例,对应的细胞状态是上一个时间步的细胞状态,其相乘的式子为
  2. 输入门:输入门用于控制新信息的学习比例,对应的是候选细胞状态,其相乘的式子为
  3. 输出门:控制内部记忆转化为当前输出的比例,对应的是压缩细胞状态,其相乘的式子为

将这里的内容与前文的LSTM细胞图进行对比,从图中找到门与其对应的细胞状态辅助理解。

④公式

有了以上的了解,公式就很容易看明白了,整体公式如下所示:

其中,就是前文图中的时间步

分别对应的是遗忘门、输入门、输出门的输出。

代表各个门的线性投影(权重矩阵与输入向量的乘积)

是三个门对应的偏置项,也是可学习的参数,它直接与线性投影的结果相加。

分别表示的Sigmoid激活函数和Tanh激活函数。

表示逐项相乘(逐项相乘也叫Hadamard积,与矩阵的乘法计算方式不同)

表示当前时间步的细胞状态,它等于经过了一定遗忘的上一时间步的细胞状态与一定比例候选细胞状态的和,代表长期记忆。

表示当前时间步的隐藏状态,它等于一定比例的压缩细胞状态,代表短期记忆。

PS:偏置项

  • 是什么?
    偏置是一个可学习的偏移量,用于调整神经元(或门)的激活阈值。
  • 作用是什么?
    核心作用是为每个门控提供一个可学习的、独立于输入的基准激活倾向。
  • 怎么用?
    它直接与线性投影(权重矩阵与输入向量的乘积)的结果相加。
  • 举个例子?
    比如遗忘门偏置常初始化为正数(如1.0) ,让最开始细胞状态在没有学习到多少内容时,让遗忘门控制过去记忆的保留比例尽可能大一些,避免还没学到东西就开始大量遗忘内容。

2.LSTM模块

PS:LSTM模块,LSTM层,LSTM细胞

为了方便区分下文中提到的各种LSTM,我会将其分为LSTM模块,LSTM层,LSTM细胞来进行描述,三者分别代表:

  • LSTM模块:如a = nn.LSTM(),这里的lstm就是一个LSTM模块
  • LSTM层:如b = nn.LSTM(num_layers=2),代表这个b中包含两个LSTM层
  • LSTM细胞:如前文定义,是LSTM层里的小组件

常见的LSTM的文章介绍的时候基本上都会给一个时间展开图,在展开图里,上一个时间步的细胞状态和隐藏状态从左边输入进LSTM细胞,下面输入当前时间步的输入特征,经过LSTM细胞计算后,再从右边输出当前时间步的细胞状态和隐藏状态,如下图所示:

需要注意的是,这里的时间展开图表示的是一个LSTM细胞在不同时间步所进行的运算输入输出,并不代表多个LSTM细胞。

事实上,一个LSTM层只有一个LSTM细胞,其输出为一个元组,代表当前时间步的隐藏状态与细胞状态,两者的维度由hidden_size的值决定,也就是隐藏层的大小,下面是Pytorch中LSTM模块的一个简单实例化:

python 复制代码
import torch
import torch.nn as nn

# 最简单的LSTM实例化
lstm = nn.LSTM(input_size=100, hidden_size=50, num_layers=1)

# 参数说明:
# input_size: 输入特征的维度
# hidden_size: 隐藏状态的维度
# num_layers: LSTM的层数

这里参数hidden_size=50,代表这个LSTM模块中的LSTM细胞的隐藏状态与细胞状态的维度为50,num_layers=1,代表这个LSTM模块中只有一个LSTM层。

①单层LSTM

假设一个LSTM模块的参数num_layers为1,即这个模块只包含一个LSTM层,这样的模块就被称为单层LSTM。

举一个简单的单层LSTM的例子:

python 复制代码
# 输入数据的形状为(batch_size=3, sequence_length=10, input_size=100)
input_data = torch.randn(3, 10, 100)
# 一个简单的单层LSTM
lstm = nn.LSTM(input_size=100, hidden_size=64, num_layers=1, batch_first=True)

其输入的数据形状为(3,10,100),代表:

  • batch_size=3,代表输入是三个独立的序列
  • sequence_length=10,代表每个序列有10个时间步
  • input_size=100,代表每个时间步的向量维度是100

那么数据在经过这个LSTM层时,这个LSTM层里的LSTM细胞会被重复调用sequence_length=10次,并且在第个时间步输入当前时间步所对应的input_size=100个特征,并且输出两个形状为(batch_size=3,hidden_size=64)的张量,分别代表每个batch当前时间步的隐藏状态和细胞状态。

LSTM层会反复调用LSTM细胞sequence_length=10次,保留每个时间步的隐藏状态和最后时间步的隐藏状态和细胞状态,最后得到一个张量和元组,其中张量的形状为(batch_size=3, sequence_length=10 ,hidden_size=64),对应每个batch、每个时间步的隐藏状态,元组包括两个形状为(num_layers=1,batch_size=3,hidden_size=64)的张量,它们分别对应最后一个时间步(这里指第sequence_length=10个时间步)每个batch的隐藏状态和细胞状态。

在训练和预测时,前者的隐藏状态张量都会作为这个LSTM层的主要输出被使用,而后者的隐藏状态和细胞状态在预测时通常被忽略,在训练时,则根据具体任务需求决定是否继续使用这些状态。

②堆叠LSTM(多层LSTM)

假设一个LSTM模块的参数num_layers不为1,那么代表这个模块包含多个LSTM层,这样的模块就被称为堆叠LSTM(多层LSTM)。

将前文中单层LSTM例子的参数num_layers修改为2,得到一个新的例子:

python 复制代码
# 输入数据的形状为(batch_size=3, sequence_length=10, input_size=100)
input_data = torch.randn(3, 10, 100)
# 一个简单的堆叠LSTM
lstm = nn.LSTM(input_size=100, hidden_size=64, num_layers=2, batch_first=True)

此时,这里的LSTM模块中包括两个叠加的LSTM层(为了区分这两个堆叠的LSTM层,称靠近输入的为底层LSTM,另一个为上层LSTM)。

在这个情况下,底层LSTM的输入输出和前文提到的参数num_layers为1的LSTM层一模一样,最终得到一个最后得到一个形状为(batch_size=3, sequence_length=10 ,hidden_size=64)张量和包括两个形状为(1,batch_size=3,hidden_size=64)的元组。

而上层LSTM以底层LSTM输出的第一个张量作为输入,第二个元组中保存的最后一个时间步(这里指第sequence_length=10个时间步)每个batch的隐藏状态和细胞状态作为该LSTM层的隐藏状态和细胞状态的初始状态(即用这个输出初始化该层的隐藏状态和细胞状态)。

并且上层LSTM中的LSTM细胞与底层LSTM的LSTM细胞不同,底层细胞的输入为input_size=100, 而上层细胞的输入为hidden_size=64,如果在上层LSTM之上继续堆叠定LSTM层,即参数设时num_layers>2,那么后面的LSTM层的输入数据维度、输出数据维度依旧与上层LSTM相同,并且数据来源一样是自己的上一层LSTM。

与单层LSTM对比,堆叠LSTM通过增加深度提升了模型表达能力但代价是训练更复杂且易过拟合,而单层LSTM训练简单稳定但建模能力有限。

③双向LSTM

前文只讲述了单向的LSTM层,而双向LSTM层会同时使用两个LSTM层,分别从前向后和从后向前两个方向处理输入的序列,也就是说,在同样的num_layers的参数下,双向LSTM会有两倍的LSTM层。

在pytorch中,可以通过设置参数bidirectional=True来使用双向LSTM层。

在双向LSTM中,正向的LSTM层与反向的LSTM层的理论机制完全相同,不同的是,反向LSTM层会将输入的正向序列反转为反向序列后再进行处理,此时的双向LSTM既可以学习到正向的知识,也可以学习到反向的知识,在处理自然语言的数据时,能够十分有效地捕捉到上下文的信息。

这里不再对其理论机制进行过多赘述,只给出两者的输出维度对比。

python 复制代码
import torch
import torch.nn as nn

input_data = torch.randn(3, 10, 100)

# 单向LSTM
unidirectional = nn.LSTM(100, 64, bidirectional=False, batch_first=True)
output_uni, (hn_uni, cn_uni) = unidirectional(input_data)
print("=== 单向LSTM ===")
print(f"输出形状: {output_uni.shape}")      # [3, 10, 64]
print(f"隐藏状态形状: {hn_uni.shape}")       # [1, 3, 64]
print(f"细胞状态形状: {cn_uni.shape}")        # [1, 3, 64]

# 双向LSTM  
bidirectional = nn.LSTM(100, 64, bidirectional=True, batch_first=True)
output_bi, (hn_bi, cn_bi) = bidirectional(input_data)
print("\n=== 双向LSTM ===")
print(f"输出形状: {output_bi.shape}")       # [3, 10, 128]
print(f"隐藏状态形状: {hn_bi.shape}")        # [2, 3, 64]
print(f"细胞状态形状: {cn_bi.shape}")         # [2, 3, 64]

当num_layers取2时:

python 复制代码
import torch
import torch.nn as nn

input_data = torch.randn(3, 10, 100)

# 单向LSTM,2层
unidirectional = nn.LSTM(100, 64, num_layers=2, bidirectional=False, batch_first=True)
output_uni, (hn_uni, cn_uni) = unidirectional(input_data)

print("=== 单向LSTM (2层) ===")
print(f"输出形状: {output_uni.shape}")      # [3, 10, 64]
print(f"隐藏状态形状: {hn_uni.shape}")       # [2, 3, 64]
print(f"细胞状态形状: {cn_uni.shape}")        # [2, 3, 64]

# 双向LSTM,2层  
bidirectional = nn.LSTM(100, 64, num_layers=2, bidirectional=True, batch_first=True)
output_bi, (hn_bi, cn_bi) = bidirectional(input_data)

print("\n=== 双向LSTM (2层) ===")
print(f"输出形状: {output_bi.shape}")       # [3, 10, 128]
print(f"隐藏状态形状: {hn_bi.shape}")        # [4, 3, 64]
print(f"细胞状态形状: {cn_bi.shape}")         # [4, 3, 64]

# 对于双向2层LSTM的隐藏状态:
print(f"hn_bi形状: {hn_bi.shape}")  # [4, 3, 64]

# 按层和方向访问:
layer0_forward = hn_bi[0]   # 第0层,前向
layer0_backward = hn_bi[1]  # 第0层,反向  
layer1_forward = hn_bi[2]   # 第1层,前向
layer1_backward = hn_bi[3]  # 第1层,反向

3.总结

1. 基本单元

  • LSTM细胞:包含3个门控 + 2个状态的核心计算单元
  • 三个门:遗忘门、输入门、输出门(Sigmoid激活,控制信息流)
  • 两个状态:细胞状态cₜ(长期记忆)、隐藏状态hₜ(短期记忆)

2. 网络结构

  • 单层LSTM:1个LSTM细胞处理所有时间步
  • 堆叠LSTM:多层LSTM,上层以下层的输出为输入
  • 双向LSTM:正向+反向两个LSTM,输出拼接
相关推荐
翔云 OCR API7 小时前
基于深度学习与OCR研发的报关单识别接口技术解析
人工智能·深度学习·ocr
wwlsm_zql7 小时前
京津冀工业智能体赋能:重构产业链升级新篇章
人工智能·重构
lzjava20247 小时前
Spring AI实现一个智能客服
java·人工智能·spring
hweiyu007 小时前
数据挖掘 miRNA调节网络的构建(视频教程)
人工智能·数据挖掘
飞哥数智坊7 小时前
AI Coding 新手常见的3大误区
人工智能·ai编程
3Bronze1Pyramid7 小时前
深度学习参数优化
人工智能·深度学习
笨笨没好名字7 小时前
自然语言处理(NLP)之文本预处理:词元化——以《时间机器》文本数据集为例
人工智能·自然语言处理
skywalk81637 小时前
简单、高效且低成本的预训练、微调与服务,惠及大众基于 Ray 架构设计的覆盖大语言模型(LLM)完整生命周期的解决方案byzer-llm
人工智能·语言模型·自然语言处理
urkay-7 小时前
Android Cursor AI代码编辑器
android·人工智能·编辑器·iphone·androidx