第二章、LSTM(Long Short-term Memory:长短时记忆网络)

0 前言

RNN(循环神经网络)本身存在各种各样的缺陷,比如梯度弥散、梯度爆炸和短时记忆的问题。为弥补RNN的这些问题,瑞士人工智能科学家于1997提出了Long Short-term Memory(长短时记忆网络),即现在常用的LSTM。

1 RNN的局限性

以下阐述流程

  • 问题出现的原因
  • 直观的解决问题的方法

循环神经网络会出现这三个问题的绝大多数原因取决于其参数梯度中的δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht这一项。其展开如下所示,此处不做推导:
δhtδhi=Πj=it−1diag(σ′(Wxhxj+1+Whhhj+b))Whh\frac{\delta h_t}{\delta h_i}=\Pi^{t-1}{j=i}diag(\sigma'(W{xh}x_{j+1}+W_{hh}h_j+b))W_{hh}δhiδht=Πj=it−1diag(σ′(Wxhxj+1+Whhhj+b))Whh

观察上式我们发现实际上这个式子中存在WhhW_{hh}Whh的连乘运算,那么如果矩阵WhhW_{hh}Whh的最大特征值小于1,连乘会导致δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht趋近于0,这就导致了梯度弥散。相对应的如果该值大于1,则会导致δhtδhi\frac{\delta h_t}{\delta h_i}δhiδht值爆炸式增长,即梯度爆炸。

1.1 梯度爆炸

很自然的,因为某个值过大而产生的问题,我们可以通过限制该值来解决。我们可以做梯度裁减,使WWW中的所有元素都在一定范围内就可以了。

  • 假设张量为WWW,令所有元素wij∈[min,max]w_{ij}\in[min,max]wij∈[min,max]
  • 假设张量为WWW,限制张量的二范数∣∣W∣∣2∈[0,max]||W||_2\in[0,max]∣∣W∣∣2∈[0,max],若∣∣W∣∣2>max||W||_2>max∣∣W∣∣2>max,则令W′=W∣∣W∣∣2⋅maxW'=\frac{W}{||W||_2}\cdot maxW′=∣∣W∣∣2W⋅max
  • 假设张量为WWW,考虑全局范数裁减,令global_norm=∑i∣∣W(i)∣∣22global\_norm=\sqrt{\sum_i{||W^{(i)}||_2 }^2}global_norm=∑i∣∣W(i)∣∣22 ,则有W(i)=W(i)⋅max_normmax(global_norm,max_norm)W^{(i)}=\frac{W^{(i)}\cdot max\_norm}{max(global\_norm,max\_norm)}W(i)=max(global_norm,max_norm)W(i)⋅max_norm

上面的三种方法实际上只是从不同角度出发的裁减,目的都是一样的防止WWW过大导致梯度爆炸。

1.2 梯度弥散

对于梯度弥散现象,可以通过增加学习率、减少网络深度、添加SKip Connection(跳接,不了解可以看看unet)等一系列措施抑制。

1.3 短时记忆

上述两个问题必然会导致RNN的短时记忆,那么接下来就是来看LSTM是怎么解决这些问题的,我们先介绍门控制,再对门控制进行组合成为LSTM。

2 门控机制

实际上门这个概念很好理解,不管是电路、生物还是电脑的最底层理论里无外乎都是这些东西,那什么是门,通俗的理解就是有的东西能过去有的东西过不去,它对信号也好,化学物质也好做了筛选,实际上LSTM中的门控也是这样的。

LSTM的门控机制如上图所示,这个图实际上就表明了输出o=输入x∗门控值g输出o=输入x*门控值g输出o=输入x∗门控值g,门控制g∈(0,1)g\in (0,1)g∈(0,1),显然g=0g=0g=0表示门关闭输入完全没有进来,g=1g=1g=1时刚好相反。

这个理念很好理解,但这里存在一个问题,我们的大脑对自动根据环境信息判断当前的信息要不要接收或者接受多少,这个东西就是所谓的门控值ggg,那么在LSTM中这个门控值ggg怎么计算呢?

实际上也很简单,我们也根据现在输入的环境信息获取一个取值范围在0到1之间的值就可以了。

LSTM有两个很重要的变量一个是输出hth_tht,一个是状态ctc_tct。

2.1 遗忘门

如上图所示实际上遗忘门就是对过去的状态ct−1c_{t-1}ct−1做筛选,而该门的门控值是通过ht−1、xth_{t-1}、x_tht−1、xt得到的,而门控值的取值范围是(0,1)(0,1)(0,1),因此最合理的方式是采取sigmoidsigmoidsigmoid函数,即ft=sigmoid(Wf[ht−1,xt]+bf)f_t=sigmoid(W_f[h_{t-1},x_t]+b_f)ft=sigmoid(Wf[ht−1,xt]+bf),经过该遗忘门后状态向量ct−1c_{t-1}ct−1变为ft∗ct−1f_t*c_{t-1}ft∗ct−1。

2.2 输入门

上图中的蓝色虚线部分就是输入门的部分,输入门的门控值依然是通过ht−1、xth_{t-1}、x_tht−1、xt得到的,即it=sigmoid(Wi[ht−1,xt]+bi)i_t=sigmoid(W_i[h_{t-1},x_t]+b_i)it=sigmoid(Wi[ht−1,xt]+bi),而输入门要过滤的值同样与输入相关,ct~=tanh(Wc[ht−1,xt]+bc)\tilde{c_t}=tanh(W_c[h_{t-1},x_t]+b_c)ct~=tanh(Wc[ht−1,xt]+bc),该值经过输入门后变为it∗ct~i_t*\tilde{c_t}it∗ct~。

将输入门的结果与遗忘门的结果相加得到的就是新的状态向量ct=ft∗ct−1+it∗ct~c_t=f_t*c_{t-1}+i_t*\tilde{c_t}ct=ft∗ct−1+it∗ct~

2.3 输出门

输出门的门控值依然是通过ht−1、xth_{t-1}、x_tht−1、xt得到的,即ot=sigmoid(Wo[ht−1,xt]+bo)o_t=sigmoid(W_o[h_{t-1},x_t]+b_o)ot=sigmoid(Wo[ht−1,xt]+bo),而输出门要过滤的值是tanh(ct)tanh(c_t)tanh(ct),所以输出ht=ot∗tanh(ct)h_t=o_t*tanh(c_t)ht=ot∗tanh(ct)

2.4 LSTM解决梯度爆炸及梯度弥散的方法

实际上我们通过简单的推理就能知道:ctc0≈Πj=1tfi\frac{c_t}{c_0}\approx \Pi^t_{j=1}f_ic0ct≈Πj=1tfi,其中fif_ifi是门控制,它的取值范围在(0,1)(0,1)(0,1)之间,实际上来说这也是一种裁减方式。fk<1f_k<1fk<1的约束避免了梯度爆炸。

深究RNN 我们会发现实际上导致梯度弥散的本质是激活函数求导造成的。

正向传播:

ht=σ(W⋅[ht−1,xt]+b)h_t=σ(W⋅[h_{t−1},x_t]+b)ht=σ(W⋅[ht−1,xt]+b)

σσσ 是激活函数(如 tanhtanhtanh 或 sigmoidsigmoidsigmoid)

反向传播(关键路径):

损失函数 LLL 对 ht−kh_{t−k}ht−k​ 的梯度依赖于链式法则:

∂L∂ht−k=∂L∂ht⋅(Πj=t−k+1t∂hj∂hj−1)\frac{\partial L}{\partial h_{t-k}}=\frac{\partial L}{\partial h_t}\cdot (\Pi^t_{j=t-k+1}\frac{\partial h_j}{\partial h_{j-1}})∂ht−k∂L=∂ht∂L⋅(Πj=t−k+1t∂hj−1∂hj)

而Πj=t−k+1t∂hj∂hj−1\Pi^t_{j=t-k+1}\frac{\partial h_j}{\partial h_{j-1}}Πj=t−k+1t∂hj−1∂hj部分的连乘是导致梯度弥散的关键。

其中∂hj∂hj−1=diag(σ′(zj))⋅Whh\frac{\partial h_j}{\partial h_{j-1}}=diag(\sigma'(z_j))\cdot W_{hh}∂hj−1∂hj=diag(σ′(zj))⋅Whh

其中σ′(zj)\sigma'(z_j)σ′(zj)是激活函数的导数,其值远小于1,所以就算WhhW_hhWhh特征值接近于1,连乘还是会导致梯度弥散。

但是LSTM中实际上梯度ctc0≈Πj=1tfi\frac{c_t}{c_0}\approx \Pi^t_{j=1}f_ic0ct≈Πj=1tfi只与门控值相关,没有激活函数的导数,从而及大程度的避免了梯度弥散的出现。

相关推荐
算家计算12 分钟前
OpenAI 全新开源模型 GPT-OSS-20B本地部署教程:16GB显存跑透128K上下文
人工智能·开源·openai
墨尘游子15 分钟前
2- Python 网络爬虫 — 如何精准提取网页数据?XPath、Beautiful Soup、pyquery 与 parsel 实战指南
人工智能·网络爬虫·知识图谱·机器翻译
lll482331 小时前
数字图像处理4
人工智能·计算机视觉
大海的John1 小时前
FinQ4Cn: 基于 MCP 协议的中国 A 股量化分析
人工智能
霍格沃兹测试开发学社测试人社区1 小时前
Dify入门指南(2):5 分钟部署 Dify:云服务 vs 本地 Docker
人工智能·低代码
西猫雷婶1 小时前
scikit-learn/sklearn学习|岭回归解读
开发语言·人工智能·机器学习·支持向量机·回归·scikit-learn·sklearn
AI导航猿1 小时前
能和MJ掰手腕?免费开源的Krea Dev,我用日常场景实测了一整天|AI导航猿
人工智能
CoderLiu2 小时前
AI提示词工程优化指南:8个技巧,释放大语言模型的全部潜力
前端·人工智能·ai编程
机器之心2 小时前
是「福尔摩斯」,也是「列文虎克」,智谱把OpenAI藏着掖着的视觉推理能力开源了
人工智能·openai
gaosushexiangji2 小时前
一种基于空间聚类的低特征场景下多目标跟踪技术
图像处理·人工智能·计算机视觉