GSPO(Group Sequence Policy Optimization)的核心思想是将 PPO 中 token 级别的重要性比率替换为序列级别的重要性比率,从而获得更稳定的策略优化信号。
设定具体输入参数
python
import torch
import torch.nn.functional as F
# ============================================================
# 场景设定:batch_size=2, response_length=4
# 第1条样本:4个token全部有效(mask全1)
# 第2条样本:4个token中只有3个有效(最后一个padding)
# ============================================================
torch.manual_seed(42)
# 旧策略下每个token的log概率
old_log_prob = torch.tensor([
[-0.5, -0.8, -0.6, -0.9], # 样本1
[-0.4, -0.7, -0.5, -0.3], # 样本2
])
# 当前策略下每个token的log概率(训练过程中更新后)
log_prob = torch.tensor([
[-0.4, -0.7, -0.5, -0.8], # 样本1:略有改善
[-0.3, -0.6, -0.8, -0.3], # 样本2:token3变差
])
# 优势函数(来自GRPO/GAE等)
advantages = torch.tensor([
[ 0.8, 0.8, 0.8, 0.8], # 样本1:正优势(好动作,鼓励)
[-0.5, -0.5, -0.5, -0.5], # 样本2:负优势(坏动作,抑制)
])
# 响应mask(1=有效token,0=padding)
response_mask = torch.tensor([
[1, 1, 1, 1], # 样本1:4个有效token
[1, 1, 1, 0], # 样本2:3个有效token,最后是padding
], dtype=torch.float)
# Config参数
clip_ratio_low = 0.2 # 下裁剪边界
clip_ratio_high = 0.2 # 上裁剪边界(symmetric)
Step1: 计算token级别的log比率
python
negative_approx_kl = log_prob - old_log_prob
python
old_log_prob = [[-0.5, -0.8, -0.6, -0.9],
[-0.4, -0.7, -0.5, -0.3]]
log_prob = [[-0.4, -0.7, -0.5, -0.8],
[-0.3, -0.6, -0.8, -0.3]]
negative_approx_kl = log_prob - old_log_prob
= [[-0.4-(-0.5), -0.7-(-0.8), -0.5-(-0.6), -0.8-(-0.9)],
[-0.3-(-0.4), -0.6-(-0.7), -0.8-(-0.5), -0.3-(-0.3)]]
= [[ 0.1, 0.1, 0.1, 0.1], ← 样本1每个token概率都略微上升
[ 0.1, 0.1, -0.3, 0.0]] ← 样本2 token3概率下降
Step2 : 计算序列级别的重要性比率(GSPO核心)
python
seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
python
response_mask.sum(dim=-1) = [4, 3]
clamp(min=1) → seq_lengths = [4.0, 3.0]
样本1 有效token数 = 4
样本2 有效token数 = 3(最后一个是padding,不算)
python
negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths
每条序列的平均对数概率比
python
step1: negative_approx_kl * response_mask
= [[ 0.1, 0.1, 0.1, 0.1], ← 样本1乘以[1,1,1,1]
[ 0.1, 0.1, -0.3, 0.0]] ← 样本2乘以[1,1,1,0]
step2: sum(dim=-1)
= [0.1+0.1+0.1+0.1, 0.1+0.1-0.3+0.0]
= [0.4, -0.1]
step3: /seq_lengths
= [0.4/4, -0.1/3]
= [0.1, -0.0333]
negative_approx_kl_seq = [0.1000, -0.0333]
↑ ↑
样本1序列 样本2序列
整体变好 整体略变差
Step3: 构造token级别的联合重要性比率
python
log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
python
部分A: log_prob - log_prob.detach()
梯度意义:值为0,但保留了 log_prob 的梯度!
这让 loss 对 θ 求导时,梯度从 log_prob 流回网络
部分B: negative_approx_kl_seq.detach().unsqueeze(-1)
.detach() → 这是一个常数,不参与梯度计算
.unsqueeze(-1) → shape [2] → [2,1],广播到每个token位置
组合效果:
∂/∂θ [log_seq_importance_ratio]_i,t = ∂log_prob_i,t/∂θ
数值上 = 0 + negative_approx_kl_seq_i = 序列平均对数比
python
negative_approx_kl_seq.detach() = [0.1000, -0.0333]
.unsqueeze(-1) = [[0.1000],
[-0.0333]] # shape [2,1]
广播后 = [[0.1000, 0.1000, 0.1000, 0.1000],
[-0.0333,-0.0333,-0.0333,-0.0333]]
log_prob - log_prob.detach() = 全零矩阵(数值为0,但有梯度)
log_seq_importance_ratio = 0 + 以上广播结果
= [[ 0.1000, 0.1000, 0.1000, 0.1000],
[-0.0333, -0.0333, -0.0333, -0.0333]]
💡 理解这个设计:
梯度路径 → 通过 log_prob(当前网络参数)
数值 → 使用序列平均比率(稳定)
停止梯度 → 避免序列比率自我循环影响梯度
python
log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)
python
当前最大值是 0.1,远小于 10.0,数值裁剪不生效
→ 结果不变(防止 exp 爆炸的安全措施)
Step4: 计算序列重要性比率
python
seq_importance_ratio = torch.exp(log_seq_importance_ratio)
python
exp([[ 0.1000, 0.1000, 0.1000, 0.1000],
[-0.0333, -0.0333, -0.0333, -0.0333]])
= [[e^0.1, e^0.1, e^0.1, e^0.1 ],
[e^-0.033,e^-0.033,e^-0.033,e^-0.033]]
= [[1.1052, 1.1052, 1.1052, 1.1052], ← 样本1: ratio>1,当前策略选更多
[0.9672, 0.9672, 0.9672, 0.9672]] ← 样本2: ratio<1,当前策略选更少
🔑 关键性质:同一序列内所有 token 共享同一个重要性比率!(与PPO不同,PPO每个token有独立的ratio)
Step 5:计算裁剪前后的策略损失
python
pg_losses1 = -advantages * seq_importance_ratio
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
python
# clip_ratio_low = clip_ratio_high = 0.2
# 裁剪范围:[0.8, 1.2]
# 样本1: ratio=1.1052,在[0.8,1.2]内,clamp不生效
# 样本2: ratio=0.9672,在[0.8,1.2]内,clamp不生效
seq_importance_ratio_clamped = seq_importance_ratio(不变)
# advantages:
# 样本1: +0.8(正,好动作)
# 样本2: -0.5(负,坏动作)
pg_losses1 = -advantages * seq_importance_ratio
= -[[ 0.8, 0.8, 0.8, 0.8], * [[1.1052,1.1052,1.1052,1.1052],
[-0.5, -0.5, -0.5, -0.5]] [0.9672,0.9672,0.9672,0.9672]]
= [[-0.8842, -0.8842, -0.8842, -0.8842], ← 负值→对loss有贡献→鼓励好动作✓
[ 0.4836, 0.4836, 0.4836, 0.4836]] ← 正值→惩罚坏动作✓
pg_losses2 = -advantages * clamp(ratio, 0.8, 1.2)
# 本例clamp不生效,pg_losses2 = pg_losses1
验证裁剪生效的极端情况(假设 ratio=1.5 > 1.2):
python
# 好动作(adv=+0.8), ratio超出上界1.2
pg_losses1 = -0.8 * 1.5 = -1.20 (未裁剪,更大收益)
pg_losses2 = -0.8 * 1.2 = -0.96 (裁剪后,收益上限)
pg_loss = max(-1.20, -0.96) = -0.96 ← 取保守的较大值!
# 防止策略更新过激!
Step6 :取最大值(裁剪后保守估计)
python
pg_losses = torch.maximum(pg_losses1, pg_losses2)
python
本例 pg_losses1 == pg_losses2(ratio在合法范围内)
pg_losses = [[-0.8842, -0.8842, -0.8842, -0.8842],
[ 0.4836, 0.4836, 0.4836, 0.4836]]
Step7: 聚合损失
python
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask,
loss_agg_mode="seq-mean-token-mean")
"seq-mean-token-mean" 含义:先在每条序列内做 token 平均,再在 batch 维度做序列平均
python
step1: 每条序列内token平均(只算有效token)
样本1: (-0.8842*4)/4 = -0.8842
样本2: ( 0.4836*3)/3 = 0.4836 (第4个token被mask掉)
per_seq_loss = [-0.8842, 0.4836]
step2: batch维度取均值
pg_loss = (-0.8842 + 0.4836) / 2 = -0.2003
💡 为什么GSPO推荐 seq-mean-token-mean:与序列级别的重要性比率对齐,避免长序列主导梯度
Step 8:计算监控指标
python
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
python
# gt(pg_losses2, pg_losses1):哪些位置裁剪生效了?
# 本例 pg_losses2 == pg_losses1,所以全是 False
torch.gt(pg_losses2, pg_losses1) = [[False,False,False,False],
[False,False,False,False]]
masked_mean(...) = 0.0 ← 裁剪发生率 0%(本例ratio都在合法范围)
python
pg_clipfrac_lower = torch.tensor(0.0) # GSPO无下界裁剪统计
python
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
python
-negative_approx_kl = [[-0.1, -0.1, -0.1, -0.1],
[-0.1, -0.1, 0.3, 0.0]]
有效token处取均值(排除样本2的第4个token):
= (-0.1*4 + (-0.1+(-0.1)+0.3)) / (4+3)
= (-0.4 + 0.1) / 7
= -0.3/7 ≈ -0.0429
ppo_kl ≈ -0.0429 ← 负值说明当前策略整体概率比旧策略高(KL反向)
完整流程总结
python
┌─────────────────────────────────────────────────────────────────────┐
│ GSPO 计算流程图 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ old_log_prob, log_prob │
│ │ │
│ ▼ │
│ token级 log差值 (negative_approx_kl) │
│ [[ 0.1, 0.1, 0.1, 0.1], │
│ [ 0.1, 0.1, -0.3, 0.0]] │
│ │ │
│ ▼ (×mask, sum, ÷seq_len) │
│ 序列级平均 log比率 (negative_approx_kl_seq) │
│ [0.1000, -0.0333] ← 每条序列一个标量 │
│ │ │
│ ▼ (广播+detach技巧) │
│ token级联合 log比率 (log_seq_importance_ratio) │
│ [[ 0.1, 0.1, 0.1, 0.1], ← 同序列内所有token相同 │
│ [-0.033,-0.033,-0.033,-0.033]] │
│ │ │
│ ▼ exp() │
│ 序列重要性比率 (seq_importance_ratio) │
│ [[1.1052, 1.1052, 1.1052, 1.1052], │
│ [0.9672, 0.9672, 0.9672, 0.9672]] │
│ │ │
│ ▼ ×(-advantages), clamp[0.8,1.2], maximum │
│ pg_losses = max(未裁剪损失, 裁剪后损失) │
│ [[-0.8842,-0.8842,-0.8842,-0.8842], │
│ [ 0.4836, 0.4836, 0.4836, 0.4836]] │
│ │ │
│ ▼ seq-mean-token-mean │
│ pg_loss = -0.2003 ← 反向传播更新策略网络 │
│ │
└─────────────────────────────────────────────────────────────────────┘