1 GRPO
第一部分:奖励函数
代码分为两个部分,其一是 奖励函数 ,计算模型 response 的奖励,其中分为 形式上的奖励 和 内容上的奖励 :
python
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
def mark_num(text):
reward = 0
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
rewards = torch.zeros(len(responses), device=args.device)
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
scale = 3.0
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
tmp_chat = messages + [{"role": "assistant", "content": response}]
score = reward_model.get_score(reward_tokenizer, tmp_chat)
score = max(min(score, scale), -scale)
if args.reasoning == 1:
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
if answer_match:
answer_content = answer_match.group(1).strip()
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
return rewards
第二部分:训练
其二,是关于 GRPO 训练的核心代码,分为四个阶段:
Group Generation :GRPO 的核心是 "组",对于同一个问题(Prompt),模型一口气生成 GGG 个不同的回答。所以 Responses 的 Shape 为 [batch * group, seqlen, dim]。
Raw Rewards :分为规则奖励和模型奖励两种。前者通过正则表达式的方式对 responses 进行硬匹配,判断格式是否符合要求。后者通过 Reward Model 对语义内容进行判断。一般后者的得分区间 > 前者的得分区间,然后总分相加就是 Raw Rewards 的奖励得分。
Relative Advantage :计算每组 GGG 中的 response 之间的相对优势, A=Total Reward−μσA = \frac{\text{Total Reward} - \mu}{\sigma}A=σTotal Reward−μ,其中 μ\muμ 为均值,σ\sigmaσ 为标准差(反映这组得分的波动程度)。
Policy Update :最后,模型根据优势值 AAA 去调整自己的参数。比如,对于 回答 A,模型会想:"这个回答比同组其他回答都好,我以后要多生成这种!",对于 回答 C,模型会想:"虽然我努力写了标签,但被同组 A 吊打了,我以后要少写这种错答案。"价值函数如下:
JGRPO(θ)=E[1G∑i=1G(πθ(oi∣q)πθold(oi∣q)A^i−βDKL(πθ∣∣πref))]J_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^{G} \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} \hat{A}i - \beta \mathbb{D}{KL}(\pi_{\theta} || \pi_{ref}) \right) \right]JGRPO(θ)=E[G1i=1∑G(πθold(oi∣q)πθ(oi∣q)A^i−βDKL(πθ∣∣πref))]
损失函数如下:
L(θ)=−J(θ)L(\theta) = -J(\theta)L(θ)=−J(θ)
- πθ(oi∣q)\pi_{\theta}(o_i|q)πθ(oi∣q):在给定查询 qqq 的前提下,策略 πθπ_θπθ 基于观测 oio_ioi 输出动作的概率。
- πθold(oi∣q)\pi_{\theta_{old}}(o_i|q)πθold(oi∣q):在给定查询 qqq 的前提下,策略 πθπ_θπθ 基于观测 oio_ioi 输出动作的概率。
- 策略比率 (Importance Sampling Ratio)πθ(oi∣q)πθold(oi∣q)\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}πθold(oi∣q)πθ(oi∣q):它告诉模型,如果这个动作(Token)的优势很大,就通过增加这个比率来提高该动作出现的概率。
- A^i\hat{A}iA^i:这就是 GRPO 名字里 "Group Relative" 的由来。它不看你拿了多少分,看你比同组平均分高多少。如果 Advantage 是正的,Loss 就会引导模型增加产生这个序列的概率 πθ(oi∣q)πθold(oi∣q)\frac{\pi{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}πθold(oi∣q)πθ(oi∣q)。
- KL 散度惩罚 (The Anchor) :这是计算当前模型和 Reference Model(基准模型) 之间的差异。目的是防止模型为了拿高分而"走捷径",变成一个只会吐正确答案但失去语言能力的"复读机"。它强制模型不要偏离原始模型太远。
代码:
python
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
# 遍历数据加载器,step 是当前迭代次数,batch 包含原始 Prompt 文本
for step, batch in enumerate(loader, start=start_step + 1):
# 1. 获取原始提示词列表,B 是 Batch Size
prompts = batch['prompt'] # list[str], 长度为 B
# 2. 将提示词转为 Token ID。padding_side="left" 是生成任务的标准做法,确保 Prompt 在左,生成的空间在右
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device)
# 3. 强制截断过长的 Prompt,防止显存爆炸
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# --- 第一阶段:采样生成 ---
with torch.no_grad():
# 如果是分布式训练 (DDP),需要通过 .module 调用底层的 generate 方法
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
# 核心:每个 Prompt 生成 num_generations (G) 个回答。输出形状为 [B*G, P+R]
outputs = model_for_gen.generate(
**prompt_inputs,
max_new_tokens=args.max_gen_len,
do_sample=True, # 开启采样,保证生成的 G 个回答各不相同
temperature=0.8, # 控制随机性
num_return_sequences=args.num_generations,
pad_token_id=tokenizer.pad_token_id
)
# 4. 只截取模型生成的部分(Completion),去掉前面的 Prompt 部分
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*G, R]
# --- 第二阶段:概率计算辅助函数 ---
def get_per_token_logps(mdl, input_ids, n_keep):
"""计算序列中每个 token 的对数概率 log(π(a|s))"""
# 克隆数据防止推理模式下的张量修改错误
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
# 获取所有 token 的 logits。logits_to_keep 优化计算量,只保留生成部分的 logits
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :] # [B*G, R, Vocab]
per_token_logps = []
# 遍历这组回答,找到模型实际选中的那个 token 的概率
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
# 使用 gather 从词表维度提取对应 token 的 log_softmax 概率值
per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
return torch.stack(per_token_logps) # 返回 [B*G, R]
# --- 第三阶段:计算新旧策略概率 ---
with autocast_ctx:
# 计算当前正在训练的模型(Actor)生成这些 token 的对数概率
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
# 如果是混合专家模型 (MoE),计算辅助损失以保持负载均衡
res = model(outputs) if lm_config.use_moe else None
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
with torch.no_grad():
# 计算参考模型(Ref Model,通常是冻结的 SFT 模型)的对数概率,用于后续计算 KL 散度
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))
# --- 第四阶段:计算奖励与组内优势 ---
# 5. 解码回答并计算奖励分数(包含正则规则评分和 Reward Model 评分)
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*G]
# 6. GRPO 核心:组内归一化。将奖励按 Prompt 分组 [B, G]
grouped_rewards = rewards.view(-1, args.num_generations)
# 计算每组(每个问题)的平均奖励
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
# 计算每组奖励的标准差
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)
# 计算优势函数:(当前奖励 - 组内平均) / 标准差。限制范围在 [-10, 10] 防止梯度爆炸
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
# 全局优势归一化,进一步稳定训练
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*G]
# --- 第五阶段:计算 Mask 与 KL 散度 ---
# 7. 识别结束符 EOS,生成 Mask,忽略 EOS 之后的填充部分
is_eos = completion_ids == tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()
# 8. 计算 KL 散度惩罚,防止模型偏离原模型太远(KL 近似公式:exp(x)-x-1)
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*G, R]
# --- 第六阶段:计算最终 Loss 并反向传播 ---
# 9. 结合策略梯度与 KL 惩罚。注意:这里使用了重要性采样的 Ratio (exp(新-旧))
# 这里的旧概率就是采样时的概率,用 detach 断开梯度,只更新当前模型
per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)
# 10. 只对有效 token (mask=1) 求平均 Loss
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# 加上 MoE 的辅助 Loss,并按梯度累积步数缩放
loss = (policy_loss + aux_loss) / args.accumulation_steps
loss.backward()
# --- 第七阶段:优化器更新 ---
if (step + 1) % args.accumulation_steps == 0:
# 11. 梯度裁剪,防止梯度过大导致参数崩坏
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step() # 更新参数
scheduler.step() # 更新学习率
optimizer.zero_grad() # 清空梯度
# --- 第八阶段:日志监控 ---
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps
current_aux_loss = aux_loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
# 控制台输出
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
# WandB 在线同步
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"aux_loss": current_aux_loss,
"reward": avg_reward_val,
"avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
})
# --- 第九阶段:模型保存 ---
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 提取原始模型状态(处理 DDP 或 Compile 包装)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
# 保存为半精度 CPU 张量,节省空间
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
# 12. 显式删除大张量并清理缓存,防止显存随步数增加而缓慢增长
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask