【RL】 kl loss

http://joschu.net/blog/kl-approx.html

好的,这篇文章出自 John Schulman(PPO 和 TRPO 算法的作者之一),是一篇非常经典和重要的技术博客。它解释了在强化学习中,我们是如何以及为什么使用某种特定形式来近似计算 KL 散度的。

我会为你进行翻译,并在关键部分加上注释,帮助你理解它与你遇到的 kl_loss 问题的关系。


博客标题:近似 KL 散度 (Approximating KL Divergence)

核心思想(简单概括)

这篇文章的核心思想是:在像 PPO 这样的算法中,我们需要在每一步计算新旧策略之间的 KL 散度,以确保更新步子不会迈得太大。但精确计算 KL 散度很困难,所以我们用采样数据来估算它。然而,有多种估算公式,有些公式看起来很直观,但实际使用时方差很大(会导致数值跳动剧烈),而另一些稍作修改的公式则表现得更稳定、偏差更小。

作者推荐的"好"的估算公式是 (r-1) - log(r) ,其中 r 是新旧策略的概率比。


全文翻译与解读

第一部分:我们为什么要在乎 KL 散度?

策略梯度(Policy Gradient)方法通过下面的公式来更新策略参数 θ
g = E[ A(s,a) * ∇_θ log π_θ(a|s) ]

一个棘手的问题是:我们应该用多大的学习率(步长)呢?如果步子迈得太大,整个策略可能会彻底崩溃。

信任区域(Trust Region)方法通过限制策略的改变量来解决这个问题。具体来说,它要求新策略 π_new 和旧策略 π_old 之间的 KL 散度不能超过一个很小的常数 δ
KL(π_old || π_new) ≤ δ

【解读】 :这正是我们在 PPO 训练中看到 kl_loss 的原因。它作为一个"安全绳",防止模型为了追求奖励而"放飞自我",忘记了基本的语言能力。

第二部分:数学推导

KL 散度的定义如下:
KL(π_old || π_new) = E_{a ~ π_old} [ log( π_old(a|s) / π_new(a|s) ) ]

【解读】 :这个期望 E 是在旧策略 π_old 的动作分布下计算的。在实际操作中,我们无法遍历所有可能的动作,所以只能通过采样(sampling)来估算这个期望值。也就是说,我们从旧策略中采样一批动作,然后计算这些样本的 log(π_old / π_new) 的平均值。

我们把概率比(ratio)定义为 r = π_new / π_old。那么 KL 散度就是 E[ -log(r) ]

现在,让我们对 -log(r)r=1 附近进行泰勒展开。(为什么是 r=1?因为在训练初期,新旧策略非常接近,所以 r 的值就在 1 附近)。

-log(r) ≈ -(r-1) + (r-1)² / 2

把这个近似式代入到 KL 散度的期望公式中:
KL(π_old || π_new) = E[ -log(r) ] ≈ E[ -(r-1) + (r-1)² / 2 ]
KL(π_old || π_new) ≈ -E[r-1] + E[(r-1)² / 2]

这里有一个重要的技巧:E[r-1] = E[π_new / π_old - 1] = E[π_new / π_old] - 1。根据重要性采样的性质,E_{a ~ π_old}[π_new / π_old] 应该等于 1。所以 E[r-1] = 1 - 1 = 0

因此,上面那个近似式的第一项 -E[r-1] 就等于 0 了。我们得到一个更简洁的 KL 近似:
KL(π_old || π_new) ≈ (1/2) * E[ (r-1)² ],其中 r = π_new / π_old

【解读】 :这个公式 (1/2) * E[ (r-1)² ] 看起来非常简单直观,很多代码库可能就直接用它来估算 KL 散度。但作者接下来指出,这个公式有问题。

第三部分:最终结论(The Punchline)

我们刚刚推导出了一个近似公式:
KL_approx_1 = (1/2) * E[ (π_new/π_old - 1)² ]

然而,有一个不同的、并且更好 的近似公式,它直接来自于 -log(r) 的另一个近似,即 r-1-log(r)
KL_approx_2 = E[ (π_new/π_old - 1) - log(π_new/π_old) ]

【解读】 :这里的 KL_approx_2 就是许多高质量 RLHF 代码库(比如 TRL)中实际使用的 KL 散度计算方式。它通常被记为 (r-1) - log(r)

为什么 KL_approx_2 更好?

作者指出,KL_approx_1 虽然在数学上看起来是 KL_approx_2 的二阶近似,但在实际采样估算时,KL_approx_1方差(variance)非常大

r(概率比)变得很大时,(r-1)² 会爆炸性增长,导致单一样本对整体估算值产生巨大影响,从而使 kl_loss 的曲线剧烈跳动。而 (r-1) - log(r) 这个表达式对于大的 r 值增长得更慢,表现得更稳定。

所以,作者给出了几个估算 KL 散度的备选方案:

  1. E[ (r-1)² ] (不好,高方差)
  2. E[ (r-1) - log(r) ] (很好,低方差,低偏差)

和你的问题有什么关系?

  1. 解释了 KL Loss 的波动性来源 :你的 kl_loss 曲线变化很明显(剧烈波动),其中一个潜在原因可能就是你使用的 RL 框架在计算 KL 散度时,采用了方差较大的估算器(比如 (r-1)²)。当你的模型在 Rollout 阶段采样到一些"极端"的数据(导致 ratio 值很大或很小)时,高方差的估算器就会让 kl_loss 的值上蹿下跳。

  2. 提供了诊断方向 :你可以检查一下你所使用的代码库(例如 Megatron-DeepSpeedtrl)中计算 KL 散度的具体实现。看看它使用的是哪个公式。高质量的实现通常会使用 (ratio - 1) - log_ratio 这种形式。如果你发现它用的是 (ratio - 1)²,那么曲线的剧烈波动就有了理论解释。

  3. 强调了随机性的本质 :即使使用了最好的估算器,kl_loss 仍然会波动。因为它是基于采样 的估算,每一批(batch)的数据都不同,导致 ratio 的分布也不同,所以 kl_loss 自然会变化。你的那张"剧烈变化"的图,正是在数据差异和估算器特性的双重作用下产生的。

总而言之,这篇博客深刻揭示了在强化学习实践中,一个看似微小的数学公式选择,会对训练的稳定性和观测指标的平滑度产生巨大的影响。

相关推荐
BitaHub20248 小时前
深度推理力量:用 DeepSeek V3.2 Speciale 打造自动数据分析系统
人工智能·deepseek
开放知识图谱9 小时前
论文浅尝 | 图上生成:将大语言模型视为智能体与知识图谱以解决不完整知识图谱问答(EMNLP2024)
人工智能·语言模型·自然语言处理·知识图谱
珂朵莉MM9 小时前
2025年睿抗机器人开发者大赛CAIP-编程技能赛-本科组(国赛)解题报告 | 珂学家
java·人工智能·算法·机器人·无人机
果粒蹬i9 小时前
当CNN遇见Transformer:混合模型的特征可视化与融合攻略
人工智能·cnn·transformer
悟道心9 小时前
8. 自然语言处理NLP -GPT
人工智能·gpt·自然语言处理
乐迪信息9 小时前
乐迪信息:船体AI烟火检测,24小时火灾自动预警
人工智能·物联网·算法·目标检测·语音识别
且去填词9 小时前
DeepSeek :基于 AST 与 AI 的遗留系统“手术刀”式治理方案
人工智能·自动化·llm·ast·agent·策略模式·deepseek
llilian_169 小时前
相位差测量仪 高精度相位计相位差测量仪的应用 相位计
大数据·人工智能·功能测试·单片机
云雾J视界9 小时前
从Boost的设计哲学到工业实践:解锁下一代AI中间件架构的密码
c++·人工智能·中间件·架构·stackoverflow·boost