好的,我们来详细解释 clipfrac (clipping fraction) 在代码中是如何计算的。这个指标衡量了在PPO损失计算中,有多大比例的token因为ratio超出范围而被裁剪。
clipfrac 的计算通常在 ActorWorker.loss_func 方法内部,紧随着PPO损失的核心计算步骤。
ActorWorker.loss_func 中的计算逻辑
让我们回顾一下 loss_func 的核心部分,并定位 clipfrac 的计算来源。
python
# in roll/pipeline/base_worker.py, class ActorWorker, method loss_func
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
# ... (获取 advantages, old_log_probs 等)
# 1. 计算 ratio
if self.pipeline_config.importance_sampling == "token":
ratio = (log_probs - old_log_probs).exp()
elif self.pipeline_config.importance_sampling == "seq":
# ... (计算序列级别的 ratio)
# 2. 定义裁剪边界
pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip
# 假设 pg_clip = 0.2, 那么 low = 0.2, high = 0.2
# 裁剪区间为 [1 - 0.2, 1 + 0.2] = [0.8, 1.2]
# 3. 计算PPO损失
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
loss = -torch.min(surr1, surr2)
# ...
# -------------------------------------------------------------
# 4. 计算 clipfrac
# -------------------------------------------------------------
# 判断 ratio 是否小于下界
clipped_low = (ratio < 1 - pg_clip_low).float()
# 判断 ratio 是否大于上界
clipped_high = (ratio > 1 + pg_clip_high).float()
# 将两种裁剪情况相加,得到一个表示是否被裁剪的0-1矩阵
# 因为 ratio 不可能同时小于下界又大于上界,所以直接相加即可
clipped = (clipped_low + clipped_high).float()
# 5. 聚合和记录
# ...
loss_metric = {
"actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(),
"actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(),
# 这里的 clipped.mean() 就是我们通常所说的总 clipfrac
"actor/ppo_ratio_clipfrac": clipped.mean().detach().item(),
# ...
}
# ...
return total_loss, pg_metrics
详细步骤与举例
我们用一个简单的例子来说明这个过程。假设我们有一个批次(batch_size=2),序列长度为 4,pg_clip (ε) 设置为 0.2。
1. 输入:ratio 矩阵
经过计算,我们得到了 ratio 矩阵,形状为 (2, 4):
python
import torch
ratio = torch.tensor([
[1.10, 0.95, 1.25, 1.05], # 样本 1
[0.75, 1.15, 0.82, 0.99] # 样本 2
])
2. 定义裁剪边界
裁剪区间是 [1 - 0.2, 1 + 0.2],即 [0.8, 1.2]。
3. 计算 clipped_low
判断 ratio 矩阵中的每个元素是否小于下界 0.8。
python
clipped_low = (ratio < 0.8).float()
# clipped_low 的值:
# tensor([
# [0., 0., 0., 0.], # 样本 1: 没有元素小于0.8
# [1., 0., 0., 0.] # 样本 2: 第一个元素 0.75 < 0.8
# ])
4. 计算 clipped_high
判断 ratio 矩阵中的每个元素是否大于上界 1.2。
python
clipped_high = (ratio > 1.2).float()
# clipped_high 的值:
# tensor([
# [0., 0., 1., 0.], # 样本 1: 第三个元素 1.25 > 1.2
# [0., 0., 0., 0.] # 样本 2: 没有元素大于1.2
# ])
5. 计算 clipped
将 clipped_low 和 clipped_high 相加。
python
clipped = clipped_low + clipped_high
# clipped 的值:
# tensor([
# [0., 0., 1., 0.],
# [1., 0., 0., 0.]
# ])
这个 clipped 矩阵是一个0-1矩阵,1 所在的位置就代表该token的ratio被裁剪了。
6. 计算最终的 clipfrac
对 clipped 矩阵求平均值。这会计算出被裁剪的token占所有token总数的比例。
python
clipfrac = clipped.mean().item()
clipped矩阵中元素的总数是2 * 4 = 8。clipped矩阵中1的数量是2。clipfrac = 2 / 8 = 0.25
所以,在这个例子中,clipfrac 是 0.25 或 25%。
总结
clipfrac 的计算过程如下:
- 计算出每个token的**
ratio**(新旧策略概率比)。 - 定义PPO的裁剪区间
[1 - ε, 1 + ε]。 - 创建一个布尔掩码(boolean mask),标记出所有
ratio值小于1 - ε或 大于1 + ε的位置。 - 计算这个掩码中
True(即被裁剪的token)的数量,然后除以总的token数量。
这个最终的比例就是clipfrac,它是一个介于0和1之间的标量,直观地反映了在当前训练步中,策略更新的"激进"程度。
好的,我们来详细讲解ratio的两种计算方式(token级别和seq级别),并分别举例说明。
ratio,即重要性采样比率(Importance Sampling Ratio),是PPO算法的核心,用于修正使用旧策略采集的数据来评估新策略时产生的分布差异。
它的基本定义是:
ratio = π_θ(a|s) / π_θ_old(a|s)
其中:
π_θ(a|s)是当前策略 (新策略)在状态s下采取动作a的概率。π_θ_old(a|s)是采样策略 (旧策略)在状态s下采取动作a的概率。
在实际计算中,为了避免数值下溢和提高计算稳定性,我们通常使用对数概率(log probabilities)来计算:
log(ratio) = log(π_θ(a|s)) - log(π_θ_old(a|s))
ratio = exp( log(π_θ(a|s)) - log(π_θ_old(a|s)) )
1. importance_sampling == "token" (Token级Ratio)
这是最直接、最经典的方式。它为每一个token 都独立计算一个ratio值。
计算逻辑
log_probs: 使用当前模型 (新策略)对批次中的序列进行前向传播,得到每个token位置上,实际采取的那个token的对数概率。结果是一个形状为(batch_size, seq_len)的张量。old_log_probs: 这是在数据采集阶段 ,使用当时的模型(旧策略)计算并存储下来的对数概率。它也具有相同的形状(batch_size, seq_len)。- 计算
ratio:ratio = (log_probs - old_log_probs).exp()。这个操作是逐元素进行的。
举例说明
假设 batch_size=2, seq_len=4。
-
log_probs(新策略的对数概率):pythonlog_probs = torch.tensor([ [-0.5, -1.2, -0.8, -2.0], # 样本 1 [-1.0, -0.4, -1.5, -0.9] # 样本 2 ]) -
old_log_probs(旧策略的对数概率):pythonold_log_probs = torch.tensor([ [-0.6, -1.1, -1.1, -2.1], # 样本 1 [-1.3, -0.5, -1.4, -0.9] # 样本 2 ]) -
计算过程:
-
计算对数差
log_probs - old_log_probs:tensor([ [ 0.1, -0.1, 0.3, 0.1], # -0.5 - (-0.6) = 0.1, etc. [ 0.3, 0.1, -0.1, 0.0] ]) -
取指数
exp(...):pythonratio = (log_probs - old_log_probs).exp() # ratio 的值: # tensor([ # [1.1052, 0.9048, 1.3499, 1.1052], # exp(0.1), exp(-0.1), etc. # [1.3499, 1.1052, 0.9048, 1.0000] # ])
-
-
结果 : 我们得到了一个形状为
(2, 4)的ratio矩阵。每个元素ratio[i, j]代表第i个样本的第j个token的新旧策略概率比。例如,1.3499意味着新策略生成这个token的概率是旧策略的1.35倍。
2. importance_sampling == "seq" (序列级Ratio)
这种方式不为每个token计算独立的ratio,而是为整个序列 计算一个统一的ratio,然后将这个值广播到该序列的所有token上。
计算逻辑
- 计算逐token的对数差 :
log_ratio_token_level = log_probs - old_log_probs。这一步和token级方法一样。 - 计算序列的平均对数差 : 使用
masked_mean函数,计算每个序列(行)的有效token的平均对数差。
masked_log_ratio_seq_level = masked_mean(log_ratio_token_level, final_response_mask, dim=-1) - 计算序列的
ratio: 对上一步得到的序列平均对数差取指数。
ratio_seq_level = masked_log_ratio_seq_level.exp() - 广播 : 将序列级别的
ratio扩展(unsqueeze and expand)成与原始log_probs相同的形状(batch_size, seq_len)。
举例说明
我们使用和上面相同的log_probs和old_log_probs,并假设一个mask。
-
log_probs和old_log_probs(同上) -
mask(假设样本1长度为3,样本2长度为4):pythonmask = torch.tensor([ [1., 1., 1., 0.], [1., 1., 1., 1.] ]) -
计算过程 :
-
计算逐token的对数差
log_ratio_token_level:tensor([ [ 0.1, -0.1, 0.3, 0.1], [ 0.3, 0.1, -0.1, 0.0] ]) -
计算每个序列的平均对数差:
- 样本1:
(0.1 + (-0.1) + 0.3) / 3 = 0.1 - 样本2:
(0.3 + 0.1 + (-0.1) + 0.0) / 4 = 0.075
masked_log_ratio_seq_level = tensor([0.1, 0.075]) - 样本1:
-
对序列平均对数差取指数,得到序列级
ratio:- 样本1:
exp(0.1) = 1.1052 - 样本2:
exp(0.075) = 1.0779
ratio_seq_level = tensor([1.1052, 1.0779]) - 样本1:
-
广播 这个序列级
ratio到整个矩阵:python# ratio 的最终值: # tensor([ # [1.1052, 1.1052, 1.1052, 1.1052], # 样本1的所有token共享同一个ratio # [1.0779, 1.0779, 1.0779, 1.0779] # 样本2的所有token共享同一个ratio # ])
-
对比与总结
| 特性 | importance_sampling="token" |
importance_sampling="seq" |
|---|---|---|
| 计算粒度 | 逐个token | 整个序列 |
| Ratio值 | 每个token都有自己独特的ratio值。 |
同一个序列中的所有token共享一个相同的ratio值。 |
| 物理意义 | 评估新策略在每个决策点上的改变。 | 评估新策略对于生成整个完整序列的概率的整体改变。 |
| 方差 | 理论上方差更高 ,因为每个token的ratio都会引入随机性。 |
理论上方差更低,因为对数比率被平均了,平滑了单个token的极端变化。 |
| 使用场景 | 标准的PPO实现。 | 在某些情况下,可能使训练更稳定,特别是在处理长序列时,可以防止单个token的极端ratio值对梯度产生过大影响。 |
在您提供的代码中,self.pipeline_config.importance_sampling 这个配置项就是用来选择这两种不同计算方式的"开关"。