【RL】importance_sampling Ratio的计算

好的,我们来详细解释 clipfrac (clipping fraction) 在代码中是如何计算的。这个指标衡量了在PPO损失计算中,有多大比例的token因为ratio超出范围而被裁剪。

clipfrac 的计算通常在 ActorWorker.loss_func 方法内部,紧随着PPO损失的核心计算步骤。


ActorWorker.loss_func 中的计算逻辑

让我们回顾一下 loss_func 的核心部分,并定位 clipfrac 的计算来源。

python 复制代码
# in roll/pipeline/base_worker.py, class ActorWorker, method loss_func

def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
    # ... (获取 advantages, old_log_probs 等)
    
    # 1. 计算 ratio
    if self.pipeline_config.importance_sampling == "token":
        ratio = (log_probs - old_log_probs).exp()
    elif self.pipeline_config.importance_sampling == "seq":
        # ... (计算序列级别的 ratio)
    
    # 2. 定义裁剪边界
    pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
    pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
    # 假设 pg_clip = 0.2, 那么 low = 0.2, high = 0.2
    # 裁剪区间为 [1 - 0.2, 1 + 0.2] = [0.8, 1.2]

    # 3. 计算PPO损失
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
    loss = -torch.min(surr1, surr2)
    
    # ...
    
    # -------------------------------------------------------------
    # 4. 计算 clipfrac
    # -------------------------------------------------------------
    # 判断 ratio 是否小于下界
    clipped_low = (ratio < 1 - pg_clip_low).float()
    
    # 判断 ratio 是否大于上界
    clipped_high = (ratio > 1 + pg_clip_high).float()
    
    # 将两种裁剪情况相加,得到一个表示是否被裁剪的0-1矩阵
    # 因为 ratio 不可能同时小于下界又大于上界,所以直接相加即可
    clipped = (clipped_low + clipped_high).float()

    # 5. 聚合和记录
    # ...
    loss_metric = {
        "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(),
        "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(),
        # 这里的 clipped.mean() 就是我们通常所说的总 clipfrac
        "actor/ppo_ratio_clipfrac": clipped.mean().detach().item(),
        # ...
    } 
    # ...
    
    return total_loss, pg_metrics

详细步骤与举例

我们用一个简单的例子来说明这个过程。假设我们有一个批次(batch_size=2),序列长度为 4,pg_clip (ε) 设置为 0.2

1. 输入:ratio 矩阵

经过计算,我们得到了 ratio 矩阵,形状为 (2, 4):

python 复制代码
import torch

ratio = torch.tensor([
    [1.10, 0.95, 1.25, 1.05],  # 样本 1
    [0.75, 1.15, 0.82, 0.99]   # 样本 2
])
2. 定义裁剪边界

裁剪区间是 [1 - 0.2, 1 + 0.2],即 [0.8, 1.2]

3. 计算 clipped_low

判断 ratio 矩阵中的每个元素是否小于下界 0.8

python 复制代码
clipped_low = (ratio < 0.8).float()

# clipped_low 的值:
# tensor([
#     [0., 0., 0., 0.],  # 样本 1: 没有元素小于0.8
#     [1., 0., 0., 0.]   # 样本 2: 第一个元素 0.75 < 0.8
# ])
4. 计算 clipped_high

判断 ratio 矩阵中的每个元素是否大于上界 1.2

python 复制代码
clipped_high = (ratio > 1.2).float()

# clipped_high 的值:
# tensor([
#     [0., 0., 1., 0.],  # 样本 1: 第三个元素 1.25 > 1.2
#     [0., 0., 0., 0.]   # 样本 2: 没有元素大于1.2
# ])
5. 计算 clipped

clipped_lowclipped_high 相加。

python 复制代码
clipped = clipped_low + clipped_high

# clipped 的值:
# tensor([
#     [0., 0., 1., 0.],
#     [1., 0., 0., 0.]
# ])

这个 clipped 矩阵是一个0-1矩阵,1 所在的位置就代表该token的ratio被裁剪了。

6. 计算最终的 clipfrac

clipped 矩阵求平均值。这会计算出被裁剪的token占所有token总数的比例。

python 复制代码
clipfrac = clipped.mean().item()
  • clipped 矩阵中元素的总数是 2 * 4 = 8
  • clipped 矩阵中 1 的数量是 2
  • clipfrac = 2 / 8 = 0.25

所以,在这个例子中,clipfrac0.25 或 25%。

总结

clipfrac 的计算过程如下:

  1. 计算出每个token的**ratio**(新旧策略概率比)。
  2. 定义PPO的裁剪区间 [1 - ε, 1 + ε]
  3. 创建一个布尔掩码(boolean mask),标记出所有ratio小于 1 - ε大于 1 + ε 的位置。
  4. 计算这个掩码中 True(即被裁剪的token)的数量,然后除以总的token数量

这个最终的比例就是clipfrac,它是一个介于0和1之间的标量,直观地反映了在当前训练步中,策略更新的"激进"程度。

好的,我们来详细讲解ratio的两种计算方式(token级别和seq级别),并分别举例说明。

ratio,即重要性采样比率(Importance Sampling Ratio),是PPO算法的核心,用于修正使用旧策略采集的数据来评估新策略时产生的分布差异。

它的基本定义是:
ratio = π_θ(a|s) / π_θ_old(a|s)

其中:

  • π_θ(a|s)当前策略 (新策略)在状态s下采取动作a的概率。
  • π_θ_old(a|s)采样策略 (旧策略)在状态s下采取动作a的概率。

在实际计算中,为了避免数值下溢和提高计算稳定性,我们通常使用对数概率(log probabilities)来计算:
log(ratio) = log(π_θ(a|s)) - log(π_θ_old(a|s))
ratio = exp( log(π_θ(a|s)) - log(π_θ_old(a|s)) )


1. importance_sampling == "token" (Token级Ratio)

这是最直接、最经典的方式。它为每一个token 都独立计算一个ratio值。

计算逻辑
  1. log_probs : 使用当前模型 (新策略)对批次中的序列进行前向传播,得到每个token位置上,实际采取的那个token的对数概率。结果是一个形状为 (batch_size, seq_len) 的张量。
  2. old_log_probs : 这是在数据采集阶段 ,使用当时的模型(旧策略)计算并存储下来的对数概率。它也具有相同的形状 (batch_size, seq_len)
  3. 计算ratio : ratio = (log_probs - old_log_probs).exp()。这个操作是逐元素进行的。
举例说明

假设 batch_size=2, seq_len=4

  • log_probs (新策略的对数概率):

    python 复制代码
    log_probs = torch.tensor([
        [-0.5, -1.2, -0.8, -2.0],  # 样本 1
        [-1.0, -0.4, -1.5, -0.9]   # 样本 2
    ])
  • old_log_probs (旧策略的对数概率):

    python 复制代码
    old_log_probs = torch.tensor([
        [-0.6, -1.1, -1.1, -2.1],  # 样本 1
        [-1.3, -0.5, -1.4, -0.9]   # 样本 2
    ])
  • 计算过程:

    1. 计算对数差 log_probs - old_log_probs:

      复制代码
      tensor([
          [ 0.1, -0.1,  0.3,  0.1],  # -0.5 - (-0.6) = 0.1, etc.
          [ 0.3,  0.1, -0.1,  0.0]
      ])
    2. 取指数 exp(...):

      python 复制代码
      ratio = (log_probs - old_log_probs).exp()
      
      # ratio 的值:
      # tensor([
      #     [1.1052, 0.9048, 1.3499, 1.1052],  # exp(0.1), exp(-0.1), etc.
      #     [1.3499, 1.1052, 0.9048, 1.0000]
      # ])
  • 结果 : 我们得到了一个形状为 (2, 4)ratio矩阵。每个元素 ratio[i, j] 代表第i个样本的第j个token的新旧策略概率比。例如,1.3499 意味着新策略生成这个token的概率是旧策略的1.35倍。


2. importance_sampling == "seq" (序列级Ratio)

这种方式不为每个token计算独立的ratio,而是为整个序列 计算一个统一的ratio,然后将这个值广播到该序列的所有token上。

计算逻辑
  1. 计算逐token的对数差 : log_ratio_token_level = log_probs - old_log_probs。这一步和token级方法一样。
  2. 计算序列的平均对数差 : 使用 masked_mean 函数,计算每个序列(行)的有效token的平均对数差。
    masked_log_ratio_seq_level = masked_mean(log_ratio_token_level, final_response_mask, dim=-1)
  3. 计算序列的ratio : 对上一步得到的序列平均对数差取指数。
    ratio_seq_level = masked_log_ratio_seq_level.exp()
  4. 广播 : 将序列级别的ratio扩展(unsqueeze and expand)成与原始log_probs相同的形状 (batch_size, seq_len)
举例说明

我们使用和上面相同的log_probsold_log_probs,并假设一个mask

  • log_probsold_log_probs (同上)

  • mask (假设样本1长度为3,样本2长度为4):

    python 复制代码
    mask = torch.tensor([
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]
    ])
  • 计算过程 :

    1. 计算逐token的对数差 log_ratio_token_level:

      复制代码
      tensor([
          [ 0.1, -0.1,  0.3,  0.1],
          [ 0.3,  0.1, -0.1,  0.0]
      ])
    2. 计算每个序列的平均对数差:

      • 样本1: (0.1 + (-0.1) + 0.3) / 3 = 0.1
      • 样本2: (0.3 + 0.1 + (-0.1) + 0.0) / 4 = 0.075

      masked_log_ratio_seq_level = tensor([0.1, 0.075])

    3. 对序列平均对数差取指数,得到序列级ratio:

      • 样本1: exp(0.1) = 1.1052
      • 样本2: exp(0.075) = 1.0779

      ratio_seq_level = tensor([1.1052, 1.0779])

    4. 广播 这个序列级ratio到整个矩阵:

      python 复制代码
      # ratio 的最终值:
      # tensor([
      #     [1.1052, 1.1052, 1.1052, 1.1052],  # 样本1的所有token共享同一个ratio
      #     [1.0779, 1.0779, 1.0779, 1.0779]   # 样本2的所有token共享同一个ratio
      # ])

对比与总结

特性 importance_sampling="token" importance_sampling="seq"
计算粒度 逐个token 整个序列
Ratio值 每个token都有自己独特的ratio值。 同一个序列中的所有token共享一个相同的ratio值。
物理意义 评估新策略在每个决策点上的改变。 评估新策略对于生成整个完整序列的概率的整体改变。
方差 理论上方差更高 ,因为每个token的ratio都会引入随机性。 理论上方差更低,因为对数比率被平均了,平滑了单个token的极端变化。
使用场景 标准的PPO实现。 在某些情况下,可能使训练更稳定,特别是在处理长序列时,可以防止单个token的极端ratio值对梯度产生过大影响。

在您提供的代码中,self.pipeline_config.importance_sampling 这个配置项就是用来选择这两种不同计算方式的"开关"。

相关推荐
NAGNIP1 天前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab1 天前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab1 天前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP1 天前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年1 天前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼1 天前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS1 天前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区1 天前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈1 天前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang1 天前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx