RNN 反向传播数学推导(BPTT 时间反向传播)

RNN 反向传播数学推导(BPTT 时间反向传播)

一、标准 RNN 前向传播公式

设:

  • 输入序列:xt∈Rdinx_t \in \mathbb{R}^{d_{in}}xt∈Rdin
  • 隐藏状态:ht∈Rdhidh_t \in \mathbb{R}^{d_{hid}}ht∈Rdhid
  • 输出:ot∈Rdouto_t \in \mathbb{R}^{d_{out}}ot∈Rdout
  • 参数:
    WxhW_{xh}Wxh:输入→隐藏,WhhW_{hh}Whh:隐藏→隐藏,WhyW_{hy}Why:隐藏→输出
    bhb_hbh:隐藏偏置,byb_yby:输出偏置

前向传播
ht=σ(Wxhxt+Whhht−1+bh)ot=Whyht+byy^t=softmax(ot) \Huge \begin{align} h_t &= \sigma\big(W_{xh}x_t + W_{hh}h_{t-1} + b_h\big) \\ o_t &= W_{hy}h_t + b_y \\ \hat y_t &= \text{softmax}(o_t) \end{align} htoty^t=σ(Wxhxt+Whhht−1+bh)=Whyht+by=softmax(ot)
σ\sigmaσ 为激活函数(常用 tanh⁡\tanhtanh),损失用交叉熵:
L=∑t=1TLt,Lt=−∑iyt,ilog⁡y^t,i \mathcal{L} = \sum_{t=1}^T \mathcal{L}t,\quad \mathcal{L}t = -\sum_i y{t,i}\log\hat y{t,i} L=t=1∑TLt,Lt=−i∑yt,ilogy^t,i


二、BPTT 核心思想

RNN 参数所有时间步共享 ,反向传播沿时间维度回溯 TTT 步 ,链式求导:
∂L∂W=∑t=1T∂Lt∂W \Huge \frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \frac{\partial \mathcal{L}_t}{\partial W} ∂W∂L=t=1∑T∂W∂Lt

定义梯度记号:
δt=∂L∂ht∈Rdhid \Huge \delta_t = \frac{\partial \mathcal{L}}{\partial h_t} \in \mathbb{R}^{d_{hid}} δt=∂ht∂L∈Rdhid


三、逐位置梯度推导

1. 输出端梯度 ∂Lt∂ot\dfrac{\partial \mathcal{L}_t}{\partial o_t}∂ot∂Lt

交叉熵+softmax 简化结论:
∂Lt∂ot=y^t−yt \Huge \frac{\partial \mathcal{L}_t}{\partial o_t} = \hat y_t - y_t ∂ot∂Lt=y^t−yt

2. 隐藏层误差 δt\delta_tδt

链式法则:
δt=∂Lt∂ht=∂ot∂ht⊤∂Lt∂ot+∂ht+1∂ht⊤δt+1 \Huge \delta_t = \frac{\partial \mathcal{L}t}{\partial h_t} = \frac{\partial o_t}{\partial h_t}^\top \frac{\partial \mathcal{L}t}{\partial o_t} + \frac{\partial h{t+1}}{\partial h_t}^\top \delta{t+1} δt=∂ht∂Lt=∂ht∂ot⊤∂ot∂Lt+∂ht∂ht+1⊤δt+1

代入 ot=Whyht+byo_t=W_{hy}h_t+b_yot=Whyht+by 和 ht+1=σ(⋅)h_{t+1}=\sigma(\cdot)ht+1=σ(⋅):
∂ot∂ht=Why⊤ \Huge \frac{\partial o_t}{\partial h_t} = W_{hy}^\top ∂ht∂ot=Why⊤

∂ht+1∂ht=Whh⊤⊙σ′(zt+1) \Huge \frac{\partial h_{t+1}}{\partial h_t} = W_{hh}^\top \odot \sigma'(z_{t+1}) ∂ht∂ht+1=Whh⊤⊙σ′(zt+1)

zt+1=Wxhxt+1+Whhht+bhz_{t+1}=W_{xh}x_{t+1}+W_{hh}h_t+b_hzt+1=Wxhxt+1+Whhht+bh,⊙\odot⊙ 逐元素乘

得到 δt\delta_tδt 递推公式
δt=Why⊤(y^t−yt)+Whh⊤⋅σ′(zt+1)⊙δt+1 \Huge \delta_t = W_{hy}^\top (\hat y_t - y_t) + W_{hh}^\top \cdot \sigma'(z_{t+1}) \odot \delta_{t+1} δt=Why⊤(y^t−yt)+Whh⊤⋅σ′(zt+1)⊙δt+1

终止条件 :最后时刻 t=Tt=Tt=T,无下一时刻
δT=Why⊤(y^T−yT) \Huge \delta_T = W_{hy}^\top (\hat y_T - y_T) δT=Why⊤(y^T−yT)


四、所有参数梯度公式

1. Why, byW_{hy},\ b_yWhy, by

∂L∂Why=∑t=1T(y^t−yt) ht⊤∂L∂by=∑t=1T(y^t−yt) \Huge \begin{align} \frac{\partial \mathcal{L}}{\partial W_{hy}} &= \sum_{t=1}^T (\hat y_t - y_t)\, h_t^\top \\ \frac{\partial \mathcal{L}}{\partial b_y} &= \sum_{t=1}^T (\hat y_t - y_t) \end{align} ∂Why∂L∂by∂L=t=1∑T(y^t−yt)ht⊤=t=1∑T(y^t−yt)

2. Wxh, Whh, bhW_{xh},\ W_{hh},\ b_hWxh, Whh, bh

先记:∂ht∂zt=σ′(zt)\dfrac{\partial h_t}{\partial z_t} = \sigma'(z_t)∂zt∂ht=σ′(zt)
∂L∂Wxh=∑t=1Tδt⊙σ′(zt)  xt⊤∂L∂Whh=∑t=1Tδt⊙σ′(zt)  ht−1⊤∂L∂bh=∑t=1Tδt⊙σ′(zt) \Huge \begin{align} \frac{\partial \mathcal{L}}{\partial W_{xh}} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t)\; x_t^\top \\ \frac{\partial \mathcal{L}}{\partial W_{hh}} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t)\; h_{t-1}^\top \\ \frac{\partial \mathcal{L}}{\partial b_h} &= \sum_{t=1}^T \delta_t \odot \sigma'(z_t) \end{align} ∂Wxh∂L∂Whh∂L∂bh∂L=t=1∑Tδt⊙σ′(zt)xt⊤=t=1∑Tδt⊙σ′(zt)ht−1⊤=t=1∑Tδt⊙σ′(zt)


五、Tanh 激活导数(常用)

若 σ(z)=tanh⁡(z)\sigma(z)=\tanh(z)σ(z)=tanh(z):
σ′(z)=1−tanh⁡2(z)=1−ht2 \Huge \sigma'(z) = 1 - \tanh^2(z) = 1 - h_t^2 σ′(z)=1−tanh2(z)=1−ht2

直接代入上面公式即可。


相关推荐
战族狼魂1 分钟前
集 “自动飞行、智能识别、实时预警、勤务联动” 于一体的高速公路应急车道无人机检测系统方案
java·人工智能·大模型·无人机
月光船幽幽2 分钟前
Helio-Core临界控制:守护拓扑量子稳定
人工智能·科技·动态规划·拓扑学
jkyy20144 分钟前
大模型重构饮食健康服务链路:多维技术赋能膳食管理智能化升级
大数据·人工智能·信息可视化·重构·健康医疗
罗西的思考5 分钟前
【Agentic RL / 强化学习 / OPD】OpenClaw-RL 源码阅读笔记 --- (4)--- 系统架构
人工智能·算法·机器学习
2601_957888566 分钟前
从关键词到语义网络:生成式引擎优化(GEO)的技术原理解析与工程实践
人工智能·大模型
2501_934440237 分钟前
简申的服务哲学中,“专业”从来不是冰冷的技术名词,而是一种设身处地的责任担当
人工智能
慧一居士12 分钟前
OpenAI API 协议、 Chat Completions API、Responses API 协议 对比和联系,适用场景以及还有哪些其他协议详解
人工智能
TAOCARTS00120 分钟前
反向海淘旺季运营技巧,借助独立站快速拉升店铺单量
大数据·人工智能
lqqjuly25 分钟前
知识蒸馏:理论、算法与可运行实现
人工智能·深度学习·算法
小丶舟26 分钟前
6GB显卡跑Hermes Agent!开源AI自学习编程Agent实测
人工智能·学习·开源