【RL】ROLL中loss 计算compute_approx_kl

这行代码 ratio = (log_probs - old_log_probs).exp() 是 PPO 算法中一个非常核心且巧妙的数学计算。它的意思是 计算新策略(π_new)和旧策略(π_old)在给定状态下选择相同动作的概率之比

这个比率,通常在强化学习文献中表示为 r_t(θ)ratio

ratio = π_new(a | s) / π_old(a | s)

让我们来分解一下这个公式是如何通过代码实现的。

数学原理

  1. 对数(Logarithms)的性质:

    • log(x / y) = log(x) - log(y)
    • exp(log(x)) = x
  2. 推导过程:

    • 我们想计算 ratio = π_new / π_old
    • 直接计算除法在数值上可能不稳定,特别是当概率值非常小的时候。
    • 一个更稳健的方法是在对数空间(log-space)中进行操作。
    • 我们先对 ratio 取对数:
      log(ratio) = log(π_new / π_old)
    • 根据对数性质,这等于:
      log(ratio) = log(π_new) - log(π_old)
    • 现在,我们有了 log(ratio) 的值。为了得到原始的 ratio,我们只需要对它取指数(exp):
      ratio = exp(log(ratio)) = exp(log(π_new) - log(π_old))

代码与数学的对应关系

现在我们来看代码:

  • log_probs:

    • 这代表 log(π_new)。它是当前正在训练的模型(新策略)为序列中的每个动作(token)计算出的对数概率。
  • old_log_probs:

    • 这代表 log(π_old)。它是在本轮 PPO 训练开始之前,旧的、固定的 Actor 模型(旧策略)计算出的对数概率。这个值在整个 PPO 的内部训练循环(ppo_epochs)中是保持不变的。
  • (log_probs - old_log_probs):

    • 这完全对应了数学推导中的 log(π_new) - log(π_old),也就是 log(ratio)
  • .exp():

    • 这对应了数学推导中的 exp(...) 操作。它将对数比率转换回原始的概率比率 ratio

为什么这个 ratio 如此重要?

ratio 是 PPO 代理目标函数(Surrogate Objective)的核心组成部分。代理目标函数是:

L_CPI(θ) = E[ (π_θ(a|s) / π_θ_old(a|s)) * A_t ]
L_CPI(θ) = E[ ratio * Advantage ]

  • ratio > 1 : 意味着新策略比旧策略更倾向于选择这个动作。
  • ratio < 1 : 意味着新策略比旧策略更不倾向于选择这个动作。

ratioAdvantage 的相互作用决定了策略更新的方向:

  1. 如果 Advantage > 0 (这是一个好动作):

    • 优化器会试图增大 ratio,从而让 ratio * Advantage 更大,使得损失更小(因为 PPO 损失是 -min(...))。
    • 增大 ratio 就意味着增大 π_new,即增加选择这个好动作的概率。
  2. 如果 Advantage < 0 (这是一个坏动作):

    • 优化器会试图减小 ratio,从而让 ratio * Advantage(一个负数)变得更接近于零,使得损失更小。
    • 减小 ratio 就意味着减小 π_new,即降低选择这个坏动作的概率。

总结

ratio = (log_probs - old_log_probs).exp() 是一个在数值上稳定地计算新旧策略概率比 π_new / π_old 的方法。

这个 ratio 衡量了策略更新的变化方向和幅度,它与优势函数 Advantage 相乘,共同构成了 PPO 算法的核心驱动力,指导模型如何调整其行为以获得更高的累积奖励。在对数空间中进行减法然后取指数,是避免浮点数下溢和保持计算精度的标准技巧。

好的,我们来详细解析 compute_approx_kl 函数。这个函数在 PPO 训练中扮演着一个非常重要的角色:估算两个策略分布之间的 KL 散度(Kullback-Leibler divergence)

KL 散度是衡量两个概率分布差异的指标。在 RLHF 的 PPO 训练中,我们通常关心两种 KL 散度:

  1. approx_kl : 新策略 π_new 和旧策略 π_old 之间的 KL 散度。用来衡量策略更新的步子迈得有多大。如果这个值太大,说明训练可能不稳定。
  2. kl_loss : 新策略 π_new 和参考策略 π_ref (通常是原始的 SFT 模型) 之间的 KL 散度。作为一个正则化项加入到总损失中,防止模型在学习新技能时"忘掉"基础的语言能力。

compute_approx_kl 函数就是用来计算这些值的。

函数定义与参数

我们先假设一个常见的实现,它通常存在于 roll.utils.functionals 中:

python 复制代码
import torch

def compute_approx_kl(
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: torch.Tensor,
    kl_penalty: str = "kl"
) -> torch.Tensor:
    """
    计算 log_probs 和 log_probs_base 之间的近似 KL 散度。

    Args:
        log_probs (torch.Tensor): 新策略的对数概率。
        log_probs_base (torch.Tensor): 基准策略(旧策略或参考策略)的对数概率。
        action_mask (torch.Tensor): 掩码,标记哪些位置需要计算 KL 散度。
        kl_penalty (str): "kl", "mse", or "k3" - 使用哪种近似方法。
    """
    # 1. 计算概率比的对数
    log_ratio = log_probs - log_probs_base

    # 2. 计算概率比
    ratio = torch.exp(log_ratio)

    # 3. 根据指定的惩罚类型计算近似 KL
    if kl_penalty == "kl":
        # 这是 D_KL(P || Q) 的一个稳定近似,在 ratio=1 附近表现良好
        kl_div = ratio - 1 - log_ratio
    elif kl_penalty == "mse":
        # 使用均方误差近似,是 "kl" 公式在 log_ratio 接近 0 时的二阶泰勒展开
        kl_div = 0.5 * log_ratio.square()
    elif kl_penalty == "k3":
        # 一个更高阶的近似 (来自 Schulman 的 GAE 论文)
        kl_div = (ratio.square() - 1 - 2 * log_ratio) / 2
    else:
        raise NotImplementedError

    # 返回的是一个和输入形状相同的张量,包含了每个 token 位置的 KL 散度值。
    # 掩码和聚合(求平均)在函数外部进行。
    return kl_div

参数与输出的 Shape

让我们假设一个具体的场景:

  • batch_size = 4
  • sequence_length = 10 (截断对齐后的长度)

那么,输入和输出参数的 shape 通常是:

  • log_probs (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。log_probs[i, j] 表示在第 i 个样本的第 j 个时间步,新策略 π_new 采取 input_ids 中对应动作(token)的对数概率。
  • log_probs_base (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。log_probs_base[i, j] 表示在第 i 个样本的第 j 个时间步,基准策略π_oldπ_ref)采取 input_ids 中对应动作(token)的对数概率。
  • action_mask (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维掩码张量(通常是 0 和 1)。action_mask[i, j] = 1 表示第 i 个样本的第 j 个位置属于 "response" 部分,需要计算 KL 散度;如果为 0,则表示是 "prompt" 或 "padding",不需要计算。
  • kl_div (输出)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。kl_div[i, j] 表示在第 i 个样本的第 j 个时间步,新旧策略之间的近似 KL 散度值。这个值尚未经过掩码和平均。

举例说明执行过程

假设我们只看一个 batch_size=1, sequence_length=3 的例子。

  • 输入 :
    • log_probs: torch.tensor([[-2.0, -1.1, -1.5]])
    • log_probs_base: torch.tensor([[-1.8, -1.0, -1.6]])
    • action_mask: torch.tensor([[0, 1, 1]]) (假设第一个是 prompt,后两个是 response)
    • kl_penalty = "kl"

compute_approx_kl 函数内部执行步骤:

  1. log_ratio = log_probs - log_probs_base

    • log_ratio = [-2.0 - (-1.8), -1.1 - (-1.0), -1.5 - (-1.6)]
    • log_ratio 结果: torch.tensor([[-0.2, -0.1, 0.1]])
  2. ratio = torch.exp(log_ratio)

    • ratio = [exp(-0.2), exp(-0.1), exp(0.1)]
    • ratio 结果: torch.tensor([[0.8187, 0.9048, 1.1052]]) (近似值)
  3. kl_div = ratio - 1 - log_ratio

    • 这一步是逐元素计算的。
    • kl_div[0, 0] = 0.8187 - 1 - (-0.2) = 0.0187
    • kl_div[0, 1] = 0.9048 - 1 - (-0.1) = 0.0048
    • kl_div[0, 2] = 1.1052 - 1 - (0.1) = 0.0052
    • kl_div 结果: torch.tensor([[0.0187, 0.0048, 0.0052]])

函数返回 : torch.tensor([[0.0187, 0.0048, 0.0052]])。这是一个形状为 [1, 3] 的张量。


函数外部的操作 (在 loss_func 中):

loss_func 中,这个返回的张量会被 agg_loss 函数进一步处理。

  1. kl_loss_matrix 接收到 compute_approx_kl 的返回值 [[0.0187, 0.0048, 0.0052]]
  2. agg_loss 会使用 action_mask [[0, 1, 1]] 来过滤和聚合。
  3. 掩码操作 : kl_loss_matrixaction_mask 逐元素相乘。
    [0.0187, 0.0048, 0.0052] * [0, 1, 1] = [0, 0.0048, 0.0052]
  4. 聚合操作 : 对掩码后的结果求平均值(只计算非零元素)。
    (0.0048 + 0.0052) / 2 = 0.005
  5. 最终,传递给总损失计算的 kl_loss 是一个标量值,即 0.005

kl_penalty 的不同选择

  • "kl" : ratio - 1 - log_ratio。这是最常用的一种近似,它在 ratio=1 附近表现良好,数值稳定。
  • "mse" : 0.5 * log_ratio.square()。这是 "kl" 公式在 log_ratio 接近 0 时的二阶泰勒展开近似。计算更简单,效果也常常不错。
  • "k3" : (ratio.square() - 1 - 2*log_ratio) / 2。一个更高阶的近似,理论上可能更准确。

在您的代码中,actor/kl_loss 使用了 "k3",而 actor/approxklactor/policykl 分别使用了 "mse""kl",这是为了监控 不同近似方法计算出的 KL 散度值,以更好地了解训练动态。而实际用于计算总损失kl_loss 本身使用 "k3",可能是因为实验发现它作为正则化项效果更好。

相关推荐
陈广亮1 天前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬1 天前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia1 天前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区1 天前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两1 天前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪1 天前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232551 天前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
王鑫星1 天前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能
lnix1 天前
当“大龙虾”养在本地:我们离“反SaaS”的AI未来还有多远?
人工智能·aigc