http://joschu.net/blog/kl-approx.html
好的,这篇文章出自 John Schulman(PPO 和 TRPO 算法的作者之一),是一篇非常经典和重要的技术博客。它解释了在强化学习中,我们是如何以及为什么使用某种特定形式来近似计算 KL 散度的。
我会为你进行翻译,并在关键部分加上注释,帮助你理解它与你遇到的 kl_loss 问题的关系。
博客标题:近似 KL 散度 (Approximating KL Divergence)
核心思想(简单概括)
这篇文章的核心思想是:在像 PPO 这样的算法中,我们需要在每一步计算新旧策略之间的 KL 散度,以确保更新步子不会迈得太大。但精确计算 KL 散度很困难,所以我们用采样数据来估算它。然而,有多种估算公式,有些公式看起来很直观,但实际使用时方差很大(会导致数值跳动剧烈),而另一些稍作修改的公式则表现得更稳定、偏差更小。
作者推荐的"好"的估算公式是 (r-1) - log(r) ,其中 r 是新旧策略的概率比。
全文翻译与解读
第一部分:我们为什么要在乎 KL 散度?
策略梯度(Policy Gradient)方法通过下面的公式来更新策略参数 θ:
g = E[ A(s,a) * ∇_θ log π_θ(a|s) ]
一个棘手的问题是:我们应该用多大的学习率(步长)呢?如果步子迈得太大,整个策略可能会彻底崩溃。
信任区域(Trust Region)方法通过限制策略的改变量来解决这个问题。具体来说,它要求新策略 π_new 和旧策略 π_old 之间的 KL 散度不能超过一个很小的常数 δ。
KL(π_old || π_new) ≤ δ
【解读】 :这正是我们在 PPO 训练中看到
kl_loss的原因。它作为一个"安全绳",防止模型为了追求奖励而"放飞自我",忘记了基本的语言能力。
第二部分:数学推导
KL 散度的定义如下:
KL(π_old || π_new) = E_{a ~ π_old} [ log( π_old(a|s) / π_new(a|s) ) ]
【解读】 :这个期望
E是在旧策略π_old的动作分布下计算的。在实际操作中,我们无法遍历所有可能的动作,所以只能通过采样(sampling)来估算这个期望值。也就是说,我们从旧策略中采样一批动作,然后计算这些样本的log(π_old / π_new)的平均值。
我们把概率比(ratio)定义为 r = π_new / π_old。那么 KL 散度就是 E[ -log(r) ]。
现在,让我们对 -log(r) 在 r=1 附近进行泰勒展开。(为什么是 r=1?因为在训练初期,新旧策略非常接近,所以 r 的值就在 1 附近)。
-log(r) ≈ -(r-1) + (r-1)² / 2
把这个近似式代入到 KL 散度的期望公式中:
KL(π_old || π_new) = E[ -log(r) ] ≈ E[ -(r-1) + (r-1)² / 2 ]
KL(π_old || π_new) ≈ -E[r-1] + E[(r-1)² / 2]
这里有一个重要的技巧:E[r-1] = E[π_new / π_old - 1] = E[π_new / π_old] - 1。根据重要性采样的性质,E_{a ~ π_old}[π_new / π_old] 应该等于 1。所以 E[r-1] = 1 - 1 = 0。
因此,上面那个近似式的第一项 -E[r-1] 就等于 0 了。我们得到一个更简洁的 KL 近似:
KL(π_old || π_new) ≈ (1/2) * E[ (r-1)² ],其中 r = π_new / π_old。
【解读】 :这个公式
(1/2) * E[ (r-1)² ]看起来非常简单直观,很多代码库可能就直接用它来估算 KL 散度。但作者接下来指出,这个公式有问题。
第三部分:最终结论(The Punchline)
我们刚刚推导出了一个近似公式:
KL_approx_1 = (1/2) * E[ (π_new/π_old - 1)² ]
然而,有一个不同的、并且更好 的近似公式,它直接来自于 -log(r) 的另一个近似,即 r-1-log(r)。
KL_approx_2 = E[ (π_new/π_old - 1) - log(π_new/π_old) ]
【解读】 :这里的
KL_approx_2就是许多高质量 RLHF 代码库(比如 TRL)中实际使用的 KL 散度计算方式。它通常被记为(r-1) - log(r)。
为什么 KL_approx_2 更好?
作者指出,KL_approx_1 虽然在数学上看起来是 KL_approx_2 的二阶近似,但在实际采样估算时,KL_approx_1 的方差(variance)非常大。
当 r(概率比)变得很大时,(r-1)² 会爆炸性增长,导致单一样本对整体估算值产生巨大影响,从而使 kl_loss 的曲线剧烈跳动。而 (r-1) - log(r) 这个表达式对于大的 r 值增长得更慢,表现得更稳定。
所以,作者给出了几个估算 KL 散度的备选方案:
E[ (r-1)² ](不好,高方差)E[ (r-1) - log(r) ](很好,低方差,低偏差)
和你的问题有什么关系?
-
解释了 KL Loss 的波动性来源 :你的
kl_loss曲线变化很明显(剧烈波动),其中一个潜在原因可能就是你使用的 RL 框架在计算 KL 散度时,采用了方差较大的估算器(比如(r-1)²)。当你的模型在 Rollout 阶段采样到一些"极端"的数据(导致ratio值很大或很小)时,高方差的估算器就会让kl_loss的值上蹿下跳。 -
提供了诊断方向 :你可以检查一下你所使用的代码库(例如
Megatron-DeepSpeed或trl)中计算 KL 散度的具体实现。看看它使用的是哪个公式。高质量的实现通常会使用(ratio - 1) - log_ratio这种形式。如果你发现它用的是(ratio - 1)²,那么曲线的剧烈波动就有了理论解释。 -
强调了随机性的本质 :即使使用了最好的估算器,
kl_loss仍然会波动。因为它是基于采样 的估算,每一批(batch)的数据都不同,导致ratio的分布也不同,所以kl_loss自然会变化。你的那张"剧烈变化"的图,正是在数据差异和估算器特性的双重作用下产生的。
总而言之,这篇博客深刻揭示了在强化学习实践中,一个看似微小的数学公式选择,会对训练的稳定性和观测指标的平滑度产生巨大的影响。