聊聊GRPO算法——从Open R1来看如何训练DeepSeek R1模型

概述

首发自个人公众号:阿郎小哥的随笔驿站

DeepSeek R1系列建议阅读之前的系列文章:

聊聊DeepSeek R1的一些总结

聊聊DeepSeek R1的开源复现库------Open R1之合成数据

聊聊DeepSeek R1的知识蒸馏与应用思考

简介

GRPO 是一种在线学习算法,这意味着它通过在训练期间使用受训模型自身生成的数据来迭代改进。GRPO 目标背后的直觉是最大化生成补全的优势,同时确保模型保持接近参考策略。

GRPO 的四个主要步骤:生成补全计算优势估计 KL 散度计算损失

与传统的RL方法不同,后者通常依赖外部评估者(批评者)来引导学习,GRPO通过评估一组响应之间的相对关系来优化模型。这种方法提高了训练效率,使GRPO在需要复杂问题解决和长链思维的推理任务中表现尤为出色。

步骤分解

步骤1:选择查询

• 从训练数据集$ P(Q) 中选择一个查询 (q) $。

• 示例:假设查询是"8 + 5的和是多少?"

步骤2:生成一组响应

• 模型针对该查询生成一组$ G $个响应。

• 示例:模型生成以下响应:

• o1:"答案是13。"

• o2:"十三。"

• o3:"是12。"

• o4:"和是13。"

步骤3:计算每个响应的奖励

• 什么是奖励?奖励通过量化响应的质量来引导模型的学习。

• GRPO中的奖励类型:

• 准确性奖励:基于响应的正确性(例如,解答数学题)。

• 格式奖励:确保响应符合结构化要求(例如,推理过程需要包含在标签中)。

• 语言一致性奖励:惩罚语言混杂或格式不一致的响应。

• 根据每个响应的好坏,赋予一个奖励($ r_i $)。

例如,奖励可能取决于:

• 准确性:答案是否正确?

• 格式:响应是否结构良好?

示例:

• r1 = 1.0(正确且格式良好)

• r2 = 0.9(正确但较不正式)

• r3 = 0.0(错误答案)

• r4 = 1.0(正确且格式良好)

步骤4:比较响应(群体优势)

• 计算每个响应相对于群体的优势$ (A_i) $,paper中相关术语如下:

用简单的方式理解,就是这样:

• 比较结果优于群体平均水平的响应会获得正分,而表现较差的响应会得到负分。

• 这种方式在群体内部激发竞争,推动模型生成更好的响应。

步骤5:使用裁剪更新策略

示例:如果新策略开始给o1分配过高的概率,裁剪机制确保不会过度强调这个响应。

这种方式保证了即使在像推理这样复杂的任务中,策略优化也能保持稳定和可靠。

步骤6:通过KL散度惩罚偏差

GRPO实现

Open R1

在Open R1的复现路径中

实现了基于GRPO算法的训练,脚本如下

bash 复制代码
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yaml

confg_full.yaml

yaml 复制代码
# 基座模型
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16

# 训练数据集
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7

# GRPO训练器参数
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4   
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

Open R1提供了grpo算法的实现------grpo.py,删减了部分无关代码,关键的程序逻辑如下:

bash 复制代码
@dataclass
class GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
   
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    training_args.model_init_kwargs = model_kwargs

    #############################
    # Initialize the GRPO trainer
    #############################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    trainer.save_model(training_args.output_dir)

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    #############
    # push to hub
    #############
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

代码分析如下:

首先就是加载数据集,但数据集在加载时,会有指定的提示词,即代码中的make_conversation函数,该函数构造指定的prompt引导模型的输出,格式如下:

bash 复制代码
{
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["problem"]},
    ],
}

对于SYSTEM_PROMPT,描述如下:

bash 复制代码
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"

总的来说就是,引导模型先思考推理过程 ,再按格式将推理过程与回复 放入指定标签<think>、<answer>内。

接下来是reward函数,grpo算法有两种奖励:准确性奖励与格式正确奖励;如下

python 复制代码
def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

最后就是训练,GRPOTrainer是transformers库提供的基于Trainer的训练类,传入指定的参数即可实现基于GRPO算法的实现;其中比较关键的是reward、train_dataset。

bash 复制代码
#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)

计算训练的checkpoint与循环周期,则会在Trainer类中通过gradient_accumulation_steps(梯度累积步数)、num_train_epochs(训练轮数)以及 per_device_train_batch_size(每个设备的训练批次大小)这些参数计算训练周期。

bash 复制代码
###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)

小结

总的来说,Open R1的GRPO训练,是基于GRPOTrainer指定prompt/datasetreward等参数实现GRPO的训练。也就是说,在指定的训练数据集下,通过prompt引导模型的输出,然后基于grpo算法及其reward对 模型的输出与训练数据集的output 做奖惩打分(通过KL散度比较),计算loss,再反向传播。循环反复;最终完成模型的RL训练,达到让模型能做到CoT式的回复,即生成补全计算优势估计 KL 散度计算损失的步骤,如最开始的图所示。

对于GRPOTrainer类的源码及文档可参考:

首发自个人公众号:阿郎小哥的随笔驿站

相关推荐
慕白Lee11 分钟前
【DeepSeek】私有化本地部署图文(Win+Mac)
ai·deepseek
一颗小树x1 小时前
DeepSeek-R1 本地电脑部署 Windows系统 【轻松简易】
windows·部署·deepseek·r1
Qubernet2 小时前
通过Ollama本地部署DeepSeek R1以及简单使用
ai·ollama·deepseek
麻雀小妖2 小时前
如何在WPS和Word/Excel中直接使用DeepSeek功能
ai·wps·office·deepseek
gzpingesoft2 小时前
DeepSeek RAGFlow构建本地知识库系统
ai·知识库·ragflow·deepseek
火星大能猫2 小时前
使用C# 调用deepseek api接口,来实现正常访问
ai·c#·webapi·deepseek
追风20194 小时前
安装和使用 Ollama(实验环境windows)
windows·ai
云边有个稻草人5 小时前
深度学习与搜索引擎优化的结合:DeepSeek的创新与探索
人工智能·深度学习·搜索引擎·deepseek
视觉&物联智能6 小时前
【杂谈】-文明的量子跃迁:AI时代人类物种的自我重构
人工智能·深度学习·aigc·openai·agi·deepseek