【RL】ROLL中loss 计算compute_approx_kl

这行代码 ratio = (log_probs - old_log_probs).exp() 是 PPO 算法中一个非常核心且巧妙的数学计算。它的意思是 计算新策略(π_new)和旧策略(π_old)在给定状态下选择相同动作的概率之比

这个比率,通常在强化学习文献中表示为 r_t(θ)ratio

ratio = π_new(a | s) / π_old(a | s)

让我们来分解一下这个公式是如何通过代码实现的。

数学原理

  1. 对数(Logarithms)的性质:

    • log(x / y) = log(x) - log(y)
    • exp(log(x)) = x
  2. 推导过程:

    • 我们想计算 ratio = π_new / π_old
    • 直接计算除法在数值上可能不稳定,特别是当概率值非常小的时候。
    • 一个更稳健的方法是在对数空间(log-space)中进行操作。
    • 我们先对 ratio 取对数:
      log(ratio) = log(π_new / π_old)
    • 根据对数性质,这等于:
      log(ratio) = log(π_new) - log(π_old)
    • 现在,我们有了 log(ratio) 的值。为了得到原始的 ratio,我们只需要对它取指数(exp):
      ratio = exp(log(ratio)) = exp(log(π_new) - log(π_old))

代码与数学的对应关系

现在我们来看代码:

  • log_probs:

    • 这代表 log(π_new)。它是当前正在训练的模型(新策略)为序列中的每个动作(token)计算出的对数概率。
  • old_log_probs:

    • 这代表 log(π_old)。它是在本轮 PPO 训练开始之前,旧的、固定的 Actor 模型(旧策略)计算出的对数概率。这个值在整个 PPO 的内部训练循环(ppo_epochs)中是保持不变的。
  • (log_probs - old_log_probs):

    • 这完全对应了数学推导中的 log(π_new) - log(π_old),也就是 log(ratio)
  • .exp():

    • 这对应了数学推导中的 exp(...) 操作。它将对数比率转换回原始的概率比率 ratio

为什么这个 ratio 如此重要?

ratio 是 PPO 代理目标函数(Surrogate Objective)的核心组成部分。代理目标函数是:

L_CPI(θ) = E[ (π_θ(a|s) / π_θ_old(a|s)) * A_t ]
L_CPI(θ) = E[ ratio * Advantage ]

  • ratio > 1 : 意味着新策略比旧策略更倾向于选择这个动作。
  • ratio < 1 : 意味着新策略比旧策略更不倾向于选择这个动作。

ratioAdvantage 的相互作用决定了策略更新的方向:

  1. 如果 Advantage > 0 (这是一个好动作):

    • 优化器会试图增大 ratio,从而让 ratio * Advantage 更大,使得损失更小(因为 PPO 损失是 -min(...))。
    • 增大 ratio 就意味着增大 π_new,即增加选择这个好动作的概率。
  2. 如果 Advantage < 0 (这是一个坏动作):

    • 优化器会试图减小 ratio,从而让 ratio * Advantage(一个负数)变得更接近于零,使得损失更小。
    • 减小 ratio 就意味着减小 π_new,即降低选择这个坏动作的概率。

总结

ratio = (log_probs - old_log_probs).exp() 是一个在数值上稳定地计算新旧策略概率比 π_new / π_old 的方法。

这个 ratio 衡量了策略更新的变化方向和幅度,它与优势函数 Advantage 相乘,共同构成了 PPO 算法的核心驱动力,指导模型如何调整其行为以获得更高的累积奖励。在对数空间中进行减法然后取指数,是避免浮点数下溢和保持计算精度的标准技巧。

好的,我们来详细解析 compute_approx_kl 函数。这个函数在 PPO 训练中扮演着一个非常重要的角色:估算两个策略分布之间的 KL 散度(Kullback-Leibler divergence)

KL 散度是衡量两个概率分布差异的指标。在 RLHF 的 PPO 训练中,我们通常关心两种 KL 散度:

  1. approx_kl : 新策略 π_new 和旧策略 π_old 之间的 KL 散度。用来衡量策略更新的步子迈得有多大。如果这个值太大,说明训练可能不稳定。
  2. kl_loss : 新策略 π_new 和参考策略 π_ref (通常是原始的 SFT 模型) 之间的 KL 散度。作为一个正则化项加入到总损失中,防止模型在学习新技能时"忘掉"基础的语言能力。

compute_approx_kl 函数就是用来计算这些值的。

函数定义与参数

我们先假设一个常见的实现,它通常存在于 roll.utils.functionals 中:

python 复制代码
import torch

def compute_approx_kl(
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: torch.Tensor,
    kl_penalty: str = "kl"
) -> torch.Tensor:
    """
    计算 log_probs 和 log_probs_base 之间的近似 KL 散度。

    Args:
        log_probs (torch.Tensor): 新策略的对数概率。
        log_probs_base (torch.Tensor): 基准策略(旧策略或参考策略)的对数概率。
        action_mask (torch.Tensor): 掩码,标记哪些位置需要计算 KL 散度。
        kl_penalty (str): "kl", "mse", or "k3" - 使用哪种近似方法。
    """
    # 1. 计算概率比的对数
    log_ratio = log_probs - log_probs_base

    # 2. 计算概率比
    ratio = torch.exp(log_ratio)

    # 3. 根据指定的惩罚类型计算近似 KL
    if kl_penalty == "kl":
        # 这是 D_KL(P || Q) 的一个稳定近似,在 ratio=1 附近表现良好
        kl_div = ratio - 1 - log_ratio
    elif kl_penalty == "mse":
        # 使用均方误差近似,是 "kl" 公式在 log_ratio 接近 0 时的二阶泰勒展开
        kl_div = 0.5 * log_ratio.square()
    elif kl_penalty == "k3":
        # 一个更高阶的近似 (来自 Schulman 的 GAE 论文)
        kl_div = (ratio.square() - 1 - 2 * log_ratio) / 2
    else:
        raise NotImplementedError

    # 返回的是一个和输入形状相同的张量,包含了每个 token 位置的 KL 散度值。
    # 掩码和聚合(求平均)在函数外部进行。
    return kl_div

参数与输出的 Shape

让我们假设一个具体的场景:

  • batch_size = 4
  • sequence_length = 10 (截断对齐后的长度)

那么,输入和输出参数的 shape 通常是:

  • log_probs (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。log_probs[i, j] 表示在第 i 个样本的第 j 个时间步,新策略 π_new 采取 input_ids 中对应动作(token)的对数概率。
  • log_probs_base (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。log_probs_base[i, j] 表示在第 i 个样本的第 j 个时间步,基准策略π_oldπ_ref)采取 input_ids 中对应动作(token)的对数概率。
  • action_mask (输入)

    • Shape : [4, 10]
    • 含义 : 一个二维掩码张量(通常是 0 和 1)。action_mask[i, j] = 1 表示第 i 个样本的第 j 个位置属于 "response" 部分,需要计算 KL 散度;如果为 0,则表示是 "prompt" 或 "padding",不需要计算。
  • kl_div (输出)

    • Shape : [4, 10]
    • 含义 : 一个二维张量。kl_div[i, j] 表示在第 i 个样本的第 j 个时间步,新旧策略之间的近似 KL 散度值。这个值尚未经过掩码和平均。

举例说明执行过程

假设我们只看一个 batch_size=1, sequence_length=3 的例子。

  • 输入 :
    • log_probs: torch.tensor([[-2.0, -1.1, -1.5]])
    • log_probs_base: torch.tensor([[-1.8, -1.0, -1.6]])
    • action_mask: torch.tensor([[0, 1, 1]]) (假设第一个是 prompt,后两个是 response)
    • kl_penalty = "kl"

compute_approx_kl 函数内部执行步骤:

  1. log_ratio = log_probs - log_probs_base

    • log_ratio = [-2.0 - (-1.8), -1.1 - (-1.0), -1.5 - (-1.6)]
    • log_ratio 结果: torch.tensor([[-0.2, -0.1, 0.1]])
  2. ratio = torch.exp(log_ratio)

    • ratio = [exp(-0.2), exp(-0.1), exp(0.1)]
    • ratio 结果: torch.tensor([[0.8187, 0.9048, 1.1052]]) (近似值)
  3. kl_div = ratio - 1 - log_ratio

    • 这一步是逐元素计算的。
    • kl_div[0, 0] = 0.8187 - 1 - (-0.2) = 0.0187
    • kl_div[0, 1] = 0.9048 - 1 - (-0.1) = 0.0048
    • kl_div[0, 2] = 1.1052 - 1 - (0.1) = 0.0052
    • kl_div 结果: torch.tensor([[0.0187, 0.0048, 0.0052]])

函数返回 : torch.tensor([[0.0187, 0.0048, 0.0052]])。这是一个形状为 [1, 3] 的张量。


函数外部的操作 (在 loss_func 中):

loss_func 中,这个返回的张量会被 agg_loss 函数进一步处理。

  1. kl_loss_matrix 接收到 compute_approx_kl 的返回值 [[0.0187, 0.0048, 0.0052]]
  2. agg_loss 会使用 action_mask [[0, 1, 1]] 来过滤和聚合。
  3. 掩码操作 : kl_loss_matrixaction_mask 逐元素相乘。
    [0.0187, 0.0048, 0.0052] * [0, 1, 1] = [0, 0.0048, 0.0052]
  4. 聚合操作 : 对掩码后的结果求平均值(只计算非零元素)。
    (0.0048 + 0.0052) / 2 = 0.005
  5. 最终,传递给总损失计算的 kl_loss 是一个标量值,即 0.005

kl_penalty 的不同选择

  • "kl" : ratio - 1 - log_ratio。这是最常用的一种近似,它在 ratio=1 附近表现良好,数值稳定。
  • "mse" : 0.5 * log_ratio.square()。这是 "kl" 公式在 log_ratio 接近 0 时的二阶泰勒展开近似。计算更简单,效果也常常不错。
  • "k3" : (ratio.square() - 1 - 2*log_ratio) / 2。一个更高阶的近似,理论上可能更准确。

在您的代码中,actor/kl_loss 使用了 "k3",而 actor/approxklactor/policykl 分别使用了 "mse""kl",这是为了监控 不同近似方法计算出的 KL 散度值,以更好地了解训练动态。而实际用于计算总损失kl_loss 本身使用 "k3",可能是因为实验发现它作为正则化项效果更好。

相关推荐
sealaugh322 小时前
AI(学习笔记第十七课)langchain v1.0(SQL Agent)
人工智能·笔记·学习
zbguolei2 小时前
使用VBA将EXCEL生成PPT
人工智能·opencv·计算机视觉
易百纳2 小时前
易百纳携多模态AI桌面机器人——Kubee Robot亮相2025火山引擎冬季FORCE大会
人工智能·火山引擎
zhengfei6112 小时前
AI渗透工具——自主进攻性安全人工智能,用于指导渗透测试流程(EVA)
人工智能·安全
IT_陈寒2 小时前
React 18 性能优化实战:5个被低估的Hooks用法让你的应用快30%
前端·人工智能·后端
戴西软件2 小时前
戴西软件3DViz Convert:解锁三维数据流动,驱动一体化协同设计
大数据·人工智能·安全·3d·华为云·云计算
haiyu_y2 小时前
Day 51 在预训练 ResNet18 中注入 CBAM 注意力
人工智能·pytorch·深度学习
拉拉拉拉拉拉拉马2 小时前
感知机(Perceptron)算法详解
人工智能·python·深度学习·算法·机器学习
万邦科技Lafite2 小时前
淘宝开放API获取订单信息教程(2025年最新版)
java·开发语言·数据库·人工智能·python·开放api·电商开放平台