RNN 梯度计算详细推导 (BPTT)

RNN 梯度计算详细推导 (BPTT)

为了详细推导循环神经网络(RNN)中的梯度计算方法------沿时间反向传播(Backpropagation Through Time, BPTT),我们将使用一个最基础的RNN模型结构。

第一步:定义RNN模型和符号

在一个时间步 t t t,RNN的计算过程如下:

  1. 隐藏状态 (Hidden State)

    h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) ht=f(Uxt+Wht−1+b)

    • x t x_t xt:在时间步 t 的输入向量。
    • h t h_t ht:在时间步 t 的隐藏状态向量。 h 0 h_0 h0 通常初始化为零向量。
    • h t − 1 h_{t-1} ht−1:前一个时间步的隐藏状态。
    • U , W , b U, W, b U,W,b:循环层的参数。 U U U 是输入到隐藏层的权重矩阵, W W W 是隐藏层到隐藏层的权重矩阵(循环权重), b b b 是偏置向量。
    • f f f:激活函数,通常是 tanhReLU。这里我们假设为 tanh
  2. 输出 (Output)

    o t = V h t + c o_t = V h_t + c ot=Vht+c

    • o t o_t ot:在时间步 t 的输出(或称为 logits)。
    • V , c V, c V,c:输出层的参数。 V V V 是隐藏层到输出层的权重矩阵, c c c 是偏置向量。
  3. 预测概率 (Predicted Probability)

    y ^ t = g ( o t ) \hat{y}_t = g(o_t) y^t=g(ot)

    • g g g:输出激活函数,对于分类任务通常是 Softmax
  4. 损失函数 (Loss Function)

    L t = Loss ( y ^ t , y t ) L_t = \text{Loss}(\hat{y}_t, y_t) Lt=Loss(y^t,yt)

    • L t L_t Lt:在时间步 t 的损失,例如交叉熵损失。
    • y t y_t yt:在时间步 t 的真实标签。

总体目标

我们的目标是计算总损失 L = ∑ t = 1 T L t L = \sum_{t=1}^{T} L_t L=∑t=1TLt 对所有模型参数 θ = { U , W , V , b , c } \theta = \{U, W, V, b, c\} θ={U,W,V,b,c} 的梯度。即求解: ∂ L ∂ V , ∂ L ∂ c , ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c}, \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} ∂V∂L,∂c∂L,∂W∂L,∂U∂L,∂b∂L。


第二步:前向传播

模型按照 t = 1 , 2 , ... , T t=1, 2, \dots, T t=1,2,...,T 的顺序,依次计算出每个时间步的 h t , o t , y ^ t , L t h_t, o_t, \hat{y}_t, L_t ht,ot,y^t,Lt,并最终得到总损失 L L L。这个过程比较直接,就是将输入序列喂给模型,得到输出和损失。


第三步:反向传播 (BPTT)

梯度是反向计算的,从最后一个时间步 T T T 开始,一直传播到第一个时间步 1 1 1。

A. 输出层参数的梯度 ( ∂ L ∂ V , ∂ L ∂ c \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c} ∂V∂L,∂c∂L)

这部分比较简单,因为 V V V 和 c c c 的计算不涉及时间上的循环依赖。总损失对它们的梯度是每个时间步梯度贡献的总和。
∂ L ∂ V = ∑ t = 1 T ∂ L t ∂ V \frac{\partial L}{\partial V} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial V} ∂V∂L=t=1∑T∂V∂Lt

我们来看单个时间步 t t t 的梯度 ∂ L t ∂ V \frac{\partial L_t}{\partial V} ∂V∂Lt。根据链式法则:
∂ L t ∂ V = ∂ L t ∂ o t ∂ o t ∂ V \frac{\partial L_t}{\partial V} = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial V} ∂V∂Lt=∂ot∂Lt∂V∂ot

其中:

  • ∂ L t ∂ o t \frac{\partial L_t}{\partial o_t} ∂ot∂Lt 是损失对输出 logits 的梯度。我们将其记为 δ o , t \delta_{o,t} δo,t。例如,对于Softmax和交叉熵损失, δ o , t = y ^ t − y t \delta_{o,t} = \hat{y}_t - y_t δo,t=y^t−yt。
  • ∂ o t ∂ V \frac{\partial o_t}{\partial V} ∂V∂ot:因为 o t = V h t + c o_t = V h_t + c ot=Vht+c,所以 ∂ o t ∂ V = h t T \frac{\partial o_t}{\partial V} = h_t^T ∂V∂ot=htT(转置是为了维度匹配)。

所以,
∂ L t ∂ V = δ o , t ⋅ h t T \frac{\partial L_t}{\partial V} = \delta_{o,t} \cdot h_t^T ∂V∂Lt=δo,t⋅htT

最终,
∂ L ∂ V = ∑ t = 1 T δ o , t ⋅ h t T \boxed{\frac{\partial L}{\partial V} = \sum_{t=1}^{T} \delta_{o,t} \cdot h_t^T} ∂V∂L=t=1∑Tδo,t⋅htT

同理,对于偏置 c c c:
∂ L ∂ c = ∑ t = 1 T δ o , t \boxed{\frac{\partial L}{\partial c} = \sum_{t=1}^{T} \delta_{o,t}} ∂c∂L=t=1∑Tδo,t

B. 循环层参数的梯度 ( ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} ∂W∂L,∂U∂L,∂b∂L)

这是BPTT的核心和难点。我们以 ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L 为例进行推导。 为了解决 W W W 在时间上的复杂依赖,我们引入一个关键的中间量:总损失 L L L 对隐藏状态 h t h_t ht 的梯度 ,记为 δ h , t = ∂ L ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} δh,t=∂ht∂L。

根据链式法则, L L L 通过两条路径影响 h t h_t ht:

  1. 通过当前时间步的输出 o t o_t ot。
  2. 通过下一个时间步的隐藏状态 h t + 1 h_{t+1} ht+1(因为 h t + 1 h_{t+1} ht+1 的计算用到了 h t h_t ht)。

因此, δ h , t \delta_{h,t} δh,t 的计算是一个从后向前的递归过程:
δ h , t = ∂ L ∂ h t = ∂ L ∂ o t ∂ o t ∂ h t + ∂ L ∂ h t + 1 ∂ h t + 1 ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t}\frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_t} δh,t=∂ht∂L=∂ot∂L∂ht∂ot+∂ht+1∂L∂ht∂ht+1

将各部分代入,我们得到 δ h , t \delta_{h,t} δh,t 的递归公式:
δ h , t = δ o , t V T + δ h , t + 1 W T diag ( f ′ ( h t + 1 ) ) \delta_{h,t} = \delta_{o,t} V^T + \delta_{h,t+1} W^T \text{diag}(f'(h_{t+1})) δh,t=δo,tVT+δh,t+1WTdiag(f′(ht+1))

  • 递归的起点(Base Case) :在最后一个时间步 T T T,没有未来的隐藏状态,所以递归项为0。
    δ h , T = ∂ L ∂ h T = ∂ L T ∂ o T ∂ o T ∂ h T = δ o , T V T \delta_{h,T} = \frac{\partial L}{\partial h_T} = \frac{\partial L_T}{\partial o_T}\frac{\partial o_T}{\partial h_T} = \delta_{o,T} V^T δh,T=∂hT∂L=∂oT∂LT∂hT∂oT=δo,TVT

我们可以从 δ h , T \delta_{h,T} δh,T 开始,反向计算出 δ h , T − 1 , ... , δ h , 1 \delta_{h,T-1}, \dots, \delta_{h,1} δh,T−1,...,δh,1。

现在,我们用 δ h , t \delta_{h,t} δh,t 来计算最终的梯度 ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L
∂ L ∂ W = ∑ t = 1 T ∂ L ∂ h t ∂ h t ∂ W = ∑ t = 1 T δ h , t ∂ h t ∂ W \frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W} = \sum_{t=1}^{T} \delta_{h,t} \frac{\partial h_t}{\partial W} ∂W∂L=t=1∑T∂ht∂L∂W∂ht=t=1∑Tδh,t∂W∂ht

根据 h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) ht=f(Uxt+Wht−1+b),我们有 ∂ h t ∂ W = diag ( f ′ ( h t ) ) ⋅ h t − 1 T \frac{\partial h_t}{\partial W} = \text{diag}(f'(h_t)) \cdot h_{t-1}^T ∂W∂ht=diag(f′(ht))⋅ht−1T。 所以,
∂ L ∂ W = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ h t − 1 T \boxed{\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot h_{t-1}^T} ∂W∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅ht−1T

同理可得 ∂ L ∂ U \frac{\partial L}{\partial U} ∂U∂L 和 ∂ L ∂ b \frac{\partial L}{\partial b} ∂b∂L:
∂ L ∂ U = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ x t T \boxed{\frac{\partial L}{\partial U} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot x_t^T} ∂U∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅xtT
∂ L ∂ b = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t \boxed{\frac{\partial L}{\partial b} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t}} ∂b∂L=t=1∑Tdiag(f′(ht))⋅δh,t


第四步:梯度消失与爆炸的根源

回顾 δ h , t \delta_{h,t} δh,t 的递归公式:
δ h , t = ⋯ + δ h , t + 1 ⋅ ( W T diag ( f ′ ( h t + 1 ) ) ) \delta_{h,t} = \dots + \delta_{h,t+1} \cdot (W^T \text{diag}(f'(h_{t+1}))) δh,t=⋯+δh,t+1⋅(WTdiag(f′(ht+1)))

我们可以看到,梯度在时间上传播时,会反复乘以循环权重矩阵 W W W

  • 梯度爆炸 (Gradient Exploding) :如果 W W W 的某些特征值(或范数)大于1,经过多次连乘后,梯度会呈指数级增长,导致数值溢出,训练发散。
  • 梯度消失 (Gradient Vanishing) :如果 W W W 的某些特征值(或范数)小于1,经过多次连乘后,梯度会呈指数级衰减,趋近于0。这使得模型难以学习到长距离的依赖关系。

总结

BPTT算法的完整流程如下:

  1. 前向传播 :对于 t = 1 , ... , T t=1, \dots, T t=1,...,T,计算 h t , o t , L t h_t, o_t, L_t ht,ot,Lt,得到总损失 L L L。同时保存所有的 x t , h t x_t, h_t xt,ht。
  2. 反向传播 : a. 计算最后一个时间步的隐藏层梯度 δ h , T \delta_{h,T} δh,T。 b. 对于 t = T − 1 , ... , 1 t = T-1, \dots, 1 t=T−1,...,1,使用递归公式反向计算 δ h , t \delta_{h,t} δh,t。
  3. 计算最终梯度:根据所有时间步的中间值,使用求和公式计算出所有参数的梯度。
  4. 参数更新:使用计算出的梯度,通过梯度下降等优化算法更新模型参数。
相关推荐
独泪了无痕几秒前
Lodash-JavaScript的实用工具库
前端·javascript
有趣的老凌2 分钟前
用 Vibe Coding 搭了一个完整小程序「一定能成」
前端·javascript·后端
kyriewen11 小时前
Anthropic 估值逼近万亿美元,Claude Sonnet 5 + Claude Science 一天两连发
前端·ai编程·claude
小徐_233312 小时前
Wot UI 2.2.0 发布:Button 新增 subtle,VideoPreview 预览体验继续增强
前端·微信小程序·uni-app
天蓝色的鱼鱼14 小时前
关于 CSS 你可能不知道的属性,但关键时刻很有用
前端·css
泯泷15 小时前
第 2 篇:设计第一套字节码:Opcode、Instruction 与 Constant Pool
前端·javascript·安全
妙码生花15 小时前
从 PHP 到 AI + Golang,程序员自救转型手记(十五):优化细节、网络请求封装
前端·后端·ai编程
泯泷15 小时前
第 1 篇:从 1 + 2 开始:亲手写出第一台 JSVM
前端·javascript·安全
团团崽_七分甜15 小时前
Spring Boot 核心知识点总结
前端