【RL】importance_sampling Ratio的计算

好的,我们来详细解释 clipfrac (clipping fraction) 在代码中是如何计算的。这个指标衡量了在PPO损失计算中,有多大比例的token因为ratio超出范围而被裁剪。

clipfrac 的计算通常在 ActorWorker.loss_func 方法内部,紧随着PPO损失的核心计算步骤。


ActorWorker.loss_func 中的计算逻辑

让我们回顾一下 loss_func 的核心部分,并定位 clipfrac 的计算来源。

python 复制代码
# in roll/pipeline/base_worker.py, class ActorWorker, method loss_func

def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
    # ... (获取 advantages, old_log_probs 等)
    
    # 1. 计算 ratio
    if self.pipeline_config.importance_sampling == "token":
        ratio = (log_probs - old_log_probs).exp()
    elif self.pipeline_config.importance_sampling == "seq":
        # ... (计算序列级别的 ratio)
    
    # 2. 定义裁剪边界
    pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
    pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
    # 假设 pg_clip = 0.2, 那么 low = 0.2, high = 0.2
    # 裁剪区间为 [1 - 0.2, 1 + 0.2] = [0.8, 1.2]

    # 3. 计算PPO损失
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
    loss = -torch.min(surr1, surr2)
    
    # ...
    
    # -------------------------------------------------------------
    # 4. 计算 clipfrac
    # -------------------------------------------------------------
    # 判断 ratio 是否小于下界
    clipped_low = (ratio < 1 - pg_clip_low).float()
    
    # 判断 ratio 是否大于上界
    clipped_high = (ratio > 1 + pg_clip_high).float()
    
    # 将两种裁剪情况相加,得到一个表示是否被裁剪的0-1矩阵
    # 因为 ratio 不可能同时小于下界又大于上界,所以直接相加即可
    clipped = (clipped_low + clipped_high).float()

    # 5. 聚合和记录
    # ...
    loss_metric = {
        "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(),
        "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(),
        # 这里的 clipped.mean() 就是我们通常所说的总 clipfrac
        "actor/ppo_ratio_clipfrac": clipped.mean().detach().item(),
        # ...
    } 
    # ...
    
    return total_loss, pg_metrics

详细步骤与举例

我们用一个简单的例子来说明这个过程。假设我们有一个批次(batch_size=2),序列长度为 4,pg_clip (ε) 设置为 0.2

1. 输入:ratio 矩阵

经过计算,我们得到了 ratio 矩阵,形状为 (2, 4):

python 复制代码
import torch

ratio = torch.tensor([
    [1.10, 0.95, 1.25, 1.05],  # 样本 1
    [0.75, 1.15, 0.82, 0.99]   # 样本 2
])
2. 定义裁剪边界

裁剪区间是 [1 - 0.2, 1 + 0.2],即 [0.8, 1.2]

3. 计算 clipped_low

判断 ratio 矩阵中的每个元素是否小于下界 0.8

python 复制代码
clipped_low = (ratio < 0.8).float()

# clipped_low 的值:
# tensor([
#     [0., 0., 0., 0.],  # 样本 1: 没有元素小于0.8
#     [1., 0., 0., 0.]   # 样本 2: 第一个元素 0.75 < 0.8
# ])
4. 计算 clipped_high

判断 ratio 矩阵中的每个元素是否大于上界 1.2

python 复制代码
clipped_high = (ratio > 1.2).float()

# clipped_high 的值:
# tensor([
#     [0., 0., 1., 0.],  # 样本 1: 第三个元素 1.25 > 1.2
#     [0., 0., 0., 0.]   # 样本 2: 没有元素大于1.2
# ])
5. 计算 clipped

clipped_lowclipped_high 相加。

python 复制代码
clipped = clipped_low + clipped_high

# clipped 的值:
# tensor([
#     [0., 0., 1., 0.],
#     [1., 0., 0., 0.]
# ])

这个 clipped 矩阵是一个0-1矩阵,1 所在的位置就代表该token的ratio被裁剪了。

6. 计算最终的 clipfrac

clipped 矩阵求平均值。这会计算出被裁剪的token占所有token总数的比例。

python 复制代码
clipfrac = clipped.mean().item()
  • clipped 矩阵中元素的总数是 2 * 4 = 8
  • clipped 矩阵中 1 的数量是 2
  • clipfrac = 2 / 8 = 0.25

所以,在这个例子中,clipfrac0.25 或 25%。

总结

clipfrac 的计算过程如下:

  1. 计算出每个token的**ratio**(新旧策略概率比)。
  2. 定义PPO的裁剪区间 [1 - ε, 1 + ε]
  3. 创建一个布尔掩码(boolean mask),标记出所有ratio小于 1 - ε大于 1 + ε 的位置。
  4. 计算这个掩码中 True(即被裁剪的token)的数量,然后除以总的token数量

这个最终的比例就是clipfrac,它是一个介于0和1之间的标量,直观地反映了在当前训练步中,策略更新的"激进"程度。

好的,我们来详细讲解ratio的两种计算方式(token级别和seq级别),并分别举例说明。

ratio,即重要性采样比率(Importance Sampling Ratio),是PPO算法的核心,用于修正使用旧策略采集的数据来评估新策略时产生的分布差异。

它的基本定义是:
ratio = π_θ(a|s) / π_θ_old(a|s)

其中:

  • π_θ(a|s)当前策略 (新策略)在状态s下采取动作a的概率。
  • π_θ_old(a|s)采样策略 (旧策略)在状态s下采取动作a的概率。

在实际计算中,为了避免数值下溢和提高计算稳定性,我们通常使用对数概率(log probabilities)来计算:
log(ratio) = log(π_θ(a|s)) - log(π_θ_old(a|s))
ratio = exp( log(π_θ(a|s)) - log(π_θ_old(a|s)) )


1. importance_sampling == "token" (Token级Ratio)

这是最直接、最经典的方式。它为每一个token 都独立计算一个ratio值。

计算逻辑
  1. log_probs : 使用当前模型 (新策略)对批次中的序列进行前向传播,得到每个token位置上,实际采取的那个token的对数概率。结果是一个形状为 (batch_size, seq_len) 的张量。
  2. old_log_probs : 这是在数据采集阶段 ,使用当时的模型(旧策略)计算并存储下来的对数概率。它也具有相同的形状 (batch_size, seq_len)
  3. 计算ratio : ratio = (log_probs - old_log_probs).exp()。这个操作是逐元素进行的。
举例说明

假设 batch_size=2, seq_len=4

  • log_probs (新策略的对数概率):

    python 复制代码
    log_probs = torch.tensor([
        [-0.5, -1.2, -0.8, -2.0],  # 样本 1
        [-1.0, -0.4, -1.5, -0.9]   # 样本 2
    ])
  • old_log_probs (旧策略的对数概率):

    python 复制代码
    old_log_probs = torch.tensor([
        [-0.6, -1.1, -1.1, -2.1],  # 样本 1
        [-1.3, -0.5, -1.4, -0.9]   # 样本 2
    ])
  • 计算过程:

    1. 计算对数差 log_probs - old_log_probs:

      复制代码
      tensor([
          [ 0.1, -0.1,  0.3,  0.1],  # -0.5 - (-0.6) = 0.1, etc.
          [ 0.3,  0.1, -0.1,  0.0]
      ])
    2. 取指数 exp(...):

      python 复制代码
      ratio = (log_probs - old_log_probs).exp()
      
      # ratio 的值:
      # tensor([
      #     [1.1052, 0.9048, 1.3499, 1.1052],  # exp(0.1), exp(-0.1), etc.
      #     [1.3499, 1.1052, 0.9048, 1.0000]
      # ])
  • 结果 : 我们得到了一个形状为 (2, 4)ratio矩阵。每个元素 ratio[i, j] 代表第i个样本的第j个token的新旧策略概率比。例如,1.3499 意味着新策略生成这个token的概率是旧策略的1.35倍。


2. importance_sampling == "seq" (序列级Ratio)

这种方式不为每个token计算独立的ratio,而是为整个序列 计算一个统一的ratio,然后将这个值广播到该序列的所有token上。

计算逻辑
  1. 计算逐token的对数差 : log_ratio_token_level = log_probs - old_log_probs。这一步和token级方法一样。
  2. 计算序列的平均对数差 : 使用 masked_mean 函数,计算每个序列(行)的有效token的平均对数差。
    masked_log_ratio_seq_level = masked_mean(log_ratio_token_level, final_response_mask, dim=-1)
  3. 计算序列的ratio : 对上一步得到的序列平均对数差取指数。
    ratio_seq_level = masked_log_ratio_seq_level.exp()
  4. 广播 : 将序列级别的ratio扩展(unsqueeze and expand)成与原始log_probs相同的形状 (batch_size, seq_len)
举例说明

我们使用和上面相同的log_probsold_log_probs,并假设一个mask

  • log_probsold_log_probs (同上)

  • mask (假设样本1长度为3,样本2长度为4):

    python 复制代码
    mask = torch.tensor([
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]
    ])
  • 计算过程 :

    1. 计算逐token的对数差 log_ratio_token_level:

      复制代码
      tensor([
          [ 0.1, -0.1,  0.3,  0.1],
          [ 0.3,  0.1, -0.1,  0.0]
      ])
    2. 计算每个序列的平均对数差:

      • 样本1: (0.1 + (-0.1) + 0.3) / 3 = 0.1
      • 样本2: (0.3 + 0.1 + (-0.1) + 0.0) / 4 = 0.075

      masked_log_ratio_seq_level = tensor([0.1, 0.075])

    3. 对序列平均对数差取指数,得到序列级ratio:

      • 样本1: exp(0.1) = 1.1052
      • 样本2: exp(0.075) = 1.0779

      ratio_seq_level = tensor([1.1052, 1.0779])

    4. 广播 这个序列级ratio到整个矩阵:

      python 复制代码
      # ratio 的最终值:
      # tensor([
      #     [1.1052, 1.1052, 1.1052, 1.1052],  # 样本1的所有token共享同一个ratio
      #     [1.0779, 1.0779, 1.0779, 1.0779]   # 样本2的所有token共享同一个ratio
      # ])

对比与总结

特性 importance_sampling="token" importance_sampling="seq"
计算粒度 逐个token 整个序列
Ratio值 每个token都有自己独特的ratio值。 同一个序列中的所有token共享一个相同的ratio值。
物理意义 评估新策略在每个决策点上的改变。 评估新策略对于生成整个完整序列的概率的整体改变。
方差 理论上方差更高 ,因为每个token的ratio都会引入随机性。 理论上方差更低,因为对数比率被平均了,平滑了单个token的极端变化。
使用场景 标准的PPO实现。 在某些情况下,可能使训练更稳定,特别是在处理长序列时,可以防止单个token的极端ratio值对梯度产生过大影响。

在您提供的代码中,self.pipeline_config.importance_sampling 这个配置项就是用来选择这两种不同计算方式的"开关"。

相关推荐
Cx330❀1 天前
脉脉平台深度测评:【AI创作者xAMA】从职场社交到AI创作赋能
数据库·人工智能·脉脉
攻城狮7号1 天前
通用 GUI 智能体基座 MAI-UI 开源:告别“人工智障”?
人工智能·mai-ui·tongyi-mai·阿里通义实验室·gui智能体
寻星探路1 天前
【深度长文】深入理解网络原理:TCP/IP 协议栈核心实战与性能调优
java·网络·人工智能·python·网络协议·tcp/ip·ai
轻竹办公PPT1 天前
实测多款 AI:2026 年工作计划 PPT 哪种更好修改
人工智能·python·powerpoint
AIHubPro未来百科1 天前
三天用AI开发完成开源WordPress导航主题:要哇棱镜主题详解 + 完整部署教程
人工智能·开源
执笔论英雄1 天前
【RL】advantages 与 ratio之间的关系
人工智能
切糕师学AI1 天前
AI 领域中的 Prompt(提示词/提示)是什么?
人工智能·prompt
HZZD_HZZD1 天前
喜讯|合众致达成功中标宁夏宝丰集团水电表计量结算管理平台项目
大数据·人工智能
AI_56781 天前
基于职业发展的Python与Java深度对比分析
java·人工智能·python·信息可视化