GRPO 深度解析 (TRL 源码视角)

GRPO 深度解析 (TRL 源码视角)

参考:huggingface/trl

TRL 是一个面向后训练(post-training)的前沿库,支持 SFT、GRPO、DPO 等技术,构建于 🤗 Transformers 生态之上,可跨多种硬件规模扩展。


一、核心思想

GRPO(Group Relative Policy Optimization) 是一种专为大语言模型设计的强化学习算法。

它最大的创新在于:去掉了传统 PPO 中庞大的 Critic 模型

取而代之的方式是:

  1. 对同一个 Prompt,一次性生成 G 条回复
  2. 用 Reward 函数对这 G 条回复分别打分
  3. 以这组分数的组内均值作为基线,计算每条回复的相对优势(Advantage)
  4. 用 Advantage 指导参数更新:高于均值的 → 增大生成概率(鼓励),低于均值的 → 减小概率(惩罚)

对比 PPO:PPO 需要一个额外的 Critic 网络来估计当前状态的价值函数 V(s),而 GRPO 直接用组内统计量替代,大幅降低了显存和计算开销。


二、整体框架

核心骨架围绕三个方面展开:模型体系奖励机制(Reward)生成引擎(vLLM 等)

2.1 Reference Model(参考模型)

Reference Model 用于计算 KL 散度,防止策略模型训练偏离太远(由超参 beta 控制)。

python 复制代码
self.beta = args.beta
if self.beta == 0.0:
    # beta=0 时不需要 KL 约束,Reference Model 置空
    self.ref_model = None
elif is_peft_model(model):
    # 使用 PEFT(如 LoRA)时,关闭 adapter 即可还原为初始模型,无需单独维护
    self.ref_model = None
else:
    # 全量微调或 DeepSpeed/FSDP 场景:从头创建一份 Reference Model
    self.ref_model = create_model_from_path(...)
   

一种需要new model instance , 一种不需要

2.2 Reward Model(奖励模型)

Reward Model 有两种形式:

  • 预训练神经网络AutoModelForSequenceClassification
  • 基于规则的函数 (如 if-else 长度判断、正则匹配输出格式、调用外部代码编译器)
python 复制代码
# 统一包装成列表,支持多个 Reward 函数组合
if not isinstance(reward_funcs, list):
    reward_funcs = [reward_funcs]

for i, reward_func in enumerate(reward_funcs):
    if isinstance(reward_func, str):
        # 传入模型路径字符串 → 加载预训练的序列分类模型
        reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(...)

当传入多个 Reward 函数(如 [length_reward, format_reward, accuracy_reward])时,GRPOTrainer 会将多位"裁判"的分数合并。


三、生成与打分:_generate_and_score_completions

这是训练循环的核心函数,负责:

  1. 为当前 Batch 的 Prompts 生成多条补全(completions)
  2. 对补全进行 Padding 和张量化
  3. 计算每条补全的 per-token log 概率(用于后续 loss 计算)
  4. 调用 Reward 函数打分,计算 Advantage

3.1 生成与 Padding

python 复制代码
# 生成回答(支持 vLLM 和 HF generate 两种引擎)
(
    prompt_ids_list,
    completion_ids_list,
    ...
) = self._generate(prompts)

# 将变长序列 Padding 成等长张量
# Prompt 左 Padding,Completion 右 Padding
prompt_ids  = pad(prompt_ids,  padding_side="left",  ...).to(device)
completion_ids = pad(completion_ids, padding_side="right", ...).to(device)

3.2 Advantage 计算(多目标 Reward 合并)

当存在多个 Reward 函数时,TRL 提供两种合并策略:

策略一:sum_then_normalize(先加权求和,再组内归一化)

python 复制代码
# 1. 对各 reward 函数的输出加权求和
rewards = (rewards_per_func * self.reward_weights).nansum(dim=1)

# 2. 组内均值作为基线(GRPO 核心:对比 PPO 的优势函数)
mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)

# 3. 计算优势值,并用组内标准差归一化
advantages = rewards - mean_grouped_rewards
advantages = advantages / (std_rewards + 1e-4)

策略二:normalize_then_sum(先各自归一化,再加权求和)

python 复制代码
# 先对每个 reward 函数的输出单独做组内归一化
grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs))
reward_k = (grouped - mean_k) / (std_k + 1e-4)

# 再加权求和,得到最终 advantages
rewards = (reward_k * self.reward_weights).nansum(dim=1)
advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)

四、损失函数计算:_compute_loss

4.1 概率比(重要性权重)

python 复制代码
# log_ratio = log(π_θ / π_old)
log_ratio = per_token_logps - old_per_token_logps

# coef_1 = π_θ / π_old(新旧策略的概率比)
coef_1 = torch.exp(log_ratio)

4.2 KL 散度惩罚项

python 复制代码
if self.beta != 0.0:
    # 用近似公式计算 token 级别的 KL(ref || model)
    per_token_kl = (
        torch.exp(ref_per_token_logps - per_token_logps)
        - (ref_per_token_logps - per_token_logps)
        - 1
    )

4.3 各 Loss 类型

TRL 支持多种 loss 变体,核心是 GRPO:

python 复制代码
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
    # coef_2:对概率比进行 Clipping,防止单步更新过大(类比 PPO 的 clip trick)
    coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

    per_token_loss1 = coef_1 * advantages
    per_token_loss2 = coef_2 * advantages
    # 取两者中较保守的一个,防止策略飞车
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

Clipping 的作用 :把新旧策略的概率比强制截断在 [1-ε, 1+ε] 之间,防止一次梯度步走得太远,是 PPO-style 算法的核心稳定性机制。

4.4 最终 Loss 聚合

python 复制代码
# GRPO:对每条序列内取 token 平均,再对 batch 取均值
if self.loss_type == "grpo":
    loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()

# 加上 KL 惩罚
if self.beta != 0.0:
    per_token_loss = per_token_loss + self.beta * per_token_kl

五、分组采样实现:_get_train_sampler

GRPO 的关键工程细节:同一个 Prompt 必须在同一个 GPU 上被复制 G 次,才能在组内计算 Advantage。

复制代码
                              |   GPU 0  |   GPU 1  |

          global_step   step    <──>  num_generations=2
                                <─────>  per_device_train_batch_size=3
grad_accum=2  ▲  ▲  0     0    0  0  1  1  2  2   ← 生成 prompts 0-11;存缓存;取第1切片算 loss
              ▼  |  0     1    3  3  4  4  5  5   ← 从缓存取第2切片
                 |
                 |  1     2    6  6  7  7  8  8   ← 从缓存取第3切片
steps_per_gen=4  ▼  1     3    9  9 10 10 11 11   ← 从缓存取第4切片

                    2     4   12 12 13 13 14 14   ← 再次生成(第二轮)
python 复制代码
return RepeatSampler(
    data_source=dataset,
    mini_repeat_count=self.num_generations,        # 每个 prompt 复制 G 次
    batch_size=self.args.generation_batch_size // self.num_generations,
    repeat_count=self.num_iterations * self.args.steps_per_generation,
    shuffle=self.shuffle_dataset,
    seed=self.args.seed,
)

六、双引擎架构:生成与训练解耦

6.1 为什么要解耦?

文本生成(自回归 Inference)和梯度更新(Forward/Backward)在计算特性上有本质差异:

阶段 瓶颈 优化手段
推理(生成) 显存带宽(KV Cache 频繁读写) vLLM PagedAttention、FlashAttention
训练(梯度) 显存容量(激活值 + 梯度 + 优化器状态) 混合精度、梯度累积、ZeRO

两种优化策略互不兼容,混在一起只会两头受损。RL 训练的标准做法:攒一个超级大 Batch 做生成 → 切片 → 多次梯度更新

代码逻辑:

python 复制代码
generate_every = self.args.steps_per_generation * self.num_iterations

if self._step % generate_every == 0 or self._buffered_inputs is None:
    # 1. 一次性生成大 Batch 并打分
    generation_batch = self._generate_and_score_completions(generation_batch)
    # 2. 切分成小份,存入缓存
    generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
    self._buffered_inputs = generation_batches

# 3. 每次训练只取缓存里的一小份算 loss 和梯度
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]

6.2 权重同步问题

双引擎架构带来一个工程挑战:vLLM 的推理权重和 PyTorch 的训练权重需要保持同步

python 复制代码
# 每次生成前,检查 step 是否变化,若变化则同步权重
if self.state.global_step != self._last_loaded_step:
    self.vllm_generation.sync_weights()
    self._last_loaded_step = self.state.global_step

七、重要性采样修正(Importance Sampling)

7.1 问题来源

设想:vLLM 用"旧权重"生成了 1000 条回答,切成 10 份,PyTorch 做 10 次梯度更新。第 2~10 次更新时,训练数据仍是旧权重生成的,但模型已经被更新过了 → 数据分布偏移(off-policy)。

7.2 数学修正

E x ∼ π sample [ π θ ( x ) π sample ( x ) × Loss ( x ) ] \mathbb{E}{x \sim \pi{\text{sample}}} \left[ \frac{\pi_{\theta}(x)}{\pi_{\text{sample}}(x)} \times \text{Loss}(x) \right] Ex∼πsample[πsample(x)πθ(x)×Loss(x)]

在对数空间下,概率除法变减法,避免浮点下溢:

python 复制代码
# 计算重要性采样比值
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask
vllm_importance_sampling_ratio = torch.exp(per_token_logps_diff)

# 用 clamp 截断过大的比值,防止数值爆炸
vllm_importance_sampling_ratio = torch.clamp(
    vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)

# 作用到最终 loss 上
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]

直觉理解:这是一种坐标系修正。就像在运动学中,从地面参考系切换到运动参考系时需要做速度变换一样,这里是把"从旧策略角度看到的数据"修正为"从当前策略角度看到的期望"。


八、从零手写极简训练循环(伪代码)

思考题:脱离 TRL 的工程封装,从 promptsloss.backward() + optimizer.step(),完整流水线是什么?

python 复制代码
def train_step(prompts, model, ref_model, reward_fn, optimizer):
    G = num_generations  # 每个 prompt 生成的回复数

    # ── Step 1: 生成 ──────────────────────────────────
    # 对每个 prompt,采样 G 条回复
    completions = model.generate(prompts, num_return_sequences=G)
    # completions.shape: (B*G, seq_len)

    # ── Step 2: 打分 ──────────────────────────────────
    rewards = reward_fn(prompts, completions)  # shape: (B*G,)

    # ── Step 3: 计算 Advantage(GRPO 核心)────────────
    rewards_grouped = rewards.view(B, G)
    mean_rewards = rewards_grouped.mean(dim=1, keepdim=True)   # 组内均值作为基线
    std_rewards  = rewards_grouped.std(dim=1, keepdim=True)
    advantages = ((rewards_grouped - mean_rewards) / (std_rewards + 1e-4)).view(B*G)

    # ── Step 4: 计算 per-token log prob ───────────────
    per_token_logps     = model.get_logps(prompts, completions)
    ref_per_token_logps = ref_model.get_logps(prompts, completions)  # 如果 beta > 0

    # ── Step 5: 计算 Loss ─────────────────────────────
    # Clipped surrogate loss(类 PPO)
    old_logps = per_token_logps.detach()
    ratio = torch.exp(per_token_logps - old_logps)          # coef_1
    ratio_clipped = torch.clamp(ratio, 1 - eps, 1 + eps)    # coef_2
    policy_loss = -torch.min(ratio * advantages, ratio_clipped * advantages)

    # KL 惩罚(防止偏离参考模型太远)
    kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    per_token_loss = policy_loss + beta * kl

    # 对 token 和 batch 维度取均值
    loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(1)
    loss = loss.mean()

    # ── Step 6: 反向传播 ──────────────────────────────
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()
相关推荐
guslegend2 小时前
3月29日(openclaw安全保障)
人工智能·安全·大模型
最初的↘那颗心2 小时前
企业级 AI Agent 工程方法论:从原型到生产的完整指南(上)
大模型·rag·ai agent·mcp·企业级ai
简简单单做算法3 小时前
基于Q-Learning强化学习的小车倒立摆平衡控制系统matlab性能仿真
算法·matlab·强化学习·qlearning·小车倒立摆平衡控制
加斯顿工程师3 小时前
国产开源大模型发布时间线
大模型
光仔December3 小时前
【从0学习Spring AI Alibaba】2、Spring AI Alibaba版本选型及环境搭建
人工智能·大模型·saa·spring ai·ai alibaba
张彦峰ZYF4 小时前
大模型LLM ACA - ACP认证考试模拟试卷六
大模型·llm·aca - acp
guslegend15 小时前
大模型RAG进阶多格式文档解析
人工智能·大模型
靴子学长16 小时前
Decoder only 架构下 - KV cache 的理解
pytorch·深度学习·算法·大模型·kv