RNN 梯度计算详细推导 (BPTT)

RNN 梯度计算详细推导 (BPTT)

为了详细推导循环神经网络(RNN)中的梯度计算方法------沿时间反向传播(Backpropagation Through Time, BPTT),我们将使用一个最基础的RNN模型结构。

第一步:定义RNN模型和符号

在一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,RNN的计算过程如下:

  1. 隐藏状态 (Hidden State)

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) </math>ht=f(Uxt+Wht−1+b)

    • <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt:在时间步 t 的输入向量。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht:在时间步 t 的隐藏状态向量。 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 0 h_0 </math>h0 通常初始化为零向量。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t-1} </math>ht−1:前一个时间步的隐藏状态。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> U , W , b U, W, b </math>U,W,b:循环层的参数。 <math xmlns="http://www.w3.org/1998/Math/MathML"> U U </math>U 是输入到隐藏层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 是隐藏层到隐藏层的权重矩阵(循环权重), <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 是偏置向量。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f:激活函数,通常是 tanhReLU。这里我们假设为 tanh
  2. 输出 (Output)

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> o t = V h t + c o_t = V h_t + c </math>ot=Vht+c

    • <math xmlns="http://www.w3.org/1998/Math/MathML"> o t o_t </math>ot:在时间步 t 的输出(或称为 logits)。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> V , c V, c </math>V,c:输出层的参数。 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 是隐藏层到输出层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 是偏置向量。
  3. 预测概率 (Predicted Probability)

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y ^ t = g ( o t ) \hat{y}_t = g(o_t) </math>y^t=g(ot)

    • <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g:输出激活函数,对于分类任务通常是 Softmax
  4. 损失函数 (Loss Function)

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L t = Loss ( y ^ t , y t ) L_t = \text{Loss}(\hat{y}_t, y_t) </math>Lt=Loss(y^t,yt)

    • <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt:在时间步 t 的损失,例如交叉熵损失。
    • <math xmlns="http://www.w3.org/1998/Math/MathML"> y t y_t </math>yt:在时间步 t 的真实标签。

总体目标

我们的目标是计算总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = ∑ t = 1 T L t L = \sum_{t=1}^{T} L_t </math>L=∑t=1TLt 对所有模型参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ = { U , W , V , b , c } \theta = \{U, W, V, b, c\} </math>θ={U,W,V,b,c} 的梯度。即求解: <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ V , ∂ L ∂ c , ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c}, \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} </math>∂V∂L,∂c∂L,∂W∂L,∂U∂L,∂b∂L。


第二步:前向传播

模型按照 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 1 , 2 , ... , T t=1, 2, \dots, T </math>t=1,2,...,T 的顺序,依次计算出每个时间步的 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t , o t , y ^ t , L t h_t, o_t, \hat{y}_t, L_t </math>ht,ot,y^t,Lt,并最终得到总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L。这个过程比较直接,就是将输入序列喂给模型,得到输出和损失。


第三步:反向传播 (BPTT)

梯度是反向计算的,从最后一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 开始,一直传播到第一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。

A. 输出层参数的梯度 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ V , ∂ L ∂ c \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c} </math>∂V∂L,∂c∂L)

这部分比较简单,因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 的计算不涉及时间上的循环依赖。总损失对它们的梯度是每个时间步梯度贡献的总和。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ V = ∑ t = 1 T ∂ L t ∂ V \frac{\partial L}{\partial V} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial V} </math>∂V∂L=t=1∑T∂V∂Lt

我们来看单个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L t ∂ V \frac{\partial L_t}{\partial V} </math>∂V∂Lt。根据链式法则:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L t ∂ V = ∂ L t ∂ o t ∂ o t ∂ V \frac{\partial L_t}{\partial V} = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial V} </math>∂V∂Lt=∂ot∂Lt∂V∂ot

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L t ∂ o t \frac{\partial L_t}{\partial o_t} </math>∂ot∂Lt 是损失对输出 logits 的梯度。我们将其记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ o , t \delta_{o,t} </math>δo,t。例如,对于Softmax和交叉熵损失, <math xmlns="http://www.w3.org/1998/Math/MathML"> δ o , t = y ^ t − y t \delta_{o,t} = \hat{y}_t - y_t </math>δo,t=y^t−yt。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o t ∂ V \frac{\partial o_t}{\partial V} </math>∂V∂ot:因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> o t = V h t + c o_t = V h_t + c </math>ot=Vht+c,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o t ∂ V = h t T \frac{\partial o_t}{\partial V} = h_t^T </math>∂V∂ot=htT(转置是为了维度匹配)。

所以,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L t ∂ V = δ o , t ⋅ h t T \frac{\partial L_t}{\partial V} = \delta_{o,t} \cdot h_t^T </math>∂V∂Lt=δo,t⋅htT

最终,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ V = ∑ t = 1 T δ o , t ⋅ h t T \boxed{\frac{\partial L}{\partial V} = \sum_{t=1}^{T} \delta_{o,t} \cdot h_t^T} </math>∂V∂L=t=1∑Tδo,t⋅htT

同理,对于偏置 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ c = ∑ t = 1 T δ o , t \boxed{\frac{\partial L}{\partial c} = \sum_{t=1}^{T} \delta_{o,t}} </math>∂c∂L=t=1∑Tδo,t

B. 循环层参数的梯度 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} </math>∂W∂L,∂U∂L,∂b∂L)

这是BPTT的核心和难点。我们以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W \frac{\partial L}{\partial W} </math>∂W∂L 为例进行推导。 为了解决 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 在时间上的复杂依赖,我们引入一个关键的中间量:总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 对隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht 的梯度 ,记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t = ∂ L ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} </math>δh,t=∂ht∂L。

根据链式法则, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 通过两条路径影响 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht:

  1. 通过当前时间步的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> o t o_t </math>ot。
  2. 通过下一个时间步的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t + 1 h_{t+1} </math>ht+1(因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t + 1 h_{t+1} </math>ht+1 的计算用到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht)。

因此, <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的计算是一个从后向前的递归过程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = ∂ L ∂ h t = ∂ L ∂ o t ∂ o t ∂ h t + ∂ L ∂ h t + 1 ∂ h t + 1 ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t}\frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_t} </math>δh,t=∂ht∂L=∂ot∂L∂ht∂ot+∂ht+1∂L∂ht∂ht+1

将各部分代入,我们得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的递归公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = δ o , t V T + δ h , t + 1 W T diag ( f ′ ( h t + 1 ) ) \delta_{h,t} = \delta_{o,t} V^T + \delta_{h,t+1} W^T \text{diag}(f'(h_{t+1})) </math>δh,t=δo,tVT+δh,t+1WTdiag(f′(ht+1))

  • 递归的起点(Base Case) :在最后一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T,没有未来的隐藏状态,所以递归项为0。
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , T = ∂ L ∂ h T = ∂ L T ∂ o T ∂ o T ∂ h T = δ o , T V T \delta_{h,T} = \frac{\partial L}{\partial h_T} = \frac{\partial L_T}{\partial o_T}\frac{\partial o_T}{\partial h_T} = \delta_{o,T} V^T </math>δh,T=∂hT∂L=∂oT∂LT∂hT∂oT=δo,TVT

我们可以从 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T \delta_{h,T} </math>δh,T 开始,反向计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T − 1 , ... , δ h , 1 \delta_{h,T-1}, \dots, \delta_{h,1} </math>δh,T−1,...,δh,1。

现在,我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 来计算最终的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W \frac{\partial L}{\partial W} </math>∂W∂L
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W = ∑ t = 1 T ∂ L ∂ h t ∂ h t ∂ W = ∑ t = 1 T δ h , t ∂ h t ∂ W \frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W} = \sum_{t=1}^{T} \delta_{h,t} \frac{\partial h_t}{\partial W} </math>∂W∂L=t=1∑T∂ht∂L∂W∂ht=t=1∑Tδh,t∂W∂ht

根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) </math>ht=f(Uxt+Wht−1+b),我们有 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ h t ∂ W = diag ( f ′ ( h t ) ) ⋅ h t − 1 T \frac{\partial h_t}{\partial W} = \text{diag}(f'(h_t)) \cdot h_{t-1}^T </math>∂W∂ht=diag(f′(ht))⋅ht−1T。 所以,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ h t − 1 T \boxed{\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot h_{t-1}^T} </math>∂W∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅ht−1T

同理可得 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ U \frac{\partial L}{\partial U} </math>∂U∂L 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ b \frac{\partial L}{\partial b} </math>∂b∂L:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ U = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ x t T \boxed{\frac{\partial L}{\partial U} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot x_t^T} </math>∂U∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅xtT
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ b = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t \boxed{\frac{\partial L}{\partial b} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t}} </math>∂b∂L=t=1∑Tdiag(f′(ht))⋅δh,t


第四步:梯度消失与爆炸的根源

回顾 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的递归公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = ⋯ + δ h , t + 1 ⋅ ( W T diag ( f ′ ( h t + 1 ) ) ) \delta_{h,t} = \dots + \delta_{h,t+1} \cdot (W^T \text{diag}(f'(h_{t+1}))) </math>δh,t=⋯+δh,t+1⋅(WTdiag(f′(ht+1)))

我们可以看到,梯度在时间上传播时,会反复乘以循环权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W

  • 梯度爆炸 (Gradient Exploding) :如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 的某些特征值(或范数)大于1,经过多次连乘后,梯度会呈指数级增长,导致数值溢出,训练发散。
  • 梯度消失 (Gradient Vanishing) :如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 的某些特征值(或范数)小于1,经过多次连乘后,梯度会呈指数级衰减,趋近于0。这使得模型难以学习到长距离的依赖关系。

总结

BPTT算法的完整流程如下:

  1. 前向传播 :对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 1 , ... , T t=1, \dots, T </math>t=1,...,T,计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t , o t , L t h_t, o_t, L_t </math>ht,ot,Lt,得到总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L。同时保存所有的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t , h t x_t, h_t </math>xt,ht。
  2. 反向传播 : a. 计算最后一个时间步的隐藏层梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T \delta_{h,T} </math>δh,T。 b. 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = T − 1 , ... , 1 t = T-1, \dots, 1 </math>t=T−1,...,1,使用递归公式反向计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t。
  3. 计算最终梯度:根据所有时间步的中间值,使用求和公式计算出所有参数的梯度。
  4. 参数更新:使用计算出的梯度,通过梯度下降等优化算法更新模型参数。
相关推荐
网络点点滴4 分钟前
前端与后端的区别与联系
前端
EnCi Zheng28 分钟前
M5-markconv自定义CSS样式指南 [特殊字符]
前端·css·python
kyriewen32 分钟前
你的网页慢,用户不说直接走——前端性能监控教你“读心术”
前端·性能优化·监控
广州华水科技33 分钟前
北斗GNSS变形监测在大坝安全监测中的应用与优势分析
前端
前端老石人44 分钟前
前端开发中的 URL 完全指南
开发语言·前端·javascript·css·html
CAE虚拟与现实44 分钟前
五一假期闲来无事,来个前段、后端的说明吧
前端·后端·vtk·three.js·前后端
Sarvartha1 小时前
三目运算符
linux·服务器·前端
晓晨的博客1 小时前
ROS1录制的bag包转换为ROS2格式
前端·chrome
Wect1 小时前
LeetCode 72. 编辑距离:动态规划经典题解
前端·算法·typescript
donecoding1 小时前
别再让 pnpm 跟着 nvm 跑了!独立安装终极指南
前端·node.js·前端工程化