这行代码 ratio = (log_probs - old_log_probs).exp() 是 PPO 算法中一个非常核心且巧妙的数学计算。它的意思是 计算新策略(π_new)和旧策略(π_old)在给定状态下选择相同动作的概率之比。
这个比率,通常在强化学习文献中表示为 r_t(θ) 或 ratio。
ratio = π_new(a | s) / π_old(a | s)
让我们来分解一下这个公式是如何通过代码实现的。
数学原理
-
对数(Logarithms)的性质:
log(x / y) = log(x) - log(y)exp(log(x)) = x
-
推导过程:
- 我们想计算
ratio = π_new / π_old。 - 直接计算除法在数值上可能不稳定,特别是当概率值非常小的时候。
- 一个更稳健的方法是在对数空间(log-space)中进行操作。
- 我们先对
ratio取对数:
log(ratio) = log(π_new / π_old) - 根据对数性质,这等于:
log(ratio) = log(π_new) - log(π_old) - 现在,我们有了
log(ratio)的值。为了得到原始的ratio,我们只需要对它取指数(exp):
ratio = exp(log(ratio)) = exp(log(π_new) - log(π_old))
- 我们想计算
代码与数学的对应关系
现在我们来看代码:
-
log_probs:- 这代表
log(π_new)。它是当前正在训练的模型(新策略)为序列中的每个动作(token)计算出的对数概率。
- 这代表
-
old_log_probs:- 这代表
log(π_old)。它是在本轮 PPO 训练开始之前,旧的、固定的 Actor 模型(旧策略)计算出的对数概率。这个值在整个 PPO 的内部训练循环(ppo_epochs)中是保持不变的。
- 这代表
-
(log_probs - old_log_probs):- 这完全对应了数学推导中的
log(π_new) - log(π_old),也就是log(ratio)。
- 这完全对应了数学推导中的
-
.exp():- 这对应了数学推导中的
exp(...)操作。它将对数比率转换回原始的概率比率ratio。
- 这对应了数学推导中的
为什么这个 ratio 如此重要?
ratio 是 PPO 代理目标函数(Surrogate Objective)的核心组成部分。代理目标函数是:
L_CPI(θ) = E[ (π_θ(a|s) / π_θ_old(a|s)) * A_t ]
L_CPI(θ) = E[ ratio * Advantage ]
ratio > 1: 意味着新策略比旧策略更倾向于选择这个动作。ratio < 1: 意味着新策略比旧策略更不倾向于选择这个动作。
ratio 和 Advantage 的相互作用决定了策略更新的方向:
-
如果
Advantage > 0(这是一个好动作):- 优化器会试图增大
ratio,从而让ratio * Advantage更大,使得损失更小(因为 PPO 损失是-min(...))。 - 增大
ratio就意味着增大π_new,即增加选择这个好动作的概率。
- 优化器会试图增大
-
如果
Advantage < 0(这是一个坏动作):- 优化器会试图减小
ratio,从而让ratio * Advantage(一个负数)变得更接近于零,使得损失更小。 - 减小
ratio就意味着减小π_new,即降低选择这个坏动作的概率。
- 优化器会试图减小
总结
ratio = (log_probs - old_log_probs).exp() 是一个在数值上稳定地计算新旧策略概率比 π_new / π_old 的方法。
这个 ratio 衡量了策略更新的变化方向和幅度,它与优势函数 Advantage 相乘,共同构成了 PPO 算法的核心驱动力,指导模型如何调整其行为以获得更高的累积奖励。在对数空间中进行减法然后取指数,是避免浮点数下溢和保持计算精度的标准技巧。
好的,我们来详细解析 compute_approx_kl 函数。这个函数在 PPO 训练中扮演着一个非常重要的角色:估算两个策略分布之间的 KL 散度(Kullback-Leibler divergence)。
KL 散度是衡量两个概率分布差异的指标。在 RLHF 的 PPO 训练中,我们通常关心两种 KL 散度:
approx_kl: 新策略π_new和旧策略π_old之间的 KL 散度。用来衡量策略更新的步子迈得有多大。如果这个值太大,说明训练可能不稳定。kl_loss: 新策略π_new和参考策略π_ref(通常是原始的 SFT 模型) 之间的 KL 散度。作为一个正则化项加入到总损失中,防止模型在学习新技能时"忘掉"基础的语言能力。
compute_approx_kl 函数就是用来计算这些值的。
函数定义与参数
我们先假设一个常见的实现,它通常存在于 roll.utils.functionals 中:
python
import torch
def compute_approx_kl(
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: torch.Tensor,
kl_penalty: str = "kl"
) -> torch.Tensor:
"""
计算 log_probs 和 log_probs_base 之间的近似 KL 散度。
Args:
log_probs (torch.Tensor): 新策略的对数概率。
log_probs_base (torch.Tensor): 基准策略(旧策略或参考策略)的对数概率。
action_mask (torch.Tensor): 掩码,标记哪些位置需要计算 KL 散度。
kl_penalty (str): "kl", "mse", or "k3" - 使用哪种近似方法。
"""
# 1. 计算概率比的对数
log_ratio = log_probs - log_probs_base
# 2. 计算概率比
ratio = torch.exp(log_ratio)
# 3. 根据指定的惩罚类型计算近似 KL
if kl_penalty == "kl":
# 这是 D_KL(P || Q) 的一个稳定近似,在 ratio=1 附近表现良好
kl_div = ratio - 1 - log_ratio
elif kl_penalty == "mse":
# 使用均方误差近似,是 "kl" 公式在 log_ratio 接近 0 时的二阶泰勒展开
kl_div = 0.5 * log_ratio.square()
elif kl_penalty == "k3":
# 一个更高阶的近似 (来自 Schulman 的 GAE 论文)
kl_div = (ratio.square() - 1 - 2 * log_ratio) / 2
else:
raise NotImplementedError
# 返回的是一个和输入形状相同的张量,包含了每个 token 位置的 KL 散度值。
# 掩码和聚合(求平均)在函数外部进行。
return kl_div
参数与输出的 Shape
让我们假设一个具体的场景:
batch_size = 4sequence_length = 10(截断对齐后的长度)
那么,输入和输出参数的 shape 通常是:
-
log_probs(输入)- Shape :
[4, 10] - 含义 : 一个二维张量。
log_probs[i, j]表示在第i个样本的第j个时间步,新策略π_new采取input_ids中对应动作(token)的对数概率。
- Shape :
-
log_probs_base(输入)- Shape :
[4, 10] - 含义 : 一个二维张量。
log_probs_base[i, j]表示在第i个样本的第j个时间步,基准策略 (π_old或π_ref)采取input_ids中对应动作(token)的对数概率。
- Shape :
-
action_mask(输入)- Shape :
[4, 10] - 含义 : 一个二维掩码张量(通常是 0 和 1)。
action_mask[i, j] = 1表示第i个样本的第j个位置属于 "response" 部分,需要计算 KL 散度;如果为 0,则表示是 "prompt" 或 "padding",不需要计算。
- Shape :
-
kl_div(输出)- Shape :
[4, 10] - 含义 : 一个二维张量。
kl_div[i, j]表示在第i个样本的第j个时间步,新旧策略之间的近似 KL 散度值。这个值尚未经过掩码和平均。
- Shape :
举例说明执行过程
假设我们只看一个 batch_size=1, sequence_length=3 的例子。
- 输入 :
log_probs:torch.tensor([[-2.0, -1.1, -1.5]])log_probs_base:torch.tensor([[-1.8, -1.0, -1.6]])action_mask:torch.tensor([[0, 1, 1]])(假设第一个是 prompt,后两个是 response)kl_penalty = "kl"
compute_approx_kl 函数内部执行步骤:
-
log_ratio = log_probs - log_probs_baselog_ratio = [-2.0 - (-1.8), -1.1 - (-1.0), -1.5 - (-1.6)]log_ratio结果:torch.tensor([[-0.2, -0.1, 0.1]])
-
ratio = torch.exp(log_ratio)ratio = [exp(-0.2), exp(-0.1), exp(0.1)]ratio结果:torch.tensor([[0.8187, 0.9048, 1.1052]])(近似值)
-
kl_div = ratio - 1 - log_ratio- 这一步是逐元素计算的。
kl_div[0, 0] = 0.8187 - 1 - (-0.2) = 0.0187kl_div[0, 1] = 0.9048 - 1 - (-0.1) = 0.0048kl_div[0, 2] = 1.1052 - 1 - (0.1) = 0.0052kl_div结果:torch.tensor([[0.0187, 0.0048, 0.0052]])
函数返回 : torch.tensor([[0.0187, 0.0048, 0.0052]])。这是一个形状为 [1, 3] 的张量。
函数外部的操作 (在 loss_func 中):
在 loss_func 中,这个返回的张量会被 agg_loss 函数进一步处理。
kl_loss_matrix接收到compute_approx_kl的返回值[[0.0187, 0.0048, 0.0052]]。agg_loss会使用action_mask[[0, 1, 1]]来过滤和聚合。- 掩码操作 :
kl_loss_matrix与action_mask逐元素相乘。
[0.0187, 0.0048, 0.0052] * [0, 1, 1] = [0, 0.0048, 0.0052] - 聚合操作 : 对掩码后的结果求平均值(只计算非零元素)。
(0.0048 + 0.0052) / 2 = 0.005 - 最终,传递给总损失计算的
kl_loss是一个标量值,即0.005。
kl_penalty 的不同选择
"kl":ratio - 1 - log_ratio。这是最常用的一种近似,它在ratio=1附近表现良好,数值稳定。"mse":0.5 * log_ratio.square()。这是"kl"公式在log_ratio接近 0 时的二阶泰勒展开近似。计算更简单,效果也常常不错。"k3":(ratio.square() - 1 - 2*log_ratio) / 2。一个更高阶的近似,理论上可能更准确。
在您的代码中,actor/kl_loss 使用了 "k3",而 actor/approxkl 和 actor/policykl 分别使用了 "mse" 和 "kl",这是为了监控 不同近似方法计算出的 KL 散度值,以更好地了解训练动态。而实际用于计算总损失 的 kl_loss 本身使用 "k3",可能是因为实验发现它作为正则化项效果更好。