详解 KL 散度的反向传播计算:以三分类神经网络为例

一、先明确反向传播的核心前提

1. 模型结构简化(聚焦关键层)

为了清晰展示,我们简化神经网络结构(实际模型可能有多个隐藏层,但核心逻辑一致):

  • 输入层:手写数字图像的特征向量(如扁平化后的28×28=784维);
  • 隐藏层:1个全连接层(假设输出维度为100,激活函数用ReLU);
  • 输出层 :全连接层(输出维度=类别数3,无激活函数,输出为「logits」:z0,z1,z2z_0, z_1, z_2z0,z1,z2);
  • 激活层 :Softmax函数(将logits转换为预测概率分布 Q(x)=[Q0,Q1,Q2]Q(x) = [Q_0, Q_1, Q_2]Q(x)=[Q0,Q1,Q2])。
2. 损失函数定义

由于我们的目标是对齐 PPP 和 QQQ,整体损失函数 L\mathcal{L}L 直接等于KL散度 (实际中可能会叠加其他损失项,但此处聚焦KL散度的反向传播):
L=DKL(P∥Q)=∑c=02Pc⋅ln⁡(PcQc)\mathcal{L} = D_{KL}(P \parallel Q) = \sum_{c=0}^2 P_c \cdot \ln\left( \frac{P_c}{Q_c} \right)L=DKL(P∥Q)=c=0∑2Pc⋅ln(QcPc)

  • 其中:PcP_cPc 是真实分布的固定值(如 P0=0.3,P1=0.5,P2=0.2P_0=0.3, P_1=0.5, P_2=0.2P0=0.3,P1=0.5,P2=0.2),Qc=Softmax(zc)=ezc∑k=02ezkQ_c = \text{Softmax}(z_c) = \frac{e^{z_c}}{\sum_{k=0}^2 e^{z_k}}Qc=Softmax(zc)=∑k=02ezkezc(zcz_czc 是输出层logits,是模型参数的函数)。
3. 反向传播目标

计算损失 L\mathcal{L}L 对模型所有可训练参数(隐藏层权重 W1W_1W1、偏置 b1b_1b1;输出层权重 W2W_2W2、偏置 b2b_2b2)的梯度,然后用梯度下降更新参数: θ=θ−η⋅∂L∂θ\theta = \theta - \eta \cdot \frac{\partial \mathcal{L}}{\partial \theta}θ=θ−η⋅∂θ∂L (θ\thetaθ 表示任意可训练参数,η\etaη 是学习率,如0.001)。

二、关键步骤:推导KL散度对logits的梯度(核心桥梁)

反向传播的核心是"链式法则",而 KL散度对输出层logits zcz_czc 的梯度 ∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}∂zc∂L 是连接损失和模型参数的关键(因为logits直接由输出层参数计算得到,再往前传播到隐藏层即可)。

1. 简化KL散度公式(方便求导)

先拆分KL散度的表达式: L=∑c=02Pc⋅ln⁡Pc−∑c=02Pc⋅ln⁡Qc\mathcal{L} = \sum_{c=0}^2 P_c \cdot \ln P_c - \sum_{c=0}^2 P_c \cdot \ln Q_c L=c=0∑2Pc⋅lnPc−c=0∑2Pc⋅lnQc

  • 第一项 ∑c=02Pc⋅ln⁡Pc\sum_{c=0}^2 P_c \cdot \ln P_c∑c=02Pc⋅lnPc:是"真实分布的熵",PcP_cPc 是固定值(从数据中统计得到),因此对任何模型参数(包括 zcz_czc)的导数都为 0
  • 第二项是关键:L=−∑c=02Pc⋅ln⁡Qc\mathcal{L} = - \sum_{c=0}^2 P_c \cdot \ln Q_cL=−∑c=02Pc⋅lnQc(求导时只需关注这一项)。
2. 代入Softmax公式,

求 ∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}∂zc∂L 已知 Qc=ezc∑k=02ezk=ezcZQ_c = \frac{e^{z_c}}{\sum_{k=0}^2 e^{z_k}} = \frac{e^{z_c}}{Z}Qc=∑k=02ezkezc=Zezc(其中 Z=∑k=02ezkZ = \sum_{k=0}^2 e^{z_k}Z=∑k=02ezk 是归一化常数)。 根据链式法则:∂L∂zc=−∑j=02Pj⋅∂ln⁡Qj∂zc\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{\partial \ln Q_j}{\partial z_c}∂zc∂L=−j=0∑2Pj⋅∂zc∂lnQj(对每个logit zcz_czc,需要考虑它对所有 QjQ_jQj 的影响,因为 ZZZ 包含所有 zkz_kzk)。

进一步展开: ∂ln⁡Qj∂zc=1Qj⋅∂Qj∂zc \frac{\partial \ln Q_j}{\partial z_c} = \frac{1}{Q_j} \cdot \frac{\partial Q_j}{\partial z_c} ∂zc∂lnQj=Qj1⋅∂zc∂Qj

因此: ∂L∂zc=−∑j=02Pj⋅1Qj⋅∂Qj∂zc\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{1}{Q_j} \cdot \frac{\partial Q_j}{\partial z_c} ∂zc∂L=−j=0∑2Pj⋅Qj1⋅∂zc∂Qj

3. 利用Softmax的梯度性质(关键简化)

Softmax函数有一个核心梯度性质(必须记住,推导略):∂Qj∂zc=Qj⋅(δc,j−Qc)\frac{\partial Q_j}{\partial z_c} = Q_j \cdot (\delta_{c,j} - Q_c)∂zc∂Qj=Qj⋅(δc,j−Qc)

  • 其中 δc,j\delta_{c,j}δc,j 是「克罗内克函数」:当 c=jc = jc=j 时,δc,j=1\delta_{c,j} = 1δc,j=1;当 c≠jc \neq jc=j 时,δc,j=0\delta_{c,j} = 0δc,j=0。 将这个性质代入上式: ∂L∂zc=−∑j=02Pj⋅1Qj⋅Qj⋅(δc,j−Qc)\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot \frac{1}{Q_j} \cdot Q_j \cdot (\delta_{c,j} - Q_c)∂zc∂L=−j=0∑2Pj⋅Qj1⋅Qj⋅(δc,j−Qc)
  • QjQ_jQj 约分后简化为: ∂L∂zc=−∑j=02Pj⋅(δc,j−Qc)\frac{\partial \mathcal{L}}{\partial z_c} = - \sum_{j=0}^2 P_j \cdot (\delta_{c,j} - Q_c) ∂zc∂L=−j=0∑2Pj⋅(δc,j−Qc)
4. 拆分求和项,最终化简

将求和项拆分为 j=cj=cj=c 和 j≠cj \neq cj=c 两部分:

  • 当 j=cj = cj=c 时:δc,j=1\delta_{c,j} = 1δc,j=1,该项为 Pc⋅(1−Qc)P_c \cdot (1 - Q_c)Pc⋅(1−Qc);
    • 当 j≠cj \neq cj=c 时:δc,j=0\delta_{c,j} = 0δc,j=0,该项为 Pj⋅(0−Qc)=−Pj⋅QcP_j \cdot (0 - Q_c) = - P_j \cdot Q_cPj⋅(0−Qc)=−Pj⋅Qc。
  • 因此: ∂L∂zc=−[Pc⋅(1−Qc)−Qc⋅∑j≠cPj]\frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c \cdot (1 - Q_c) - Q_c \cdot \sum_{j \neq c} P_j \right]∂zc∂L=− Pc⋅(1−Qc)−Qc⋅j=c∑Pj
  • 又因为 ∑j=02Pj=1\sum_{j=0}^2 P_j = 1∑j=02Pj=1,所以 ∑j≠cPj=1−Pc\sum_{j \neq c} P_j = 1 - P_c∑j=cPj=1−Pc,代入后: ∂L∂zc=−[Pc−PcQc−Qc(1−Pc)]\frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c - P_c Q_c - Q_c (1 - P_c) \right]∂zc∂L=−[Pc−PcQc−Qc(1−Pc)]
  • 展开括号并化简:∂L∂zc=−[Pc−PcQc−Qc+PcQc]=−[Pc−Qc]=Qc−Pc \frac{\partial \mathcal{L}}{\partial z_c} = - \left[ P_c - P_c Q_c - Q_c + P_c Q_c \right] = - \left[ P_c - Q_c \right] = Q_c - P_c∂zc∂L=−[Pc−PcQc−Qc+PcQc]=−[Pc−Qc]=Qc−Pc
最终核心结论(震惊的简化!)

KL散度对输出层logits zcz_czc 的梯度,竟然简化为: ∂L∂zc=Qc−Pc\boxed{\frac{\partial \mathcal{L}}{\partial z_c} = Q_c - P_c}∂zc∂L=Qc−Pc

  • 这意味着:每个logit zcz_czc 的梯度 = 对应类别的预测概率 - 真实概率。

三、结合例子计算梯度(衔接之前的数值)

假设例子中:

  • 真实分布 P=[P0,P1,P2]=[0.3,0.5,0.2]P = [P_0, P_1, P_2] = [0.3, 0.5, 0.2]P=[P0,P1,P2]=[0.3,0.5,0.2]
  • 预测分布 Q=[Q0,Q1,Q2]=[0.35,0.45,0.2]Q = [Q_0, Q_1, Q_2] = [0.35, 0.45, 0.2]Q=[Q0,Q1,Q2]=[0.35,0.45,0.2]
    代入梯度公式,计算每个logit的梯度:
  • 对 z0z_0z0 的梯度:∂L∂z0=Q0−P0=0.35−0.3=0.05\frac{\partial \mathcal{L}}{\partial z_0} = Q_0 - P_0 = 0.35 - 0.3 = 0.05∂z0∂L=Q0−P0=0.35−0.3=0.05
  • 对 z1z_1z1 的梯度:∂L∂z1=Q1−P1=0.45−0.5=−0.05\frac{\partial \mathcal{L}}{\partial z_1} = Q_1 - P_1 = 0.45 - 0.5 = -0.05∂z1∂L=Q1−P1=0.45−0.5=−0.05
  • 对 z2z_2z2 的梯度:∂L∂z2=Q2−P2=0.2−0.2=0\frac{\partial \mathcal{L}}{\partial z_2} = Q_2 - P_2 = 0.2 - 0.2 = 0∂z2∂L=Q2−P2=0.2−0.2=0

梯度结果解读:

  • ∂L∂z0=0.05>0\frac{\partial \mathcal{L}}{\partial z_0} = 0.05 > 0∂z0∂L=0.05>0:损失 L\mathcal{L}L 随 z0z_0z0 增大而增大,因此参数更新时要减小 与 z0z_0z0 相关的权重(让 z0z_0z0 变小,进而让 Q0Q_0Q0 从0.35下降到0.3,贴近 P0P_0P0);
  • ∂L∂z1=−0.05<0\frac{\partial \mathcal{L}}{\partial z_1} = -0.05 < 0∂z1∂L=−0.05<0:损失 L\mathcal{L}L 随 z1z_1z1 增大而减小,因此参数更新时要增大 与 z1z_1z1 相关的权重(让 z1z_1z1 变大,进而让 Q1Q_1Q1 从0.45上升到0.5,贴近 P1P_1P1);
  • ∂L∂z2=0\frac{\partial \mathcal{L}}{\partial z_2} = 0∂z2∂L=0:Q2Q_2Q2 已完全贴近 P2P_2P2,无需调整与 z2z_2z2 相关的参数。

四、梯度反向传播到前层参数(完整流程)

有了logits的梯度 ∂L∂zc\frac{\partial \mathcal{L}}{\partial z_c}∂zc∂L,接下来通过链式法则反向传播到隐藏层和输入层的参数。我们以输出层权重 W2W_2W2 和偏置 b2b_2b2 为例(隐藏层参数同理)。

1. 输出层的线性计算关系

输出层的logits zzz 是隐藏层输出 hhh 与权重 W2W_2W2、偏置 b2b_2b2 的线性组合: z=h⋅W2+b2z = h \cdot W_2 + b_2z=h⋅W2+b2

  • 维度说明(假设隐藏层输出 hhh 是1×100的向量):
    -- hhh:1×100(批量大小=1时的隐藏层输出);
    -- W2W_2W2:100×3(隐藏层到输出层的权重矩阵,每行对应隐藏层一个神经元,每列对应一个类别);
    -- b2b_2b2:1×3(输出层偏置);
    -- zzz:1×3(输出层logits)。
2. 计算对输出层权重 W2W_2W2 的梯度

根据矩阵求导规则: ∂L∂W2=hT⋅∂L∂z\frac{\partial \mathcal{L}}{\partial W_2} = h^T \cdot \frac{\partial \mathcal{L}}{\partial z} ∂W2∂L=hT⋅∂z∂L

  • hTh^ThT:100×1(隐藏层输出的转置);
  • ∂L∂z\frac{\partial \mathcal{L}}{\partial z}∂z∂L:1×3(logits的梯度向量,即 [0.05, -0.05, 0]);
  • 结果 ∂L∂W2\frac{\partial \mathcal{L}}{\partial W_2}∂W2∂L:100×3(与 W2W_2W2 维度一致,可直接用于更新)。
3. 计算对输出层偏置 b2b_2b2 的梯度

偏置 b2b_2b2 的梯度直接等于logits的梯度(因为偏置对每个logit的贡献是线性的,导数为1): ∂L∂b2=∂L∂z=[0.05,−0.05,0]\frac{\partial \mathcal{L}}{\partial b_2} = \frac{\partial \mathcal{L}}{\partial z} = [0.05, -0.05, 0]∂b2∂L=∂z∂L=[0.05,−0.05,0]

4. 反向传播到隐藏层

隐藏层的梯度计算同理,利用链式法则: ∂L∂h=∂L∂z⋅W2T\frac{\partial \mathcal{L}}{\partial h} = \frac{\partial \mathcal{L}}{\partial z} \cdot W_2^T ∂h∂L=∂z∂L⋅W2T

  • W2TW_2^TW2T:3×100(输出层权重的转置);
  • 结果 ∂L∂h\frac{\partial \mathcal{L}}{\partial h}∂h∂L:1×100(与隐藏层输出 hhh 维度一致)。

再结合隐藏层的激活函数(如ReLU)的导数,可计算出对隐藏层权重 W1W_1W1 和偏置 b1b_1b1 的梯度,最终完成所有参数的梯度计算。

五、参数更新(梯度下降执行)

得到所有参数的梯度后,用梯度下降法更新参数(以输出层权重 W2W_2W2 和偏置 b2b_2b2 为例):

  • 权重更新:W2=W2−η⋅∂L∂W2W_2 = W_2 - \eta \cdot \frac{\partial \mathcal{L}}{\partial W_2}W2=W2−η⋅∂W2∂L
  • 偏置更新:b2=b2−η⋅∂L∂b2b_2 = b_2 - \eta \cdot \frac{\partial \mathcal{L}}{\partial b_2}b2=b2−η⋅∂b2∂L 假设学习率 η=0.001\eta = 0.001η=0.001,则:
    -- 偏置 b0b_0b0(对应 z0z_0z0 的偏置)更新:b0=b0−0.001×0.05=b0−0.00005b_0 = b_0 - 0.001 \times 0.05 = b_0 - 0.00005b0=b0−0.001×0.05=b0−0.00005(减小偏置,让 z0z_0z0 变小,Q0Q_0Q0 下降);
    -- 偏置 b1b_1b1(对应 z1z_1z1 的偏置)更新:b1=b1−0.001×(−0.05)=b1+0.00005b_1 = b_1 - 0.001 \times (-0.05) = b_1 + 0.00005b1=b1−0.001×(−0.05)=b1+0.00005(增大偏置,让 z1z_1z1 变大,Q1Q_1Q1 上升);
    -- 偏置 b2b_2b2 无更新(梯度为0)。

六、迭代优化效果

经过一次反向传播和参数更新后,模型的logits z0z_0z0 会略微减小,z1z_1z1 会略微增大,导致:

  • 预测分布 Q0Q_0Q0 从0.35 → 0.33(更贴近 P0=0.3P_0=0.3P0=0.3);
  • Q1Q_1Q1 从0.45 → 0.47(更贴近 P1=0.5P_1=0.5P1=0.5);
  • KL散度从0.0063 → 更小的值(如0.004)。

重复这个过程(迭代训练),直到KL散度收敛到极小值(如0.001以下),模型的预测分布就会完全贴近真实分布。

核心总结:

KL散度反向传播的关键

  1. 梯度简化奇迹 :KL散度对logits的梯度最终简化为 Qc−PcQ_c - P_cQc−Pc,无需复杂计算,这也是它在深度学习中广泛使用的原因;
  2. 反向传播逻辑:损失(KL散度)→ logits梯度 → 输出层参数梯度 → 隐藏层参数梯度 → 梯度下降更新;
  3. 例子呼应:通过具体数值展示了梯度的正负和大小如何指导参数调整,让预测分布逐步贴近真实分布;
  4. 实际代码启示:在TensorFlow/PyTorch中,无需手动推导梯度,框架会自动计算,但理解这个过程能帮你调优模型(如学习率选择、损失函数设计)。

简单来说,KL散度用于反向传播的本质是:通过"预测概率与真实概率的差值"指导参数调整,让模型的输出分布越来越接近目标分布

相关推荐
自然语1 小时前
数字生已经进化到一个分水岭面临选择?先实现“动态识别“还是先实现“特征信息归纳分类“,文中给出以给出答案,大家选哪个方向?
人工智能·分类·数据挖掘
高洁011 小时前
卷积神经网络(CNN)详细介绍及其原理详解(3)
python·神经网络·机器学习·transformer
肖邦德夜曲1 小时前
1.强化学习基本概念
机器学习·强化学习
iiiiii111 小时前
【论文阅读笔记】FOCAL 离线元强化学习,从静态数据中快速适应新任务
论文阅读·人工智能·笔记·学习·机器学习·学习方法·具身智能
荒野火狐2 小时前
【强化学习】关于PPO收敛问题
python·深度学习·机器学习·强化学习
江上鹤.1482 小时前
Day 28 复习日
人工智能·python·机器学习
进阶的小蜉蝣2 小时前
[Machine Learning] 机器学习中的Collate
人工智能·机器学习
火山引擎开发者社区2 小时前
Vector Bucket:云原生向量存储新范式
人工智能·机器学习·云原生