RNN/LSTM/GRU 学习笔记

文章目录

RNN/LSTM/GRU

一、RNN

1、为何引入RNN?

循环神经网络(Recurrent Neural Network,RNN) 是用来建模序列化数据的一种主流深度学习模型。我们知道,传统的前馈神经网络一般的输入都是一个定长的向量,无法处理变长的序列信息,即使通过一些方法把序列处理成定长的向量,模型也很难捕捉序列中的长距离依赖关系。RNN则通过将神经元串行起来处理序列化的数据。由于每个神经元能用它的内部变量保存之前输入的序列信息,因此整个序列被浓缩成抽象的表示,并可以据此进行分类或生成新的序列[1](#1)

2、RNN的基本结构

RNN的朴素形式可分别由如下两幅图表示[2](#2)

其中 x 1 , x 2 , ⋯   , x T x_1,x_2,\cdots,x_T x1,x2,⋯,xT 是输入,每一个位置是一个实数向量; U U U、 V V V、 W W W 是权重矩阵,通常在模型初始化时随机生成,通过梯度下降进行优化; h t h_t ht 是位于隐藏层上的活性值,很多文献上也称为状态(State)或隐状态(Hidden State); p t p_t pt 表示第 t t t 个位置上的输出。

h t h_t ht、 p t p_t pt 可由下列公式得出( b b b 是偏置项):
h t = tanh ⁡ ( U ⋅ h t − 1 + W ⋅ x t + b ) h_t=\tanh\left(U\cdot h_{t-1}+W\cdot x_t+b\right) ht=tanh(U⋅ht−1+W⋅xt+b)

p t = s o f t m a x ( V ⋅ h t + c ) p_t=\mathrm{softmax}(V\cdot h_t+c) pt=softmax(V⋅ht+c)

3、各种形式的RNN及其应用

(图片来自于cs231n)

模式 描述 应用领域
One to One 单个输入对应单个输出 图像分类、回归任务
One to Many 单个输入生成序列输出 图像字幕生成、音乐生成
Many to One 序列输入生成单个输出 情感分析、时间序列分类
Many to Many 序列输入对应序列输出 机器翻译、语音识别
Many to Many(同步) 同步序列输入输出 视频帧分类、实时语音处理

4、RNN的缺陷

RNN通过在所有时间步共享相同的权重,使得可以在不同时间步之间传递和积累信息,从而更好地捕捉序列数据中的长期依赖关系,但是缺点也很明显:在RNN的学习过程中,由于共享权重 W W W,导致随着时间步的增加,权重矩阵 W W W 不断连乘,最终产生梯度消失(即 ∂ L t ∂ h k \frac{\partial \mathcal{L}{t}}{\partial \boldsymbol{h}{k}} ∂hk∂Lt 消失, 1 ≤ k ≤ t 1 \le k\le t 1≤k≤t )和梯度爆炸,具体解释如下:

首先由RNN前向传播公式:
h t = f ( W ⋅ h t − 1 + U ⋅ x t + b ) h_t=f(W\cdot h_{t-1}+U\cdot x_t+b) ht=f(W⋅ht−1+U⋅xt+b)

其中 f f f 为激活函数。

在反向传播时(BPTT),损失函数 L \mathcal{L} L 对某一时间步长的梯度涉及到时间上所有的前置状态,因此梯度会被多个矩阵连乘表示为:
∂ L ∂ h t = ∂ L ∂ h T ⋅ ∏ k = t T − 1 A k \frac{\partial\mathcal{L}}{\partial h_t}=\frac{\partial\mathcal{L}}{\partial h_T}\cdot\prod_{k=t}^{T-1}A_k ∂ht∂L=∂hT∂L⋅k=t∏T−1Ak

其中 A k = diag ⁡ ( f ′ ( h k ) ) ⋅ W A_k=\operatorname{diag}(f^{\prime}(h_k))\cdot W Ak=diag(f′(hk))⋅W 。

显然若 W > 1 W>1 W>1,随着时间的增加,多个 W W W 连乘后结果会不断增大,最终导致梯度爆炸;

同理 W < 1 W<1 W<1,多个 W W W 连乘后结果会不断减小至趋于0,最终导致梯度消失。

而在CNN中,每一层的权重矩阵 W W W 是不同的,并且在初始化时它们是独立同分布的,因此最后可以相互抵消,不容易发生梯度爆炸或消失。

5、如何应对RNN的缺陷?

对于梯度爆炸 ,一般通过权重衰减(Weight Decay)梯度截断(Gradient Clipping) 来避免[3](#3)。权重衰减,通过引入衰减系数来约束并避免权重矩阵元素过大,从而减少梯度连乘时的爆炸风险;梯度截断,直接将梯度大小进行限制以防止梯度爆炸,比如按值截断:在第 t t t 次迭代时,梯度为 g t g_t gt ,给定一个区间 [ a , b ] [a,b] [a,b] ,如果一个参数的梯度小于 a a a 时,就将其设为 a a a ;如果大于 b b b 时,就将其设为 b b b,公式如下:
g t = max ⁡ ( min ⁡ ( g t , b ) , a ) . \mathbf{g}_t=\max(\min(\mathbf{g}_t,b),a). gt=max(min(gt,b),a).

对于梯度消失 ,一个想法是改进激活函数,比如替换成 ReLU ,因为其右侧导数恒为 1 ,可以缓解梯度消失(不能杜绝,因为本质上是权重矩阵的问题)。缺点是不好解决梯度爆炸,从 RNN 的前向传播公式来看待这个问题,前向传播公式如下:
h t = f ( W ⋅ h t − 1 + U ⋅ x t + b ) h_t=f(W\cdot h_{t-1}+U\cdot x_t+b) ht=f(W⋅ht−1+U⋅xt+b)

使用 ReLU 激活函数后, h t h_t ht 可表达为:
h t = r e l u ( W ⋅ h t − 1 + U ⋅ x t + b ) h_t=\mathrm{relu}\left(W\cdot h_{t-1}+U\cdot x_t+b\right) ht=relu(W⋅ht−1+U⋅xt+b)

显然不管 h t − 1 h_{t-1} ht−1 怎么变化,前面始终要乘上一个权重矩阵 W W W ,因此替换激活函数后,并不能实质上解决由于权重矩阵 W W W 连乘而导致的梯度爆炸问题。

③ 使用合适的权重初始化 方法,如 Xavier 初始化或 He 初始化,使 W W W 的特征值接近 1 。

如果从结构上来考虑,通过改变网络结构来减缓梯度消失或爆炸,长短期记忆网络(LSTM,Long Short-Term Memory) 就是基于这个想法诞生的。

6、BPTT和BP的区别

BP算法:只处理纵向层级间的梯度反向传播,适用于前馈神经网络。

BPTT算法:在训练RNN时,需要同时处理纵向层级间的反向传播(深度方向)和时间维度上的反向传播(时间方向)。

二、LSTM

1、LSTM 简介

LSTM 是循环神经网络的一个变体,可以有效地解决简单循环神经网络的梯度爆炸或消失问题。LSTM 网络结构如下:

在这里插入图片描述

LSTM 网络引入门控机制(Gating Mechanism) 来控制信息传递的路径,公式如下:
i t = σ ( U i ⋅ h t − 1 + W i ⋅ x t + b i ) f t = σ ( U f ⋅ h t − 1 + W f ⋅ x t + b f ) o t = σ ( U o ⋅ h t − 1 + W o ⋅ x t + b o ) c ~ t = tanh ⁡ ( U c ⋅ h t − 1 + W c ⋅ x t + b c ) c t = i t ⊙ c ~ t + f t ⊙ c t − 1 h t = o t ⊙ tanh ⁡ ( c t ) \begin{array}{c}\boldsymbol{i}{t}=\sigma\left(\boldsymbol{U}{i} \cdot \boldsymbol{h}{t-1}+\boldsymbol{W}{i} \cdot \boldsymbol{x}{t}+\boldsymbol{b}{i}\right) \\\boldsymbol{f}{t}=\sigma\left(\boldsymbol{U}{f} \cdot \boldsymbol{h}{t-1}+\boldsymbol{W}{f} \cdot \boldsymbol{x}{t}+\boldsymbol{b}{f}\right) \\\boldsymbol{o}{t}=\sigma\left(\boldsymbol{U}{o} \cdot \boldsymbol{h}{t-1}+\boldsymbol{W}{o} \cdot \boldsymbol{x}{t}+\boldsymbol{b}{o}\right) \\\tilde{\boldsymbol{c}}{t}=\tanh \left(\boldsymbol{U}{c} \cdot \boldsymbol{h}{t-1}+\boldsymbol{W}{c} \cdot \boldsymbol{x}{t}+\boldsymbol{b}{c}\right) \\\boldsymbol{c}{t}=\boldsymbol{i}{t} \odot \tilde{\boldsymbol{c}}{t}+\boldsymbol{f}{t} \odot \boldsymbol{c}{t-1} \\\boldsymbol{h}{t}=\boldsymbol{o}{\boldsymbol{t}} \odot \tanh \left(\boldsymbol{c}{t}\right)\end{array} it=σ(Ui⋅ht−1+Wi⋅xt+bi)ft=σ(Uf⋅ht−1+Wf⋅xt+bf)ot=σ(Uo⋅ht−1+Wo⋅xt+bo)c~t=tanh(Uc⋅ht−1+Wc⋅xt+bc)ct=it⊙c~t+ft⊙ct−1ht=ot⊙tanh(ct)

进一步可以简写成:

c \~ t o t i t f t \] = \[ tanh ⁡ σ σ σ \] ( W \[ x t h t − 1 \] + b ) , c t = f t ⊙ c t − 1 + i t ⊙ c \~ t , h t = o t ⊙ tanh ⁡ ( c t ) , \\begin{aligned}\\begin{bmatrix}\\tilde{\\boldsymbol{c}}_t\\\\\\\\\\boldsymbol{o}_t\\\\\\\\\\boldsymbol{i}_t\\\\\\\\\\boldsymbol{f}_t\\end{bmatrix}\&=\\begin{bmatrix}\\tanh\\\\\\\\\\sigma\\\\\\\\\\sigma\\\\\\\\\\sigma\\end{bmatrix}\\begin{pmatrix}\\boldsymbol{W}\\begin{bmatrix}\\boldsymbol{x}_t\\\\\\\\\\boldsymbol{h}_{t-1}\\end{bmatrix}+\\boldsymbol{b}\\end{pmatrix},\\\\\\\\\\boldsymbol{c}_t\&=\\boldsymbol{f}_t\\odot\\boldsymbol{c}_{t-1}+\\boldsymbol{i}_t\\odot\\boldsymbol{\\tilde{c}}_t,\\\\\\boldsymbol{h}_t\&=\\boldsymbol{o}_t\\odot\\tanh\\left(\\boldsymbol{c}_t\\right),\\end{aligned} c\~totitft ctht= tanhσσσ W xtht−1 +b ,=ft⊙ct−1+it⊙c\~t,=ot⊙tanh(ct), 公式中有三个"门",分别为输入门 i t \\boldsymbol{i}_t it 、遗忘门 f t \\boldsymbol{f}_t ft 和输出门 o t \\boldsymbol{o}_t ot 。这三个门的作用为 * 遗忘门 f t f_t ft 控制上一个时刻的内部状态 c t − 1 \\boldsymbol c_t-1 ct−1 需要遗忘多少信息。 * 输入门 i t \\boldsymbol{i}_t it 控制当前时刻的候选状态 c \~ t \\tilde{\\boldsymbol{c}}_t c\~t 有多少信息需要保存。 * 输出门 o t \\boldsymbol{o}_t ot 控制当前时刻的内部状态 c t \\boldsymbol{c}_t ct 有多少信息需要输出给外部状态 h t . \\boldsymbol{h}_t. ht. 具体的可点击查看如下视频,很清晰易懂: 【【官方双语】LSTM(长短期记忆神经网络)最简单清晰的解释来了!】 https://www.bilibili.com/video/BV1zD421N7nA/?share_source=copy_web\&vd_source=199a3f4e3a9db6061e1523e94505165a #### 2、LSTM如何缓解梯度消失与梯度爆炸? LSTM的细胞状态更新机制(下图黄色部分)可以有效地存储长期的信息: ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/9249c39720f64bcea46e2a61596415f8.png) 其更新公式如下: C t = f t ⊙ C t − 1 + i t ⊙ C \~ t C_t=f_t\\odot C_{t-1}+i_t\\odot\\tilde{C}_t Ct=ft⊙Ct−1+it⊙C\~t 由于这一过程本质是**线性操作**(加权求和),相当于是所有候选路径的线性组合,故不会因为一个路径上梯度的消失,而导致整体梯度不断衰减。LSTM的细胞状态经过门控机制(通过或阻断,即 1 或 0)控制这个线性组合,达到缓解梯度消失的效果;而门控机制又可以通过调节输入输出,通过灵活地舍弃一些部分,来缓解梯度爆炸问题。 简言之,由于此线性组合会通过门控机制自主的调节,而非 RNN 那样直接连乘,因此可以达到减缓梯度消失和梯度爆炸的效果,并实现对信息的过滤,从而达到对长期记忆的保存与控制。 ### 三、GRU **门控循环单元(GRU)** 是对 LSTM 进行简化得到的模型。对于 LSTM 与 GRU 而言,它们效果相当,但由于 GRU 参数更少,所以 GRU 的收敛速度更快,计算效率更高。 与LSTM相比,GRU 仅有两个门------更新门(update gate)和重置门(reset gate),不使用记忆元。重置门有助于捕获序列中的短期依赖关系,更新门有助于捕获序列中的长期依赖关系,详细结构如下图: ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/e83ecdc4271b4c8bb414b50e25d9ea9c.png) ### 四、参考文献 *** ** * ** *** 1. 诸葛越, 葫芦娃, 百面机器学习, 北京:人民邮电出版社, 2018 [↩︎](#↩︎) 2. 李航. 机器学习方法\[M\]. 第一版. 清华大学出版社, 2022. [↩︎](#↩︎) 3. 邱锡鹏, 神经网络与深度学习, 北京:机械工业出版社, 2020 [↩︎](#↩︎)

相关推荐
聆风吟º6 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
AI_56786 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5777 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
h64648564h7 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切7 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
小鸡吃米…8 小时前
机器学习的商业化变现
人工智能·机器学习
学电子她就能回来吗9 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
Coder_Boy_9 小时前
TensorFlow小白科普
人工智能·深度学习·tensorflow·neo4j