GSPO策略损失完整演示

GSPO(Group Sequence Policy Optimization)的核心思想是将 PPO 中 token 级别的重要性比率替换为序列级别的重要性比率,从而获得更稳定的策略优化信号。

设定具体输入参数

python 复制代码
import torch
import torch.nn.functional as F

# ============================================================
# 场景设定:batch_size=2, response_length=4
# 第1条样本:4个token全部有效(mask全1)
# 第2条样本:4个token中只有3个有效(最后一个padding)
# ============================================================

torch.manual_seed(42)

# 旧策略下每个token的log概率
old_log_prob = torch.tensor([
    [-0.5, -0.8, -0.6, -0.9],   # 样本1
    [-0.4, -0.7, -0.5, -0.3],   # 样本2
])

# 当前策略下每个token的log概率(训练过程中更新后)
log_prob = torch.tensor([
    [-0.4, -0.7, -0.5, -0.8],   # 样本1:略有改善
    [-0.3, -0.6, -0.8, -0.3],   # 样本2:token3变差
])

# 优势函数(来自GRPO/GAE等)
advantages = torch.tensor([
    [ 0.8,  0.8,  0.8,  0.8],   # 样本1:正优势(好动作,鼓励)
    [-0.5, -0.5, -0.5, -0.5],   # 样本2:负优势(坏动作,抑制)
])

# 响应mask(1=有效token,0=padding)
response_mask = torch.tensor([
    [1, 1, 1, 1],   # 样本1:4个有效token
    [1, 1, 1, 0],   # 样本2:3个有效token,最后是padding
], dtype=torch.float)

# Config参数
clip_ratio_low  = 0.2   # 下裁剪边界
clip_ratio_high = 0.2   # 上裁剪边界(symmetric)

Step1: 计算token级别的log比率

python 复制代码
negative_approx_kl = log_prob - old_log_prob
python 复制代码
old_log_prob = [[-0.5, -0.8, -0.6, -0.9],
                [-0.4, -0.7, -0.5, -0.3]]

log_prob     = [[-0.4, -0.7, -0.5, -0.8],
                [-0.3, -0.6, -0.8, -0.3]]

negative_approx_kl = log_prob - old_log_prob
= [[-0.4-(-0.5), -0.7-(-0.8), -0.5-(-0.6), -0.8-(-0.9)],
   [-0.3-(-0.4), -0.6-(-0.7), -0.8-(-0.5), -0.3-(-0.3)]]

= [[ 0.1,  0.1,  0.1,  0.1],   ← 样本1每个token概率都略微上升
   [ 0.1,  0.1, -0.3,  0.0]]   ← 样本2 token3概率下降

Step2 : 计算序列级别的重要性比率(GSPO核心)

python 复制代码
seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
python 复制代码
response_mask.sum(dim=-1) = [4, 3]
clamp(min=1) → seq_lengths = [4.0, 3.0]

样本1 有效token数 = 4
样本2 有效token数 = 3(最后一个是padding,不算)
python 复制代码
negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths

每条序列的平均对数概率比

python 复制代码
step1: negative_approx_kl * response_mask

= [[ 0.1,  0.1,  0.1,  0.1],    ← 样本1乘以[1,1,1,1]
   [ 0.1,  0.1, -0.3,  0.0]]    ← 样本2乘以[1,1,1,0]

step2: sum(dim=-1)
= [0.1+0.1+0.1+0.1,  0.1+0.1-0.3+0.0]
= [0.4,  -0.1]

step3: /seq_lengths
= [0.4/4,  -0.1/3]
= [0.1,  -0.0333]

negative_approx_kl_seq = [0.1000, -0.0333]
                              ↑           ↑
                         样本1序列     样本2序列
                         整体变好      整体略变差

Step3: 构造token级别的联合重要性比率

python 复制代码
log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
python 复制代码
部分A: log_prob - log_prob.detach()
梯度意义:值为0,但保留了 log_prob 的梯度!
这让 loss 对 θ 求导时,梯度从 log_prob 流回网络

部分B: negative_approx_kl_seq.detach().unsqueeze(-1)
.detach() → 这是一个常数,不参与梯度计算
.unsqueeze(-1) → shape [2] → [2,1],广播到每个token位置

组合效果:
∂/∂θ [log_seq_importance_ratio]_i,t = ∂log_prob_i,t/∂θ
数值上 = 0 + negative_approx_kl_seq_i = 序列平均对数比
python 复制代码
negative_approx_kl_seq.detach() = [0.1000, -0.0333]
                  .unsqueeze(-1) = [[0.1000],
                                    [-0.0333]]  # shape [2,1]

广播后 = [[0.1000, 0.1000, 0.1000, 0.1000],
          [-0.0333,-0.0333,-0.0333,-0.0333]]

log_prob - log_prob.detach() = 全零矩阵(数值为0,但有梯度)

log_seq_importance_ratio = 0 + 以上广播结果
= [[ 0.1000,  0.1000,  0.1000,  0.1000],
   [-0.0333, -0.0333, -0.0333, -0.0333]]

💡 理解这个设计:

梯度路径 → 通过 log_prob(当前网络参数)

数值 → 使用序列平均比率(稳定)

停止梯度 → 避免序列比率自我循环影响梯度

python 复制代码
log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)
python 复制代码
当前最大值是 0.1,远小于 10.0,数值裁剪不生效
→ 结果不变(防止 exp 爆炸的安全措施)

Step4: 计算序列重要性比率

python 复制代码
seq_importance_ratio = torch.exp(log_seq_importance_ratio)
python 复制代码
exp([[ 0.1000,  0.1000,  0.1000,  0.1000],
     [-0.0333, -0.0333, -0.0333, -0.0333]])

= [[e^0.1,   e^0.1,   e^0.1,   e^0.1  ],
   [e^-0.033,e^-0.033,e^-0.033,e^-0.033]]

= [[1.1052, 1.1052, 1.1052, 1.1052],   ← 样本1: ratio>1,当前策略选更多
   [0.9672, 0.9672, 0.9672, 0.9672]]   ← 样本2: ratio<1,当前策略选更少

🔑 关键性质:同一序列内所有 token 共享同一个重要性比率!(与PPO不同,PPO每个token有独立的ratio)

Step 5:计算裁剪前后的策略损失

python 复制代码
pg_losses1 = -advantages * seq_importance_ratio
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
python 复制代码
# clip_ratio_low = clip_ratio_high = 0.2
# 裁剪范围:[0.8, 1.2]

# 样本1: ratio=1.1052,在[0.8,1.2]内,clamp不生效
# 样本2: ratio=0.9672,在[0.8,1.2]内,clamp不生效

seq_importance_ratio_clamped = seq_importance_ratio(不变)

# advantages:
# 样本1: +0.8(正,好动作)
# 样本2: -0.5(负,坏动作)

pg_losses1 = -advantages * seq_importance_ratio
= -[[ 0.8,  0.8,  0.8,  0.8],   * [[1.1052,1.1052,1.1052,1.1052],
    [-0.5, -0.5, -0.5, -0.5]]      [0.9672,0.9672,0.9672,0.9672]]

= [[-0.8842, -0.8842, -0.8842, -0.8842],   ← 负值→对loss有贡献→鼓励好动作✓
   [ 0.4836,  0.4836,  0.4836,  0.4836]]   ← 正值→惩罚坏动作✓

pg_losses2 = -advantages * clamp(ratio, 0.8, 1.2)
# 本例clamp不生效,pg_losses2 = pg_losses1

验证裁剪生效的极端情况(假设 ratio=1.5 > 1.2):

python 复制代码
# 好动作(adv=+0.8), ratio超出上界1.2
pg_losses1 = -0.8 * 1.5  = -1.20  (未裁剪,更大收益)
pg_losses2 = -0.8 * 1.2  = -0.96  (裁剪后,收益上限)
pg_loss = max(-1.20, -0.96) = -0.96  ← 取保守的较大值!
# 防止策略更新过激!

Step6 :取最大值(裁剪后保守估计)

python 复制代码
pg_losses = torch.maximum(pg_losses1, pg_losses2)
python 复制代码
本例 pg_losses1 == pg_losses2(ratio在合法范围内)

pg_losses = [[-0.8842, -0.8842, -0.8842, -0.8842],
             [ 0.4836,  0.4836,  0.4836,  0.4836]]

Step7: 聚合损失

python 复制代码
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, 
                   loss_agg_mode="seq-mean-token-mean")

"seq-mean-token-mean" 含义:先在每条序列内做 token 平均,再在 batch 维度做序列平均

python 复制代码
step1: 每条序列内token平均(只算有效token)

样本1: (-0.8842*4)/4 = -0.8842
样本2: ( 0.4836*3)/3 =  0.4836   (第4个token被mask掉)

per_seq_loss = [-0.8842, 0.4836]

step2: batch维度取均值

pg_loss = (-0.8842 + 0.4836) / 2 = -0.2003

💡 为什么GSPO推荐 seq-mean-token-mean:与序列级别的重要性比率对齐,避免长序列主导梯度

Step 8:计算监控指标

python 复制代码
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
python 复制代码
# gt(pg_losses2, pg_losses1):哪些位置裁剪生效了?
# 本例 pg_losses2 == pg_losses1,所以全是 False

torch.gt(pg_losses2, pg_losses1) = [[False,False,False,False],
                                     [False,False,False,False]]

masked_mean(...) = 0.0   ← 裁剪发生率 0%(本例ratio都在合法范围)
python 复制代码
pg_clipfrac_lower = torch.tensor(0.0)  # GSPO无下界裁剪统计
python 复制代码
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
python 复制代码
-negative_approx_kl = [[-0.1, -0.1, -0.1, -0.1],
                        [-0.1, -0.1,  0.3,  0.0]]

有效token处取均值(排除样本2的第4个token):
= (-0.1*4 + (-0.1+(-0.1)+0.3)) / (4+3)
= (-0.4 + 0.1) / 7
= -0.3/7 ≈ -0.0429

ppo_kl ≈ -0.0429   ← 负值说明当前策略整体概率比旧策略高(KL反向)

完整流程总结

python 复制代码
┌─────────────────────────────────────────────────────────────────────┐
│                         GSPO 计算流程图                              │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  old_log_prob, log_prob                                             │
│         │                                                           │
│         ▼                                                           │
│  token级 log差值 (negative_approx_kl)                               │
│  [[ 0.1,  0.1,  0.1,  0.1],                                        │
│   [ 0.1,  0.1, -0.3,  0.0]]                                        │
│         │                                                           │
│         ▼ (×mask, sum, ÷seq_len)                                    │
│  序列级平均 log比率 (negative_approx_kl_seq)                         │
│  [0.1000, -0.0333]    ← 每条序列一个标量                            │
│         │                                                           │
│         ▼ (广播+detach技巧)                                         │
│  token级联合 log比率 (log_seq_importance_ratio)                      │
│  [[ 0.1, 0.1, 0.1, 0.1],  ← 同序列内所有token相同                  │
│   [-0.033,-0.033,-0.033,-0.033]]                                    │
│         │                                                           │
│         ▼ exp()                                                     │
│  序列重要性比率 (seq_importance_ratio)                               │
│  [[1.1052, 1.1052, 1.1052, 1.1052],                                │
│   [0.9672, 0.9672, 0.9672, 0.9672]]                                │
│         │                                                           │
│         ▼ ×(-advantages), clamp[0.8,1.2], maximum                  │
│  pg_losses = max(未裁剪损失, 裁剪后损失)                            │
│  [[-0.8842,-0.8842,-0.8842,-0.8842],                               │
│   [ 0.4836, 0.4836, 0.4836, 0.4836]]                               │
│         │                                                           │
│         ▼ seq-mean-token-mean                                       │
│  pg_loss = -0.2003   ← 反向传播更新策略网络                         │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
相关推荐
&星痕&2 小时前
从零开始手搓 (1)计算图 (c++,python语言实现)
c++·python·深度学习·机器学习
坚持学习前端日记2 小时前
python对接comfyui的过程
开发语言·网络·python
竹林8182 小时前
从数据混乱到丝滑管理:我在Python项目中重构SQLite数据库的实战记录
python·sqlite
今儿敲了吗2 小时前
python基础学习笔记第四章
c++·笔记·python·学习
电商API&Tina2 小时前
淘宝商品视频的采集需要注意哪些问题||item_video-获得淘宝商品视频
大数据·网络·数据库·人工智能·python·音视频
唐叔在学习2 小时前
Python桌面端应用消息提醒功能开发实践
后端·python·程序员
程序员小远2 小时前
单元测试知识详解
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
数形长夏2 小时前
一心多用的艺术:Python多任务处理模式
python·架构