RNN Backpropagation: Mathematical Derivation (BPTT, Backpropagation Through Time)
I. Standard RNN Forward Pass

Let:

- Input sequence: $x_t \in \mathbb{R}^{d_{in}}$
- Hidden state: $h_t \in \mathbb{R}^{d_{hid}}$
- Output: $o_t \in \mathbb{R}^{d_{out}}$
- Parameters:
  - $W_{xh}$: input → hidden, $W_{hh}$: hidden → hidden, $W_{hy}$: hidden → output
  - $b_h$: hidden bias, $b_y$: output bias
Forward pass:

$$
\begin{aligned}
h_t &= \sigma\big(W_{xh}x_t + W_{hh}h_{t-1} + b_h\big) \\
o_t &= W_{hy}h_t + b_y \\
\hat y_t &= \mathrm{softmax}(o_t)
\end{aligned}
$$

Here $\sigma$ is the activation function (commonly $\tanh$), and the loss is the cross-entropy summed over time:

$$
\mathcal{L} = \sum_{t=1}^T \mathcal{L}_t, \qquad \mathcal{L}_t = -\sum_i y_{t,i}\log \hat y_{t,i}
$$
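As a minimal sketch of this forward pass in NumPy (all dimensions, initializations, and variable names here are illustrative assumptions, not taken from the text):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative dimensions (assumptions, not from the text)
d_in, d_hid, d_out, T = 3, 4, 2, 5

W_xh = rng.normal(0, 0.1, (d_hid, d_in))   # input  -> hidden
W_hh = rng.normal(0, 0.1, (d_hid, d_hid))  # hidden -> hidden
W_hy = rng.normal(0, 0.1, (d_out, d_hid))  # hidden -> output
b_h, b_y = np.zeros(d_hid), np.zeros(d_out)

def softmax(o):
    e = np.exp(o - o.max())  # shift by max for numerical stability
    return e / e.sum()

def forward(xs, h0):
    """h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), y_hat_t = softmax(W_hy h_t + b_y)."""
    h, z, y_hat = {-1: h0}, {}, {}
    for t in range(len(xs)):
        z[t] = W_xh @ xs[t] + W_hh @ h[t - 1] + b_h
        h[t] = np.tanh(z[t])
        y_hat[t] = softmax(W_hy @ h[t] + b_y)
    return h, z, y_hat

xs = [rng.normal(size=d_in) for _ in range(T)]
targets = [rng.integers(d_out) for _ in range(T)]            # class indices
h, z, y_hat = forward(xs, np.zeros(d_hid))
loss = -sum(np.log(y_hat[t][targets[t]]) for t in range(T))  # cross-entropy
```

The dictionaries keyed by time step keep the cached $h_t$, $z_t$, $\hat y_t$ that the backward pass below will need; `h[-1]` holds the initial state $h_0$.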
II. Core Idea of BPTT

RNN parameters are shared across all time steps, so backpropagation unrolls the network through time over $T$ steps and sums the per-step contributions via the chain rule:

$$
\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \frac{\partial \mathcal{L}_t}{\partial W}
$$

Define the gradient notation:

$$
\delta_t = \frac{\partial \mathcal{L}}{\partial h_t} \in \mathbb{R}^{d_{hid}}
$$
III. Position-wise Gradient Derivation

1. Output gradient $\dfrac{\partial \mathcal{L}_t}{\partial o_t}$

With softmax outputs and a cross-entropy loss, the gradient simplifies to:

$$
\frac{\partial \mathcal{L}_t}{\partial o_t} = \hat y_t - y_t
$$
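This simplification can be checked numerically with central finite differences (a quick sketch; the specific logits and target are arbitrary):

```python
import numpy as np

def softmax(o):
    e = np.exp(o - o.max())
    return e / e.sum()

def ce_loss(o, y):
    """Cross-entropy of softmax(o) against a one-hot target y."""
    return -float(y @ np.log(softmax(o)))

o = np.array([1.0, -0.5, 0.3])      # arbitrary logits
y = np.array([0.0, 1.0, 0.0])       # one-hot target

analytic = softmax(o) - y           # the claimed gradient: y_hat - y

# Central finite differences over each component of o
eps = 1e-6
numeric = np.array([
    (ce_loss(o + eps * e, y) - ce_loss(o - eps * e, y)) / (2 * eps)
    for e in np.eye(len(o))
])
assert np.allclose(analytic, numeric, atol=1e-6)
```

Note that the gradient components sum to zero, since both $\hat y_t$ and the one-hot $y_t$ sum to one.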
2. Hidden-state error $\delta_t$

By the chain rule, $h_t$ affects the loss through the current output $o_t$ and through the next hidden state $h_{t+1}$:

$$
\delta_t = \frac{\partial \mathcal{L}}{\partial h_t}
= \left(\frac{\partial o_t}{\partial h_t}\right)^{\!\top}\frac{\partial \mathcal{L}_t}{\partial o_t}
+ \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\!\top}\delta_{t+1}
$$
Substituting $o_t = W_{hy}h_t + b_y$ and $h_{t+1} = \sigma(z_{t+1})$, the two Jacobians are:

$$
\frac{\partial o_t}{\partial h_t} = W_{hy}
$$

$$
\frac{\partial h_{t+1}}{\partial h_t} = \mathrm{diag}\big(\sigma'(z_{t+1})\big)\,W_{hh}
$$

where $z_{t+1} = W_{xh}x_{t+1} + W_{hh}h_t + b_h$, and $\odot$ below denotes element-wise multiplication.
This gives the recurrence for $\delta_t$:

$$
\delta_t = W_{hy}^\top(\hat y_t - y_t) + W_{hh}^\top\big(\delta_{t+1} \odot \sigma'(z_{t+1})\big)
$$

Boundary condition: at the last step $t = T$ there is no next time step, so

$$
\delta_T = W_{hy}^\top(\hat y_T - y_T)
$$
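The recurrence and its boundary condition translate directly into a backward loop. A sketch assuming a tanh activation (the toy shapes and data below are arbitrary):

```python
import numpy as np

def hidden_deltas(W_hy, W_hh, y_hat, y, z):
    """delta_t = W_hy^T (y_hat_t - y_t) + W_hh^T (delta_{t+1} * sigma'(z_{t+1}))."""
    T = len(y_hat)
    delta = [None] * T
    delta[T - 1] = W_hy.T @ (y_hat[T - 1] - y[T - 1])  # boundary: no next step
    for t in range(T - 2, -1, -1):
        sp = 1.0 - np.tanh(z[t + 1]) ** 2              # sigma'(z_{t+1}) for tanh
        delta[t] = W_hy.T @ (y_hat[t] - y[t]) + W_hh.T @ (delta[t + 1] * sp)
    return delta

# Arbitrary toy data for a shape check
rng = np.random.default_rng(0)
d_hid, d_out, T = 4, 3, 5
W_hy = rng.normal(size=(d_out, d_hid))
W_hh = rng.normal(size=(d_hid, d_hid))
y_hat = [rng.dirichlet(np.ones(d_out)) for _ in range(T)]   # valid probability vectors
y = [np.eye(d_out)[rng.integers(d_out)] for _ in range(T)]  # one-hot targets
z = [rng.normal(size=d_hid) for _ in range(T)]
deltas = hidden_deltas(W_hy, W_hh, y_hat, y, z)
```

The loop runs from $t = T-1$ down to $1$, each step reusing $\delta_{t+1}$ from the previous iteration.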
IV. Gradients for All Parameters

1. $W_{hy}$ and $b_y$

$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial W_{hy}} &= \sum_{t=1}^T (\hat y_t - y_t)\, h_t^\top \\
\frac{\partial \mathcal{L}}{\partial b_y} &= \sum_{t=1}^T (\hat y_t - y_t)
\end{aligned}
$$
2. $W_{xh}$, $W_{hh}$, and $b_h$

First note that $\dfrac{\partial h_t}{\partial z_t} = \mathrm{diag}\big(\sigma'(z_t)\big)$, so the error reaching $z_t$ is $\delta_t \odot \sigma'(z_t)$:

$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial W_{xh}} &= \sum_{t=1}^T \big(\delta_t \odot \sigma'(z_t)\big)\, x_t^\top \\
\frac{\partial \mathcal{L}}{\partial W_{hh}} &= \sum_{t=1}^T \big(\delta_t \odot \sigma'(z_t)\big)\, h_{t-1}^\top \\
\frac{\partial \mathcal{L}}{\partial b_h} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t)
\end{aligned}
$$
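Putting the per-parameter sums together, a sketch of the gradient accumulation (tanh activation assumed; `h` is a dict indexed so that `h[-1]` is the initial state $h_0$; all names and toy data are illustrative):

```python
import numpy as np

def bptt_param_grads(delta, z, xs, h, y_hat, y):
    """Accumulate the five parameter gradients over t = 1..T.
    Assumes tanh, so sigma'(z_t) = 1 - tanh(z_t)^2; h is a dict with h[-1] = h_0."""
    T = len(xs)
    d_hid, d_in, d_out = z[0].size, xs[0].size, y_hat[0].size
    dW_xh, dW_hh = np.zeros((d_hid, d_in)), np.zeros((d_hid, d_hid))
    dW_hy = np.zeros((d_out, d_hid))
    db_h, db_y = np.zeros(d_hid), np.zeros(d_out)
    for t in range(T):
        dz = delta[t] * (1.0 - np.tanh(z[t]) ** 2)  # delta_t ⊙ sigma'(z_t)
        dW_xh += np.outer(dz, xs[t])
        dW_hh += np.outer(dz, h[t - 1])
        db_h += dz
        dW_hy += np.outer(y_hat[t] - y[t], h[t])
        db_y += y_hat[t] - y[t]
    return dW_xh, dW_hh, dW_hy, db_h, db_y

# Arbitrary toy data for a shape check
rng = np.random.default_rng(0)
d_in, d_hid, d_out, T = 3, 4, 2, 5
xs = [rng.normal(size=d_in) for _ in range(T)]
z = [rng.normal(size=d_hid) for _ in range(T)]
h = {t: np.tanh(z[t]) for t in range(T)}
h[-1] = np.zeros(d_hid)
delta = [rng.normal(size=d_hid) for _ in range(T)]
y_hat = [rng.dirichlet(np.ones(d_out)) for _ in range(T)]
y = [np.eye(d_out)[rng.integers(d_out)] for _ in range(T)]
dW_xh, dW_hh, dW_hy, db_h, db_y = bptt_param_grads(delta, z, xs, h, y_hat, y)
```

Each outer product matches one summand in the formulas: a column error vector times a row of cached activations.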
V. Tanh Activation Derivative (Common Case)

If $\sigma(z) = \tanh(z)$, then:

$$
\sigma'(z_t) = 1 - \tanh^2(z_t) = 1 - h_t^2
$$

Substitute this directly into the formulas above.
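As an end-to-end sanity check of the whole derivation, the analytic $\partial\mathcal{L}/\partial W_{hh}$ built from the formulas above can be compared against a central-difference numerical gradient on a tiny tanh RNN (all sizes and data are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hid, d_out, T = 2, 3, 2, 4

W_xh = rng.normal(0, 0.5, (d_hid, d_in))
W_hh = rng.normal(0, 0.5, (d_hid, d_hid))
W_hy = rng.normal(0, 0.5, (d_out, d_hid))
b_h = rng.normal(0, 0.5, d_hid)
b_y = rng.normal(0, 0.5, d_out)
xs = [rng.normal(size=d_in) for _ in range(T)]
ys = [np.eye(d_out)[rng.integers(d_out)] for _ in range(T)]  # one-hot targets

def softmax(o):
    e = np.exp(o - o.max())
    return e / e.sum()

def loss_fn(W_hh_):
    """Total cross-entropy loss as a function of W_hh alone."""
    h, L = np.zeros(d_hid), 0.0
    for t in range(T):
        h = np.tanh(W_xh @ xs[t] + W_hh_ @ h + b_h)
        L -= ys[t] @ np.log(softmax(W_hy @ h + b_y))
    return L

# Forward pass, caching h_t and y_hat_t (h[-1] = h_0 = 0)
h, yhat = {-1: np.zeros(d_hid)}, {}
for t in range(T):
    h[t] = np.tanh(W_xh @ xs[t] + W_hh @ h[t - 1] + b_h)
    yhat[t] = softmax(W_hy @ h[t] + b_y)

# Backward recursion for delta_t (sigma'(z_t) = 1 - h_t^2 for tanh)
delta = {T - 1: W_hy.T @ (yhat[T - 1] - ys[T - 1])}
for t in range(T - 2, -1, -1):
    delta[t] = (W_hy.T @ (yhat[t] - ys[t])
                + W_hh.T @ (delta[t + 1] * (1 - h[t + 1] ** 2)))

# Analytic gradient from the summation formula
dW_hh = sum(np.outer(delta[t] * (1 - h[t] ** 2), h[t - 1]) for t in range(T))

# Numerical gradient, entry by entry, via central differences
num, eps = np.zeros_like(W_hh), 1e-5
for i in range(d_hid):
    for j in range(d_hid):
        Wp, Wm = W_hh.copy(), W_hh.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        num[i, j] = (loss_fn(Wp) - loss_fn(Wm)) / (2 * eps)

assert np.allclose(dW_hh, num, atol=1e-6)
```

Agreement to within finite-difference precision confirms the recurrence, the boundary condition, and the parameter-gradient sums jointly.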