07 Trainer 模块详解
【总】开篇概述
Trainer 模块的核心价值
Trainer 模块是 verl 训练流程的控制中心,负责编排完整的 PPO(Proximal Policy Optimization)训练循环。它将数据加载、模型推理、奖励计算、优势估计、策略更新等环节串联成一个端到端的训练流水线,同时协调多个分布式 Worker 之间的协作。
verl 的 Trainer 并非简单的训练脚本,而是一个精心设计的分布式编排器:驱动进程(Driver)仅执行轻量级的优势计算与数据调度,所有重计算(前向推理、反向传播)均通过 Ray RPC 委托给远端 Worker 执行,从而实现了控制逻辑与计算逻辑的彻底解耦。
核心问题
verl 如何编排完整的 PPO 训练循环? 具体而言:
- 如何从入口脚本启动分布式训练?
- 如何注册和初始化各类 Worker(Actor、Critic、RefPolicy 等)?
- 如何在单控制器上协调 Rollout → Reward → Advantage → Update 的完整数据流?
- 如何支持 GAE、GRPO、REINFORCE++ 等多种优势估计器?
- 如何处理 KL 惩罚、Rollout 修正等高级算法特性?
全局概览图
#mermaid-svg-KsFRbPK1GX6EV6oR{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-KsFRbPK1GX6EV6oR .error-icon{fill:#552222;}#mermaid-svg-KsFRbPK1GX6EV6oR .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-KsFRbPK1GX6EV6oR .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-KsFRbPK1GX6EV6oR .marker{fill:#333333;stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .marker.cross{stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-KsFRbPK1GX6EV6oR p{margin:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label text{fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label span{color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster-label span p{background-color:transparent;}#mermaid-svg-KsFRbPK1GX6EV6oR .label text,#mermaid-svg-KsFRbPK1GX6EV6oR span{fill:#333;color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .node rect,#mermaid-svg-KsFRbPK1GX6EV6oR .node circle,#mermaid-svg-KsFRbPK1GX6EV6oR .node ellipse,#mermaid-svg-KsFRbPK1GX6EV6oR .node polygon,#mermaid-svg-KsFRbPK1GX6EV6oR .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .rough-node .label text,#mermaid-svg-KsFRbPK1GX6EV6oR .node .label text,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label,#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label{text-anchor:middle;}#mermaid-svg-KsFRbPK1GX6EV6oR .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .rough-node .label,#mermaid-svg-KsFRbPK1GX6EV6oR .node .label,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label,#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label{text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .node.clickable{cursor:pointer;}#mermaid-svg-KsFRbPK1GX6EV6oR .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .arrowheadPath{fill:#333333;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-KsFRbPK1GX6EV6oR .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster text{fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR .cluster span{color:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-KsFRbPK1GX6EV6oR .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-KsFRbPK1GX6EV6oR rect.text{fill:none;stroke-width:0;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape p,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-KsFRbPK1GX6EV6oR .icon-shape .label rect,#mermaid-svg-KsFRbPK1GX6EV6oR .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-KsFRbPK1GX6EV6oR .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-KsFRbPK1GX6EV6oR .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-KsFRbPK1GX6EV6oR :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} Worker 层
辅助模块
算法层
编排层
入口层
main_ppo.py / main_ppo_sync.py
Hydra 配置加载
Ray 集群初始化
TaskRunner.run
RayPPOTrainer / PPOTrainer
init_workers
fit - 主训练循环
core_algos.py
GAE 优势估计
GRPO 优势估计
REINFORCE++ 优势估计
PPO 裁剪损失
KL 惩罚控制
reward.py - 奖励提取
metric_utils.py - 指标计算
rollout_corr_helper.py - Rollout 修正
SkipManager - 条件跳过
padding_utils.py - 批次填充
ActorRolloutRefWorker
TrainingWorker - Critic
RewardLoopManager
AgentLoopManager
关键结论预览
- 单控制器编排模式:Trainer 运行在驱动进程上,通过 Ray RPC 调度 Worker,自身仅执行轻量计算(优势估计、KL 惩罚),实现了控制与计算的分离。
- 可插拔算法架构 :通过
@register_adv_est和@register_policy_loss装饰器注册机制,支持 GAE、GRPO、RLOO、ReMax 等 10+ 种优势估计器和 vanilla、DPPO、GSPO、SAPO 等 10+ 种策略损失函数。 - 双模式训练 :同时支持异步 PPO(
RayPPOTrainer,基于 DataProto)和同步 PPO(PPOTrainer,基于 TransferQueue),后者是未来主推方向。 - Rollout 修正体系:通过重要性采样(IS)权重和拒绝采样(RS)修正 rollout 与训练策略之间的分布偏移,支持 Bypass 模式(跳过 old_log_prob 重计算)。
- Critic 暖启动 :支持
critic_warmup机制,在指定步数内仅更新 Critic 而冻结 Actor,稳定训练初期。
【分】逐层展开
1. 入口与启动流程
1.1 两种入口脚本
verl 提供两种 PPO 训练入口:
| 脚本 | Trainer 类 | 数据传输 | 状态 |
|---|---|---|---|
main_ppo.py |
RayPPOTrainer |
DataProto | 已废弃(v0.8.0 移除) |
main_ppo_sync.py |
PPOTrainer |
TransferQueue | 推荐使用 |
两者共享 run_ppo() 函数进行 Ray 初始化和 TaskRunner 启动,差异在于数据传输机制和 Worker 协作方式。
1.2 启动时序图
RayPPOTrainer TaskRunner Ray 集群 Hydra 配置 main_ppo.py 用户 RayPPOTrainer TaskRunner Ray 集群 Hydra 配置 main_ppo.py 用户 #mermaid-svg-KPh0357k8oKN2Dio{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-KPh0357k8oKN2Dio .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-KPh0357k8oKN2Dio .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-KPh0357k8oKN2Dio .error-icon{fill:#552222;}#mermaid-svg-KPh0357k8oKN2Dio .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-KPh0357k8oKN2Dio .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-KPh0357k8oKN2Dio .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-KPh0357k8oKN2Dio .marker{fill:#333333;stroke:#333333;}#mermaid-svg-KPh0357k8oKN2Dio .marker.cross{stroke:#333333;}#mermaid-svg-KPh0357k8oKN2Dio svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-KPh0357k8oKN2Dio p{margin:0;}#mermaid-svg-KPh0357k8oKN2Dio .actor{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio text.actor>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .actor-line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);}#mermaid-svg-KPh0357k8oKN2Dio .innerArc{stroke-width:1.5;stroke-dasharray:none;}#mermaid-svg-KPh0357k8oKN2Dio .messageLine0{stroke-width:1.5;stroke-dasharray:none;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .messageLine1{stroke-width:1.5;stroke-dasharray:2,2;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio #arrowhead path{fill:#333;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .sequenceNumber{fill:white;}#mermaid-svg-KPh0357k8oKN2Dio #sequencenumber{fill:#333;}#mermaid-svg-KPh0357k8oKN2Dio #crosshead path{fill:#333;stroke:#333;}#mermaid-svg-KPh0357k8oKN2Dio .messageText{fill:#333;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .labelBox{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio .labelText,#mermaid-svg-KPh0357k8oKN2Dio .labelText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .loopText,#mermaid-svg-KPh0357k8oKN2Dio .loopText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .loopLine{stroke-width:2px;stroke-dasharray:2,2;stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);}#mermaid-svg-KPh0357k8oKN2Dio .note{stroke:#aaaa33;fill:#fff5ad;}#mermaid-svg-KPh0357k8oKN2Dio .noteText,#mermaid-svg-KPh0357k8oKN2Dio .noteText>tspan{fill:black;stroke:none;}#mermaid-svg-KPh0357k8oKN2Dio .activation0{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .activation1{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .activation2{fill:#f4f4f4;stroke:#666;}#mermaid-svg-KPh0357k8oKN2Dio .actorPopupMenu{position:absolute;}#mermaid-svg-KPh0357k8oKN2Dio .actorPopupMenuPanel{position:absolute;fill:#ECECFF;box-shadow:0px 8px 16px 0px rgba(0,0,0,0.2);filter:drop-shadow(3px 5px 2px rgb(0 0 0 / 0.4));}#mermaid-svg-KPh0357k8oKN2Dio .actor-man line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;}#mermaid-svg-KPh0357k8oKN2Dio .actor-man circle,#mermaid-svg-KPh0357k8oKN2Dio line{stroke:hsl(259.6261682243, 59.7765363128%, 87.9019607843%);fill:#ECECFF;stroke-width:2px;}#mermaid-svg-KPh0357k8oKN2Dio :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} python main_ppo.py加载 ppo_trainer.yaml返回 configauto_set_device(config)migrate_legacy_reward_impl(config)ray.init(runtime_env)集群就绪ray.get(runner.run.remote(config))add_actor_rollout_worker(config)add_critic_worker(config)init_resource_pool_mgr(config)create_rl_dataset / create_rl_samplerRayPPOTrainer(config, ...)trainer.init_workers()trainer.fit()训练完成返回结果
1.3 Hydra 配置加载
通过 @hydra.main(config_path="config", config_name="ppo_trainer") 装饰器自动加载 YAML 配置。关键配置项包括:
actor_rollout_ref:Actor/Rollout/RefPolicy 共享配置critic:Critic 模型配置algorithm:算法参数(adv_estimator、gamma、lam 等)reward:奖励配置(函数奖励/模型奖励)trainer:训练控制参数(总步数、保存频率等)data:数据集配置
1.4 Ray 集群初始化
run_ppo() 函数负责 Ray 集群初始化:
python
# 关键步骤
default_runtime_env = get_ppo_ray_runtime_env() # 设置 TOKENIZERS_PARALLELISM、NCCL_DEBUG 等
ray.init(**OmegaConf.to_container(ray_init_kwargs))
constants_ppo.py 中定义了 Ray 运行时环境变量,包括 tokenizer 并行化、NCCL 调试级别、vLLM 日志级别等。
1.5 TaskRunner
TaskRunner 是一个 Ray Remote 类,作为训练任务的远程执行器:
python
@ray.remote(num_cpus=1)
class TaskRunner:
def run(self, config):
# 1. 注册 Worker
self.add_actor_rollout_worker(config)
self.add_critic_worker(config)
# 2. 初始化资源池
self.init_resource_pool_mgr(config)
# 3. 创建数据集
train_dataset = create_rl_dataset(...)
val_dataset = create_rl_dataset(...)
train_sampler = create_rl_sampler(...)
# 4. 实例化 Trainer
trainer = RayPPOTrainer(config, ...)
trainer.init_workers()
trainer.fit()
Worker 注册逻辑:
add_actor_rollout_worker():注册ActorRolloutRefWorker,若使用 LoRA 则 RefPolicy 融合在 Actor 中(Role.ActorRollout),否则独立注册(Role.ActorRolloutRef)add_critic_worker():注册TrainingWorker作为 Critic,仅在need_critic(config)为 True 时注册add_reward_model_resource_pool():若启用奖励模型,注册独立资源池或共享全局池
ResourcePool 初始化:
python
resource_pool_spec = {
"global_pool": [n_gpus_per_node] * nnodes, # Actor/Critic/Ref 共享
"reward_pool": [...], # 可选:独立奖励模型池
"teacher_pool": [...], # 可选:独立教师模型池
}
数据集创建:
create_rl_dataset():根据data_config动态选择数据集类(通过get_dataset_class),支持 JSONL、Parquet 等格式create_rl_sampler():创建采样器,支持随机采样(RandomSampler,可恢复种子)和顺序采样(SequentialSampler)
2. RayPPOTrainer 详解
2.1 初始化
RayPPOTrainer.__init__() 完成以下初始化:
- 配置解析:解析 hybrid_engine、use_reference_policy、use_critic 等标志
- LoRA 判断 :根据
lora_rank > 0或lora_adapter_path判断 RefPolicy 是否融合在 Actor 中 - KL 控制器 :若
use_kl_in_reward=True,创建AdaptiveKLController或FixedKLController - 数据加载器 :创建训练/验证 DataLoader(
StatefulDataLoader,支持断点恢复)
2.2 init_workers() - Worker 初始化流程
init_workers() 是 Trainer 最复杂的初始化方法,负责创建所有分布式 Worker:
#mermaid-svg-RkWruDcd9ZGLibAn{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-RkWruDcd9ZGLibAn .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-RkWruDcd9ZGLibAn .error-icon{fill:#552222;}#mermaid-svg-RkWruDcd9ZGLibAn .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-RkWruDcd9ZGLibAn .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-RkWruDcd9ZGLibAn .marker{fill:#333333;stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .marker.cross{stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-RkWruDcd9ZGLibAn p{margin:0;}#mermaid-svg-RkWruDcd9ZGLibAn .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label text{fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label span{color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster-label span p{background-color:transparent;}#mermaid-svg-RkWruDcd9ZGLibAn .label text,#mermaid-svg-RkWruDcd9ZGLibAn span{fill:#333;color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .node rect,#mermaid-svg-RkWruDcd9ZGLibAn .node circle,#mermaid-svg-RkWruDcd9ZGLibAn .node ellipse,#mermaid-svg-RkWruDcd9ZGLibAn .node polygon,#mermaid-svg-RkWruDcd9ZGLibAn .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .rough-node .label text,#mermaid-svg-RkWruDcd9ZGLibAn .node .label text,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label,#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label{text-anchor:middle;}#mermaid-svg-RkWruDcd9ZGLibAn .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .rough-node .label,#mermaid-svg-RkWruDcd9ZGLibAn .node .label,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label,#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label{text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .node.clickable{cursor:pointer;}#mermaid-svg-RkWruDcd9ZGLibAn .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .arrowheadPath{fill:#333333;}#mermaid-svg-RkWruDcd9ZGLibAn .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-RkWruDcd9ZGLibAn .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-RkWruDcd9ZGLibAn .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster text{fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn .cluster span{color:#333;}#mermaid-svg-RkWruDcd9ZGLibAn div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-RkWruDcd9ZGLibAn .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-RkWruDcd9ZGLibAn rect.text{fill:none;stroke-width:0;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape p,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-RkWruDcd9ZGLibAn .icon-shape .label rect,#mermaid-svg-RkWruDcd9ZGLibAn .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-RkWruDcd9ZGLibAn .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-RkWruDcd9ZGLibAn .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-RkWruDcd9ZGLibAn :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} init_workers
创建资源池
注册 ActorRolloutRefWorker
注册 Critic TrainingWorker
注册 RefPolicy Worker
create_colocated_worker_cls
RayWorkerGroup 初始化
spawn Worker 子组
初始化 Critic
critic_wg.reset
set_loss_fn - value_loss
初始化 ActorRollout
actor_rollout_wg.init_model
初始化 RefPolicy
ref_policy_wg.init_model
创建 RewardLoopManager
创建 AgentLoopManager
创建 LLMServerManager
创建 CheckpointEngineManager
sleep_replicas - 休眠副本以加载检查点
关键设计要点:
- 共置 Worker(Colocated Worker) :Actor、Critic、RefPolicy 通过
create_colocated_worker_cls创建为共置 Worker,共享同一组 GPU 资源,通过spawn拆分为独立子组 - Critic 配置转换 :将
CriticConfig转换为TrainingWorkerConfig,统一由TrainingWorker处理 - AgentLoopManager:管理异步 Rollout,支持流式奖励计算(Agent-Reward Loop)
- CheckpointEngineManager:管理检查点的休眠/唤醒,实现训练权重到 Rollout 引擎的增量同步
2.3 fit() - 主训练循环详解
fit() 是 Trainer 的核心方法,实现了完整的 PPO 训练循环:
#mermaid-svg-LwplTXaqt6x9z99l{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-LwplTXaqt6x9z99l .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-LwplTXaqt6x9z99l .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-LwplTXaqt6x9z99l .error-icon{fill:#552222;}#mermaid-svg-LwplTXaqt6x9z99l .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-LwplTXaqt6x9z99l .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-LwplTXaqt6x9z99l .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-LwplTXaqt6x9z99l .marker{fill:#333333;stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .marker.cross{stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-LwplTXaqt6x9z99l p{margin:0;}#mermaid-svg-LwplTXaqt6x9z99l .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label text{fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label span{color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster-label span p{background-color:transparent;}#mermaid-svg-LwplTXaqt6x9z99l .label text,#mermaid-svg-LwplTXaqt6x9z99l span{fill:#333;color:#333;}#mermaid-svg-LwplTXaqt6x9z99l .node rect,#mermaid-svg-LwplTXaqt6x9z99l .node circle,#mermaid-svg-LwplTXaqt6x9z99l .node ellipse,#mermaid-svg-LwplTXaqt6x9z99l .node polygon,#mermaid-svg-LwplTXaqt6x9z99l .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .rough-node .label text,#mermaid-svg-LwplTXaqt6x9z99l .node .label text,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label,#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label{text-anchor:middle;}#mermaid-svg-LwplTXaqt6x9z99l .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .rough-node .label,#mermaid-svg-LwplTXaqt6x9z99l .node .label,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label,#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label{text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .node.clickable{cursor:pointer;}#mermaid-svg-LwplTXaqt6x9z99l .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .arrowheadPath{fill:#333333;}#mermaid-svg-LwplTXaqt6x9z99l .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-LwplTXaqt6x9z99l .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-LwplTXaqt6x9z99l .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-LwplTXaqt6x9z99l .cluster text{fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l .cluster span{color:#333;}#mermaid-svg-LwplTXaqt6x9z99l div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-LwplTXaqt6x9z99l .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-LwplTXaqt6x9z99l rect.text{fill:none;stroke-width:0;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape,#mermaid-svg-LwplTXaqt6x9z99l .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape p,#mermaid-svg-LwplTXaqt6x9z99l .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-LwplTXaqt6x9z99l .icon-shape .label rect,#mermaid-svg-LwplTXaqt6x9z99l .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-LwplTXaqt6x9z99l .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-LwplTXaqt6x9z99l .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-LwplTXaqt6x9z99l :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
未完成
已完成
Yes
No
fit 开始
_load_checkpoint
checkpoint_manager.update_weights
val_before_train?
_validate
epoch 循环
batch 循环
- 生成 Rollout
- 计算奖励
- 平衡批次
- 计算 old_log_prob
- use_ref_policy?
计算 ref_log_prob - use_critic?
计算 values - 计算优势估计
use_kl_in_reward?
apply_kl_penalty
rollout_correction?
compute_rollout_correction
compute_advantage - use_critic?
update_critic
critic_warmup?
仅 update_weights - update_actor
需要验证? - save_checkpoint
update_weights
_validate - 计算指标
logger.log
global_steps++
各步骤详解:
a. 生成 Rollout
python
gen_batch_output = gen_batch.repeat(repeat_times=rollout_n, interleave=True)
combined_gen_output = self.async_rollout_manager.generate_sequences(combined_gen_batch)
- 将 prompt 重复
rollout.n次(交错排列),用于 GRPO 等需要多样本的优势估计 - 若使用 ReMax,额外生成一条贪心基线
- 通过
AgentLoopManager异步调度生成任务
b. 计算奖励
奖励计算分为两条路径:
- 函数奖励 :通过
RewardLoopManager调用用户自定义的compute_score函数 - 模型奖励 :通过
RewardLoopManager调用奖励模型(可共置或独立资源池)
c. KL 惩罚(apply_kl_penalty)
python
def apply_kl_penalty(data, kl_ctrl, kl_penalty="kl"):
kld = core_algos.kl_penalty(old_log_probs, ref_log_prob, kl_penalty)
token_level_rewards = token_level_scores - beta * kld
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
KL 惩罚从 token 级别的奖励中减去 β × KL 散度,β 由 AdaptiveKLController 动态调整。
d. 优势估计
详见第 4 节 core_algos.py 分析。
e. PPO 策略更新(update_actor)
python
actor_output = self.actor_rollout_wg.update_actor(batch_td)
将数据转换为 no-padding 格式后发送给 Actor Worker,Worker 内部执行 PPO 裁剪损失计算和梯度更新。
f. Critic 更新(update_critic)
python
output = self.critic_wg.train_mini_batch(batch_td)
Critic 使用 value_loss 作为损失函数,支持 mini-batch 训练。
g. 检查点保存
保存 Actor、Critic 和 DataLoader 状态,支持本地和 HDFS 远程存储,可配置最大保留检查点数量。
h. 验证循环
_validate() 方法遍历验证集,生成响应并计算奖励,支持 majority voting、best-of-N 等评估指标。
3. 核心算法 core_algos.py
3.1 算法调用关系图
#mermaid-svg-aDyzNLGvJf4N39mG{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-aDyzNLGvJf4N39mG .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-aDyzNLGvJf4N39mG .error-icon{fill:#552222;}#mermaid-svg-aDyzNLGvJf4N39mG .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-aDyzNLGvJf4N39mG .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-aDyzNLGvJf4N39mG .marker{fill:#333333;stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .marker.cross{stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-aDyzNLGvJf4N39mG p{margin:0;}#mermaid-svg-aDyzNLGvJf4N39mG .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label text{fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label span{color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster-label span p{background-color:transparent;}#mermaid-svg-aDyzNLGvJf4N39mG .label text,#mermaid-svg-aDyzNLGvJf4N39mG span{fill:#333;color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .node rect,#mermaid-svg-aDyzNLGvJf4N39mG .node circle,#mermaid-svg-aDyzNLGvJf4N39mG .node ellipse,#mermaid-svg-aDyzNLGvJf4N39mG .node polygon,#mermaid-svg-aDyzNLGvJf4N39mG .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .rough-node .label text,#mermaid-svg-aDyzNLGvJf4N39mG .node .label text,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label,#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label{text-anchor:middle;}#mermaid-svg-aDyzNLGvJf4N39mG .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .rough-node .label,#mermaid-svg-aDyzNLGvJf4N39mG .node .label,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label,#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label{text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .node.clickable{cursor:pointer;}#mermaid-svg-aDyzNLGvJf4N39mG .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .arrowheadPath{fill:#333333;}#mermaid-svg-aDyzNLGvJf4N39mG .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-aDyzNLGvJf4N39mG .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-aDyzNLGvJf4N39mG .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster text{fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG .cluster span{color:#333;}#mermaid-svg-aDyzNLGvJf4N39mG div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-aDyzNLGvJf4N39mG .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-aDyzNLGvJf4N39mG rect.text{fill:none;stroke-width:0;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape p,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-aDyzNLGvJf4N39mG .icon-shape .label rect,#mermaid-svg-aDyzNLGvJf4N39mG .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-aDyzNLGvJf4N39mG .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-aDyzNLGvJf4N39mG .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-aDyzNLGvJf4N39mG :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 辅助函数
策略损失注册表
优势估计器注册表
@register_adv_est 装饰器
compute_gae_advantage_return
compute_grpo_outcome_advantage
compute_grpo_vectorized_outcome_advantage
compute_grpo_passk_outcome_advantage
compute_rloo_outcome_advantage
compute_rloo_vectorized_outcome_advantage
compute_reinforce_plus_plus_outcome_advantage
compute_reinforce_plus_plus_baseline_outcome_advantage
compute_remax_outcome_advantage
compute_opo_outcome_advantage
compute_gpg_outcome_advantage
compute_gdpo_outcome_advantage
compute_optimal_token_baseline_advantage
compute_multi_turn_optimal_token_baseline_advantage
@register_policy_loss 装饰器
compute_policy_loss_vanilla
compute_policy_loss_dppo_tv
compute_policy_loss_dppo_kl
compute_policy_loss_gspo
compute_policy_loss_sapo
compute_policy_loss_gpg
compute_policy_loss_clip_cov
compute_policy_loss_kl_cov
compute_policy_loss_geo_mean
compute_policy_loss_cispo
compute_policy_loss_bypass_mode
agg_loss - 损失聚合
kl_penalty - KL 惩罚计算
AdaptiveKLController / FixedKLController
compute_value_loss - 价值损失
3.2 优势估计器详解
GAE(Generalized Advantage Estimation)
python
def compute_gae_advantage_return(token_level_rewards, values, response_mask, gamma, lam):
# 从后向前递推
for t in reversed(range(gen_len)):
delta = rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
# 在 EOS 后跳过值和 TD 误差
advantages = masked_whiten(advantages, response_mask)
returns = advantages + values
GAE 是唯一需要 Critic 的优势估计器,通过 λ 参数平衡偏差与方差。
GRPO(Group Relative Policy Optimization)
python
def compute_grpo_outcome_advantage(token_level_rewards, response_mask, index, ...):
# 按 uid 分组
for i in range(bsz):
id2score[index[i]].append(scores[i])
# 组内标准化
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
# 广播到 token 级别
scores = scores.unsqueeze(-1) * response_mask
GRPO 不需要 Critic,通过同一 prompt 的多个响应进行组内相对比较。norm_adv_by_std_in_grpo 控制是否除以标准差(Dr.GRPO 建议设为 False)。
REINFORCE++
python
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards, response_mask, config):
# 从后向前累积回报
for t in reversed(range(response_length)):
running_return = rewards[:, t] + gamma * running_return
returns[:, t] = running_return
running_return *= response_mask[:, t] # EOS 后重置
advantages = masked_whiten(returns, response_mask)
REINFORCE++ 是 REINFORCE 的改进版,使用折扣累积回报而非单步奖励。
其他估计器:
- RLOO :Leave-One-Out 基线,
advantage = n/(n-1) * score - n/(n-1) * mean - ReMax :使用贪心基线,
advantage = returns - greedy_baseline - GDPO:组奖励解耦归一化,各奖励维度独立标准化后加权聚合
- OTB:最优 Token 基线,使用累积路径方差代理作为权重
3.3 KL 惩罚控制
python
class AdaptiveKLController:
def update(self, current_kl, n_steps):
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / horizon
self.value *= mult
KL 惩罚支持多种估计方法(k1/k2/k3),k3+ 后缀表示使用 straight-through 梯度估计。
3.4 PPO 裁剪损失
python
def compute_policy_loss_vanilla(old_log_prob, log_prob, advantages, response_mask, ...):
ratio = exp(log_prob - old_log_prob)
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * clamp(ratio, 1-clip_ratio, 1+clip_ratio)
# 双裁剪 PPO
pg_losses3 = -advantages * clip_ratio_c
pg_losses = where(advantages < 0, min(pg_losses3, max(pg_losses1, pg_losses2)), max(pg_losses1, pg_losses2))
支持非对称裁剪(clip_ratio_low / clip_ratio_high)和双裁剪 PPO(clip_ratio_c)。
3.5 损失聚合(agg_loss)
支持四种聚合模式:
| 模式 | 公式 | 说明 |
|---|---|---|
token-mean |
Σ(loss × mask) / total_tokens × dp_size | 全局 token 均值 |
seq-mean-token-sum |
Σ_seq(Σ_token(loss × mask)) / batch_size × dp_size | 序列均值 |
seq-mean-token-sum-norm |
seq-mean-token-sum / scale_factor | 带归一化 |
seq-mean-token-mean |
Σ_seq(Σ_token(loss × mask) / token_count) / batch_size | 序列内 token 均值 |
4. 奖励处理 reward.py
4.1 奖励函数加载
reward.py 提供灵活的奖励函数加载机制:
python
def get_custom_reward_fn(config):
# 从外部文件动态加载奖励函数
raw_fn = load_extern_object(module_path=module_path, object_name=fn_name)
# 合并额外参数
return partial(_call_with_kwargs, raw_fn, reward_kwargs)
支持同步和异步奖励函数(_call_with_kwargs / _call_with_kwargs_async)。
4.2 RewardManager 加载
python
def load_reward_manager(config, tokenizer, **reward_kwargs):
compute_score = get_custom_reward_fn(config) or get_default_compute_score(...)
reward_manager_cls = resolve_reward_manager_cls(config)
return reward_manager_cls(config=config, tokenizer=tokenizer, compute_score=compute_score, ...)
RewardManager 支持两种来源:
register:从注册表获取importlib:从外部模块动态导入
4.3 extract_reward 函数
python
def extract_reward(batch: DataProto):
reward_tensor = batch.batch["rm_scores"]
reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
reward_extra_infos_dict = {key: batch.non_tensor_batch[key] for key in reward_extra_keys}
return reward_tensor, reward_extra_infos_dict
从批次数据中提取奖励张量和额外信息(如各维度奖励分数),用于 GDPO 等需要细粒度奖励的算法。
4.4 函数奖励 vs 模型奖励
| 特性 | 函数奖励 | 模型奖励 |
|---|---|---|
| 来源 | 用户自定义 compute_score |
奖励模型推理 |
| 执行位置 | RewardLoopManager | RewardLoopManager |
| 资源需求 | CPU 即可 | 需要 GPU |
| 独立资源池 | 不需要 | 可选(enable_resource_pool) |
| 典型场景 | 数学题验证、代码执行 | 对话质量评估 |
5. 指标系统 metric_utils.py
5.1 compute_data_metrics
计算训练数据的核心统计指标:
- 分数指标 :
critic/score/mean|max|min--- 序列级分数统计 - 奖励指标 :
critic/rewards/mean|max|min--- 序列级奖励统计 - 优势指标 :
critic/advantages/mean|max|min--- 优势值统计 - 回报指标 :
critic/returns/mean|max|min--- 回报值统计 - 价值指标 (需 Critic):
critic/values/mean|max|min、critic/vf_explained_var - 长度指标 :
response_length/mean|max|min|clip_ratio、prompt_length/... - 中止率 :
response/aborted_ratio--- 响应长度为 0 的比例 - 多轮指标 :
num_turns/mean|max|min(多轮对话场景)
特别注意:非中止样本的响应长度单独统计(response_length_non_aborted/*),避免零长度样本扭曲统计。
5.2 compute_throughout_metrics
计算吞吐量指标:
python
{
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
}
5.3 compute_timing_metrics
计算各阶段耗时指标,包括原始耗时(timing_s/{name})和每 token 耗时(timing_per_token_ms/{name}):
gen:Rollout 生成(仅计算响应 token)ref/values/adv/update_critic/update_actor:其他阶段(计算全部 token)
5.4 compute_variance_proxy_metrics
计算方差代理指标,用于监控梯度方差:
variance_proxy/proxy1_signal_strength:||ḡ||²(信号强度)variance_proxy/proxy2_total_power:E\|\|ĝ_τ\|\|²(总功率)variance_proxy/proxy3_pure_noise:方差估计
5.5 process_validation_metrics
处理验证指标,支持丰富的统计方法:
mean@N/std@N:N 个样本的均值/标准差best@N/mean|std:Best-of-N 的 bootstrap 统计worst@N/mean|std:Worst-of-N 的 bootstrap 统计maj@N/mean|std:Majority voting 的 bootstrap 统计
6. Skip 机制
6.1 SkipManager 概述
SkipManager 是一个条件跳过训练步骤的管理器,通过装饰器模式实现:
python
class SkipManager:
config: SkipManagerConfig | None = None
step: int = -1
skip_instances: dict = {}
@classmethod
def annotate(cls, role: str, **kwargs_outer) -> Callable:
# 装饰器:根据条件跳过或替换函数执行
6.2 工作流程
-
SkipManager.init(config)--- 从配置初始化,创建所有注册的 Skip 实例 -
SkipManager.set_step(step)--- 设置当前全局步数 -
@SkipManager.annotate(role)--- 装饰器,在指定步骤上:- 若满足前置条件(
meet_precondition),执行替代函数(warp_function) - 否则正常执行,但收集数据(
prepare_data)供后续步骤使用
- 若满足前置条件(
-
验证阶段自动跳过(
_should_bypass_for_validation)
6.3 使用场景
- 跳过某些步骤的 Rollout 生成,复用历史数据
- 在特定步骤使用缓存的优势估计
- 条件性地跳过奖励计算
7. Rollout 修正(rollout_corr_helper.py)
7.1 问题背景
在 PPO 训练中,Rollout 策略(如 vLLM BF16)与训练策略(如 FSDP FP32)之间存在精度差异,导致 off-policy 问题。rollout_corr_helper.py 提供完整的修正方案:
7.2 核心能力
-
重要性采样(IS)权重:
- Token 级别:
w_t = π_train(y_t) / π_rollout(y_t) - 序列级别:
w_seq = Π_t w_t - 截断 IS(TIS):
clamp(w, max=threshold) - IcePop:
w ∈ [lower, upper]范围外置零
- Token 级别:
-
拒绝采样(RS):
- 基于 KL 散度的硬信任区域
- 支持 token 级别和序列级别
- 多种 KL 估计器:k1(直接估计)、k2(MSE)、k3(低方差估计)
-
Off-policy 诊断指标:
- KL 散度、困惑度(PPL)、χ² 散度
- 有效样本量(ESS)
7.3 Bypass 模式
python
def apply_bypass_mode(batch, rollout_corr_config, policy_loss_config):
# 跳过 old_log_prob 重计算,直接使用 rollout_log_probs
batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]
policy_loss_config["loss_mode"] = "bypass_mode"
Bypass 模式将三策略(π_rollout, π_old, π_θ)简化为两策略(π_rollout, π_θ),节省一次 Actor 前向传播,同时通过 IS 权重和 RS 修正分布偏移。
7.4 集成方式
在 ray_trainer.py 的 fit() 中:
python
# 解耦模式:重新计算 old_log_prob 作为近端锚点
if not bypass_recomputing_logprobs:
old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
# 计算 Rollout 修正
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
8. 辅助工具
8.1 padding_utils.py
upsample_batch_to_divisible_size() 函数用于 TransferQueue 模式下,当批次大小不能被 DP 大小或 mini-batch 大小整除时,填充合成样本:
- 使用最小序列(1 prompt token + 1 response token)作为模板
- 填充样本的
rm_scores、rollout_log_probs等置零,不影响 PPO 损失 - 在 tag 中标记
is_padding=True,供指标计算时过滤
8.2 utils.py(Role 与条件判断)
Role 枚举定义了所有 Worker 角色:
| Role | 值 | 说明 |
|---|---|---|
| Actor | 0 | 策略模型 |
| Rollout | 1 | 生成引擎 |
| ActorRollout | 2 | Actor+Rollout 融合 |
| Critic | 3 | 价值模型 |
| RefPolicy | 4 | 参考策略 |
| RewardModel | 5 | 奖励模型 |
| ActorRolloutRef | 6 | Actor+Rollout+Ref 融合 |
| TeacherModel | 8 | 教师模型(蒸馏) |
条件判断函数:
need_reference_policy():use_kl_in_reward或use_kl_loss为 Trueneed_critic():critic.enable=True或adv_estimator=GAEneed_reward_model():reward.reward_model.enable=Trueneed_teacher_policy():蒸馏启用时
【总】总结升华
核心设计要点回顾
-
单控制器编排:Trainer 在驱动进程上编排整个 PPO 循环,通过 Ray RPC 调度 Worker,自身仅执行轻量计算。这种设计使得训练逻辑集中、易于调试,同时充分利用分布式计算资源。
-
可插拔算法注册表 :
@register_adv_est和@register_policy_loss装饰器实现了算法组件的开放-封闭原则------新增优势估计器或策略损失函数无需修改 Trainer 代码,只需注册即可。 -
双模式数据传输:DataProto 模式(异步 PPO)和 TransferQueue 模式(同步 PPO)分别适用于不同场景,后者支持零拷贝数据传输和流式奖励计算,是未来主推方向。
-
Rollout 修正体系:IS 权重 + RS 拒绝采样的组合,系统性地解决了 rollout 与训练策略之间的分布偏移问题,Bypass 模式进一步优化了计算效率。
-
资源池管理 :通过
ResourcePoolManager统一管理 GPU 资源分配,支持共置(Actor/Critic/Ref 共享 GPU)和独立(奖励模型独占 GPU)两种部署模式。
设计亮点与权衡
亮点:
- Critic 暖启动:初期仅更新 Critic,避免不准确的 Value 估计导致 Actor 更新不稳定
- 序列长度平衡 :
_balance_batch()按 token 负载均衡分配到各 DP rank,减少流水线气泡 - 增量权重同步 :
CheckpointEngineManager通过休眠/唤醒机制实现训练权重到 Rollout 引擎的高效同步 - 全面的指标体系:从数据指标、时序指标、吞吐指标到方差代理指标,提供全方位的训练监控
权衡:
- 驱动进程瓶颈:所有优势估计在驱动进程上执行,对于极大批次可能成为瓶颈(但实际中优势计算相比 GPU 计算很轻量)
- 两套 Trainer 的维护成本 :
RayPPOTrainer和PPOTrainer存在大量重复逻辑,未来需统一 - 配置复杂度:Hydra 配置项众多,新用户上手门槛较高
与其他 RL 框架 Trainer 的对比
| 特性 | verl | TRL (PPOTrainer) | OpenRLHF |
|---|---|---|---|
| 分布式后端 | Ray | Accelerate/DeepSpeed | Ray |
| Rollout 引擎 | vLLM/SGLang | HuggingFace generate | vLLM |
| 优势估计器 | 10+ 种(注册表) | GAE | GAE/GRPO |
| 策略损失 | 10+ 种(注册表) | PPO-clip | PPO-clip |
| Rollout 修正 | IS + RS + Bypass | 无 | 无 |
| 资源管理 | ResourcePoolManager | 单 GPU/单节点 | Ray 资源池 |
| 多轮对话 | AgentLoop | 不支持 | 不支持 |
| 蒸馏支持 | TeacherModel | 不支持 | 不支持 |
verl 的 Trainer 在算法丰富度、分布式能力和工程可扩展性上具有明显优势,尤其适合大规模 LLM 的 RL 训练场景。其注册表机制和 Rollout 修正体系是区别于其他框架的核心差异化特性。