1. 解决的问题
SPO vs GRPO 核心问题解决
先明确基础前提:SPO 和 GRPO 都是大模型强化学习微调的策略优化方法,GRPO 是早期偏向 "奖励硬引导" 的策略优化方式,SPO 是针对 GRPO 的核心痛点做的针对性改进,核心改进围绕「策略稳定性、优势估计准确性、训练效率」三大维度展开:
1. 策略更新稳定性维度
SPO 用「动态 KL 散度约束(token 级 + 自适应 rho 调整)」解决了 GRPO「策略更新幅度过大、易偏离参考模型导致生成内容失控」的问题,因为 GRPO 存在「奖励信号直接主导策略更新,缺乏对策略偏离的有效约束,容易生成和预训练模型风格 / 逻辑完全脱节的内容」的问题,因为 GRPO 的「核心逻辑是 "奖励越高越好",仅依赖奖励引导策略更新,没有引入参考模型的 KL 约束机制,策略更新无边界」。
通俗解释:GRPO 像 "只看分数不看规矩" 的老师,只要 AI 答案得分高,不管和原本学的知识差多远都鼓励;SPO 则加了 "纪律委员(参考模型)",用 KL 散度限制 AI 不能偏离基础知识太远,还通过 rho 动态调整约束强度(就像代码里的compute_rho),既保证 AI 拿高分,又不跑偏。
2. 优势估计准确性维度
SPO 用「自适应价值基线(AutoAdaptiveValueTracker)」解决了 GRPO「优势估计偏差大、训练过程震荡(忽好忽坏)」的问题,因为 GRPO 存在「固定基线 / 简单滑动平均基线无法适配奖励分布变化,单次奖励波动会严重误导策略更新」的问题,因为 GRPO 的「基线是静态的(比如固定取历史平均奖励),无法根据 AI 生成风格(logprob)动态调整,优势计算受偶然高 / 低奖励影响大」。
通俗解释:GRPO 的 "平均分记录本" 是死的,比如一直用 80 分当平均分,哪怕 AI 最近答题风格变了(比如从写步骤变成只写答案),还是按 80 分判断好坏;SPO 的智能记录本会根据 AI 答题风格(KL)调整 "信任旧平均分的比例(rho)",让平均分跟着风格变,优势估计更准,训练不震荡(对应代码里的update方法)。
3. 损失计算精准度维度
SPO 用「token 级的损失计算 + 有效掩码(completion_mask)」解决了 GRPO「整体序列级损失粗糙、局部 token 生成质量差」的问题,因为 GRPO 存在「仅对整个生成序列计算损失,忽略无效 padding/token,导致局部错误 token(比如格式错误、错别字)无法被精准惩罚」的问题,因为 GRPO 的「损失计算是 "一刀切" 的,没有区分有效生成 token 和无效 padding token,局部错误被平均掩盖」。
通俗解释:GRPO 批改作业只看总分,哪怕 AI 答案里有一半是乱码(padding),也按整体算分;SPO 则会先划掉乱码(completion_mask),只针对有效答案部分算损失,哪个 token 错了就罚哪个(对应代码里的per_token_loss),纠错更精准。
4. 训练效率与显存适配维度
SPO 用「梯度累积 + 损失归一化(args.accumulation_steps)」解决了 GRPO「大批次训练显存溢出、小批次训练不稳定」的问题,因为 GRPO 存在「一次性计算全批次梯度,显存占用高,只能用小批次训练,导致梯度噪声大」的问题,因为 GRPO 的「损失计算和梯度更新是绑定的,没有拆分梯度累积步骤,无法适配大模型的显存限制」。
通俗解释:GRPO 像 "一次性搬完所有砖",显存不够就只能少搬(小批次),但少搬又容易出错;SPO 则是 "分几次搬,最后一起算"(梯度累积),既适配显存,又保证批次大小足够,训练更稳定(对应代码里loss = (policy_loss + aux_loss) / args.accumulation_steps和梯度累积更新逻辑)。
总结
SPO 针对 GRPO 的核心痛点做了四大改进,核心逻辑可总结为:
1.加「动态 KL 约束」解决 GRPO 策略无边界偏离的问题;
2.用「自适应基线」解决 GRPO 优势估计不准、训练震荡的问题;
3.做「token 级精准损失」解决 GRPO 整体损失粗糙的问题;
4.靠「梯度累积」解决 GRPO 显存适配差、训练效率低的问题。
本质上,SPO 是 GRPO 的 "精细化升级版"------ 既保留了 GRPO "奖励引导策略优化" 的核心,又补上了 GRPO 缺乏的「约束、适配、精准性」三大短板,让大模型强化学习微调更稳定、更精准、更适配工程落地。
咱们就按「智能记录本→批改老师→总教练」的顺序,一步拆到底,每个角色先讲「核心目标」→「具体怎么做(生活例子)」→「对应代码的关键逻辑」,全程不碰复杂公式,只讲人话。
第一部分:智能记录本(AutoAdaptiveValueTracker)
🔴 核心目标
记准AI的「平均答题水平」,用"平均分"判断AI单次答题是进步还是退步,避免单次运气分(比如偶尔蒙对)误导训练。
🟡 生活例子(完全对应代码逻辑)
假设你是课代表,要记班里同学的数学平均分:
- 开学初,你先预估一个初始平均分(比如80分);
- 每次考完试,你不是直接替换平均分,而是"加权更新"------比如信任旧平均分90%,信任新分数10%,避免一次考砸就把平均分拉太低;
- 如果发现班里答题风格突然变了(比如都开始乱蒙),你就降低旧平均分的权重(比如只信70%),让新分数更快修正平均分;
- 最终用"当前平均分"判断:这次考90分=比平均好10分,考70分=比平均差10分。
🟢 代码里的具体操作(对应例子拆)
1. 初始化小本本(init)
python
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
self.rho_mode = 'kl' # 怎么调整"信任旧分/新分"的比例
self.rho_const = 0.9 # 默认信任旧平均分90%
self.clip_lower/upper = 0.5/0.96 # 信任比例不能太低(50%)或太高(96%)
# 初始化平均分的"计算参数"(不用懂公式,就当是开学初预估的80分)
N_init = 1.0 / (1.0 - self.clip_lower)
self.alpha = 0.5 * N_init
self.beta = 0.5 * N_init
self.old_mean_logprob = None # 记上一次的答题风格(后面讲)
- 关键:
rho= 信任旧平均分的比例(比如0.9=90%信任旧分,10%信任新分)。
2. 查当前平均分(get_baselines)
python
def get_baselines(self, batch_size):
baseline = self.alpha / (self.alpha + self.beta) # 算当前平均分(0~1之间)
return torch.full((batch_size,), baseline) # 给全班(批次)每个人都用这个平均分
- 比如
alpha=9,beta=1→ 平均分=0.9(对应90分);alpha=5,beta=5→ 平均分=0.5(对应50分); - 为什么返回和批次大小一样的数?因为全班同学用同一个平均分判断进步/退步。
3. 调整信任比例(compute_rho)
python
def compute_rho(self, cur_mean_logprob):
if self.rho_mode == 'constant':
return 0.9 # 固定信任旧分90%
if self.old_mean_logprob is None:
return 0.9 # 第一次考试,没历史数据,先用固定值
# KL = 答题风格变化程度(比如上次都写步骤,这次都只写答案=KL大)
kl = abs(self.old_mean_logprob - cur_mean_logprob)
rho = 2 ** (-kl / self.D_half) # 风格变化越大,rho越小(越不信任旧分)
return max(min(rho, 0.96), 0.5) # 限制rho在50%~96%之间
- 人话翻译:
- AI答题风格没变(KL小)→ rho≈0.9 → 继续信任旧平均分;
- AI答题风格突变(KL大)→ rho≈0.5 → 多听新分数的,赶紧修正平均分。
4. 更新平均分(update)
python
def update(self, rewards, cur_logprobs=None, response_masks=None):
# 先算rho(信任比例)
if cur_logprobs is not None:
mean_logprob = 算当前答题风格
rho = self.compute_rho(mean_logprob)
self.old_mean_logprob = mean_logprob
else:
rho = 0.9 # 没风格数据,用固定值
# 把分数归一化(比如-3~3分 → 0~1分,和平均分范围匹配)
scale = 3.0
normalized_rewards = (rewards + scale) / (2 * scale) # 例子:3分→1,0分→0.5,-3分→0
avg_normalized_reward = normalized_rewards.mean().item() # 这次考试的平均分
# 加权更新:旧分*rho + 新分*(1-rho)(代码写法不同,但逻辑一样)
self.alpha = rho * self.alpha + avg_normalized_reward
self.beta = rho * self.beta + (1 - avg_normalized_reward)
return rho
- 关键:不是直接替换平均分,而是"慢慢更" → 避免单次考砸/蒙对让平均分乱跳。
📌 智能记录本小结
只干3件事:
- 算当前平均分(基线);
- 根据答题风格调整"信任比例";
- 加权更新平均分,保持稳定。
第二部分:批改老师(calculate_rewards)
🔴 核心目标
给 AI 生成的每一个答案打「综合总分」------ 既查格式是否符合要求(推理题专属),又评内容质量(奖励模型打分),最终分数直接决定 AI 这个答案 "好不好",是后续训练的核心依据。
🟡 生活例子(完全对应代码逻辑)
假设你是数学老师,批改「应用题 + 推理题」的作业,评分规则如下:
-
基础分:所有作业初始 0 分;
-
格式分(仅推理题):
规则 1:作业必须按 「思考过程 → 答案」 规范写,也就是:
我是怎么想的...
最终答案是...只要整体格式对得上,就给 +0.5 分。
规则 2:标签不能多也不能少:
- 有且只有 1 个 `` → +0.25
- 有且只有 1 个 `` → +0.25
- 有且只有 1 个 `` → +0.25
- 有且只有 1 个 `` → +0.25
这一项最多 +1 分。
格式分加起来:最高 1.5 分。
-
内容分(所有题目都算):
请一个"专业评分老师"(奖励模型),看完题目和答案,直接给出一个内容质量分。
分数太高太低会被限制在 -3 ~ +3 之间,防止极端分数乱带节奏。
-
特别加分(推理题):
不只看整段话,只看 `` 里的答案本身 ,再让评分老师打一次分。
最后按:
整段得分 × 0.4 + 答案内容得分 × 0.6更看重"答案对不对"。
-
最终总分:
总分 = 格式分 + 内容质量分这个总分,就是给 AI 这次回答的最终评价。
🟢 对应代码逐段讲解 + 注释
python
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""
【批改老师总入口】
输入:题目(prompts)、AI写的答案(responses)、评分老师(reward_model)、分词器(reward_tokenizer)
输出:每个答案的综合分数(rewards)
"""
1. 推理题格式评分(内部小老师)
python
def reasoning_model_reward(rewards):
"""
只给推理题打分:检查格式规不规范
"""
# 两种合法格式:要么直接 think+answer,要么中间空一行
pattern = r"^\n.*?\n\n\n.*?\n$"
pattern2 = r"^\n.*?\n\n\n\n.*?\n$"
# 批量检查每个回答是否符合整体格式
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5) # 格式对 → +0.5
else:
format_rewards.append(0.0) # 格式错 → 0 分
# 把格式分加到总分里
rewards += torch.tensor(format_rewards, device=args.device)
# 再检查标签数量:4个关键标签必须各出现 1 次
def mark_num(text):
reward = 0
if text.count("") == 1: reward += 0.25
if text.count("") == 1: reward += 0.25
if text.count("") == 1: reward += 0.25
if text.count("") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
rewards += torch.tensor(mark_rewards, device=args.device)
return rewards
2. 开始打分
python
# 初始所有答案都是 0 分
rewards = torch.zeros(len(responses), device=args.device)
# 如果是推理题,先算格式分
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
3. 奖励模型评内容质量(真正的"专业老师")
python
with torch.no_grad(): # 评分老师不参与训练,只打分
reward_model_scores = []
batch_size = len(prompts)
scale = 3.0 # 分数限制在 -3 ~ +3
# 一个题目、对应 N 个生成答案,逐个打分
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
# 把题目解析成对话格式:system / user / assistant
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
# 拼完整对话:题目 + 模型给出的回答
tmp_chat = messages + [{"role": "assistant", "content": response}]
# 奖励模型打分
score = reward_model.get_score(reward_tokenizer, tmp_chat)
score = max(min(score, scale), -scale) # 限制范围
# 如果是推理题:单独看 里的内容,再打一次分,加权融合
if args.reasoning == 1:
answer_match = re.search(r'(.*?)', response, re.DOTALL)
if answer_match:
answer_content = answer_match.group(1).strip()
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
# 更看重答案本身:0.4 整段 + 0.6 答案内容
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
# 把内容分也加进去
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
# 最终返回:格式分 + 内容分 = 综合奖励
return rewards
📌 一句话总结这段代码在干嘛
calculate_rewards 就是一个严格的批改老师:
先看格式对不对(推理题专属),再看内容好不好(奖励模型),最后给一个能直接用来训练AI的总分。
加粗样式
第三部分:GRPO 训练主循环(grpo_train_epoch)
🔴 核心目标
把"批改老师"打的分数转化为 AI 的"成长动力"------ 让 AI 记住"哪些答案得分高、该多写","哪些得分低、要改掉",同时控制 AI 的"变化幅度"(避免改得太离谱),最终让 AI 越练越会写高分答案。
🟡 生活例子(完全对应代码逻辑)
假设你是教练,训练一个"答题选手"(AI),训练流程如下:
- 出题:给选手一批数学题(prompt);
- 选手答题:每道题让选手写 N 个不同答案(num_generations);
- 批改打分:请之前的"批改老师"给每个答案打分(rewards);
- 分析优劣:对比同一道题的 N 个答案,算出"相对优势"(比如:这道题的 3 个答案里,这个答案比平均分高多少);
- 针对性纠错 :
- 对得分高的答案:告诉选手"这种写法要多保留,下次还这么写";
- 对得分低的答案:告诉选手"这种写法要改掉,别再这么写";
- 控制纠错幅度:不能让选手"改太猛"(比如原本会写基础题,改完只会写难题),用"参考版本"(ref_model)约束,确保改完还是"会写基础题";
- 逐步优化:每练一批题就记录进步(日志),练到一定程度就保存"最佳状态"( checkpoint),最终选手越练越会写高分答案。
🟢 对应代码逐段讲解 + 注释
python
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
"""
【GRPO 训练主循环】
输入:
epoch:当前训练轮次(比如第 3 轮)
loader:题目数据集(一批批出题)
iters:本轮要练多少道题(总步数)
ref_model:参考模型(选手的"基础版本",用来约束改题幅度)
reward_model/reward_tokenizer:批改老师
start_step:起始步数(断点续训,比如上次练到 500 步,这次从 501 开始)
wandb:日志工具(记录训练过程)
输出:无(核心是更新 AI 模型参数)
"""
# 1. 一批批出题,逐个训练
for step, batch in enumerate(loader, start=start_step + 1):
# ===== 第一步:出题 + 预处理 =====
prompts = batch['prompt'] # 本次训练的题目列表 [B],B=批次大小(比如一次练 32 道题)
# 把题目转成模型能看懂的"数字格式"(tokenize):左对齐填充、不额外加特殊符号
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # [B, P],P=题目长度
# 限制题目长度(防止太长超出模型能力)
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
# ===== 第二步:让选手(AI)答题 =====
with torch.no_grad(): # 答题阶段不更新模型,只生成答案
# 分布式训练时,要通过 .module 访问模型的"答题功能"
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
# 每道题生成 N 个答案(num_generations),用"随机采样"(temperature=0.8)避免答案千篇一律
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R],R=答案长度
# 提取答案部分(去掉题目,只留选手写的答案)
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
# ===== 第三步:核心工具函数:计算"答题思路的概率" =====
def get_per_token_logps(mdl, input_ids, n_keep):
"""
计算选手写每个字(token)时的"思路概率":
比如写"1+1=2",模型写"1"的概率、写"+"的概率...... 反映模型"为什么这么写"
输入:
mdl:模型(当前选手 / 参考版本)
input_ids:完整的"题目+答案" [N, P+R],N=B*num_gen
n_keep:只算答案部分的概率(R 个token)
输出:
per_token_logps:每个字的对数概率 [N, R]
"""
# 推理时克隆数据,避免改乱原数据
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
# 让模型"复盘"答题过程,输出每个字的概率(logits)
logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :] # [N, R, V],V=所有可选字的数量
per_token_logps = []
# 逐个分析每个答案的每个字
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
# 把概率转成"对数概率"(避免数值溢出),再提取每个字对应的概率
token_logp = torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_logp)
return torch.stack(per_token_logps) # 整理成张量
# ===== 第四步:计算当前选手的"答题思路概率" =====
with autocast_ctx: # 混合精度训练,提速+省内存
# 算选手写答案时,每个字的思路概率
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))
# 如果是 MoE 模型,额外计算"负载均衡损失"(避免模型偏科)
res = model(outputs) if lm_config.use_moe else None
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
# ===== 第五步:计算参考版本的"答题思路概率"(固定不变,用来约束) =====
with torch.no_grad(): # 参考版本不参与训练,只做对比
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
# ===== 第六步:批改打分(调用之前的 calculate_rewards) =====
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) # 把数字答案转回文字
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*num_gen]
# ===== 第七步:分析"相对优势"(Advantage) =====
# 按题目分组:把同一道题的 N 个答案放一起对比 [B, num_gen]
grouped_rewards = rewards.view(-1, args.num_generations)
# 算每道题 N 个答案的平均分、标准差
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
# 计算"相对优势":当前答案比同题平均分高多少(标准化,避免极端值)
advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # [B*num_gen]
# ===== 第八步:处理"答案结束标记"(只算有效答案的损失) =====
# 找到答案里"结束符(eos)"的位置,只计算结束符前的内容(忽略后面的无效字符)
is_eos = completion_ids == tokenizer.eos_token_id # [B*num_gen, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int() # [B*num_gen, R]
# ===== 第九步:计算"纠错损失"(核心:让选手往高分方向改) =====
# 1. 计算"思路偏差"(KL散度):当前选手 vs 参考版本,避免改太猛
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
# 2. 总损失 = 优势引导(高分多写) - KL约束(别改太猛) + 辅助损失
per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
# 3. 只算有效答案的损失,求平均
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
loss = (policy_loss + aux_loss) / args.accumulation_steps # 梯度累积,适配小显存
# ===== 第十步:更新选手(AI)的参数(纠错) =====
loss.backward() # 计算"该怎么改"(梯度)
# 累积一定步数后,统一更新参数(梯度累积)
if (step + 1) % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # 限制梯度,避免改太猛
optimizer.step() # 执行更新(纠错)
scheduler.step() # 调整学习率(越练越慢,精细打磨)
optimizer.zero_grad() # 清空梯度,准备下一轮
# ===== 第十一步:记录训练过程(日志) =====
if step % args.log_interval == 0 or step == iters:
# 提取关键指标:损失、奖励、答案长度、学习率
policy_loss_val = loss.item() * args.accumulation_steps
current_aux_loss = aux_loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
# 打印日志
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
# 可视化日志(wandb)
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"aux_loss": current_aux_loss,
"reward": avg_reward_val,
"avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
})
# ===== 第十二步:保存最佳状态(checkpoint) =====
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval() # 切换到评估模式
# 生成保存文件名(区分 MoE 模型)
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
# 分布式训练时,提取原始模型
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
# 保存模型参数(转半精度,省空间)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
# 保存训练状态(断点续训用)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train() # 切回训练模式
del state_dict # 释放内存
# ===== 清理内存 =====
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
📌 一句话总结这段代码在干嘛
grpo_train_epoch 就是 AI 的"训练教练":
让 AI 答题 → 批改打分 → 分析"哪些写法好" → 用损失函数引导 AI 往高分方向改,同时用参考模型控制改题幅度,练完还记录进步、保存最佳状态,最终让 AI 越练越会写高分答案。
核心要点回顾
- 奖励到损失的转化:通过"相对优势"把"批改分数"变成模型能理解的"纠错信号",高分答案对应低损失(鼓励保留),低分答案对应高损失(强制改掉);
- KL 散度约束:用参考模型限制模型更新幅度,避免"改得太猛"导致模型遗忘基础能力;
- 工程优化:通过梯度累积、混合精度、内存清理等手段,适配大模型训练的显存/性能需求。