GRU 和 LSTM 公式推导与矩阵变换过程图解

GRU 和 LSTM 公式推导与矩阵变换过程图解


本文的前置篇链接: 单向/双向,单层/多层RNN输入输出维度问题一次性解决


GRU

GRU(Gate Recurrent Unit)是循环神经网络(RNN)的一种,可以解决RNN中不能长期记忆和反向传播中的梯度等问题,与LSTM的作用类似,不过比LSTM简单,容易进行训练。

GRU的输入输出结构与普通的RNN是一样的:

GRU 内部结构有两个核心的门控状态,分别为r控制重置的门控(reset gate) , z控制更新的门控(update gate) ,这两个门控状态值通过上一时刻传递过来的 h t − 1 h_{t-1} ht−1和当前节点输入 x t x_{t} xt结合sigmoid得出,具体公式如下:

下面我们更加细致来看一下两个门控状态值计算过程中的矩阵维度变换:

得到门控信号后,我们首先使用重置门r来得到重置后的 h t − 1 ′ = h t − 1 ⊙ r {h_{t-1}}' =h_{t-1}\odot r ht−1′=ht−1⊙r :

符号 ⊙ 通常表示 Hadamard 乘积(也称为元素-wise 乘积或 Schur 乘积)。Hadamard 乘积是两个矩阵的元素-wise 乘积,即两个矩阵的对应元素相乘,图中用*代替Hadamard 乘积,用x代替矩阵乘法运算

再将 h t − 1 ′ {h_{t-1}}' ht−1′与输入 x t x_{t} xt进行拼接,再通过一个tanh激活函数对数据进行非线性变换,得到 h ′ {h}' h′, 具体计算过程如下图所示

更新记忆阶段,我们同时进行了遗忘记忆两个步骤。我们使用了先前得到的更新门控z,更新表达式如下:
h t = z ⊙ h t − 1 + ( 1 − z ) ⊙ h ′ h_{t} = z\odot h_{t-1} + (1-z) \odot {h}' ht=z⊙ht−1+(1−z)⊙h′

门控信号z的范围为0~1。门控信号越接近1,代表"记忆"下来的数据越多;而越接近0则代表"遗忘"的越多。

GRU很聪明的一点就在于,我们使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。

  • z ⊙ h t − 1 z\odot h_{t-1} z⊙ht−1 表示对原本隐藏状态的选择性遗忘,这里的 z z z可以想象成遗忘门,忘记 h t − 1 h_{t-1} ht−1中一些不重要的信息。
  • ( 1 − z ) ⊙ h ′ (1-z)\odot {h}' (1−z)⊙h′ 表示对包含当前节点信息的 h ′ {h}' h′进行选择性记忆,这里的 ( 1 − z ) (1-z) (1−z)同理会忘记 h ′ {h}' h′中一些不重要的信息,或者看成是对 h ′ {h}' h′中某些重要信息的筛选。
  • h t = z ⊙ h t − 1 + ( 1 − z ) ⊙ h ′ h_{t} = z\odot h_{t-1} + (1-z) \odot {h}' ht=z⊙ht−1+(1−z)⊙h′ 总的来看就是忘记传递下来的 h t − 1 h_{t-1} ht−1中的某些不重要信息,并加入当前节点输入信息中某些重要部分。

更新门 与 重置门 总结:

  • 重置门的作用是决定前一时间步的隐藏状态在多大程度上被忽略。当重置门的输出接近0时,网络倾向于"忘记"前一时间步的信息,仅依赖于当前输入;而当输出接近1时,前一时间步的信息将被更多地保留。
  • 更新门的作用是决定当前时间步的隐藏状态需要保留多少前一个时间步的信息。更新门的输出值介于0和1之间,值越大表示保留的过去信息越多,值越小则意味着更多地依赖于当前输入的信息。

把GRU所有流程放在一张图展示:

总的来说:

  • GRU输入输出的结构与普通的RNN相似,其中的内部思想与LSTM相似。
  • 与LSTM相比,GRU内部少了一个"门控",参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力和时间成本,因而很多时候我们也就会选择更加"实用"的GRU啦。

补充 ( 看完 L S T M 部分后,再看该补充说明 ) : 补充(看完LSTM部分后,再看该补充说明): 补充(看完LSTM部分后,再看该补充说明):

  • r (reset gate)实际上与他的名字有些不符合,因为我们仅仅使用它来获得了 h ′ {h}' h′。
  • GRU中的 h ′ {h}' h′实际上可以看成对应于LSTM中的hidden state,而上一个节点传下的 h t − 1 h_{t-1} ht−1则对应于LSTM中的cell state。
  • z z z对应的则是LSTM中的 z f z_{f} zf,那么 ( 1 − z ) (1-z) (1−z)就可以看成是选择门 z i z_{i} zi。

LSTM

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

LSTM 的输入输出结构与普通的RNN区别如下图所示:

相比于RNN只有一个传递状态 h t h_{t} ht,LSTM有两个传输状态,一个 c t c_{t} ct(cell state) ,和一个 h t h_{t} ht(hidden state)。

LSTM中的 c t c_{t} ct相当于RNN中的 h t h_{t} ht

下面我们来看一下LSTM内部结构,首先使用LSTM的当前输入 x t x_{t} xt和上一个状态传递下来的 h t − 1 h_{t-1} ht−1计算得到四个状态值:



z i , z f , z o z_{i},z_{f},z{o} zi,zf,zo 通过sigmoid非线性变换,分别作为输入门控,遗忘门控和输出门控,而 z z z则是通过tanh非线性变换,作为输入数据。

LSTM内部主要有三个阶段:

  1. 选择记忆阶段: 该阶段会对输入的数据 z z z进行有选择性的记忆,而选择的门控信号由 z i z_{i} zi来进行控制。

  2. 忘记阶段: 这个阶段是对上一个节点传入的输入进行选择性的忘记,通过 z f z_{f} zf作为忘记门控,来控制上一个状态 c t − 1 c_{t-1} ct−1哪些需要留下,哪些需要忘记。

  1. 将上面两个阶段得到的结果相加,即可得到传输给下一个状态的 c t c_{t} ct。
  1. 输出阶段: 这个阶段会决定哪些将会被当成当前状态的输出,主要是通过 z o z_{o} zo进行控制,并且还对上一个阶段得到的 c t c_{t} ct进行了tanh的非线性变换。
  1. 与普通RNN类似,输出的 y t y_{t} yt往往最终也是通过 h t h_{t} ht变换得到,如: y t = s i g m o i d ( W h t ) y_{t} = sigmoid(W h_{t}) yt=sigmoid(Wht)

把LSTM所有流程放在一张图展示:

总结:

  • 以上,就是LSTM的内部结构。通过门控状态来控制传输状态,记住需要长时间记忆的,忘记不重要的信息;而不像普通的RNN那样只能够"呆萌"地仅有一种记忆叠加方式。对很多需要"长期记忆"的任务来说,尤其好用。

  • 但也因为引入了很多内容,导致参数变多,也使得训练难度加大了很多。因此很多时候我们往往会使用效果和LSTM相当但参数更少的GRU来构建大训练量的模型。

相关推荐
工藤新一¹15 小时前
蓝桥杯算法题 -蛇形矩阵(方向向量)
c++·算法·矩阵·蓝桥杯·方向向量
机器学习之心16 小时前
SHAP分析!Transformer-GRU组合模型SHAP分析,模型可解释不在发愁!
深度学习·gru·transformer·shap分析
Despacito0o17 小时前
RGB矩阵照明系统详解及WS2812配置指南
c语言·线性代数·矩阵·计算机外设·qmk
唐山柳林18 小时前
现代化水库运行管理矩阵平台如何建设?
线性代数·矩阵
passionSnail3 天前
《用MATLAB玩转游戏开发:从零开始打造你的数字乐园》基础篇(2D图形交互)-俄罗斯方块:用旋转矩阵打造经典
算法·matlab·矩阵·游戏程序·交互
KingDol_MIni3 天前
Transformer-LSTM混合模型在时序回归中的完整流程研究
回归·lstm·transformer
Akiiiira4 天前
【日撸 Java 三百行】Day 7(Java的数组与矩阵元素相加)
线性代数·矩阵
拓端研究室TRL4 天前
CNN-LSTM、GRU、XGBoost、LightGBM风电健康诊断、故障与中国银行股票预测应用实例
人工智能·神经网络·cnn·gru·lstm
HHONGQI1234 天前
LVGL- 按钮矩阵控件
矩阵·lvlgl
元亓亓亓5 天前
LeetCode热题100--54.螺旋矩阵--中等
算法·leetcode·矩阵