面试-GRPO强化学习

1 GRPO

第一部分:奖励函数

代码分为两个部分,其一是 奖励函数 ,计算模型 response 的奖励,其中分为 形式上的奖励内容上的奖励

python 复制代码
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """整合所有奖励函数计算总奖励"""
    def reasoning_model_reward(rewards):
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]

        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        rewards += torch.tensor(format_rewards, device=args.device)

        def mark_num(text):
            reward = 0
            if text.count("<think>") == 1: reward += 0.25
            if text.count("</think>") == 1: reward += 0.25
            if text.count("<answer>") == 1: reward += 0.25
            if text.count("</answer>") == 1: reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

    with torch.no_grad():
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0

        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]

                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]

                tmp_chat = messages + [{"role": "assistant", "content": response}]
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)

                if args.reasoning == 1:
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        score = score * 0.4 + answer_score * 0.6

                reward_model_scores.append(score)

        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

    return rewards

第二部分:训练

其二,是关于 GRPO 训练的核心代码,分为四个阶段:

Group Generation :GRPO 的核心是 "组",对于同一个问题(Prompt),模型一口气生成 GGG 个不同的回答。所以 Responses 的 Shape 为 [batch * group, seqlen, dim]。

Raw Rewards :分为规则奖励和模型奖励两种。前者通过正则表达式的方式对 responses 进行硬匹配,判断格式是否符合要求。后者通过 Reward Model 对语义内容进行判断。一般后者的得分区间 > 前者的得分区间,然后总分相加就是 Raw Rewards 的奖励得分。

Relative Advantage :计算每组 GGG 中的 response 之间的相对优势, A=Total Reward−μσA = \frac{\text{Total Reward} - \mu}{\sigma}A=σTotal Reward−μ,其中 μ\muμ 为均值,σ\sigmaσ 为标准差(反映这组得分的波动程度)。

Policy Update :最后,模型根据优势值 AAA 去调整自己的参数。比如,对于 回答 A,模型会想:"这个回答比同组其他回答都好,我以后要多生成这种!",对于 回答 C,模型会想:"虽然我努力写了标签,但被同组 A 吊打了,我以后要少写这种错答案。"价值函数如下:
JGRPO(θ)=E[1G∑i=1G(πθ(oi∣q)πθold(oi∣q)A^i−βDKL(πθ∣∣πref))]J_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^{G} \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} \hat{A}i - \beta \mathbb{D}{KL}(\pi_{\theta} || \pi_{ref}) \right) \right]JGRPO(θ)=E[G1i=1∑G(πθold(oi∣q)πθ(oi∣q)A^i−βDKL(πθ∣∣πref))]

损失函数如下:

L(θ)=−J(θ)L(\theta) = -J(\theta)L(θ)=−J(θ)

  • πθ(oi∣q)\pi_{\theta}(o_i|q)πθ(oi∣q):在给定查询 qqq 的前提下,策略 πθπ_θπθ 基于观测 oio_ioi 输出动作的概率。
  • πθold(oi∣q)\pi_{\theta_{old}}(o_i|q)πθold(oi∣q):在给定查询 qqq 的前提下,策略 πθπ_θπθ 基于观测 oio_ioi 输出动作的概率。
  • 策略比率 (Importance Sampling Ratio)πθ(oi∣q)πθold(oi∣q)\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}πθold(oi∣q)πθ(oi∣q):它告诉模型,如果这个动作(Token)的优势很大,就通过增加这个比率来提高该动作出现的概率。
  • A^i\hat{A}iA^i:这就是 GRPO 名字里 "Group Relative" 的由来。它不看你拿了多少分,看你比同组平均分高多少。如果 Advantage 是正的,Loss 就会引导模型增加产生这个序列的概率 πθ(oi∣q)πθold(oi∣q)\frac{\pi{\theta}(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}πθold(oi∣q)πθ(oi∣q)。
  • KL 散度惩罚 (The Anchor) :这是计算当前模型和 Reference Model(基准模型) 之间的差异。目的是防止模型为了拿高分而"走捷径",变成一个只会吐正确答案但失去语言能力的"复读机"。它强制模型不要偏离原始模型太远。

代码:

python 复制代码
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
    # 遍历数据加载器,step 是当前迭代次数,batch 包含原始 Prompt 文本
    for step, batch in enumerate(loader, start=start_step + 1):
        # 1. 获取原始提示词列表,B 是 Batch Size
        prompts = batch['prompt']  # list[str], 长度为 B
        
        # 2. 将提示词转为 Token ID。padding_side="left" 是生成任务的标准做法,确保 Prompt 在左,生成的空间在右
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  
        
        # 3. 强制截断过长的 Prompt,防止显存爆炸
        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

        # --- 第一阶段:采样生成 ---
        with torch.no_grad():
            # 如果是分布式训练 (DDP),需要通过 .module 调用底层的 generate 方法
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
            # 核心:每个 Prompt 生成 num_generations (G) 个回答。输出形状为 [B*G, P+R]
            outputs = model_for_gen.generate(
                **prompt_inputs, 
                max_new_tokens=args.max_gen_len, 
                do_sample=True,          # 开启采样,保证生成的 G 个回答各不相同
                temperature=0.8,         # 控制随机性
                num_return_sequences=args.num_generations, 
                pad_token_id=tokenizer.pad_token_id
            )  

        # 4. 只截取模型生成的部分(Completion),去掉前面的 Prompt 部分
        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*G, R]
        
        # --- 第二阶段:概率计算辅助函数 ---
        def get_per_token_logps(mdl, input_ids, n_keep):
            """计算序列中每个 token 的对数概率 log(π(a|s))"""
            # 克隆数据防止推理模式下的张量修改错误
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            # 获取所有 token 的 logits。logits_to_keep 优化计算量,只保留生成部分的 logits
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]  # [B*G, R, Vocab]
            
            per_token_logps = []
            # 遍历这组回答,找到模型实际选中的那个 token 的概率
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                # 使用 gather 从词表维度提取对应 token 的 log_softmax 概率值
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            return torch.stack(per_token_logps)  # 返回 [B*G, R]
        
        # --- 第三阶段:计算新旧策略概率 ---
        with autocast_ctx:
            # 计算当前正在训练的模型(Actor)生成这些 token 的对数概率
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  
            # 如果是混合专家模型 (MoE),计算辅助损失以保持负载均衡
            res = model(outputs) if lm_config.use_moe else None
            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
        
        with torch.no_grad():
            # 计算参考模型(Ref Model,通常是冻结的 SFT 模型)的对数概率,用于后续计算 KL 散度
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))

        # --- 第四阶段:计算奖励与组内优势 ---
        # 5. 解码回答并计算奖励分数(包含正则规则评分和 Reward Model 评分)
        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*G]

        # 6. GRPO 核心:组内归一化。将奖励按 Prompt 分组 [B, G]
        grouped_rewards = rewards.view(-1, args.num_generations)  
        # 计算每组(每个问题)的平均奖励
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  
        # 计算每组奖励的标准差
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  
        # 计算优势函数:(当前奖励 - 组内平均) / 标准差。限制范围在 [-10, 10] 防止梯度爆炸
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        # 全局优势归一化,进一步稳定训练
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*G]

        # --- 第五阶段:计算 Mask 与 KL 散度 ---
        # 7. 识别结束符 EOS,生成 Mask,忽略 EOS 之后的填充部分
        is_eos = completion_ids == tokenizer.eos_token_id  
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()

        # 8. 计算 KL 散度惩罚,防止模型偏离原模型太远(KL 近似公式:exp(x)-x-1)
        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*G, R]

        # --- 第六阶段:计算最终 Loss 并反向传播 ---
        # 9. 结合策略梯度与 KL 惩罚。注意:这里使用了重要性采样的 Ratio (exp(新-旧))
        # 这里的旧概率就是采样时的概率,用 detach 断开梯度,只更新当前模型
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  
        
        # 10. 只对有效 token (mask=1) 求平均 Loss
        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        # 加上 MoE 的辅助 Loss,并按梯度累积步数缩放
        loss = (policy_loss + aux_loss) / args.accumulation_steps  
        loss.backward()

        # --- 第七阶段:优化器更新 ---
        if (step + 1) % args.accumulation_steps == 0:
            # 11. 梯度裁剪,防止梯度过大导致参数崩坏
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()      # 更新参数
            scheduler.step()      # 更新学习率
            optimizer.zero_grad() # 清空梯度

        # --- 第八阶段:日志监控 ---
        if step % args.log_interval == 0 or step == iters:
            policy_loss_val = loss.item() * args.accumulation_steps
            current_aux_loss = aux_loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']

            # 控制台输出
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')

            # WandB 在线同步
            if wandb and is_main_process():
                wandb.log({
                    "policy_loss": policy_loss_val,
                    "aux_loss": current_aux_loss,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                })

        # --- 第九阶段:模型保存 ---
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # 提取原始模型状态(处理 DDP 或 Compile 包装)
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            # 保存为半精度 CPU 张量,节省空间
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 
                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict

        # 12. 显式删除大张量并清理缓存,防止显存随步数增加而缓慢增长
        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
相关推荐
NAGNIP19 小时前
轻松搞懂全连接神经网络结构!
人工智能·算法·面试
moshuying20 小时前
别让AI焦虑,偷走你本该有的底气
前端·人工智能
董董灿是个攻城狮21 小时前
零基础带你用 AI 搞定命令行
人工智能
喝拿铁写前端1 天前
Dify 构建 FE 工作流:前端团队可复用 AI 工作流实战
前端·人工智能
阿里云大数据AI技术1 天前
阿里云 EMR Serverless Spark + DataWorks 技术实践:引领企业 Data+AI 一体化转型
人工智能
billhan20161 天前
MCP 深入理解:协议原理与自定义开发
人工智能
Jahzo1 天前
openclaw桌面端体验--ClawX
人工智能·github
billhan20161 天前
Agent 开发全流程:从概念到生产
人工智能
threerocks1 天前
过了个年,AI 圈变天了?但没人告诉你为什么
人工智能