LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍

标准的RNN存在梯度消失梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢?

例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 "the clouds are in the sky",基于"the clouds are in the",预测"sky",在这样的场景中,预测的词和提供的信息之间位置间隔是非常小的,如下图所示,RNN可以捕捉到先前的信息。

然而,针对复杂场景,我们有一句话"I grew up in France... I speak fluent French","French"基于"France"推断,但是它们之间的间隔很远很远,RNN 会丧失学习到连接如此远信息的能力。这就是长期依赖关系。

为了解决该问题,LSTM通过引入三种门遗忘门输入门输出门控制信息的流入和流出,有助于保留长期依赖关系,并缓解梯度消失【注意:没有梯度爆炸昂】。LSTM在1997年被提出


2️⃣ 原理

下面这张图是标准的RNN结构:

  • x t x_t xt是t时刻的输入
  • s t s_t st是t时刻的隐层输出, s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(U⋅xt+W⋅st−1),f表示激活函数, s t − 1 s_{t-1} st−1表示t-1时刻的隐层输出
  • h t h_t ht是t时刻的输出, h t = s o f t m a x ( V ⋅ s t ) h_t=softmax(V\cdot s_t) ht=softmax(V⋅st)

LSTM的整体结构如下图所示,第一眼看到,反正我是看不懂。前面讲到LSTM引入三种门遗忘门输入门输出门,现在我们逐一击破,一个个分析一下它们到底是什么。

这是3D视角的LSTM:

首先来看遗忘门,也就是下面这张图:

遗忘门输入包含两部分

  • s t − 1 s_{t-1} st−1:表示t-1时刻的短期记忆(即隐层输出),在LSTM中当前时间步的输出 h t − 1 h_{t-1} ht−1就是隐层输出 s t − 1 s_{t-1} st−1
  • x t x_t xt:表示t时刻的输入

遗忘门输出为 f t f_t ft,公式表示为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma\left(W_f\cdot[h_{t-1},x_t] + b_f\right) ft=σ(Wf⋅[ht−1,xt]+bf)

其中, W f W_f Wf和 b f b_f bf是遗忘门的参数, [ s t − 1 , x t ] [s_{t-1},x_t] [st−1,xt]表示concat操作。 σ ( ) \sigma() σ()表示sigmoid函数。

遗忘门定我们会从长期记忆中丢弃什么信息【理解为:删除什么日记】,输出一个在 0 到 1 之间的数值,1 表示"完全保留",0 表示"完全舍弃"。

然后来看输入门

输入门的输入包含两部分:

  • s t − 1 s_{t-1} st−1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入

输入门的输出为新添加的内容 i t ∗ C ~ t i_t * \tilde{C}t it∗C~t,其具体操作为:
i t = σ ( W i ⋅ [ s t − 1 , x t ] + b i ) C ~ t = tanh ⁡ ( W C ⋅ [ s t − 1 , x t ] + b C ) \begin{aligned}i
{t}&=\sigma\left(W_i\cdot[s_{t-1},x_t] + b_i\right)\\\tilde{C}{t}&=\tanh(W_C\cdot[s{t-1},x_t] + b_C)\end{aligned} itC~t=σ(Wi⋅[st−1,xt]+bi)=tanh(WC⋅[st−1,xt]+bC)

输入门决定什么样的新信息被加入到长期记忆(即细胞状态)中【理解为:添加什么日记】。

然后,我们来更新长期记忆,将 C t − 1 C_{t-1} Ct−1更新为 C t C_t Ct。我们把旧状态 C t − 1 C_{t-1} Ct−1与遗忘门的输出 f t f_t ft相乘,忘记一些东西。接着加上输入门的输出 i t ∗ C ~ t i_t * \tilde{C}_t it∗C~t,新加一些东西,最终得到新的长期记忆 C t C_t Ct。具体操作为:

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ft∗Ct−1+it∗C~t

最后来看输出门

输出门的输入包含:

  • s t − 1 s_{t-1} st−1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入
  • c t c_t ct:更新后的长期记忆

输出门的输出为 h t h_{t} ht和 s t s_{t} st, h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1,其具体操作为:
o t = σ ( W o [ s t − 1 , x t ] + b o ) s t = h t = o t ∗ t a n h ( C t ) \begin{aligned}&o_{t}=\sigma\left(W_{o} \left[ s_{t-1},x_{t}\right] + b_{o}\right)\\&s_{t}=h_{t}=o_{t}*\mathrm{tanh}\left(C_{t}\right)\end{aligned} ot=σ(Wo[st−1,xt]+bo)st=ht=ot∗tanh(Ct)

首先,我们运行一个 sigmoid 层来确定长期记忆的哪个部分将输出出去。接着,我们把长期记忆通过 tanh 进行处理(得到一个在-1到1之间的值)并将它和 o t o_{t} ot相乘,最终将输出copy成两份 h t h_t ht和 s t s_{t} st, h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1。

LSTM的结构分析完了,那为什么LSTM能够缓解梯度消失呢?

我前面写的这篇文章中介绍了为什么RNN会有梯度消失和爆炸:点这里查看

主要原因是反向传播时,梯度中有这一部分:
∏ j = k + 1 3 ∂ s j ∂ s j − 1 = ∏ j = k + 1 3 t a n h ′ W \prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3tanh^{'}W j=k+1∏3∂sj−1∂sj=j=k+1∏3tanh′W

LSTM的作用就是让 ∂ s j ∂ s j − 1 \frac{\partial s_j}{\partial s_{j-1}} ∂sj−1∂sj≈1

在LSTM里,隐藏层的输出换了个符号,从 s s s变成 C C C了,即 C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}t Ct=ft∗Ct−1+it∗C~t。注意, f t f_t ft , i t 和 C ~ t i{t\text{ 和}}\tilde{C}t it 和C~t 都是 C t − 1 C{t-1} Ct−1的复合函数(因为它们都和 h t − 1 h_{t-1} ht−1有关,而 h t − 1 h_{t-1} ht−1又和 C t − 1 C_{t-1} Ct−1有关)。因此我们来求一下 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} ∂Ct−1∂Ct:
∂ C t ∂ C t − 1 = f t + ∂ f t ∂ C t − 1 ⋅ C t − 1 + ... \frac{\partial C_t}{\partial C_{t-1}}=f_t+\frac{\partial f_t}{\partial C_{t-1}}\cdot C_{t-1}+\ldots ∂Ct−1∂Ct=ft+∂Ct−1∂ft⋅Ct−1+...

后面的我们就不管了,展开求导太麻烦了。这里面 f t f_t ft是遗忘门的输出,1表示完全保留旧状态,0表示完全舍弃旧状态,如果我们把 f t f_t ft设置成1或者是接近于1,那 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} ∂Ct−1∂Ct就有梯度了。因此LSTM可以一定程度上缓解梯度消失,然而如果时间步很长的话,依然会存在梯度消失问题,所以只是缓解

注意:LSTM可以缓解梯度消失,但是梯度爆炸并不能解决,因为LSTM不影响参数W


3️⃣ 代码

python 复制代码
# 创建一个LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,num_layers,output_size):
        super().__init__()
        self.num_layers=num_layers
        self.hidden_size=hidden_size

        # 定义LSTM层
        # batch_first=True则输入形状为(batch, seq_len, input_size)
        self.lstm=nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)
        # 定义全连接层,用于输出
        self.fc=nn.Linear(hidden_size,output_size)
    def forward(self, x):
        # self.lstm(x)会返回两个值
        # out:形状为 (batch,seq_len,hidden_size)
        # 隐层状态和细胞状态:形状为 (batch, num_layers, hidden_size);在这里,我们忽略隐层状态和细胞状态的输出,因此使用了占位符
        out, _ = self.lstm(x)
        out = self.fc(out)
        return out

    
if __name__=='__main__':
    input_size=10
    hidden_size=64
    num_layers=1
    output_size=1
    net=LSTM(input_size,hidden_size,num_layers,output_size)
    # x的形状为(batch_size, seq_len, input_size)
    x=torch.randn(16,8,input_size)
    out=net(x)
    print(out.shape)

输出结果为:

python 复制代码
torch.Size([16, 8, 1]),表示有16个batch,对于每个batch,有8个时间步,每个时间步的output大小为1

4️⃣ 总结

  • 思考一个问题,对于多层LSTM,如何理解呢?

    注意:图中颜色相同的其实表达的值一样, h = s h=s h=s。

    1. 第一层 LSTM 首先初始隐层状态 s 0 l a y e r 1 s^{layer1}0 s0layer1和细胞状态 c 0 l a y e r 1 c^{layer1}0 c0layer1,然后输入 x t − 1 x{t-1} xt−1 生成隐层状态和输出 s t − 1 l a y e r 1 = h t − 1 l a r y e r 1 s^{layer1}{t-1}=h_{t-1}^{laryer1} st−1layer1=ht−1laryer1和细胞状态 c t − 1 l a y e r 1 c^{layer1}_{t-1} ct−1layer1。
    2. 第二层 LSTM首先初始隐层状态 s 0 l a y e r 2 s^{layer2}0 s0layer2和细胞状态 c 0 l a y e r 2 c^{layer2}0 c0layer2,然后接收第一层的输出 h t − 1 l a r y e r 1 h{t-1}^{laryer1} ht−1laryer1作为输入,生成 s t − 1 l a y e r 2 = h t − 1 l a r y e r 2 s^{layer2}{t-1}=h_{t-1}^{laryer2} st−1layer2=ht−1laryer2和 c t − 1 l a y e r 2 c^{layer2}_{t-1} ct−1layer2
    3. 第N层 LSTM首先初始隐层状态 s 0 l a y e r N s^{layerN}0 s0layerN和细胞状态 c 0 l a y e r N c^{layerN}0 c0layerN,然后接收第N-1层的输出 h t − 1 l a r y e r N − 1 h{t-1}^{laryer N-1} ht−1laryerN−1作为输入,生成最终的 s t − 1 l a y e r N = h t − 1 l a r y e r 2 s^{layerN}{t-1}=h_{t-1}^{laryer2} st−1layerN=ht−1laryer2和 c t − 1 l a y e r N c^{layerN}_{t-1} ct−1layerN
  • 为什么需要多层LSTM?

    多层 LSTM 通过增加深度来增强模型的表示能力和复杂度,能够学习到更高阶、更抽象的特征

  • 通过控制遗忘门的输出 f t f_t ft来控制梯度,以缓解梯度消失问题,但不能缓解梯度爆炸

5️⃣ 参考

相关推荐
皓74114 分钟前
打造旅游卡服务新标杆:构建SOP框架与智能知识库应用
大数据·人工智能·旅游·敏捷流程
B站计算机毕业设计超人14 分钟前
计算机毕业设计Hive+Spark空气质量预测 空气质量可视化 空气质量分析 空气质量爬虫 Hadoop 机器学习 深度学习 Django 大模型
hive·hadoop·爬虫·深度学习·机器学习·spark·数据可视化
视窗中国21 分钟前
中信建投张青:以金融巨擘之姿,铸就公益慈善新篇章
人工智能·金融
幽络源小助理30 分钟前
桥梁缺陷YOLO免费数据集分享 – 6308张已标注8类缺陷图像
人工智能·计算机视觉·目标跟踪
念啊啊啊啊丶41 分钟前
【弱监督视频异常检测】2024-ESWA-基于扩散的弱监督视频异常检测常态预训练
人工智能·深度学习·神经网络·机器学习·计算机视觉
陌上阳光1 小时前
初学人工智不理解的名词3
人工智能·语音识别
ZHOU_WUYI1 小时前
5. langgraph中的react agent使用 (从零构建一个react agent)
人工智能·langchain
ZHOU_WUYI1 小时前
3. langgraph中的react agent使用 (在react agent添加系统提示)
人工智能·langchain
临水逸1 小时前
AI 编程编辑器和工具
人工智能·编辑器
学术搬运工1 小时前
【征稿倒计时!华南理工大学主办 | IEEE出版 | EI检索稳定】2024智能机器人与自动控制国际学术会议 (IRAC 2024)
人工智能·深度学习·算法·机器学习·机器人·自动化·自动驾驶