RNN 梯度计算详细推导 (BPTT)
为了详细推导循环神经网络(RNN)中的梯度计算方法------沿时间反向传播(Backpropagation Through Time, BPTT),我们将使用一个最基础的RNN模型结构。
第一步:定义RNN模型和符号
在一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,RNN的计算过程如下:
-
隐藏状态 (Hidden State):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) </math>ht=f(Uxt+Wht−1+b)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt:在时间步
t
的输入向量。 - <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht:在时间步
t
的隐藏状态向量。 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 0 h_0 </math>h0 通常初始化为零向量。 - <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t-1} </math>ht−1:前一个时间步的隐藏状态。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> U , W , b U, W, b </math>U,W,b:循环层的参数。 <math xmlns="http://www.w3.org/1998/Math/MathML"> U U </math>U 是输入到隐藏层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 是隐藏层到隐藏层的权重矩阵(循环权重), <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 是偏置向量。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f:激活函数,通常是
tanh
或ReLU
。这里我们假设为tanh
。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt:在时间步
-
输出 (Output):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> o t = V h t + c o_t = V h_t + c </math>ot=Vht+c
- <math xmlns="http://www.w3.org/1998/Math/MathML"> o t o_t </math>ot:在时间步
t
的输出(或称为 logits)。 - <math xmlns="http://www.w3.org/1998/Math/MathML"> V , c V, c </math>V,c:输出层的参数。 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 是隐藏层到输出层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 是偏置向量。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> o t o_t </math>ot:在时间步
-
预测概率 (Predicted Probability):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y ^ t = g ( o t ) \hat{y}_t = g(o_t) </math>y^t=g(ot)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g:输出激活函数,对于分类任务通常是
Softmax
。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g:输出激活函数,对于分类任务通常是
-
损失函数 (Loss Function):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L t = Loss ( y ^ t , y t ) L_t = \text{Loss}(\hat{y}_t, y_t) </math>Lt=Loss(y^t,yt)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt:在时间步
t
的损失,例如交叉熵损失。 - <math xmlns="http://www.w3.org/1998/Math/MathML"> y t y_t </math>yt:在时间步
t
的真实标签。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt:在时间步
总体目标
我们的目标是计算总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L = ∑ t = 1 T L t L = \sum_{t=1}^{T} L_t </math>L=∑t=1TLt 对所有模型参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ = { U , W , V , b , c } \theta = \{U, W, V, b, c\} </math>θ={U,W,V,b,c} 的梯度。即求解: <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ V , ∂ L ∂ c , ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c}, \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} </math>∂V∂L,∂c∂L,∂W∂L,∂U∂L,∂b∂L。
第二步:前向传播
模型按照 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 1 , 2 , ... , T t=1, 2, \dots, T </math>t=1,2,...,T 的顺序,依次计算出每个时间步的 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t , o t , y ^ t , L t h_t, o_t, \hat{y}_t, L_t </math>ht,ot,y^t,Lt,并最终得到总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L。这个过程比较直接,就是将输入序列喂给模型,得到输出和损失。
第三步:反向传播 (BPTT)
梯度是反向计算的,从最后一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 开始,一直传播到第一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
A. 输出层参数的梯度 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ V , ∂ L ∂ c \frac{\partial L}{\partial V}, \frac{\partial L}{\partial c} </math>∂V∂L,∂c∂L)
这部分比较简单,因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 的计算不涉及时间上的循环依赖。总损失对它们的梯度是每个时间步梯度贡献的总和。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ V = ∑ t = 1 T ∂ L t ∂ V \frac{\partial L}{\partial V} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial V} </math>∂V∂L=t=1∑T∂V∂Lt
我们来看单个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L t ∂ V \frac{\partial L_t}{\partial V} </math>∂V∂Lt。根据链式法则:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L t ∂ V = ∂ L t ∂ o t ∂ o t ∂ V \frac{\partial L_t}{\partial V} = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial V} </math>∂V∂Lt=∂ot∂Lt∂V∂ot
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L t ∂ o t \frac{\partial L_t}{\partial o_t} </math>∂ot∂Lt 是损失对输出 logits 的梯度。我们将其记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ o , t \delta_{o,t} </math>δo,t。例如,对于Softmax和交叉熵损失, <math xmlns="http://www.w3.org/1998/Math/MathML"> δ o , t = y ^ t − y t \delta_{o,t} = \hat{y}_t - y_t </math>δo,t=y^t−yt。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o t ∂ V \frac{\partial o_t}{\partial V} </math>∂V∂ot:因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> o t = V h t + c o_t = V h_t + c </math>ot=Vht+c,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o t ∂ V = h t T \frac{\partial o_t}{\partial V} = h_t^T </math>∂V∂ot=htT(转置是为了维度匹配)。
所以,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L t ∂ V = δ o , t ⋅ h t T \frac{\partial L_t}{\partial V} = \delta_{o,t} \cdot h_t^T </math>∂V∂Lt=δo,t⋅htT
最终,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ V = ∑ t = 1 T δ o , t ⋅ h t T \boxed{\frac{\partial L}{\partial V} = \sum_{t=1}^{T} \delta_{o,t} \cdot h_t^T} </math>∂V∂L=t=1∑Tδo,t⋅htT
同理,对于偏置 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ c = ∑ t = 1 T δ o , t \boxed{\frac{\partial L}{\partial c} = \sum_{t=1}^{T} \delta_{o,t}} </math>∂c∂L=t=1∑Tδo,t
B. 循环层参数的梯度 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W , ∂ L ∂ U , ∂ L ∂ b \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b} </math>∂W∂L,∂U∂L,∂b∂L)
这是BPTT的核心和难点。我们以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W \frac{\partial L}{\partial W} </math>∂W∂L 为例进行推导。 为了解决 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 在时间上的复杂依赖,我们引入一个关键的中间量:总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 对隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht 的梯度 ,记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t = ∂ L ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} </math>δh,t=∂ht∂L。
根据链式法则, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 通过两条路径影响 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht:
- 通过当前时间步的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> o t o_t </math>ot。
- 通过下一个时间步的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t + 1 h_{t+1} </math>ht+1(因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t + 1 h_{t+1} </math>ht+1 的计算用到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht)。
因此, <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的计算是一个从后向前的递归过程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = ∂ L ∂ h t = ∂ L ∂ o t ∂ o t ∂ h t + ∂ L ∂ h t + 1 ∂ h t + 1 ∂ h t \delta_{h,t} = \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t}\frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_t} </math>δh,t=∂ht∂L=∂ot∂L∂ht∂ot+∂ht+1∂L∂ht∂ht+1
将各部分代入,我们得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的递归公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = δ o , t V T + δ h , t + 1 W T diag ( f ′ ( h t + 1 ) ) \delta_{h,t} = \delta_{o,t} V^T + \delta_{h,t+1} W^T \text{diag}(f'(h_{t+1})) </math>δh,t=δo,tVT+δh,t+1WTdiag(f′(ht+1))
- 递归的起点(Base Case) :在最后一个时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T,没有未来的隐藏状态,所以递归项为0。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , T = ∂ L ∂ h T = ∂ L T ∂ o T ∂ o T ∂ h T = δ o , T V T \delta_{h,T} = \frac{\partial L}{\partial h_T} = \frac{\partial L_T}{\partial o_T}\frac{\partial o_T}{\partial h_T} = \delta_{o,T} V^T </math>δh,T=∂hT∂L=∂oT∂LT∂hT∂oT=δo,TVT
我们可以从 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T \delta_{h,T} </math>δh,T 开始,反向计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T − 1 , ... , δ h , 1 \delta_{h,T-1}, \dots, \delta_{h,1} </math>δh,T−1,...,δh,1。
现在,我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 来计算最终的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W \frac{\partial L}{\partial W} </math>∂W∂L。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W = ∑ t = 1 T ∂ L ∂ h t ∂ h t ∂ W = ∑ t = 1 T δ h , t ∂ h t ∂ W \frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W} = \sum_{t=1}^{T} \delta_{h,t} \frac{\partial h_t}{\partial W} </math>∂W∂L=t=1∑T∂ht∂L∂W∂ht=t=1∑Tδh,t∂W∂ht
根据 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t = f ( U x t + W h t − 1 + b ) h_t = f(U x_t + W h_{t-1} + b) </math>ht=f(Uxt+Wht−1+b),我们有 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ h t ∂ W = diag ( f ′ ( h t ) ) ⋅ h t − 1 T \frac{\partial h_t}{\partial W} = \text{diag}(f'(h_t)) \cdot h_{t-1}^T </math>∂W∂ht=diag(f′(ht))⋅ht−1T。 所以,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ h t − 1 T \boxed{\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot h_{t-1}^T} </math>∂W∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅ht−1T
同理可得 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ U \frac{\partial L}{\partial U} </math>∂U∂L 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ b \frac{\partial L}{\partial b} </math>∂b∂L:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ U = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t ⋅ x t T \boxed{\frac{\partial L}{\partial U} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot x_t^T} </math>∂U∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅xtT
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ b = ∑ t = 1 T diag ( f ′ ( h t ) ) ⋅ δ h , t \boxed{\frac{\partial L}{\partial b} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t}} </math>∂b∂L=t=1∑Tdiag(f′(ht))⋅δh,t
第四步:梯度消失与爆炸的根源
回顾 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t 的递归公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> δ h , t = ⋯ + δ h , t + 1 ⋅ ( W T diag ( f ′ ( h t + 1 ) ) ) \delta_{h,t} = \dots + \delta_{h,t+1} \cdot (W^T \text{diag}(f'(h_{t+1}))) </math>δh,t=⋯+δh,t+1⋅(WTdiag(f′(ht+1)))
我们可以看到,梯度在时间上传播时,会反复乘以循环权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W。
- 梯度爆炸 (Gradient Exploding) :如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 的某些特征值(或范数)大于1,经过多次连乘后,梯度会呈指数级增长,导致数值溢出,训练发散。
- 梯度消失 (Gradient Vanishing) :如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> W W </math>W 的某些特征值(或范数)小于1,经过多次连乘后,梯度会呈指数级衰减,趋近于0。这使得模型难以学习到长距离的依赖关系。
总结
BPTT算法的完整流程如下:
- 前向传播 :对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 1 , ... , T t=1, \dots, T </math>t=1,...,T,计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t , o t , L t h_t, o_t, L_t </math>ht,ot,Lt,得到总损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L。同时保存所有的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t , h t x_t, h_t </math>xt,ht。
- 反向传播 : a. 计算最后一个时间步的隐藏层梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , T \delta_{h,T} </math>δh,T。 b. 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = T − 1 , ... , 1 t = T-1, \dots, 1 </math>t=T−1,...,1,使用递归公式反向计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ h , t \delta_{h,t} </math>δh,t。
- 计算最终梯度:根据所有时间步的中间值,使用求和公式计算出所有参数的梯度。
- 参数更新:使用计算出的梯度,通过梯度下降等优化算法更新模型参数。