深度学习基础—循环神经网络的梯度消失与解决

引言

深度学习基础---循环神经网络(RNN)https://blog.csdn.net/sniper_fandc/article/details/143417972?fromshare=blogdetail&sharetype=blogdetail&sharerId=143417972&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link深度学习基础---语言模型和序列生成https://blog.csdn.net/sniper_fandc/article/details/143418185?fromshare=blogdetail&sharetype=blogdetail&sharerId=143418185&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link深度学习基础---序列采样https://blog.csdn.net/sniper_fandc/article/details/143457214?fromshare=blogdetail&sharetype=blogdetail&sharerId=143457214&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link

在前面的博客中,我们已经了解很多RNN网络的结构和相关基础,基本的RNN网络的前向传播和反向传播比较了解了,但是RNN还有一个比较重要的问题需要解决,就是梯度消失,让我们一起来了解一下。

1.梯度消失

现在有两个需要训练的序列:"The cat, which already ate ......, was full."和"The cats, which already ate ......, were full.",省略号表示中间还有很多词。当主语是cat(单数)时,应该用was;当主语时cats(复数),应该用were。这就意味着序列前后有长期的依赖,最前面的单词对句子后面的单词有影响。但是基础的RNN模型不擅长捕捉长期依赖效应,正是梯度消失造成了这样的影响。

前面已经讨论过,当神经网络很深的时候,会产生梯度消失现象,同理基础的RNN也会有这样的现象。因为梯度消失问题,反向传播的时候后面层的输出误差很难传播到前面层,这就意味着很难让神经网络能够意识到它要记住看到的是单数名词还是复数名词,然后在序列后面生成依赖单复数形式的was或者were。

实际上,基本的RNN模型会有很多局部影响,比如y3这个输出(上图编号9所示)主要受y3附近的值(上图编号10所示)的影响,上图编号11所示的一个数值主要与附近的输入(上图编号12所示)有关,上图编号6所示的输出,基本上很难受到序列靠前的输入(上图编号10所示)的影响。

注意:RNN也会出现梯度爆炸问题,但是这里不讨论,因为梯度爆炸容易察觉(网络的参数会变的很大直到崩溃),如果出现梯度爆炸问题,我们需要通过梯度修剪来解决。梯度修剪就是观察梯度向量,如果它大于某个阈值,缩放梯度向量,保证它不会太大。

这就是RNN的梯度消失,它与网络的长期依赖效应相关,想要解决,我们可以用GRU单元或者LSTM网络,这里我们先介绍一下GRU单元。

2.GRU单元

上图是基本的RNN网络的单元,此时引入GRU单元,GRU单元用变量c表示,代表记忆细胞。在目前这个情况,c的值和激活值a一样,即:

之所以GRU单元叫记忆细胞,就是它能通过记忆某个时间步t时的激活值,然后通过一定的时间步后,将这个值传递给此刻的输入。因此我们使用公式1表示:

式1中c上面带波浪的符号表示候选值,供某个时间步的c来选择,其表达式如下式2:

上式中,符号Γ表示门,下标r是relevance相关性的意思,主要告诉t-1时间步的c值和t时间步的候选值有多少相关性。表达式3如下:

式1中下标u是update的意思,表示更新门,更新门的公式4如下:

σ是sigmoid函数,因此更新门的值为[0,1]之间,实际上,sigmoid函数的输出往往非常接近0或1,假设更新门的值为1,式1就等价于:

此时说明选择更新操作,表示记忆细胞选择将自己记忆的信息更新到当前时间步,因此有利于捕捉序列中短期依赖关系。而如果更新门的值为0,式1等价于:

此时说明选择不更新,更有利于捕捉序列中长期依赖关系。

实际上,c是一个向量,c、候选值、更新门的值是同一个维度,比如c是100维向量,那么他们都是100维向量。而更新门的每一个元素都表示一种更新状态,可能这个元素接近0就表示不更新,那个元素接近1就表示更新。每种更新状态表示要记忆的状态,可能需要记忆主语的单复数,也需要记忆时态等等状态。

GRU单元最核心的就是公式1,比如现在在时间步1,记忆细胞记忆了cat是单数的,在时间步7,需要处理的问题是was还是were。那么这中间的时间步中更新门对应的位置都应该是0,表示不更新,然后在时间步7更新门对应的位置是1,表示更新,此时就会把cat单数的情况告诉这个单元,网络就有了捕捉长期依赖关系的能力,也解决了梯度消失的问题。

3.长短期记忆(LSTM)

3.1.前向传播

上图所示就是LSTM的网络结构,经典的LSTM与GRU单元最大的区别就是:LSTM网络结构中激活值a<t>不等于c<t>,并将更新门的选择权交给了记忆细胞。让我们一起了解一下其中的细节:

上式(公式1),由于a<t>不等于c<t>,因此候选值的计算不再依赖c<t-1>,而是a<t-1>,相当于用RNN的激活值作为候选值。

上式(公式2),更新门同样将c<t-1>变为a<t-1>,其他不变。

上式(公式3),增加遗忘门Γf(forget),其作用是代替4.2部分公式1的(1-Γu)部分,因为输出在[0,1]之间,因此如果输出接近1表示不遗忘,输出接近0表示遗忘之前记忆的信息c<t-1>。

上式(公式4),增加输出门Γo,调整网络在第t个时间步的激活值的输出。

上式(公式5),激活值的输出值不再等于c<t>,而是经过输出门调整后再输出。

上式(公式6),也是LSTM的核心公式,由更新门决定是否使用候选值替换,由遗忘门决定是否遗忘之前的信息,由此,记忆细胞便掌握了更新权和遗忘权。这样只要正确的设置更新门和遗忘门,很容易把某个t时间步的值传递给更后面的值,即网络有了长期记忆的能力。

注意:可能这样的LSTM结构和其他人所讲的有些不一致,还有一个版本叫:"偷窥孔连接",即门值不仅取决于a<t-1>和x<t>,还取决于c<t-1>。结合这三个值计算三个门值。

3.2.反向传播

对门求偏导:

对参数求偏导,对b求偏导需要将上面4个公式(公式1、2、3、4)求和:

对激活值、记忆值、输入求偏导:

LSTM的优点是更加灵活和强大,GRU单元的优点是更加简洁,根据不同的场合选择合适的结构才能更好的解决问题。

相关推荐
机智的小神仙儿12 分钟前
Query Processing——搜索与推荐系统的核心基础
人工智能·推荐算法
AI_小站19 分钟前
RAG 示例:使用 langchain、Redis、llama.cpp 构建一个 kubernetes 知识库问答
人工智能·程序人生·langchain·kubernetes·llama·知识库·rag
Doker 多克21 分钟前
Spring AI 框架使用的核心概念
人工智能·spring·chatgpt
Guofu_Liao21 分钟前
Llama模型文件介绍
人工智能·llama
思通数科多模态大模型1 小时前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
数据岛1 小时前
数据集论文:面向深度学习的土地利用场景分类与变化检测
人工智能·深度学习
学不会lostfound1 小时前
三、计算机视觉_05MTCNN人脸检测
pytorch·深度学习·计算机视觉·mtcnn·p-net·r-net·o-net
红色的山茶花1 小时前
YOLOv8-ultralytics-8.2.103部分代码阅读笔记-block.py
笔记·深度学习·yolo
龙的爹23331 小时前
论文翻译 | RECITATION-AUGMENTED LANGUAGE MODELS
人工智能·语言模型·自然语言处理·prompt·gpu算力
白光白光1 小时前
凸函数与深度学习调参
人工智能·深度学习