OPD Reverse KL

一、 OPD

在线策略蒸馏(On-Policy Distillation, 简称 OPD)。

在语言模型(LLM)的知识蒸馏中,通常有两个模型:

  • 学生模型(Student):策略为 ,参数为 (我们要优化的对象)。
  • 教师模型(Teacher):策略为 (参数固定,提供监督信号)。

逆向 KL 散度(Reverse KL, RKL)的公式为 。与前向 KL(Forward KL, )相比,逆向 KL 具有 模式寻求(Mode-seeking)的特性。它会让学生模型倾向于只在教师模型概率较高的地方生成文本,从而减少模型产生幻觉(Hallucination)或语无伦次的概率,使生成的文本更加确定和精准。

二、 公式 1:Full-vocab Reverse KL 损失函数

在生成第 t 个 token 时的损失函数:

符号解释:

* V:整个词表(Vocabulary)。

* x:输入提示词(Prompt)。

* :在 t 时刻之前已经生成的历史 token 序列。

* :教师模型可能额外享有的输入信息(例如更丰富的上下文、思维链提示或参考答案)。

* :学生模型在当前上下文下,预测词表中每个词 v 的概率分布。

* :教师模型预测的概率分布。

物理意义:

这个公式计算的是在当前步骤 t,学生模型分布与教师模型分布在全词表(Full-vocab)上的逆向 KL 散度。因为求期望的权重项是学生模型的概率 ,所以它是一种 在线/在策略(On-policy) 的评估方式------它关注的是"站在学生模型自己的视角下,其当前输出与教师的偏差"。

三、 公式 2:梯度的推导与化简

对上述损失函数关于学生模型参数 求导。

为了书写简便,我们将 简记为 ,将 简记为

目标是计算

1. 梯度推导过程

损失函数为:

由于只有学生模型 包含参数 ,教师模型 无关。利用导数的乘积法则,对 求导:

我们分别处理这两项:

  • 第二项:

因为 ,对其求导得:

将这一结果代回第二项中,与外面的 相乘消去分母,得到:

  • 第一项:

使用对数导数技巧(Log-derivative trick,常用于强化学习),即

将这两项重新整合:

利用对数导数技巧,把第二项也写成含有 的形式:

为了与图片中的形式完全一致,我们将括号里的项取负号倒过来:

代入后即得到公式:

四、 恒等式的消去作用

里面的 "+1" 在全词表下会被 这条恒等式抵消。

  • 为什么该恒等式成立?

因为概率分布在全词表上的和恒等于 1(即 ),对其两边求导:

再利用对数导数技巧 ,即可得到:

  • 消除后的简化梯度:

这意味着公式中括号里的常数 -1(展开后与外面的负号结合变成 +1)在对整个词表求和时,其贡献为 0。因此,实际计算时梯度可以简化为:

五、 物理意义与直观理解

如果我们把简化后的梯度写成策略梯度(Policy Gradient)中常见的形式(考虑最小化损失函数,参数更新方向为负梯度 ):

这相当于一种自带基线(Baseline)的策略梯度算法:

  1. 动作空间:在当前步骤,学生模型在全词表 V 上进行探索。

  2. 权重项(Reward):对于词表中的每一个词 v,其受到的奖励/惩罚因子为

  • 时:说明学生模型低估了该词的概率。此时 ,则 。梯度更新会提高该词的生成概率。

* 当 时:说明学生模型过度自信(高估了该词)。此时 。梯度更新会压低该词的生成概率。

  1. 全词表覆盖:因为是 Full-vocab,算法不仅对采样到的单个词进行更新,而是同时对词表中的所有词进行推拉(Push-Pull)。这使得训练过程比单样本采样的策略梯度更加平滑和稳定。
相关推荐
Lihua奏3 天前
# 机器学习:机器是怎么从数据里学出规则的
机器学习
饼干哥哥3 天前
用AI全自动剪辑,日更 100条爆款视频——HyperFrames、Remotion、Git使用入门
人工智能·机器学习·ai编程
魏祖潇4 天前
我在飞书里养了个“分身”——私聊喊它办事,群里 @ 它干活,还能替我传话
人工智能·机器学习
哥布林学者10 天前
深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力
机器学习·ai