07_verl-Trainer模块详解

07 Trainer 模块详解

【总】开篇概述

Trainer 模块的核心价值

Trainer 模块是 verl 训练流程的控制中心,负责编排完整的 PPO(Proximal Policy Optimization)训练循环。它将数据加载、模型推理、奖励计算、优势估计、策略更新等环节串联成一个端到端的训练流水线,同时协调多个分布式 Worker 之间的协作。

verl 的 Trainer 并非简单的训练脚本,而是一个精心设计的分布式编排器:驱动进程(Driver)仅执行轻量级的优势计算与数据调度,所有重计算(前向推理、反向传播)均通过 Ray RPC 委托给远端 Worker 执行,从而实现了控制逻辑与计算逻辑的彻底解耦。

核心问题

verl 如何编排完整的 PPO 训练循环? 具体而言:

  • 如何从入口脚本启动分布式训练?
  • 如何注册和初始化各类 Worker(Actor、Critic、RefPolicy 等)?
  • 如何在单控制器上协调 Rollout → Reward → Advantage → Update 的完整数据流?
  • 如何支持 GAE、GRPO、REINFORCE++ 等多种优势估计器?
  • 如何处理 KL 惩罚、Rollout 修正等高级算法特性?

全局概览图

#mermaid-svg-KsFRbPK1GX6EV6oR{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-KsFRbPK1GX6EV6oR .error-icon{fill:#552222;}#mermaid-svg-KsFRbPK1GX6EV6oR .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-KsFRbPK1GX6EV6oR .marker{fill:#333333;stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .marker.cross{stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-KsFRbPK1GX6EV6oR p{margin:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label text{fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label span{color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label span p{background-color:transparent;}#mermaid-svg-KsFRbPK1GX6EV6oR .label text,#mermaid-svg-KsFRbPK1GX6EV6oR span{fill:#333;color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .node rect,#mermaid-svg-KsFRbPK1GX6EV6oR .node circle,#mermaid-svg-KsFRbPK1GX6EV6oR .node ellipse,#mermaid-svg-KsFRbPK1GX6EV6oR .node polygon,#mermaid-svg-KsFRbPK1GX6EV6oR .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .rough-node .label text,#mermaid-svg-KsFRbPK1GX6EV6oR .node .label text,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label,#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label{text-anchor:middle;}#mermaid-svg-KsFRbPK1GX6EV6oR .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .rough-node .label,#mermaid-svg-KsFRbPK1GX6EV6oR .node .label,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label,#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label{text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .node.clickable{cursor:pointer;}#mermaid-svg-KsFRbPK1GX6EV6oR .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .arrowheadPath{fill:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-KsFRbPK1GX6EV6oR .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster text{fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster span{color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-KsFRbPK1GX6EV6oR .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR rect.text{fill:none;stroke-width:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape p,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label rect,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-KsFRbPK1GX6EV6oR .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-KsFRbPK1GX6EV6oR :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} Worker 层
辅助模块
算法层
编排层
入口层
main_ppo.py / main_ppo_sync.py
Hydra 配置加载
Ray 集群初始化
TaskRunner.run
RayPPOTrainer / PPOTrainer
init_workers
fit - 主训练循环
core_algos.py
GAE 优势估计
GRPO 优势估计
REINFORCE++ 优势估计
PPO 裁剪损失
KL 惩罚控制
reward.py - 奖励提取
metric_utils.py - 指标计算
rollout_corr_helper.py - Rollout 修正
SkipManager - 条件跳过
padding_utils.py - 批次填充
ActorRolloutRefWorker
TrainingWorker - Critic
RewardLoopManager
AgentLoopManager

关键结论预览

  1. 单控制器编排模式:Trainer 运行在驱动进程上,通过 Ray RPC 调度 Worker,自身仅执行轻量计算(优势估计、KL 惩罚),实现了控制与计算的分离。
  2. 可插拔算法架构 :通过 @register_adv_est@register_policy_loss 装饰器注册机制,支持 GAE、GRPO、RLOO、ReMax 等 10+ 种优势估计器和 vanilla、DPPO、GSPO、SAPO 等 10+ 种策略损失函数。
  3. 双模式训练 :同时支持异步 PPO(RayPPOTrainer,基于 DataProto)和同步 PPO(PPOTrainer,基于 TransferQueue),后者是未来主推方向。
  4. Rollout 修正体系:通过重要性采样(IS)权重和拒绝采样(RS)修正 rollout 与训练策略之间的分布偏移,支持 Bypass 模式(跳过 old_log_prob 重计算)。
  5. Critic 暖启动 :支持 critic_warmup 机制,在指定步数内仅更新 Critic 而冻结 Actor,稳定训练初期。

【分】逐层展开

1. 入口与启动流程

1.1 两种入口脚本

verl 提供两种 PPO 训练入口:

脚本 Trainer 类 数据传输 状态
main_ppo.py RayPPOTrainer DataProto 已废弃(v0.8.0 移除)
main_ppo_sync.py PPOTrainer TransferQueue 推荐使用

两者共享 run_ppo() 函数进行 Ray 初始化和 TaskRunner 启动,差异在于数据传输机制和 Worker 协作方式。

1.2 启动时序图

RayPPOTrainer TaskRunner Ray 集群 Hydra 配置 main_ppo.py 用户 RayPPOTrainer TaskRunner Ray 集群 Hydra 配置 main_ppo.py 用户 #mermaid-svg-KPh0357k8oKN2Dio{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-KPh0357k8oKN2Dio .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-KPh0357k8oKN2Dio .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-KPh0357k8oKN2Dio .error-icon{fill:#552222;}#mermaid-svg-KPh0357k8oKN2Dio .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-KPh0357k8oKN2Dio .marker{fill:#333333;stroke:#333333;}#mermaid-svg-KPh0357k8oKN2Dio .marker.cross{stroke:#333333;}#mermaid-svg-KPh0357k8oKN2Dio svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-KPh0357k8oKN2Dio p{margin:0;}#mermaid-svg-KPh0357k8oKN2Dio .actor{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio text.actor>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .actor-line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);}#mermaid-svg-KPh0357k8oKN2Dio .innerArc{stroke-width:1.5;stroke-dasharray:none;}#mermaid-svg-KPh0357k8oKN2Dio .messageLine0{stroke-width:1.5;stroke-dasharray:none;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .messageLine1{stroke-width:1.5;stroke-dasharray:2,2;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio #arrowhead path{fill:#333;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .sequenceNumber{fill:white;}#mermaid-svg-KPh0357k8oKN2Dio #sequencenumber{fill:#333;}#mermaid-svg-KPh0357k8oKN2Dio #crosshead path{fill:#333;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .messageText{fill:#333;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .labelBox{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio .labelText,#mermaid-svg-KPh0357k8oKN2Dio .labelText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .loopText,#mermaid-svg-KPh0357k8oKN2Dio .loopText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .loopLine{stroke-width:2px;stroke-dasharray:2,2;stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);}#mermaid-svg-KPh0357k8oKN2Dio .note{stroke:#aaaa33;fill:#fff5ad;}#mermaid-svg-KPh0357k8oKN2Dio .noteText,#mermaid-svg-KPh0357k8oKN2Dio .noteText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .activation0{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .activation1{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .activation2{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .actorPopupMenu{position:absolute;}#mermaid-svg-KPh0357k8oKN2Dio .actorPopupMenuPanel{position:absolute;fill:#ECECFF;box-shadow:0px 8px 16px 0px rgba(0,0,0,0.2);filter:drop-shadow(3px 5px 2px rgb(0 0 0 / 0.4));}#mermaid-svg-KPh0357k8oKN2Dio .actor-man line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio .actor-man circle,#mermaid-svg-KPh0357k8oKN2Dio line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;stroke-width:2px;}#mermaid-svg-KPh0357k8oKN2Dio :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} python main_ppo.py加载 ppo_trainer.yaml返回 configauto_set_device(config)migrate_legacy_reward_impl(config)ray.init(runtime_env)集群就绪ray.get(runner.run.remote(config))add_actor_rollout_worker(config)add_critic_worker(config)init_resource_pool_mgr(config)create_rl_dataset / create_rl_samplerRayPPOTrainer(config, ...)trainer.init_workers()trainer.fit()训练完成返回结果

1.3 Hydra 配置加载

通过 @hydra.main(config_path="config", config_name="ppo_trainer") 装饰器自动加载 YAML 配置。关键配置项包括:

  • actor_rollout_ref:Actor/Rollout/RefPolicy 共享配置
  • critic:Critic 模型配置
  • algorithm:算法参数(adv_estimator、gamma、lam 等)
  • reward:奖励配置(函数奖励/模型奖励)
  • trainer:训练控制参数(总步数、保存频率等)
  • data:数据集配置
1.4 Ray 集群初始化

run_ppo() 函数负责 Ray 集群初始化:

python 复制代码
# 关键步骤
default_runtime_env = get_ppo_ray_runtime_env()  # 设置 TOKENIZERS_PARALLELISM、NCCL_DEBUG 等
ray.init(**OmegaConf.to_container(ray_init_kwargs))

constants_ppo.py 中定义了 Ray 运行时环境变量,包括 tokenizer 并行化、NCCL 调试级别、vLLM 日志级别等。

1.5 TaskRunner

TaskRunner 是一个 Ray Remote 类,作为训练任务的远程执行器:

python 复制代码
@ray.remote(num_cpus=1)
class TaskRunner:
    def run(self, config):
        # 1. 注册 Worker
        self.add_actor_rollout_worker(config)
        self.add_critic_worker(config)
        # 2. 初始化资源池
        self.init_resource_pool_mgr(config)
        # 3. 创建数据集
        train_dataset = create_rl_dataset(...)
        val_dataset = create_rl_dataset(...)
        train_sampler = create_rl_sampler(...)
        # 4. 实例化 Trainer
        trainer = RayPPOTrainer(config, ...)
        trainer.init_workers()
        trainer.fit()

Worker 注册逻辑

  • add_actor_rollout_worker():注册 ActorRolloutRefWorker,若使用 LoRA 则 RefPolicy 融合在 Actor 中(Role.ActorRollout),否则独立注册(Role.ActorRolloutRef
  • add_critic_worker():注册 TrainingWorker 作为 Critic,仅在 need_critic(config) 为 True 时注册
  • add_reward_model_resource_pool():若启用奖励模型,注册独立资源池或共享全局池

ResourcePool 初始化

python 复制代码
resource_pool_spec = {
    "global_pool": [n_gpus_per_node] * nnodes,  # Actor/Critic/Ref 共享
    "reward_pool": [...],  # 可选:独立奖励模型池
    "teacher_pool": [...],  # 可选:独立教师模型池
}

数据集创建

  • create_rl_dataset():根据 data_config 动态选择数据集类(通过 get_dataset_class),支持 JSONL、Parquet 等格式
  • create_rl_sampler():创建采样器,支持随机采样(RandomSampler,可恢复种子)和顺序采样(SequentialSampler

2. RayPPOTrainer 详解

2.1 初始化

RayPPOTrainer.__init__() 完成以下初始化:

  1. 配置解析:解析 hybrid_engine、use_reference_policy、use_critic 等标志
  2. LoRA 判断 :根据 lora_rank > 0lora_adapter_path 判断 RefPolicy 是否融合在 Actor 中
  3. KL 控制器 :若 use_kl_in_reward=True,创建 AdaptiveKLControllerFixedKLController
  4. 数据加载器 :创建训练/验证 DataLoader(StatefulDataLoader,支持断点恢复)
2.2 init_workers() - Worker 初始化流程

init_workers() 是 Trainer 最复杂的初始化方法,负责创建所有分布式 Worker:
#mermaid-svg-RkWruDcd9ZGLibAn{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-RkWruDcd9ZGLibAn .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-RkWruDcd9ZGLibAn .error-icon{fill:#552222;}#mermaid-svg-RkWruDcd9ZGLibAn .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-RkWruDcd9ZGLibAn .marker{fill:#333333;stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .marker.cross{stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-RkWruDcd9ZGLibAn p{margin:0;}#mermaid-svg-RkWruDcd9ZGLibAn .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label text{fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label span{color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label span p{background-color:transparent;}#mermaid-svg-RkWruDcd9ZGLibAn .label text,#mermaid-svg-RkWruDcd9ZGLibAn span{fill:#333;color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .node rect,#mermaid-svg-RkWruDcd9ZGLibAn .node circle,#mermaid-svg-RkWruDcd9ZGLibAn .node ellipse,#mermaid-svg-RkWruDcd9ZGLibAn .node polygon,#mermaid-svg-RkWruDcd9ZGLibAn .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .rough-node .label text,#mermaid-svg-RkWruDcd9ZGLibAn .node .label text,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label,#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label{text-anchor:middle;}#mermaid-svg-RkWruDcd9ZGLibAn .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .rough-node .label,#mermaid-svg-RkWruDcd9ZGLibAn .node .label,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label,#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label{text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .node.clickable{cursor:pointer;}#mermaid-svg-RkWruDcd9ZGLibAn .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .arrowheadPath{fill:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-RkWruDcd9ZGLibAn .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-RkWruDcd9ZGLibAn .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster text{fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster span{color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-RkWruDcd9ZGLibAn .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn rect.text{fill:none;stroke-width:0;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape p,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label rect,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-RkWruDcd9ZGLibAn .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-RkWruDcd9ZGLibAn :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} init_workers
创建资源池
注册 ActorRolloutRefWorker
注册 Critic TrainingWorker
注册 RefPolicy Worker
create_colocated_worker_cls
RayWorkerGroup 初始化
spawn Worker 子组
初始化 Critic
critic_wg.reset
set_loss_fn - value_loss
初始化 ActorRollout
actor_rollout_wg.init_model
初始化 RefPolicy
ref_policy_wg.init_model
创建 RewardLoopManager
创建 AgentLoopManager
创建 LLMServerManager
创建 CheckpointEngineManager
sleep_replicas - 休眠副本以加载检查点

关键设计要点:

  • 共置 Worker(Colocated Worker) :Actor、Critic、RefPolicy 通过 create_colocated_worker_cls 创建为共置 Worker,共享同一组 GPU 资源,通过 spawn 拆分为独立子组
  • Critic 配置转换 :将 CriticConfig 转换为 TrainingWorkerConfig,统一由 TrainingWorker 处理
  • AgentLoopManager:管理异步 Rollout,支持流式奖励计算(Agent-Reward Loop)
  • CheckpointEngineManager:管理检查点的休眠/唤醒,实现训练权重到 Rollout 引擎的增量同步
2.3 fit() - 主训练循环详解

fit() 是 Trainer 的核心方法,实现了完整的 PPO 训练循环:
#mermaid-svg-LwplTXaqt6x9z99l{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-LwplTXaqt6x9z99l .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-LwplTXaqt6x9z99l .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-LwplTXaqt6x9z99l .error-icon{fill:#552222;}#mermaid-svg-LwplTXaqt6x9z99l .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-LwplTXaqt6x9z99l .marker{fill:#333333;stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .marker.cross{stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-LwplTXaqt6x9z99l p{margin:0;}#mermaid-svg-LwplTXaqt6x9z99l .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label text{fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label span{color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label span p{background-color:transparent;}#mermaid-svg-LwplTXaqt6x9z99l .label text,#mermaid-svg-LwplTXaqt6x9z99l span{fill:#333;color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .node rect,#mermaid-svg-LwplTXaqt6x9z99l .node circle,#mermaid-svg-LwplTXaqt6x9z99l .node ellipse,#mermaid-svg-LwplTXaqt6x9z99l .node polygon,#mermaid-svg-LwplTXaqt6x9z99l .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .rough-node .label text,#mermaid-svg-LwplTXaqt6x9z99l .node .label text,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label,#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label{text-anchor:middle;}#mermaid-svg-LwplTXaqt6x9z99l .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .rough-node .label,#mermaid-svg-LwplTXaqt6x9z99l .node .label,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label,#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label{text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .node.clickable{cursor:pointer;}#mermaid-svg-LwplTXaqt6x9z99l .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .arrowheadPath{fill:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-LwplTXaqt6x9z99l .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-LwplTXaqt6x9z99l .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .cluster text{fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster span{color:#333;}#mermaid-svg-LwplTXaqt6x9z99l div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-LwplTXaqt6x9z99l .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l rect.text{fill:none;stroke-width:0;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape,#mermaid-svg-LwplTXaqt6x9z99l .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape p,#mermaid-svg-LwplTXaqt6x9z99l .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label rect,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-LwplTXaqt6x9z99l .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-LwplTXaqt6x9z99l :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
未完成
已完成
Yes
No
fit 开始
_load_checkpoint
checkpoint_manager.update_weights
val_before_train?
_validate
epoch 循环
batch 循环

  1. 生成 Rollout
  2. 计算奖励
  3. 平衡批次
  4. 计算 old_log_prob
  5. use_ref_policy?
    计算 ref_log_prob
  6. use_critic?
    计算 values
  7. 计算优势估计
    use_kl_in_reward?
    apply_kl_penalty
    rollout_correction?
    compute_rollout_correction
    compute_advantage
  8. use_critic?
    update_critic
    critic_warmup?
    仅 update_weights
  9. update_actor
    需要验证?
  10. save_checkpoint
    update_weights
    _validate
  11. 计算指标
    logger.log
    global_steps++

各步骤详解

a. 生成 Rollout

python 复制代码
gen_batch_output = gen_batch.repeat(repeat_times=rollout_n, interleave=True)
combined_gen_output = self.async_rollout_manager.generate_sequences(combined_gen_batch)
  • 将 prompt 重复 rollout.n 次(交错排列),用于 GRPO 等需要多样本的优势估计
  • 若使用 ReMax,额外生成一条贪心基线
  • 通过 AgentLoopManager 异步调度生成任务

b. 计算奖励

奖励计算分为两条路径:

  • 函数奖励 :通过 RewardLoopManager 调用用户自定义的 compute_score 函数
  • 模型奖励 :通过 RewardLoopManager 调用奖励模型(可共置或独立资源池)

c. KL 惩罚(apply_kl_penalty)

python 复制代码
def apply_kl_penalty(data, kl_ctrl, kl_penalty="kl"):
    kld = core_algos.kl_penalty(old_log_probs, ref_log_prob, kl_penalty)
    token_level_rewards = token_level_scores - beta * kld
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)

KL 惩罚从 token 级别的奖励中减去 β × KL 散度,β 由 AdaptiveKLController 动态调整。

d. 优势估计

详见第 4 节 core_algos.py 分析。

e. PPO 策略更新(update_actor)

python 复制代码
actor_output = self.actor_rollout_wg.update_actor(batch_td)

将数据转换为 no-padding 格式后发送给 Actor Worker,Worker 内部执行 PPO 裁剪损失计算和梯度更新。

f. Critic 更新(update_critic)

python 复制代码
output = self.critic_wg.train_mini_batch(batch_td)

Critic 使用 value_loss 作为损失函数,支持 mini-batch 训练。

g. 检查点保存

保存 Actor、Critic 和 DataLoader 状态,支持本地和 HDFS 远程存储,可配置最大保留检查点数量。

h. 验证循环

_validate() 方法遍历验证集,生成响应并计算奖励,支持 majority voting、best-of-N 等评估指标。


3. 核心算法 core_algos.py

3.1 算法调用关系图

#mermaid-svg-aDyzNLGvJf4N39mG{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-aDyzNLGvJf4N39mG .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-aDyzNLGvJf4N39mG .error-icon{fill:#552222;}#mermaid-svg-aDyzNLGvJf4N39mG .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-aDyzNLGvJf4N39mG .marker{fill:#333333;stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .marker.cross{stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-aDyzNLGvJf4N39mG p{margin:0;}#mermaid-svg-aDyzNLGvJf4N39mG .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label text{fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label span{color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label span p{background-color:transparent;}#mermaid-svg-aDyzNLGvJf4N39mG .label text,#mermaid-svg-aDyzNLGvJf4N39mG span{fill:#333;color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .node rect,#mermaid-svg-aDyzNLGvJf4N39mG .node circle,#mermaid-svg-aDyzNLGvJf4N39mG .node ellipse,#mermaid-svg-aDyzNLGvJf4N39mG .node polygon,#mermaid-svg-aDyzNLGvJf4N39mG .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .rough-node .label text,#mermaid-svg-aDyzNLGvJf4N39mG .node .label text,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label,#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label{text-anchor:middle;}#mermaid-svg-aDyzNLGvJf4N39mG .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .rough-node .label,#mermaid-svg-aDyzNLGvJf4N39mG .node .label,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label,#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label{text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .node.clickable{cursor:pointer;}#mermaid-svg-aDyzNLGvJf4N39mG .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .arrowheadPath{fill:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-aDyzNLGvJf4N39mG .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-aDyzNLGvJf4N39mG .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster text{fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster span{color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-aDyzNLGvJf4N39mG .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG rect.text{fill:none;stroke-width:0;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape p,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label rect,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-aDyzNLGvJf4N39mG .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-aDyzNLGvJf4N39mG :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 辅助函数
策略损失注册表
优势估计器注册表
@register_adv_est 装饰器
compute_gae_advantage_return
compute_grpo_outcome_advantage
compute_grpo_vectorized_outcome_advantage
compute_grpo_passk_outcome_advantage
compute_rloo_outcome_advantage
compute_rloo_vectorized_outcome_advantage
compute_reinforce_plus_plus_outcome_advantage
compute_reinforce_plus_plus_baseline_outcome_advantage
compute_remax_outcome_advantage
compute_opo_outcome_advantage
compute_gpg_outcome_advantage
compute_gdpo_outcome_advantage
compute_optimal_token_baseline_advantage
compute_multi_turn_optimal_token_baseline_advantage
@register_policy_loss 装饰器
compute_policy_loss_vanilla
compute_policy_loss_dppo_tv
compute_policy_loss_dppo_kl
compute_policy_loss_gspo
compute_policy_loss_sapo
compute_policy_loss_gpg
compute_policy_loss_clip_cov
compute_policy_loss_kl_cov
compute_policy_loss_geo_mean
compute_policy_loss_cispo
compute_policy_loss_bypass_mode
agg_loss - 损失聚合
kl_penalty - KL 惩罚计算
AdaptiveKLController / FixedKLController
compute_value_loss - 价值损失

3.2 优势估计器详解

GAE(Generalized Advantage Estimation)

python 复制代码
def compute_gae_advantage_return(token_level_rewards, values, response_mask, gamma, lam):
    # 从后向前递推
    for t in reversed(range(gen_len)):
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        # 在 EOS 后跳过值和 TD 误差
    advantages = masked_whiten(advantages, response_mask)
    returns = advantages + values

GAE 是唯一需要 Critic 的优势估计器,通过 λ 参数平衡偏差与方差。

GRPO(Group Relative Policy Optimization)

python 复制代码
def compute_grpo_outcome_advantage(token_level_rewards, response_mask, index, ...):
    # 按 uid 分组
    for i in range(bsz):
        id2score[index[i]].append(scores[i])
    # 组内标准化
    scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
    # 广播到 token 级别
    scores = scores.unsqueeze(-1) * response_mask

GRPO 不需要 Critic,通过同一 prompt 的多个响应进行组内相对比较。norm_adv_by_std_in_grpo 控制是否除以标准差(Dr.GRPO 建议设为 False)。

REINFORCE++

python 复制代码
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards, response_mask, config):
    # 从后向前累积回报
    for t in reversed(range(response_length)):
        running_return = rewards[:, t] + gamma * running_return
        returns[:, t] = running_return
        running_return *= response_mask[:, t]  # EOS 后重置
    advantages = masked_whiten(returns, response_mask)

REINFORCE++ 是 REINFORCE 的改进版,使用折扣累积回报而非单步奖励。

其他估计器

  • RLOO :Leave-One-Out 基线,advantage = n/(n-1) * score - n/(n-1) * mean
  • ReMax :使用贪心基线,advantage = returns - greedy_baseline
  • GDPO:组奖励解耦归一化,各奖励维度独立标准化后加权聚合
  • OTB:最优 Token 基线,使用累积路径方差代理作为权重
3.3 KL 惩罚控制
python 复制代码
class AdaptiveKLController:
    def update(self, current_kl, n_steps):
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / horizon
        self.value *= mult

KL 惩罚支持多种估计方法(k1/k2/k3),k3+ 后缀表示使用 straight-through 梯度估计。

3.4 PPO 裁剪损失
python 复制代码
def compute_policy_loss_vanilla(old_log_prob, log_prob, advantages, response_mask, ...):
    ratio = exp(log_prob - old_log_prob)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * clamp(ratio, 1-clip_ratio, 1+clip_ratio)
    # 双裁剪 PPO
    pg_losses3 = -advantages * clip_ratio_c
    pg_losses = where(advantages < 0, min(pg_losses3, max(pg_losses1, pg_losses2)), max(pg_losses1, pg_losses2))

支持非对称裁剪(clip_ratio_low / clip_ratio_high)和双裁剪 PPO(clip_ratio_c)。

3.5 损失聚合(agg_loss)

支持四种聚合模式:

模式 公式 说明
token-mean Σ(loss × mask) / total_tokens × dp_size 全局 token 均值
seq-mean-token-sum Σ_seq(Σ_token(loss × mask)) / batch_size × dp_size 序列均值
seq-mean-token-sum-norm seq-mean-token-sum / scale_factor 带归一化
seq-mean-token-mean Σ_seq(Σ_token(loss × mask) / token_count) / batch_size 序列内 token 均值

4. 奖励处理 reward.py

4.1 奖励函数加载

reward.py 提供灵活的奖励函数加载机制:

python 复制代码
def get_custom_reward_fn(config):
    # 从外部文件动态加载奖励函数
    raw_fn = load_extern_object(module_path=module_path, object_name=fn_name)
    # 合并额外参数
    return partial(_call_with_kwargs, raw_fn, reward_kwargs)

支持同步和异步奖励函数(_call_with_kwargs / _call_with_kwargs_async)。

4.2 RewardManager 加载
python 复制代码
def load_reward_manager(config, tokenizer, **reward_kwargs):
    compute_score = get_custom_reward_fn(config) or get_default_compute_score(...)
    reward_manager_cls = resolve_reward_manager_cls(config)
    return reward_manager_cls(config=config, tokenizer=tokenizer, compute_score=compute_score, ...)

RewardManager 支持两种来源:

  • register:从注册表获取
  • importlib:从外部模块动态导入
4.3 extract_reward 函数
python 复制代码
def extract_reward(batch: DataProto):
    reward_tensor = batch.batch["rm_scores"]
    reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
    reward_extra_infos_dict = {key: batch.non_tensor_batch[key] for key in reward_extra_keys}
    return reward_tensor, reward_extra_infos_dict

从批次数据中提取奖励张量和额外信息(如各维度奖励分数),用于 GDPO 等需要细粒度奖励的算法。

4.4 函数奖励 vs 模型奖励
特性 函数奖励 模型奖励
来源 用户自定义 compute_score 奖励模型推理
执行位置 RewardLoopManager RewardLoopManager
资源需求 CPU 即可 需要 GPU
独立资源池 不需要 可选(enable_resource_pool
典型场景 数学题验证、代码执行 对话质量评估

5. 指标系统 metric_utils.py

5.1 compute_data_metrics

计算训练数据的核心统计指标:

  • 分数指标critic/score/mean|max|min --- 序列级分数统计
  • 奖励指标critic/rewards/mean|max|min --- 序列级奖励统计
  • 优势指标critic/advantages/mean|max|min --- 优势值统计
  • 回报指标critic/returns/mean|max|min --- 回报值统计
  • 价值指标 (需 Critic):critic/values/mean|max|mincritic/vf_explained_var
  • 长度指标response_length/mean|max|min|clip_ratioprompt_length/...
  • 中止率response/aborted_ratio --- 响应长度为 0 的比例
  • 多轮指标num_turns/mean|max|min(多轮对话场景)

特别注意:非中止样本的响应长度单独统计(response_length_non_aborted/*),避免零长度样本扭曲统计。

5.2 compute_throughout_metrics

计算吞吐量指标:

python 复制代码
{
    "perf/total_num_tokens": total_num_tokens,
    "perf/time_per_step": time,
    "perf/throughput": total_num_tokens / (time * n_gpus),
}
5.3 compute_timing_metrics

计算各阶段耗时指标,包括原始耗时(timing_s/{name})和每 token 耗时(timing_per_token_ms/{name}):

  • gen:Rollout 生成(仅计算响应 token)
  • ref/values/adv/update_critic/update_actor:其他阶段(计算全部 token)
5.4 compute_variance_proxy_metrics

计算方差代理指标,用于监控梯度方差:

  • variance_proxy/proxy1_signal_strength:||ḡ||²(信号强度)
  • variance_proxy/proxy2_total_power:E\|\|ĝ_τ\|\|²(总功率)
  • variance_proxy/proxy3_pure_noise:方差估计
5.5 process_validation_metrics

处理验证指标,支持丰富的统计方法:

  • mean@N / std@N:N 个样本的均值/标准差
  • best@N/mean|std:Best-of-N 的 bootstrap 统计
  • worst@N/mean|std:Worst-of-N 的 bootstrap 统计
  • maj@N/mean|std:Majority voting 的 bootstrap 统计

6. Skip 机制

6.1 SkipManager 概述

SkipManager 是一个条件跳过训练步骤的管理器,通过装饰器模式实现:

python 复制代码
class SkipManager:
    config: SkipManagerConfig | None = None
    step: int = -1
    skip_instances: dict = {}

    @classmethod
    def annotate(cls, role: str, **kwargs_outer) -> Callable:
        # 装饰器:根据条件跳过或替换函数执行
6.2 工作流程
  1. SkipManager.init(config) --- 从配置初始化,创建所有注册的 Skip 实例

  2. SkipManager.set_step(step) --- 设置当前全局步数

  3. @SkipManager.annotate(role) --- 装饰器,在指定步骤上:

    • 若满足前置条件(meet_precondition),执行替代函数(warp_function
    • 否则正常执行,但收集数据(prepare_data)供后续步骤使用
  4. 验证阶段自动跳过(_should_bypass_for_validation

6.3 使用场景
  • 跳过某些步骤的 Rollout 生成,复用历史数据
  • 在特定步骤使用缓存的优势估计
  • 条件性地跳过奖励计算

7. Rollout 修正(rollout_corr_helper.py)

7.1 问题背景

在 PPO 训练中,Rollout 策略(如 vLLM BF16)与训练策略(如 FSDP FP32)之间存在精度差异,导致 off-policy 问题。rollout_corr_helper.py 提供完整的修正方案:

7.2 核心能力
  1. 重要性采样(IS)权重

    • Token 级别:w_t = π_train(y_t) / π_rollout(y_t)
    • 序列级别:w_seq = Π_t w_t
    • 截断 IS(TIS):clamp(w, max=threshold)
    • IcePop:w ∈ [lower, upper] 范围外置零
  2. 拒绝采样(RS)

    • 基于 KL 散度的硬信任区域
    • 支持 token 级别和序列级别
    • 多种 KL 估计器:k1(直接估计)、k2(MSE)、k3(低方差估计)
  3. Off-policy 诊断指标

    • KL 散度、困惑度(PPL)、χ² 散度
    • 有效样本量(ESS)
7.3 Bypass 模式
python 复制代码
def apply_bypass_mode(batch, rollout_corr_config, policy_loss_config):
    # 跳过 old_log_prob 重计算,直接使用 rollout_log_probs
    batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]
    policy_loss_config["loss_mode"] = "bypass_mode"

Bypass 模式将三策略(π_rollout, π_old, π_θ)简化为两策略(π_rollout, π_θ),节省一次 Actor 前向传播,同时通过 IS 权重和 RS 修正分布偏移。

7.4 集成方式

ray_trainer.pyfit() 中:

python 复制代码
# 解耦模式:重新计算 old_log_prob 作为近端锚点
if not bypass_recomputing_logprobs:
    old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)

# 计算 Rollout 修正
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)

8. 辅助工具

8.1 padding_utils.py

upsample_batch_to_divisible_size() 函数用于 TransferQueue 模式下,当批次大小不能被 DP 大小或 mini-batch 大小整除时,填充合成样本:

  • 使用最小序列(1 prompt token + 1 response token)作为模板
  • 填充样本的 rm_scoresrollout_log_probs 等置零,不影响 PPO 损失
  • 在 tag 中标记 is_padding=True,供指标计算时过滤
8.2 utils.py(Role 与条件判断)

Role 枚举定义了所有 Worker 角色:

Role 说明
Actor 0 策略模型
Rollout 1 生成引擎
ActorRollout 2 Actor+Rollout 融合
Critic 3 价值模型
RefPolicy 4 参考策略
RewardModel 5 奖励模型
ActorRolloutRef 6 Actor+Rollout+Ref 融合
TeacherModel 8 教师模型(蒸馏)

条件判断函数:

  • need_reference_policy()use_kl_in_rewarduse_kl_loss 为 True
  • need_critic()critic.enable=Trueadv_estimator=GAE
  • need_reward_model()reward.reward_model.enable=True
  • need_teacher_policy():蒸馏启用时

【总】总结升华

核心设计要点回顾

  1. 单控制器编排:Trainer 在驱动进程上编排整个 PPO 循环,通过 Ray RPC 调度 Worker,自身仅执行轻量计算。这种设计使得训练逻辑集中、易于调试,同时充分利用分布式计算资源。

  2. 可插拔算法注册表@register_adv_est@register_policy_loss 装饰器实现了算法组件的开放-封闭原则------新增优势估计器或策略损失函数无需修改 Trainer 代码,只需注册即可。

  3. 双模式数据传输:DataProto 模式(异步 PPO)和 TransferQueue 模式(同步 PPO)分别适用于不同场景,后者支持零拷贝数据传输和流式奖励计算,是未来主推方向。

  4. Rollout 修正体系:IS 权重 + RS 拒绝采样的组合,系统性地解决了 rollout 与训练策略之间的分布偏移问题,Bypass 模式进一步优化了计算效率。

  5. 资源池管理 :通过 ResourcePoolManager 统一管理 GPU 资源分配,支持共置(Actor/Critic/Ref 共享 GPU)和独立(奖励模型独占 GPU)两种部署模式。

设计亮点与权衡

亮点

  • Critic 暖启动:初期仅更新 Critic,避免不准确的 Value 估计导致 Actor 更新不稳定
  • 序列长度平衡_balance_batch() 按 token 负载均衡分配到各 DP rank,减少流水线气泡
  • 增量权重同步CheckpointEngineManager 通过休眠/唤醒机制实现训练权重到 Rollout 引擎的高效同步
  • 全面的指标体系:从数据指标、时序指标、吞吐指标到方差代理指标,提供全方位的训练监控

权衡

  • 驱动进程瓶颈:所有优势估计在驱动进程上执行,对于极大批次可能成为瓶颈(但实际中优势计算相比 GPU 计算很轻量)
  • 两套 Trainer 的维护成本RayPPOTrainerPPOTrainer 存在大量重复逻辑,未来需统一
  • 配置复杂度:Hydra 配置项众多,新用户上手门槛较高

与其他 RL 框架 Trainer 的对比

特性 verl TRL (PPOTrainer) OpenRLHF
分布式后端 Ray Accelerate/DeepSpeed Ray
Rollout 引擎 vLLM/SGLang HuggingFace generate vLLM
优势估计器 10+ 种(注册表) GAE GAE/GRPO
策略损失 10+ 种(注册表) PPO-clip PPO-clip
Rollout 修正 IS + RS + Bypass
资源管理 ResourcePoolManager 单 GPU/单节点 Ray 资源池
多轮对话 AgentLoop 不支持 不支持
蒸馏支持 TeacherModel 不支持 不支持

verl 的 Trainer 在算法丰富度、分布式能力和工程可扩展性上具有明显优势,尤其适合大规模 LLM 的 RL 训练场景。其注册表机制和 Rollout 修正体系是区别于其他框架的核心差异化特性。

相关推荐
花骨朵轻创1 小时前
基于WeChatBot框架 API 封装的 Python SDK,提供简洁易用的接口调用方式
人工智能
deepdata_cn1 小时前
面向AI Agent标准化工作环境构建的驾驭工程(Harness Engineering)
人工智能·harness engine
沪漂阿龙1 小时前
Embedding:文本怎么变成向量?语义检索为什么能工作?
人工智能·python·embedding
me8321 小时前
【AI面试】大模型面试60问(面试速记+详解)
人工智能·学习·ai
星辰_mya1 小时前
autowired和resource区别
java·后端·spring·架构·原理
来自于狂人1 小时前
第5章 记忆管理——让Agent记住事情
人工智能·算法·语言模型·自然语言处理
生信碱移1 小时前
Vscode 连接 ipynb 选择内核无法自动显示 conda 环境对应的 python
服务器·人工智能·经验分享·vscode·python
lazy_ma1 小时前
大模型实操-Spring Boot集成LangChain4j
人工智能·后端
Cloud_Shy6182 小时前
解读《Effective Python 3rd Edition》:从练气到老魔(第七章 Item 48 - 50)
开发语言·人工智能·笔记·python·microsoft·学习方法