RNN 反向传播数学推导(BPTT 时间反向传播)

RNN 反向传播数学推导(BPTT 时间反向传播)

一、标准 RNN 前向传播公式

设:

  • 输入序列:xt∈Rdinx_t \in \mathbb{R}^{d_{in}}xt∈Rdin
  • 隐藏状态:ht∈Rdhidh_t \in \mathbb{R}^{d_{hid}}ht∈Rdhid
  • 输出:ot∈Rdouto_t \in \mathbb{R}^{d_{out}}ot∈Rdout
  • 参数:
    WxhW_{xh}Wxh:输入→隐藏,WhhW_{hh}Whh:隐藏→隐藏,WhyW_{hy}Why:隐藏→输出
    bhb_hbh:隐藏偏置,byb_yby:输出偏置

前向传播
ht=σ(Wxhxt+Whhht−1+bh)ot=Whyht+byy^t=softmax(ot) \Huge \begin{align} h_t &= \sigma\big(W_{xh}x_t + W_{hh}h_{t-1} + b_h\big) \\ o_t &= W_{hy}h_t + b_y \\ \hat y_t &= \text{softmax}(o_t) \end{align} htoty^t=σ(Wxhxt+Whhht−1+bh)=Whyht+by=softmax(ot)
σ\sigmaσ 为激活函数(常用 tanh⁡\tanhtanh),损失用交叉熵:
L=∑t=1TLt,Lt=−∑iyt,ilog⁡y^t,i \mathcal{L} = \sum_{t=1}^T \mathcal{L}t,\quad \mathcal{L}t = -\sum_i y{t,i}\log\hat y{t,i} L=t=1∑TLt,Lt=−i∑yt,ilogy^t,i


二、BPTT 核心思想

RNN 参数所有时间步共享 ,反向传播沿时间维度回溯 TTT 步 ,链式求导:
∂L∂W=∑t=1T∂Lt∂W \Huge \frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \frac{\partial \mathcal{L}_t}{\partial W} ∂W∂L=t=1∑T∂W∂Lt

定义梯度记号:
δt=∂L∂ht∈Rdhid \Huge \delta_t = \frac{\partial \mathcal{L}}{\partial h_t} \in \mathbb{R}^{d_{hid}} δt=∂ht∂L∈Rdhid


三、逐位置梯度推导

1. 输出端梯度 ∂Lt∂ot\dfrac{\partial \mathcal{L}_t}{\partial o_t}∂ot∂Lt

交叉熵+softmax 简化结论:
∂Lt∂ot=y^t−yt \Huge \frac{\partial \mathcal{L}_t}{\partial o_t} = \hat y_t - y_t ∂ot∂Lt=y^t−yt

2. 隐藏层误差 δt\delta_tδt

链式法则:
δt=∂Lt∂ht=∂ot∂ht⊤∂Lt∂ot+∂ht+1∂ht⊤δt+1 \Huge \delta_t = \frac{\partial \mathcal{L}t}{\partial h_t} = \frac{\partial o_t}{\partial h_t}^\top \frac{\partial \mathcal{L}t}{\partial o_t} + \frac{\partial h{t+1}}{\partial h_t}^\top \delta{t+1} δt=∂ht∂Lt=∂ht∂ot⊤∂ot∂Lt+∂ht∂ht+1⊤δt+1

代入 ot=Whyht+byo_t=W_{hy}h_t+b_yot=Whyht+by 和 ht+1=σ(⋅)h_{t+1}=\sigma(\cdot)ht+1=σ(⋅):
∂ot∂ht=Why⊤ \Huge \frac{\partial o_t}{\partial h_t} = W_{hy}^\top ∂ht∂ot=Why⊤

∂ht+1∂ht=Whh⊤⊙σ′(zt+1) \Huge \frac{\partial h_{t+1}}{\partial h_t} = W_{hh}^\top \odot \sigma'(z_{t+1}) ∂ht∂ht+1=Whh⊤⊙σ′(zt+1)

zt+1=Wxhxt+1+Whhht+bhz_{t+1}=W_{xh}x_{t+1}+W_{hh}h_t+b_hzt+1=Wxhxt+1+Whhht+bh,⊙\odot⊙ 逐元素乘

得到 δt\delta_tδt 递推公式
δt=Why⊤(y^t−yt)+Whh⊤⋅σ′(zt+1)⊙δt+1 \Huge \delta_t = W_{hy}^\top (\hat y_t - y_t) + W_{hh}^\top \cdot \sigma'(z_{t+1}) \odot \delta_{t+1} δt=Why⊤(y^t−yt)+Whh⊤⋅σ′(zt+1)⊙δt+1

终止条件 :最后时刻 t=Tt=Tt=T,无下一时刻
δT=Why⊤(y^T−yT) \Huge \delta_T = W_{hy}^\top (\hat y_T - y_T) δT=Why⊤(y^T−yT)


四、所有参数梯度公式

1. Why, byW_{hy},\ b_yWhy, by

∂L∂Why=∑t=1T(y^t−yt) ht⊤∂L∂by=∑t=1T(y^t−yt) \Huge \begin{align} \frac{\partial \mathcal{L}}{\partial W_{hy}} &= \sum_{t=1}^T (\hat y_t - y_t)\, h_t^\top \\ \frac{\partial \mathcal{L}}{\partial b_y} &= \sum_{t=1}^T (\hat y_t - y_t) \end{align} ∂Why∂L∂by∂L=t=1∑T(y^t−yt)ht⊤=t=1∑T(y^t−yt)

2. Wxh, Whh, bhW_{xh},\ W_{hh},\ b_hWxh, Whh, bh

先记:∂ht∂zt=σ′(zt)\dfrac{\partial h_t}{\partial z_t} = \sigma'(z_t)∂zt∂ht=σ′(zt)
∂L∂Wxh=∑t=1Tδt⊙σ′(zt)  xt⊤∂L∂Whh=∑t=1Tδt⊙σ′(zt)  ht−1⊤∂L∂bh=∑t=1Tδt⊙σ′(zt) \Huge \begin{align} \frac{\partial \mathcal{L}}{\partial W_{xh}} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t)\; x_t^\top \\ \frac{\partial \mathcal{L}}{\partial W_{hh}} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t)\; h_{t-1}^\top \\ \frac{\partial \mathcal{L}}{\partial b_h} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t) \end{align} ∂Wxh∂L∂Whh∂L∂bh∂L=t=1∑Tδt⊙σ′(zt)xt⊤=t=1∑Tδt⊙σ′(zt)ht−1⊤=t=1∑Tδt⊙σ′(zt)


五、Tanh 激活导数(常用)

若 σ(z)=tanh⁡(z)\sigma(z)=\tanh(z)σ(z)=tanh(z):
σ′(z)=1−tanh⁡2(z)=1−ht2 \Huge \sigma'(z) = 1 - \tanh^2(z) = 1 - h_t^2 σ′(z)=1−tanh2(z)=1−ht2

直接代入上面公式即可。


相关推荐
码农小白AI1 小时前
漏电流报告审核为何进入“新速度时代”?IACheck用AI报告审核重构效率与精度
人工智能·重构
wanhengidc1 小时前
算力服务器的应用场景
运维·服务器·人工智能·安全·web安全·智能手机
企微增长观察1 小时前
2026企业微信AI SCRM实测:微盛·企微管家全行业私域运营
大数据·人工智能·企业微信
一只数据集1 小时前
Unitree G1苹果拾取放置深度数据集:963条高质量RGB-D操作轨迹助力3D感知与机器人学习
人工智能·学习·3d·机器人·制造
Black蜡笔小新1 小时前
自动化AI算法训练服务器/企业AI算力工作站DLTM重塑企业AI开发模式赋能企业智能转型
人工智能·算法·自动化
Mr数据杨1 小时前
【CanMV K210】AI 视觉 68 点人脸关键点检测与轮廓定位
人工智能·硬件开发·canmv k210
才兄说1 小时前
机器人二次开发机器狗巡检?多源传感器融合建图
人工智能·机器人
xinshu5271 小时前
2026企业联系方式查询平台对比:哪个能查到详细电话?
人工智能·技术分享
renhongxia11 小时前
开源大模型VS闭源大模型:2026年格局再梳理
深度学习·算法·语言模型·分类·开源