极客说|强化学习(RL)与有监督微调(SFT)的选择以及奖励函数的优化

「极客说」 是一档专注 AI 时代开发者分享的专栏,我们邀请来自微软以及技术社区专家,带来最前沿的技术干货与实践经验。在这里,您将看到深度教程、最佳实践和创新解决方案。关注「极客说」,与行业顶尖专家一起探索科技的无限可能!投稿请联系:17278094563(微信号)

本文首先将阐述强化学习(RL)和监督微调(SFT)在实现方式上的区别,然后通过一个具体案例,详细说明如何对奖励函数进行优化。

从简单例子入手理解 SFT 和 RL

监督微调(SFT)- 像老师教学生

监督微调(Supervised Fine-Tuning,简称 SFT)相当于作为老师,自己先列出很多问题,再告诉模型标准的回答,比如用数据(训练集)教它:

我们让模型一遍又一遍模仿训练语料中的标准答案,直到我们符合要求。

SFT 具体步骤(算法的介绍)

  1. 我们拿出一个问题:苹果什么颜色?
  2. 模型自己尝试回答:比如它乱回答成 蓝色
  3. 我们就立马纠正,告诉它正确的答案应该是红色,给它一个明确的误差信号: [ 误差 = - log P("红色") ]
  4. 然后模型用这个误差信号帮助它更新自己说法,让下次"红色"概率增加。

所以,监督学习过程如下:

css 复制代码
for 问题, 标准答案 in 数据集:
    模型答案 = 模型生成(问题)
    误差 = 计算交叉熵Loss(模型答案, 标准答案)
    模型更新(误差)

优点:安全、稳定

缺点:模型永远只能模仿,不太能创造性地发现新答案。

强化学习(RL)-- 让模型自己摸索

强化不直接教标准答案,而是用"鼓励"和"惩罚"引导模型。

我们问模型:"1加1等于?"

  • 它如果乱说了:"香蕉!",我们立刻给个负面奖励(-1);
  • 如果它说对了:"2",我们给它正面奖励(+2)。

模型得到这些奖励和惩罚之后,会慢慢去摸索和记忆,知道怎么才能得到更多奖励(而不是直接告诉它标准答案)。

强化学习大致算法:

bash 复制代码
# RL过程:
for 问题 in 数据集:
    # 让鹦鹉自由生成多个答案(探索)
    多个答案 = 模型生成多个可行答案(问题) 

    # 每个答案给奖励
    for 每个答案 in 多个答案:
        奖励 = 奖励函数(每个答案)
        更新策略(奖励 * log(生成该答案概率))

优势:模型能够自己发现最优策略,能主动"探索",学得更主动;

危险:但探索过猛容易产生 KL 爆冲、梯度爆炸、最终模型崩盘。

SFT 和 RL 选择

大多数情况下训练模型先 SFT 再 RL 更安全、更高效,尤其是对能力尚弱的小模型或需要严格格式输出的任务。不过这并不是绝对法则,下面补充几点可作为快速校验的要点。

为什么"先 SFT 后 RL"通常更好

训练稳定性

  • 直接 RL(尤其是小模型)容易出现 KL 爆冲、梯度爆炸,模型甚至崩盘。
  • SFT 先把策略锚定在"基本正确、格式合规"的空间,再让 RL 微调,KL 跳变小很多,收敛更稳。

数据利用效率

  • SFT 等价于"先喂答案教基础功";RL 更像"在掌握基础后练举一反三"。
  • 如果一开始就 RL,模型会在大量无意义探索上浪费步数。

人工标注成本

  • SFT 阶段可用少量高质量标注(或合成高质量标注)直接模仿;
  • RL 阶段只用奖励信号即可继续放大效果,二者配合能节省标注量。

直接 RL 的合理场景

  1. 几乎没有标注数据、但可以自动计算奖励,例如:解数独、玩 Atari 游戏,环境本身给出分数。
  2. 大模型已具备强基础能力 GPT-4、Claude 3-Sonnet 这一级别,格式和基本推理已比较稳,直接 RL(或 RLAIF)效果也可接受。
  3. 任务鼓励高多样性、无法提供单一"标准答案" 如创意写作、对话风格优化,仅用偏好打分即可训练。

实践经验速查表

  1. 我们的奖励函数是不是完全依赖"答案==标准答案"? 如果是,说明我们已经有明确标注;SFT 通常先做更划算。
  2. 我们有多大 GPU/TPU 预算? RL(尤其 GRPO/PPO)往往需要比 SFT 高 2-4 倍的算力。
  3. 任务对"推理链"可解释性要求高吗? 先 SFT(教会标签格式)再 RL(提升正确率)更容易满足可解释输出。

结论

"先 SFT 再 RL"并非硬性规定,但在绝大多数需要结构化输出、且有可用标注的场景下是最省心、最稳妥的路径。只有当标注极少或任务天然提供可计算奖励时,才会优先考虑"直接 RL"。

RL 常见问题

前文提到的 RL 常见的 KL 爆冲、梯度爆炸、模型崩盘问题,本小节详细介绍。

一般情况下,这三个问题会组成一条「连锁反应」:

rust 复制代码
奖励函数设计不佳或超参错误
      ↓↓导致↓↓
   KL爆冲 --> 梯度爆炸 --> 模型参数剧烈变化或NaN
      ↓↓进一步导致↓↓
   模型崩盘 (输出单一、低质)

KL 爆冲

KL 散度(Kullback--Leibler Divergence)本质上衡量的确实是两个概率分布之间的差距 。在 DPO(Direct Preference Optimization)方法 中,参考模型(reference model)和 训练中模型(policy model )之间计算的就是 KL 散度

用简单例子解释一下:

假设默认模型只会讲三句话:"我们好"、"谢谢"、"再见"。

它现在的"说话概率"(也可以叫"原始概率分布")是:

我们心目中理想的"模型应该说话的概率分布"(目标概率分布)是:

我们希望模型朝着目标概率(Q 分布 )学习,但它原本的习惯是当前概率(P 分布)。

这时候,为了知道我们的鹦鹉目前的概率分布 P目标概率分布 Q 差距有多远。

  • KL 散度越小 = 两个概率越接近
  • KL 散度越大 = 两个概率分布的差距越明显

在例子中,如果原来模型会说:"我们好(Hello)",但我们想教它说:"谢谢(Thank you)",那么就有了:

  • 一个原始模型的分布(Original distribution):擅长说"我们好";
  • 一个目标模型的分布(Target distribution):我们希望它能学会说"谢谢"。

假设我们给了模型过分高的奖励,比如只要提到"谢谢",我们奖励20分。模型会在几步内学得太猛,突然所有问题只回复:"谢谢谢谢!"这就是 KL 距离瞬间爆发。

KL 爆冲发生以后,需要用算法调整 KL 惩罚系数(β)

Loss 总 = 奖励损失 + β × KL 散度

提高 β,比如 0.01 → 0.1,约束模型变化的幅度。

梯度爆炸

深度学习中很常见的梯度爆炸问题主要是指:

  • 网络在训练过程中 因为某次更新的梯度过大 ,导致模型参数突然变化过大,从而网络可能变得不稳定甚至崩溃。

最常见导致梯度爆炸的情况,很少是简单的代码 Bug;事实上更多是算法超参设置不当或数值计算不稳定导致的:

  • 学习率(LR)过大:如原本建议的学习率是 1e-5,但使用了过高学习率(如 1e-2 或者更高),一次参数更新迈步过大,造成梯度过大。
  • 奖励信号设计不合理(尺度过大):有时设计奖励信号时,没有进行归一化处理,例如我们奖励的值过大(比如正常奖励是 ±1,却给了数百甚至上万),导致更新步幅过猛,产生极大的梯度数值。
  • 网络结构本身设计或优化器配置不好:比如神经网络某些层的初始化不合理,或梯度累计出现了数值问题,使得运动过程中梯度持续放大。

未使用梯度裁剪或裁剪设置值过大:如果训练过程中未用梯度裁剪方法,或梯度裁剪的上限值设置过大(如10以上),一旦梯度猛增就不能约束,即可引发梯度爆炸。

算法表现为梯度值剧烈变大甚至 NaN。

模型崩盘

模型崩盘的本质含义是:

  • 模型的参数被 "过度优化" 到单一或极少数的策略上(也称为 Mode Collapse);
  • 策略分布发生严重的退化,模型无法再生成丰富、多样化的内容。

模型崩盘有典型的指标,例如:

  • 输出的熵大幅降低(Entropy↓),表示语言多样性消失;
  • 生成内容变得单一固定,重复度极高;
  • 在训练数据以外的泛化能力和稳健性大幅下降。

算法上,熵的定义是:

scss 复制代码
熵值 = -sum( p(X_i)*log(p(X_i)) )
# 熵越低,表示模型生成的语言越单调单一,越接近崩盘

一种典型的模型崩盘的表现是:

  • 训练前语言多样性熵值 ≈ 8 到 10;
  • 训练后模型崩盘,语言熵值 下降至 1~2 左右。

模型崩盘最常见的直接原因是源于强化学习训练过程本身的一系列内在问题(尤其是强化学习),例如:

  • 奖励函数过于单一和简单:导致模型倾向走极端,重复一种行为;
  • 长时间训练、KL 问题持续未解决:模型能力持续退化,最终彻底丢失多样性;
  • **连续出现梯度爆炸但未干预:参数持续异常更新,模型能力根本不能正常保留;
  • 数据质量较低或过拟合于一种模式:模型长时间反复学习有限模式,无法泛化。

如果出现上述问题我们还继续训练,鹦鹉最后脑袋就真的弄坏了。比如它彻底只会一招,一问就吐出"苹果苹果"或彻底傻掉不回话,再训练也没用(模型崩溃)。

SFT 与 GRPO 的两阶段训练

接下来,参考 repo 中 code 目录下的训练代码,我们详细介绍 SFT 和 GRPO 的区别。

说明

  • SFT 阶段脚本里会对 GAIR/LIMO.select(range(1600)) 之类抽样;原始仓库约 817 条(train),938 条(dev+test)。
  • GRPO 阶段在 openai/gsm8k"main" 配置上取 train split,再 select(range(3500)) 抽子集做 RL;test split 用于离线评测。

要在 Hugging Face Hub 搜索 "GAIR/LIMO" 和 "openai/gsm8k" 即可查看与下载完整数据。

两阶段各自"训练了什么"?

阶段①:SFT(Supervised Fine-Tuning)

训练信号

  • 交叉熵(Cross-Entropy),对教师答案逐 token 强制对齐。

学到内容

  • XML 模板必须完整闭合。
  • 里如何写链式思考(First ... Therefore ...)。
  • 在 标签里只出现一个纯数字。
  • LoRA 参数被拉近"正确格式 + 基本推理"的低损失区。

不学/很少学到

  • GSM8K 真值数字(因为数据集不同)。
  • 高阶数学技巧(量太少、只有 1 epoch)。

阶段②:GRPO(Reinforcement Learning, KL-regularized

训练信号

  • 数值奖励 cor_reward:完全命中 +2;其余 0。
  • 格式奖励 fmt_reward:模板满足 +1;否则 0。
  • 惩罚项 KL:防止行为过度偏离基座。

学到内容

  • 如何把 数字精确等于真值(Exact-Match)。
  • 在保持模板的同时优化上一步数字。
  • 探索 ‑> 投票 ‑> 精修的策略(num_generations=8 + 众数投票)。

不再关注

  • 语言流畅度/用词:奖励里没有相应项。
  • 训练集 LIMO 里的叙述风格(如果在奖励里没加 BLEU/Rouge)。

在两阶段训练中:

  • SFT 主要把模型往"格式正确 + 推理语气自然"方向拉; 在真值层面,由于 LIMO 的答案和 GSM8K 不重叠,加的数值知识有限。
  • GRPO 不仅训练数学,还继续用 fmt_reward 维持格式; 如果把格式奖励权重调成 0,格式率会显著下降。
  • SFT 阶段也会略提升数学(因为 LIMO 题目是算数题),只是提升幅度小; GRPO 阶段才用 3 500 条 GSM8K + 180 步强化专门优化数字。
  • 最终格式 90 %+ 依然是两阶段共同作用的结果------SFT 给起点,GRPO 用奖励守住。

两阶段原始数据集字段如下:

训练脚本脚本 map 之后变成如下格式:

这样便于核对:

  • LIMO 的 solution 被嵌入 prompt → 模型在 SFT 时学习;
  • GSM8K 的 answer 纯数字保留,供 GRPO 奖励使用;
  • LIMO 的 answer 在 SFT 时只是模板演示,不参与 RL。

SFT 中的训练格式解释

如上一段内容解释,SFT map 后的训练语料并没有 completion 字段。

在 SFTConfig 训练代码里设置了

completion_only_loss=False

这表示"不要只对 completion 计算损失,而是对整条 prompt 进行 teacher-forcing"。在这种模式下,SFTTrainer 并不需要单独的 completion 字段------只要有一列 prompt 含完整参考答案即可。

  1. 但 SFTTrainer 源码要求数据集中必须存在 completion 这一列(无论用不用)。为了省事就补了空 字符串占位,使得字段齐全、代码不报错。
  2. 为什么不把 answer 放进 completion? 如果我们设 completion_only_loss=True,那就需要把 <answer>25</answer> 部分挪到 completion,让 prompt 只包含系统提示 + question + <reasoning>...</reasoning>。 当前脚本选用整串 CE 方式,所以 completion 留空即可。

简而言之:

  • completion="" 是占位;
  • 真正的教师文本(含 solution 和 answer)已经在 prompt 里,交叉熵对整串计算,所以不会损失任何监督信息。

SFT 训练损失函数的构建

三种构建 SFT 损失函数方案

把 COT + answer 全放到 completion(方案 C)会发生什么?

1、prompt 只剩 "系统提示 + 题干",长度变短 → 同批显存更低; 2、model 在训练时只要"读题干 → 预测 reasoning+answer", 形成经典的 Instruction → Target 教师强制结构; 3、优势

  • 数字与 reasoning token 都在 loss 中,权重不被系统提示稀释;
  • prompt 更短,长题目不易溢出 max_seq_length

4、可能副作用

  • 如果 <reasoning> 很长,占用了 90 % 的 loss,数字又被稀释;
  • 需要保证 <reasoning> 首 token 可由题干直接预测, 否则梯度稀疏(题干→ 标签 gap);
  • 格式标签 <answer> 仍在 completion 内,CE 会学到,但如果生成时 Temperature>0,模型还是可能漏标签,需要 RL 或格式奖励二次约束。

如何选择?

在我们当前"小数据、 1 epoch"的设置下,整串 CE 提供最稠密梯度;如果未来扩充 LIMO 到数万条并跑多 epoch,可以考虑方案 C,并在 RL 阶段继续用格式奖励守护模板,以获得更高数值准确率且不过拟合冗长 COT。

如果改成方案 C------把整个

xml 复制代码
<reasoning>......</reasoning><answer>......</answer>

都放进 completion,只让交叉熵监督这段文本,但增加一个格式类奖励仍然是最稳妥的做法。理由与操作要点如下。

  • 方案 C 把 COT+答案放在 completion 后,模型有潜力更关注数值,但仍可能在生成时漏标签;
  • 保留一个(或低权重)的 fmt_reward 作为安全带是最保险的配置;
  • 可根据任务需要把格式奖励权重动态调低或改成惩罚式,以兼顾准确率与模板稳定性。

设计欠佳奖励函数(优化前的奖励函数)

在强化学习训练中,答案正确性的判断 通常通过自动化脚本实现,而非依赖人工标注的表格。以下是具体实现逻辑。

格式奖励函数(format_reward_func

目标

确保模型输出符合预设的 XML 标签结构 <reasoning>...</reasoning><answer>...</answer>

代码实现

python 复制代码
import re

def format_reward_func(completions, **kwargs):
    """检查输出是否符合XML标签格式"""
    pattern = r"^<reasoning>[\s\S]*?<\/reasoning>\s*<answer>[\s\S]*?<\/answer>$"
    responses = [completion[0]["content"] for completion in completions]
    rewards = [1.0if re.match(pattern, response) else0.0for response in responses]
    return rewards

逻辑解析

  • 正则表达式匹配: 使用正则表达式 r"^<reasoning>[\s\S]*?<\/reasoning>\s*<answer>[\s\S]*?<\/answer>$" 严格检查输出是否包含完整的 <reasoning><answer> 标签,且顺序正确。
  • 奖励分配: 符合格式则奖励 1.0 分,否则 0.0 分。

正确性奖励函数(correctness_reward_func

目标

验证模型输出的数值答案是否与标准答案一致。

代码实现

python 复制代码
def correctness_reward_func(completions, answer, **kwargs):
    """检查答案是否正确"""
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_last_xml_answer(response) for response in responses]
    rewards = [
        2.0if extracted == correct else0.0
        for extracted, correct inzip(extracted_responses, answer)
    ]
    return rewards

依赖函数 extract_last_xml_answer

python 复制代码
def extract_last_xml_answer(response):
    """从XML标签中提取答案(若格式错误,则取最后一个数字)"""
    try:
        # 尝试解析XML标签
        answer = re.search(r"<answer>(.*?)</answer>", response).group(1).strip()
        return answer
    except:
        # 格式错误时,提取最后一个数字
        numbers = re.findall(r"\d+\.?\d*", response)
        return numbers[-1] if numbers else""

逻辑解析

  • 答案提取 :优先从 <answer> 标签中提取答案;若标签缺失或格式错误,则提取输出中的最后一个数字。
  • 奖励分配:答案与标准答案一致则奖励 2.0 分,否则 0.0 分。

总奖励计算

  • 总分范围0.0 ~ 3.0 总奖励(1.0) + 正确性奖励(2.0)。
  • 归一化处理: GRPO 算法会对组内奖励进行相对归一化(组内个体奖励减去组平均奖励),以平衡探索与利用。

关键设计考量

  • 格式与正确性的权重 正确性奖励(2.0)权重高于格式奖励(1.0),体现"答案正确性优先于格式"的设计原则。
  • 容错机制 即使格式错误,仍尝试提取最后一个数字作为答案,避免因格式问题完全丢弃有效答案。
  • 则表达式严格性 格式检查使用严格匹配(^...$),确保标签闭合且无多余内容,强制模型学习结构化输出。

奖励函数优化思路

奖励函数优化主要包含:

  1. 细化奖励分值
  2. 增加群组投票

数字奖励 cor_reward (0 / 1 / 2 分)

python 复制代码
XML_RE  = re.compile(r"<answer>(.*?)</answer>", re.S)
_num    = lambda x: re.sub(r"[%$,]", "", x).strip()

def _extract_nums(text: str):
    return [_num(m) for m in XML_RE.findall(text)]

def cor_reward(completions, **kw):
    answers = kw.get("answer") or kw.get("answers") or []
    rewards = []

    for cand_list, gt inzip(completions, answers):
        # 1) 收集 8 条回答里的所有 <answer>...</answer> 数字
        nums = [
            n
            for c in cand_list
            for n in _extract_nums(c["content"])
        ]

        # 2) 若一个数字都没抓到 → 直接 0 分
        ifnot nums:
            rewards.append(0.0)
            continue

        # 3) 群组投票:出现次数最多的数字
        vote = Counter(nums).most_common(1)[0][0]

        # 4) 评分:完全对 +2,差 1 +1,其余 0
        diff = abs(int(vote) - int(gt)) if vote.isdigit() and gt.isdigit() else999
        if   diff == 0: rewards.append(2.0)
        elif diff == 1: rewards.append(1.0)
        else:           rewards.append(0.0)

    return rewards

详细步骤

1、_extract_nums()

  • 用正则在单条回答文本里找所有 ...;
  • _num() 去掉 $ % , 等符号,得纯数字字符串。

2、组内投票(majority vote)

scss 复制代码
vote = Counter(nums).most_common(1)[0][0]
  • 把 8 条回答汇总得到的数字列表 nums 做统计;
  • 选出现频率最高的那一个(若并列,Counter 取第一出现)。
  • 投票的好处:抑制偶然的随机数;让模型有动力让多个回答趋向一致=正确数值。

3、分级奖励 diff = |vote - ground_truth| -- diff == 0 → +2 (完全正确) -- diff == 1 → +1 (只差 1,也给部分梯度) -- else     → 0 (远离真值) 这样 early 训练阶段更容易拿到非零 reward,梯度稠密,KL 更平滑。

输出示例

ini 复制代码
batch_size = 8
cor_reward → [2,1,0,2,0,1,0,2]
fmt_reward → [1,1,0,1,1,1,0,1]
total_reward → [3,2,0,3,1,2,0,3]

新旧奖励函数对比:

结果:

  • fmt_reward_mean 更快爬到 0.9;
  • cor_reward_mean 抬到 ~1.2(≈30% 完全对 + 35% 差 1);
  • KL 控制在 <0.2,训练稳定;
  • 总 reward 1.8→2.1 左右,比旧版提升约 10 %。

群组 vote 的合理性研究

先投票再对真值"合不合理,要看我们希望奖励函数起什么作用。

合理方面

1、自洽性(Self-Consistency)的经验规律 OpenAI、Google 论文都表明: "同一 prompt 让模型多生成几条推理,用众数/平均值作为最终答案, 准确率往往高于单条输出。" 投票奖励把这个经验直接注入 RL:

  • 如果 8 条里 ≥4 条写 42,那 42 很可能就是正确答案;
  • 早期即使 8 条回答各不相同,也能把出现次数最多的那个作为 "模型当前最确信" 的猜测。

2、梯度密度更高

  • 纯 0/2 模式:完全错 = 0,很容易 reward 全 0;
  • 投票 + 差 1 给 1 分:早期也能拿到非零 reward,梯度方向更连续。

3、利用并行生成的计算成本 既然我们已经花显存一次性生成了 8 条回答,把它们全都用来评奖要 比只看第一条更物超所值。

4、格式门控 + 数值投票分离 先用格式奖励约束输出形状,再用投票奖励评数值;两部分可独立调 权重,互不干扰。

局限性

  1. "集体跑偏" 如果模型内部存在系统性错误(8 条都写 41,但真值 42),投票仍会 选错。此时 reward 仍给 0 / 1,梯度作用有限。
  2. 并列众数的歧义 Counter.most_common(1) 默认返回先出现的数字; 若票数打平,选择具有随机性,可能带来噪声。 → 可以设阈值:只有票数 ≥4 才用众数,否则 reward=0。
  3. 差值阈值的 trade-off ‑ 差 1 给 1 分能 densify 梯度; ‑ 但如果阈值太宽(差 5 也给分)会削弱"完全正确"的驱动力。
  4. 生成条数与开销 num_generations=8 对 A100 2B 模型还算轻;如果用更大的模型或者 更长 completion,生成 8 条会拖慢训练。

如何让投票更棒(进一步优化的可能性)

结论

  • 小批量、有限步数的 RL 微调而言,"投票→对真值" 的奖励能显著缓解 0/2 稀疏问题,让 early reward 更快爬升,并在格式合规率上带来明显增益,是一个合理且常用的技巧。
  • 如果我们更在意"对/错的严格区分",可以保持 diff==0 才给分; 若更在意收敛速度和平滑梯度,保留差 1 + 部分分会更友好。

因此,是否保留投票机制取决于:

  • 我们能否接受多生成几条回答的时间 / 显存成本;
  • 我们更关注最终极限正确率(可考虑后期关闭差 1 奖励), 还是关注训练效率和稳定性(保留投票 + 部分分)。

训练结果指标解读

SFTTrainer 日志里出现字段:

SFT、GRPO 通用字段:

GRPOTrainer 特有字段:

备注:

  1. fmt_reward/mean ≥ 0.9 → 模板输出稳定。
  2. cor_reward/mean ≥ 1.2 → 30 % 以上完全正确(好)。
  3. kl < 0.3 → 更新稳定;若突涨,需减小学习率 / β。
  4. frac_reward_zero_std < 0.3 → 奖励信号足够密集。
  5. completions/clipped_ratio > 0.4 → 说明 128 token 不够,可调大。

奖励函数优化对训练效果实测

奖励函数优化前

训练

ruby 复制代码
source .venv/bin/activate
root@a100vm:~/Gemma-2-2B-IT-GRPO# pwd
/root/Gemma-2-2B-IT-GRPO
root@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo2.py 
root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo2.py

训练中的资源利用率

ruby 复制代码
(Gemma-2-2B-IT-GRPO) root@a100vm:~/Gemma-2-2B-IT-GRPO# nvidia-smi
Sun Jun 1511:44:162025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          Off |   00000001:00:00.0 Off |                    0 |
| N/A   49C    P0            109W /  300W |   80793MiB /  81920MiB |     48%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    449318      C   python                                      80780MiB |
+-----------------------------------------------------------------------------------------+

评估

javascript 复制代码
python3 -m venv ~/eval-env
source ~/eval-env/bin/activate
pip install "torch>=2.1""transformers>=4.49" datasets tqdm
pip install accelerate
python gsm8k-eval-tf2.py --model_dir gemma-grpo-only
python gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

纯 GRPO

matlab 复制代码
----------------------------------------
Input tokens  avg=140.5  max=269
Output tokens avg=90.9  max=257
Correct format     : 1142/1319 (86.6%)
Plausibly correct  : 566/1319 (42.9%)
Exact correct      : 559/1319 (42.4%)
========================================

SFT+GRPO

ruby 复制代码
----------------------------------------
Input tokens  avg=140.5  max=269
Output tokens avg=74.7  max=257
Correct format     : 1192/1319 (90.4%)
Plausibly correct  : 504/1319 (38.2%)
Exact correct      : 500/1319 (37.9%)
========================================
(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化后

训练

ruby 复制代码
source .venv/bin/activate
root@a100vm:~/Gemma-2-2B-IT-GRPO# pwd
/root/Gemma-2-2B-IT-GRPO
root@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo3.py 
root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo3.py

评估

javascript 复制代码
python3 -m venv ~/eval-env
source ~/eval-env/bin/activate
pip install "torch>=2.1""transformers>=4.49" datasets tqdm
pip install accelerate
python gsm8k-eval-tf2.py --model_dir gemma-grpo-only
python gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

仅 GRPO

ruby 复制代码
----------------------------------------
Input tokens  avg=140.5  max=269
Output tokens avg=92.2  max=257
Correct format     : 1120/1319 (84.9%)
Plausibly correct  : 665/1319 (50.4%)
Exact correct      : 657/1319 (49.8%)
========================================
(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

SFT+GRPO

ruby 复制代码
----------------------------------------
Input tokens  avg=140.5  max=269
Output tokens avg=75.5  max=257
Correct format     : 1161/1319 (88.0%)
Plausibly correct  : 506/1319 (38.4%)
Exact correct      : 505/1319 (38.3%)
========================================
(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化前后对比(仅仅对比 GRPO)

结论

  1. 数值准确率 新奖励把完全正确率提升了约 7 个百分点,这是只奖励 exact-match 的直接收益
  2. 格式合规率 基本持平(因为并没有优化格式奖励)
  3. 后续细化奖励规则,增加训练 step 数,准确率有望继续提升。

白皮书推荐

开发者们,别掉队!大语言模型正以前所未有的速度重塑技术格局。微软最新发布《大语言模型(LLM)上手指南》白皮书,涵盖 Microsoft Copilot 副驾驶® 在代码编写、Debug、创意发想等方面的强大功能详细解说。

点击下方链接,在角色一栏中填写"开发者"即可领取专属开发者的技术文档。

info.microsoft.com/GC-DevOps-W...

相关推荐
我爱C编程5 小时前
基于强化学习的5G通信网络基站资源动态分配策略matlab性能仿真
5g·matlab·强化学习·基站资源动态分配
SunStriKE2 天前
veRL代码阅读-1.论文原理
深度学习·强化学习·源码阅读
Listennnn3 天前
强化学习三大分类
人工智能·强化学习
JNU freshman4 天前
强化学习之 DQN、Double DQN、PPO
强化学习
MarkGosling4 天前
【资源合集】强化学习训练LLM Agents的实战资源库:AgentsMeetRL
llm·agent·强化学习
汤姆和佩琦4 天前
LLMs基础学习(八)强化学习专题(4)
学习·强化学习·策略随机探索
Gowi_fly7 天前
从 PPO、DPO 到 GRPO:大语言模型策略优化算法解析
llm·强化学习
我不是小upper7 天前
AReaL-boba²:首个全异步强化学习训练系统它来了!!
人工智能·强化学习
panbaoran9138 天前
【一】零基础--分层强化学习概览
强化学习·hrl