面试-SPO

1. 解决的问题

SPO vs GRPO 核心问题解决

先明确基础前提:SPO 和 GRPO 都是大模型强化学习微调的策略优化方法,GRPO 是早期偏向 "奖励硬引导" 的策略优化方式,SPO 是针对 GRPO 的核心痛点做的针对性改进,核心改进围绕「策略稳定性、优势估计准确性、训练效率」三大维度展开:

1. 策略更新稳定性维度

SPO 用「动态 KL 散度约束(token 级 + 自适应 rho 调整)」解决了 GRPO「策略更新幅度过大、易偏离参考模型导致生成内容失控」的问题,因为 GRPO 存在「奖励信号直接主导策略更新,缺乏对策略偏离的有效约束,容易生成和预训练模型风格 / 逻辑完全脱节的内容」的问题,因为 GRPO 的「核心逻辑是 "奖励越高越好",仅依赖奖励引导策略更新,没有引入参考模型的 KL 约束机制,策略更新无边界」。

通俗解释:GRPO 像 "只看分数不看规矩" 的老师,只要 AI 答案得分高,不管和原本学的知识差多远都鼓励;SPO 则加了 "纪律委员(参考模型)",用 KL 散度限制 AI 不能偏离基础知识太远,还通过 rho 动态调整约束强度(就像代码里的compute_rho),既保证 AI 拿高分,又不跑偏。

2. 优势估计准确性维度

SPO 用「自适应价值基线(AutoAdaptiveValueTracker)」解决了 GRPO「优势估计偏差大、训练过程震荡(忽好忽坏)」的问题,因为 GRPO 存在「固定基线 / 简单滑动平均基线无法适配奖励分布变化,单次奖励波动会严重误导策略更新」的问题,因为 GRPO 的「基线是静态的(比如固定取历史平均奖励),无法根据 AI 生成风格(logprob)动态调整,优势计算受偶然高 / 低奖励影响大」。

通俗解释:GRPO 的 "平均分记录本" 是死的,比如一直用 80 分当平均分,哪怕 AI 最近答题风格变了(比如从写步骤变成只写答案),还是按 80 分判断好坏;SPO 的智能记录本会根据 AI 答题风格(KL)调整 "信任旧平均分的比例(rho)",让平均分跟着风格变,优势估计更准,训练不震荡(对应代码里的update方法)。

3. 损失计算精准度维度

SPO 用「token 级的损失计算 + 有效掩码(completion_mask)」解决了 GRPO「整体序列级损失粗糙、局部 token 生成质量差」的问题,因为 GRPO 存在「仅对整个生成序列计算损失,忽略无效 padding/token,导致局部错误 token(比如格式错误、错别字)无法被精准惩罚」的问题,因为 GRPO 的「损失计算是 "一刀切" 的,没有区分有效生成 token 和无效 padding token,局部错误被平均掩盖」。

通俗解释:GRPO 批改作业只看总分,哪怕 AI 答案里有一半是乱码(padding),也按整体算分;SPO 则会先划掉乱码(completion_mask),只针对有效答案部分算损失,哪个 token 错了就罚哪个(对应代码里的per_token_loss),纠错更精准。

4. 训练效率与显存适配维度

SPO 用「梯度累积 + 损失归一化(args.accumulation_steps)」解决了 GRPO「大批次训练显存溢出、小批次训练不稳定」的问题,因为 GRPO 存在「一次性计算全批次梯度,显存占用高,只能用小批次训练,导致梯度噪声大」的问题,因为 GRPO 的「损失计算和梯度更新是绑定的,没有拆分梯度累积步骤,无法适配大模型的显存限制」。

通俗解释:GRPO 像 "一次性搬完所有砖",显存不够就只能少搬(小批次),但少搬又容易出错;SPO 则是 "分几次搬,最后一起算"(梯度累积),既适配显存,又保证批次大小足够,训练更稳定(对应代码里loss = (policy_loss + aux_loss) / args.accumulation_steps和梯度累积更新逻辑)。

总结

SPO 针对 GRPO 的核心痛点做了四大改进,核心逻辑可总结为:

1.加「动态 KL 约束」解决 GRPO 策略无边界偏离的问题;

2.用「自适应基线」解决 GRPO 优势估计不准、训练震荡的问题;

3.做「token 级精准损失」解决 GRPO 整体损失粗糙的问题;

4.靠「梯度累积」解决 GRPO 显存适配差、训练效率低的问题。

本质上,SPO 是 GRPO 的 "精细化升级版"------ 既保留了 GRPO "奖励引导策略优化" 的核心,又补上了 GRPO 缺乏的「约束、适配、精准性」三大短板,让大模型强化学习微调更稳定、更精准、更适配工程落地。

咱们就按「智能记录本→批改老师→总教练」的顺序,一步拆到底,每个角色先讲「核心目标」→「具体怎么做(生活例子)」→「对应代码的关键逻辑」,全程不碰复杂公式,只讲人话。


第一部分:智能记录本(AutoAdaptiveValueTracker)

🔴 核心目标

记准AI的「平均答题水平」,用"平均分"判断AI单次答题是进步还是退步,避免单次运气分(比如偶尔蒙对)误导训练

🟡 生活例子(完全对应代码逻辑)

假设你是课代表,要记班里同学的数学平均分:

  1. 开学初,你先预估一个初始平均分(比如80分);
  2. 每次考完试,你不是直接替换平均分,而是"加权更新"------比如信任旧平均分90%,信任新分数10%,避免一次考砸就把平均分拉太低;
  3. 如果发现班里答题风格突然变了(比如都开始乱蒙),你就降低旧平均分的权重(比如只信70%),让新分数更快修正平均分;
  4. 最终用"当前平均分"判断:这次考90分=比平均好10分,考70分=比平均差10分。

🟢 代码里的具体操作(对应例子拆)

1. 初始化小本本(init

python 复制代码
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
    self.rho_mode = 'kl'  # 怎么调整"信任旧分/新分"的比例
    self.rho_const = 0.9  # 默认信任旧平均分90%
    self.clip_lower/upper = 0.5/0.96  # 信任比例不能太低(50%)或太高(96%)
    # 初始化平均分的"计算参数"(不用懂公式,就当是开学初预估的80分)
    N_init = 1.0 / (1.0 - self.clip_lower)
    self.alpha = 0.5 * N_init
    self.beta = 0.5 * N_init
    self.old_mean_logprob = None  # 记上一次的答题风格(后面讲)
  • 关键:rho = 信任旧平均分的比例(比如0.9=90%信任旧分,10%信任新分)。

2. 查当前平均分(get_baselines)

python 复制代码
def get_baselines(self, batch_size):
    baseline = self.alpha / (self.alpha + self.beta)  # 算当前平均分(0~1之间)
    return torch.full((batch_size,), baseline)  # 给全班(批次)每个人都用这个平均分
  • 比如alpha=9,beta=1 → 平均分=0.9(对应90分);alpha=5,beta=5 → 平均分=0.5(对应50分);
  • 为什么返回和批次大小一样的数?因为全班同学用同一个平均分判断进步/退步。

3. 调整信任比例(compute_rho)

python 复制代码
def compute_rho(self, cur_mean_logprob):
    if self.rho_mode == 'constant':
        return 0.9  # 固定信任旧分90%
    if self.old_mean_logprob is None:
        return 0.9  # 第一次考试,没历史数据,先用固定值
    # KL = 答题风格变化程度(比如上次都写步骤,这次都只写答案=KL大)
    kl = abs(self.old_mean_logprob - cur_mean_logprob)
    rho = 2 ** (-kl / self.D_half)  # 风格变化越大,rho越小(越不信任旧分)
    return max(min(rho, 0.96), 0.5)  # 限制rho在50%~96%之间
  • 人话翻译:
    • AI答题风格没变(KL小)→ rho≈0.9 → 继续信任旧平均分;
    • AI答题风格突变(KL大)→ rho≈0.5 → 多听新分数的,赶紧修正平均分。

4. 更新平均分(update)

python 复制代码
def update(self, rewards, cur_logprobs=None, response_masks=None):
    # 先算rho(信任比例)
    if cur_logprobs is not None:
        mean_logprob = 算当前答题风格
        rho = self.compute_rho(mean_logprob)
        self.old_mean_logprob = mean_logprob
    else:
        rho = 0.9  # 没风格数据,用固定值

    # 把分数归一化(比如-3~3分 → 0~1分,和平均分范围匹配)
    scale = 3.0
    normalized_rewards = (rewards + scale) / (2 * scale)  # 例子:3分→1,0分→0.5,-3分→0
    avg_normalized_reward = normalized_rewards.mean().item()  # 这次考试的平均分

    # 加权更新:旧分*rho + 新分*(1-rho)(代码写法不同,但逻辑一样)
    self.alpha = rho * self.alpha + avg_normalized_reward
    self.beta = rho * self.beta + (1 - avg_normalized_reward)
    return rho
  • 关键:不是直接替换平均分,而是"慢慢更" → 避免单次考砸/蒙对让平均分乱跳。

📌 智能记录本小结

只干3件事:

  1. 算当前平均分(基线);
  2. 根据答题风格调整"信任比例";
  3. 加权更新平均分,保持稳定。

第二部分:批改老师(calculate_rewards)

🔴 核心目标

给 AI 生成的每一个答案打「综合总分」------ 既查格式是否符合要求(推理题专属),又评内容质量(奖励模型打分),最终分数直接决定 AI 这个答案 "好不好",是后续训练的核心依据。

🟡 生活例子(完全对应代码逻辑)

假设你是数学老师,批改「应用题 + 推理题」的作业,评分规则如下:

  • 基础分:所有作业初始 0 分;

  • 格式分(仅推理题):

    规则 1:作业必须按 「思考过程 → 答案」 规范写,也就是:
    我是怎么想的...
    最终答案是...

    只要整体格式对得上,就给 +0.5 分

    规则 2:标签不能多也不能少:

    • 有且只有 1 个 `` → +0.25
    • 有且只有 1 个 `` → +0.25
    • 有且只有 1 个 `` → +0.25
    • 有且只有 1 个 `` → +0.25
      这一项最多 +1 分

    格式分加起来:最高 1.5 分

  • 内容分(所有题目都算):

    请一个"专业评分老师"(奖励模型),看完题目和答案,直接给出一个内容质量分。

    分数太高太低会被限制在 -3 ~ +3 之间,防止极端分数乱带节奏。

  • 特别加分(推理题):

    不只看整段话,只看 `` 里的答案本身 ,再让评分老师打一次分。

    最后按:
    整段得分 × 0.4 + 答案内容得分 × 0.6

    更看重"答案对不对"。

  • 最终总分:
    总分 = 格式分 + 内容质量分

    这个总分,就是给 AI 这次回答的最终评价。

🟢 对应代码逐段讲解 + 注释

python 复制代码
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """
    【批改老师总入口】
    输入:题目(prompts)、AI写的答案(responses)、评分老师(reward_model)、分词器(reward_tokenizer)
    输出:每个答案的综合分数(rewards)
    """

1. 推理题格式评分(内部小老师)

python 复制代码
    def reasoning_model_reward(rewards):
        """
        只给推理题打分:检查格式规不规范
        """
        # 两种合法格式:要么直接 think+answer,要么中间空一行
        pattern = r"^\n.*?\n\n\n.*?\n$"
        pattern2 = r"^\n.*?\n\n\n\n.*?\n$"

        # 批量检查每个回答是否符合整体格式
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]

        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)  # 格式对 → +0.5
            else:
                format_rewards.append(0.0)  # 格式错 → 0 分
        # 把格式分加到总分里
        rewards += torch.tensor(format_rewards, device=args.device)

        # 再检查标签数量:4个关键标签必须各出现 1 次
        def mark_num(text):
            reward = 0
            if text.count("") == 1: reward += 0.25
            if text.count("") == 1: reward += 0.25
            if text.count("") == 1: reward += 0.25
            if text.count("") == 1: reward += 0.25
            return reward

        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

2. 开始打分

python 复制代码
    # 初始所有答案都是 0 分
    rewards = torch.zeros(len(responses), device=args.device)
    
    # 如果是推理题,先算格式分
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)

3. 奖励模型评内容质量(真正的"专业老师")

python 复制代码
    with torch.no_grad():  # 评分老师不参与训练,只打分
        reward_model_scores = []
        batch_size = len(prompts)
        scale = 3.0  # 分数限制在 -3 ~ +3

        # 一个题目、对应 N 个生成答案,逐个打分
        for i in range(batch_size):
            for j in range(args.num_generations):
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]

                # 把题目解析成对话格式:system / user / assistant
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                messages = [{"role": role, "content": content.strip()} for role, content in matches]

                # 拼完整对话:题目 + 模型给出的回答
                tmp_chat = messages + [{"role": "assistant", "content": response}]
                # 奖励模型打分
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)  # 限制范围

                # 如果是推理题:单独看  里的内容,再打一次分,加权融合
                if args.reasoning == 1:
                    answer_match = re.search(r'(.*?)', response, re.DOTALL)
                    if answer_match:
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        # 更看重答案本身:0.4 整段 + 0.6 答案内容
                        score = score * 0.4 + answer_score * 0.6

                reward_model_scores.append(score)

        # 把内容分也加进去
        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores

    # 最终返回:格式分 + 内容分 = 综合奖励
    return rewards

📌 一句话总结这段代码在干嘛

calculate_rewards 就是一个严格的批改老师:

先看格式对不对(推理题专属),再看内容好不好(奖励模型),最后给一个能直接用来训练AI的总分

加粗样式

第三部分:GRPO 训练主循环(grpo_train_epoch)

🔴 核心目标

把"批改老师"打的分数转化为 AI 的"成长动力"------ 让 AI 记住"哪些答案得分高、该多写","哪些得分低、要改掉",同时控制 AI 的"变化幅度"(避免改得太离谱),最终让 AI 越练越会写高分答案。

🟡 生活例子(完全对应代码逻辑)

假设你是教练,训练一个"答题选手"(AI),训练流程如下:

  1. 出题:给选手一批数学题(prompt);
  2. 选手答题:每道题让选手写 N 个不同答案(num_generations);
  3. 批改打分:请之前的"批改老师"给每个答案打分(rewards);
  4. 分析优劣:对比同一道题的 N 个答案,算出"相对优势"(比如:这道题的 3 个答案里,这个答案比平均分高多少);
  5. 针对性纠错
    • 对得分高的答案:告诉选手"这种写法要多保留,下次还这么写";
    • 对得分低的答案:告诉选手"这种写法要改掉,别再这么写";
    • 控制纠错幅度:不能让选手"改太猛"(比如原本会写基础题,改完只会写难题),用"参考版本"(ref_model)约束,确保改完还是"会写基础题";
  6. 逐步优化:每练一批题就记录进步(日志),练到一定程度就保存"最佳状态"( checkpoint),最终选手越练越会写高分答案。

🟢 对应代码逐段讲解 + 注释

python 复制代码
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
    """
    【GRPO 训练主循环】
    输入:
        epoch:当前训练轮次(比如第 3 轮)
        loader:题目数据集(一批批出题)
        iters:本轮要练多少道题(总步数)
        ref_model:参考模型(选手的"基础版本",用来约束改题幅度)
        reward_model/reward_tokenizer:批改老师
        start_step:起始步数(断点续训,比如上次练到 500 步,这次从 501 开始)
        wandb:日志工具(记录训练过程)
    输出:无(核心是更新 AI 模型参数)
    """
    # 1. 一批批出题,逐个训练
    for step, batch in enumerate(loader, start=start_step + 1):
        # ===== 第一步:出题 + 预处理 =====
        prompts = batch['prompt']  # 本次训练的题目列表 [B],B=批次大小(比如一次练 32 道题)
        # 把题目转成模型能看懂的"数字格式"(tokenize):左对齐填充、不额外加特殊符号
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # [B, P],P=题目长度
        # 限制题目长度(防止太长超出模型能力)
        if args.max_seq_len:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]

        # ===== 第二步:让选手(AI)答题 =====
        with torch.no_grad():  # 答题阶段不更新模型,只生成答案
            # 分布式训练时,要通过 .module 访问模型的"答题功能"
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
            # 每道题生成 N 个答案(num_generations),用"随机采样"(temperature=0.8)避免答案千篇一律
            outputs = model_for_gen.generate(
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)  # [B*num_gen, P+R],R=答案长度
        # 提取答案部分(去掉题目,只留选手写的答案)
        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]

        # ===== 第三步:核心工具函数:计算"答题思路的概率" =====
        def get_per_token_logps(mdl, input_ids, n_keep):
            """
            计算选手写每个字(token)时的"思路概率":
            比如写"1+1=2",模型写"1"的概率、写"+"的概率...... 反映模型"为什么这么写"
            输入:
                mdl:模型(当前选手 / 参考版本)
                input_ids:完整的"题目+答案" [N, P+R],N=B*num_gen
                n_keep:只算答案部分的概率(R 个token)
            输出:
                per_token_logps:每个字的对数概率 [N, R]
            """
            # 推理时克隆数据,避免改乱原数据
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            # 让模型"复盘"答题过程,输出每个字的概率(logits)
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]  # [N, R, V],V=所有可选字的数量
            per_token_logps = []
            # 逐个分析每个答案的每个字
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                # 把概率转成"对数概率"(避免数值溢出),再提取每个字对应的概率
                token_logp = torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)
                per_token_logps.append(token_logp)
            return torch.stack(per_token_logps)  # 整理成张量

        # ===== 第四步:计算当前选手的"答题思路概率" =====
        with autocast_ctx:  # 混合精度训练,提速+省内存
            # 算选手写答案时,每个字的思路概率
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
            # 如果是 MoE 模型,额外计算"负载均衡损失"(避免模型偏科)
            res = model(outputs) if lm_config.use_moe else None
            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)

        # ===== 第五步:计算参考版本的"答题思路概率"(固定不变,用来约束) =====
        with torch.no_grad():  # 参考版本不参与训练,只做对比
            ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]

        # ===== 第六步:批改打分(调用之前的 calculate_rewards) =====
        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)  # 把数字答案转回文字
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]

        # ===== 第七步:分析"相对优势"(Advantage) =====
        # 按题目分组:把同一道题的 N 个答案放一起对比 [B, num_gen]
        grouped_rewards = rewards.view(-1, args.num_generations)
        # 算每道题 N 个答案的平均分、标准差
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        # 计算"相对优势":当前答案比同题平均分高多少(标准化,避免极端值)
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*num_gen]

        # ===== 第八步:处理"答案结束标记"(只算有效答案的损失) =====
        # 找到答案里"结束符(eos)"的位置,只计算结束符前的内容(忽略后面的无效字符)
        is_eos = completion_ids == tokenizer.eos_token_id  # [B*num_gen, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B*num_gen, R]

        # ===== 第九步:计算"纠错损失"(核心:让选手往高分方向改) =====
        # 1. 计算"思路偏差"(KL散度):当前选手 vs 参考版本,避免改太猛
        kl_div = ref_per_token_logps - per_token_logps
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*num_gen, R]
        # 2. 总损失 = 优势引导(高分多写) - KL约束(别改太猛) + 辅助损失
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  # [B*num_gen, R]
        # 3. 只算有效答案的损失,求平均
        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        loss = (policy_loss + aux_loss) / args.accumulation_steps  # 梯度累积,适配小显存

        # ===== 第十步:更新选手(AI)的参数(纠错) =====
        loss.backward()  # 计算"该怎么改"(梯度)
        # 累积一定步数后,统一更新参数(梯度累积)
        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # 限制梯度,避免改太猛
            optimizer.step()  # 执行更新(纠错)
            scheduler.step()  # 调整学习率(越练越慢,精细打磨)
            optimizer.zero_grad()  # 清空梯度,准备下一轮

        # ===== 第十一步:记录训练过程(日志) =====
        if step % args.log_interval == 0 or step == iters:
            # 提取关键指标:损失、奖励、答案长度、学习率
            policy_loss_val = loss.item() * args.accumulation_steps
            current_aux_loss = aux_loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']
            # 打印日志
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
            # 可视化日志(wandb)
            if wandb and is_main_process():
                wandb.log({
                    "policy_loss": policy_loss_val,
                    "aux_loss": current_aux_loss,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                })

        # ===== 第十二步:保存最佳状态(checkpoint) =====
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()  # 切换到评估模式
            # 生成保存文件名(区分 MoE 模型)
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            # 分布式训练时,提取原始模型
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            # 保存模型参数(转半精度,省空间)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            # 保存训练状态(断点续训用)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 
                         epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()  # 切回训练模式
            del state_dict  # 释放内存

        # ===== 清理内存 =====
        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask

📌 一句话总结这段代码在干嘛

grpo_train_epoch 就是 AI 的"训练教练":

让 AI 答题 → 批改打分 → 分析"哪些写法好" → 用损失函数引导 AI 往高分方向改,同时用参考模型控制改题幅度,练完还记录进步、保存最佳状态,最终让 AI 越练越会写高分答案。

核心要点回顾

  1. 奖励到损失的转化:通过"相对优势"把"批改分数"变成模型能理解的"纠错信号",高分答案对应低损失(鼓励保留),低分答案对应高损失(强制改掉);
  2. KL 散度约束:用参考模型限制模型更新幅度,避免"改得太猛"导致模型遗忘基础能力;
  3. 工程优化:通过梯度累积、混合精度、内存清理等手段,适配大模型训练的显存/性能需求。
相关推荐
十步杀一人_千里不留行2 小时前
当代码评审成为手工业遗风
人工智能
AI浩2 小时前
SSVP:用于工业零样本异常检测的协同语义-视觉提示
人工智能·机器学习
Maynor9962 小时前
王煜全前哨分析框架③:如何构建产业预测方法论?
人工智能
njsgcs2 小时前
memU怎么处理记忆的
人工智能
开开心心就好2 小时前
实用PDF批量加马赛克,抹除敏感信息绿色版
java·linux·开发语言·网络·人工智能·pdf·word2vec
沐曦股份MetaX2 小时前
【智算芯闻】具身智能的新范式:利用AI智能体加速机器人学习技能
人工智能·机器人
乾元2 小时前
模型提取:黑盒环境下如何窃取对手的 AI 模型参数
网络·人工智能·安全·web安全·机器学习·架构·系统架构
志栋智能2 小时前
智能巡检自动化解决方案:从“人海战术”到“AI智巡”的效能革命
大数据·运维·人工智能·网络安全·云原生·自动化
志栋智能2 小时前
AI驱动的带内自动化巡检:编织IT生态的“智慧神经网络”
大数据·运维·网络·人工智能·神经网络·自动化