This post follows verl's implementation of DAPO. Let's start with the entry .sh and .py files; the ./recipe/dapo/ folder contains the following:
.
├── config
│ ├── dapo_megatron_trainer.yaml
│ └── dapo_trainer.yaml
├── dapo_ray_trainer.py
├── main_dapo.py
├── prepare_dapo_data.sh
├── README.md
├── run_dapo_qwen2.5_32b.sh
Overall execution order:
main_dapo.py: initializes data loading, the actor_rollout model and the RM model, and loads the reward_manager
dapo_ray_trainer.py: the RL training loop
Repeat the batch so that each prompt q is sampled n times
Record each sample's log-probs, along with the corresponding reward score and advantage
Filter out any prompt q whose samples all score 1 or all score 0, and keep fetching new prompts to sample until the number of qualifying prompts reaches train_prompt_bsz. (Note that the generation batch size is gen_prompt_bsz = 3 * train_prompt_bsz; sampling more prompts per generation batch reduces the risk that fewer than train_prompt_bsz prompts qualify.) A minimal sketch of this accumulation loop follows this list.
Update the model once per mini_batch of data
Run a forward pass (token-mean loss) and gradient accumulation per micro_batch of data
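To make the dynamic-sampling step above concrete, here is a minimal standalone sketch (toy code, not verl's implementation; sample_rewards and the batch sizes are made-up stand-ins) of accumulating prompts until train_prompt_bsz qualifying prompts have been collected:

import random

def sample_rewards(prompt, n):
    # stand-in for rolling out n responses and scoring them with a 0/1 verifier
    return [random.randint(0, 1) for _ in range(n)]

def collect_training_prompts(prompts, n=16, train_prompt_bsz=4, gen_prompt_bsz=12):
    kept = []  # prompts whose n rewards are not all identical
    it = iter(prompts)
    while len(kept) < train_prompt_bsz:
        # draw one generation batch of prompts (gen_prompt_bsz > train_prompt_bsz)
        gen_batch = [p for _, p in zip(range(gen_prompt_bsz), it)]
        if not gen_batch:
            break  # ran out of data
        for p in gen_batch:
            rewards = sample_rewards(p, n)
            if len(set(rewards)) > 1:  # drop all-correct / all-wrong groups
                kept.append((p, rewards))
    return kept[:train_prompt_bsz]  # truncate to the trainable batch size

print(len(collect_training_prompts([f"q{i}" for i in range(100)])))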
Now for the concrete code:
main_dapo.py
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
import os
import socket
import hydra
import ray
from omegaconf import OmegaConf
from verl.trainer.ppo.reward import load_reward_manager
from verl.utils.device import is_cuda_available
from .dapo_ray_trainer import RayDAPOTrainer
@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None)
def main(config):
run_ppo(config)
#################################################################
# RL training entry point
#################################################################
def run_ppo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
default_runtime_env = {
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
}
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs))
try:
if (
is_cuda_available
and config.global_profiler.tool == "nsys"
and OmegaConf.select(config.global_profiler, "steps") is not None
and len(OmegaConf.select(config.global_profiler, "steps")) > 0
):
nsight_options = OmegaConf.to_container(
config.global_profiler.global_tool_config.nsys.controller_nsight_options
)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
finally:
if ray.is_initialized():
ray.shutdown()
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
tokenizer = hf_tokenizer(local_path)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
from verl.single_controller.ray import RayWorkerGroup
#################################################################
# Load the actor workers
#################################################################
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
#################################################################
# Load the reward manager, which computes reward scores from the data
#################################################################
reward_fn = load_reward_manager(
config,
tokenizer,
0,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
)
# Note that we always use function-based RM for validation
val_reward_fn = load_reward_manager(
config,
tokenizer,
1,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
#################################################################
# Instantiate the main DAPO RL trainer class and run .fit()
#################################################################
trainer = RayDAPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == "__main__":
main()
Next, let's look at from verl.trainer.ppo.reward import load_reward_manager. The launch script verl/recipe/dapo/run_dapo_qwen2.5_32b.sh specifies the reward configuration:
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4)) # overlong soft
overlong_penalty_factor=1.0
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
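Read together with the DAPORewardManager code shown later, these options define a soft length penalty: a response shorter than max_response_length - overlong_buffer_len is not penalized, and the penalty then grows linearly to -penalty_factor across the buffer. A small sketch of that formula (the max_resp_len value here is an illustrative assumption, not read from the script):

def overlong_penalty(valid_response_length, max_resp_len=20480,
                     overlong_buffer_len=4096, penalty_factor=1.0):
    # mirrors the overlong_buffer logic shown later in verl/verl/workers/reward_manager/dapo.py
    expected_len = max_resp_len - overlong_buffer_len
    exceed_len = valid_response_length - expected_len
    return min(-exceed_len / overlong_buffer_len * penalty_factor, 0.0)

print(overlong_penalty(10000))   # 0.0  (within the expected length)
print(overlong_penalty(18432))   # -0.5 (half-way into the 4k buffer)
print(overlong_penalty(20480))   # -1.0 (at the maximum response length)

With these options in mind, here is load_reward_manager: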
def load_reward_manager(
config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
"""
Load and initialize a reward manager based on the configuration.
Args:
config: PPO trainer configuration object containing reward_model fields.
tokenizer: Tokenizer object used for processing text.
num_examine: Number of samples to examine.
**reward_kwargs: Additional keyword arguments for the reward manager.
Returns:
An instance of the specified reward manager class.
"""
# Try to get a custom reward function based on the configuration
# user defined reward manager can be registered in custom_reward_fn
compute_score = get_custom_reward_fn(config)
final_compute_score = compute_score
# The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
# naive: NaiveRewardManager
# prime: PrimeRewardManager
# batch: BatchRewardManager
# dapo: DAPORewardManager
# Note(haibin.lin): For custom reward managers, please make sure they are imported and
# registered via `verl.workers.reward_manager.register`
# By default reward_manager is set to naive (NaiveRewardManager)
#################################################################
# Resolve the concrete reward_manager class here
#################################################################
reward_manager_name = config.reward_model.get("reward_manager", "naive")
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
if compute_score is None:
sandbox_config = config.reward_model.get("sandbox_fusion")
sandbox_url = sandbox_config.get("url") if sandbox_config else None
memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024)
if sandbox_url:
sandbox_manager = multiprocessing.Manager()
# Create a semaphore to control concurrent access to the sandbox
_concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
final_compute_score = partial(
default_compute_score,
sandbox_fusion_url=sandbox_url,
concurrent_semaphore=_concurrent_semaphore,
memory_limit_mb=memory_limit_mb,
)
else:
final_compute_score = default_compute_score
#################################################################
# Here reward_manager_cls is the DAPO one (DAPORewardManager)
#################################################################
# Instantiate and return the reward manager with the specified parameters
return reward_manager_cls(
tokenizer=tokenizer,
num_examine=num_examine,
compute_score=final_compute_score,
reward_fn_key=config.data.reward_fn_key,
**reward_kwargs,
)
We still need to know what DAPO's reward_manager_cls actually is. Since rewards can only be computed once a batch of data is available, we will set the reward manager aside for now (the DAPO reward_manager_cls lives in verl/verl/workers/reward_manager/dapo.py), first look at how batches are sampled in dapo_ray_trainer.py, and then come back to the details of the reward computation. As a warm-up, a toy sketch of the group-based filtering idea follows.
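A standalone toy illustration (not verl code) of the group filter we will meet in fit(): per-trajectory rewards are grouped by each prompt's uid, and only prompts whose group reward std is nonzero are kept:

import numpy as np
from collections import defaultdict

uids    = ["q1", "q1", "q1", "q2", "q2", "q2"]   # 2 prompts, 3 rollouts each
rewards = [ 1.0,  0.0,  1.0,  0.0,  0.0,  0.0]   # q2's rollouts all score 0

# group rewards by prompt uid
uid2vals = defaultdict(list)
for uid, r in zip(uids, rewards):
    uid2vals[uid].append(r)

# keep prompts that are neither all-correct nor all-wrong
kept_uids = {uid for uid, vals in uid2vals.items() if np.std(vals) > 0}
kept_idxs = [i for i, uid in enumerate(uids) if uid in kept_uids]
print(kept_uids, kept_idxs)   # {'q1'} [0, 1, 2]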
dapo_ray_trainer.py
#################################################################
# RayDAPOTrainer inherits from RayPPOTrainer
# fit(): runs DAPO training, including (1) dynamic sampling, (2) the overlong soft reward, (3) token-level loss
#################################################################
class RayDAPOTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
self.gen_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
rollout_skip.wrap_generate_sequences()
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
self.gen_steps += 1
last_val_metrics = None
prev_step_profile = False
curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
next_step_profile = False
timing_raw = defaultdict(float)
batch = None
#################################################################
# num_prompt_in_batch: number of prompts whose reward std is nonzero after filtering; reset to 0 after each model update
# num_gen_batches: number of generation batches consumed so far; reset to 0 after each model update
#################################################################
num_prompt_in_batch = 0
num_gen_batches = 0
#################################################################
# Training proper: loop over epochs and, within each epoch, over generation batches
#################################################################
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
with marked_timer("start_profile", timing_raw):
self._start_profiling(
not prev_step_profile and curr_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
#################################################################
# new_batch is a DataProto (see verl/verl/protocol.py);
# new_batch.batch is a TensorDict
# new_batch holds 3x as many prompts as the trainable batch size (the generation batch samples extra prompts)
#################################################################
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if "multi_modal_data" in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
)
else:
# Pop the relevant keys from new_batch to build gen_batch
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
# Why repeat? Each prompt is sampled n times, so the batch is repeated n times with interleave=True.
# gen_batch: (bsz, response_length),
# gen_batch_output: (bsz*n, response_length)
gen_batch_output = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, "red"):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
# This advantage estimator can be ignored for now. REMAX first generates greedy samples whose rewards serve as the baseline for the later advantage computation.
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, "red"):
gen_baseline_batch = deepcopy(gen_batch)
# Greedy-decoding baseline: do_sample = False
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
# compute reward model score on new_batch
rm_scores = None
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
rm_scores = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(rm_scores)
reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
keys_to_pop = set(gen_baseline_output.batch.keys())
if rm_scores is not None:
keys_to_pop.update(rm_scores.batch.keys())
new_batch.pop(batch_keys=list(keys_to_pop))
new_batch.batch["reward_baselines"] = reward_baseline_tensor
del rm_scores, gen_baseline_batch, gen_baseline_output
#################################################################
# new_batch holds gen_prompt_bsz prompts
# Assign each prompt a unique identifier (uid)
# The uid is needed because, when computing rewards later, the rewards of the n samples of the same prompt are normalized within the group
#################################################################
new_batch.non_tensor_batch["uid"] = np.array(
str(uuid.uuid4()) for _ in range(len(new_batch.batch))\], dtype=object ) # 对batch中的每个key进行repeat(这里应该主要是对uid进行repeat) # repeat to align with repeated responses in rollout new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) # 把采样完的放到new_batch中 new_batch = new_batch.union(gen_batch_output) with marked_timer("reward", timing_raw, "yellow"): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. if self.use_rm and "rm_scores" not in new_batch.batch.keys(): # we first compute reward model score reward_tensor = self.rm_wg.compute_rm_score(new_batch) new_batch = new_batch.union(reward_tensor) # 计算new_batch各个采样的reward,根据设置好的self.reward_fn # we combine with rule-based rm reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn) new_batch.batch\["token_level_scores"\] = reward_tensor if reward_extra_infos_dict: new_batch.non_tensor_batch.update( {k: np.array(v) for k, v in reward_extra_infos_dict.items()} ) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: new_batch, kl_metrics = apply_kl_penalty( new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty ) metrics.update( kl_metrics ) # TODO: This will be cleared if we use multiple genenration batches else: new_batch.batch\["token_level_rewards"\] = new_batch.batch\["token_level_scores"
#################################################################
# DAPO's filtering (dynamic sampling) step
#################################################################
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When prompts after filtering is less than train batch size,
# we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = (
new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
)
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = (
new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
)
# {uid: [r1, r2, ..., rn], uid: [...], ...}: record the rewards of all samples of each prompt
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)
# std of the rewards of each prompt
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
# Keep the uids of prompts whose reward std is nonzero
kept_prompt_uids = [
uid
for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
# Accumulate the number of prompts with nonzero std
num_prompt_in_batch += len(kept_prompt_uids)
# Record the indices of the samples belonging to kept prompts
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
# Select the corresponding rows of new_batch by trajectory index
new_batch = new_batch[kept_traj_idxs]
# batch accumulates the kept trajectories
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
# Minimum trainable batch size (number of prompts) configured in the .sh file
prompt_bsz = self.config.data.train_batch_size
# If the accumulated number of kept prompts is below the configured minimum, continue and accumulate from the next generation batch
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
# max_num_gen_batches is the maximum number of generation batches that may be consumed
# If it is <= 0 there is no limit; if num_gen_batches < max_num_gen_batches, keep generating
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
self.gen_steps += 1
is_last_step = self.global_steps >= self.total_training_steps
continue
else:
raise ValueError(
f"{num_gen_batches=} >= {max_num_gen_batches=}."
+ " Generated too many. Please check if your data are too difficult."
+ " You could also try set max_num_gen_batches=0 to enable endless trials."
)
# The accumulated number of qualifying prompts is >= the minimum trainable batch size
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
#################################################################
# Align by truncation: surplus trajectories are discarded. It is unclear whether this hurts sample efficiency,
# or whether some trajectories simply never get trained on.
#################################################################
batch = batch[:traj_bsz]
#################################################################
# Actor model update
#################################################################
# === Updating ===
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
#################################################################
# Record the rollout-time token-level log-probs of each trajectory in the filtered batch,
# used to compute the importance-sampling ratio
#################################################################
# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, "blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
# For DAPO, loss_agg_mode is "token-mean"
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, "olive"):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw, "cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
# Compute token-level importance-sampling weights
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
#################################################################
# Compute advantages
#################################################################
with marked_timer("adv", timing_raw, "brown"):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, "pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
#################################################################
# Update the actor model (the batch holds train_prompt_bsz prompts)
# One optimizer step per mini_batch (parameters updated from the accumulated gradients)
# Gradients accumulated once per micro_batch
#################################################################
# update actor
with marked_timer("update_actor", timing_raw, "red"):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, "green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with marked_timer("save_checkpoint", timing_raw, "green"):
self._save_checkpoint()
with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
self._stop_profiling(
curr_step_profile and not next_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
prev_step_profile = curr_step_profile
curr_step_profile = next_step_profile
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing
metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
self.gen_steps += 1
# check if last step checkpoint exists
checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
if not os.path.exists(checkpoint_dir):
# save last step checkpoint
timing_raw = defaultdict(float)
with marked_timer("save_checkpoint", timing_raw, "green"):
self._save_checkpoint()
metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
logger.log(data=metrics, step=self.global_steps)
Now let's look at DAPO's reward manager implementation. The main difference from PPO's manager is the overlong_buffer, which adds a length-based penalty to the reward; a toy illustration of how the final scalar reward is written into the token-level reward tensor is given after the code listing.
verl/verl/workers/reward_manager/dapo.py
#################################################################
# DAPORewardManager is registered under the name "dapo", so it can be obtained via
# reward_manager_cls = get_reward_manager_cls(reward_manager_name)
#################################################################
@register("dapo")
class DAPORewardManager(AbstractRewardManager):
"""The reward manager."""
def __init__(
self,
tokenizer,
num_examine,
compute_score=None,
reward_fn_key="data_source",
max_resp_len=None,
overlong_buffer_cfg=None,
) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key
self.overlong_buffer_cfg = overlong_buffer_cfg
self.max_resp_len = max_resp_len
if self.overlong_buffer_cfg is not None:
assert self.max_resp_len is not None, (
f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"
)
assert self.max_resp_len >= self.overlong_buffer_cfg.len, (
"max_resp_len must be larger than overlong_buffer.len"
)
#################################################################
# The main entry point of the DAPO reward manager
#################################################################
def __call__(self, data: DataProto, return_dict: bool = False):
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_extra_info = defaultdict(list)
already_print_data_sources = {}
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
########################################################
# Note that prompt_ids are left-padded,
# while response_ids are right-padded
########################################################
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
eos_token = self.tokenizer.eos_token
if response_str.endswith(eos_token):
response_str = response_str[: -len(eos_token)]
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
data_source = data_item.non_tensor_batch[self.reward_fn_key]
extra_info = data_item.non_tensor_batch.get("extra_info", {})
rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
extra_info["rollout_reward_scores"] = rollout_reward_scores
result = self.compute_score(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
)
score: float
if isinstance(result, dict):
score = result["score"]
# Store the information including original reward
for key, value in result.items():
reward_extra_info[key].append(value)
else:
score = result
reward_extra_info["acc"].append(score)
reward = score
########################################################
# Overlong (soft length) reward computation
########################################################
if self.overlong_buffer_cfg.enable:
overlong_buffer_len = self.overlong_buffer_cfg.len
expected_len = self.max_resp_len - overlong_buffer_len
exceed_len = valid_response_length - expected_len
overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
reward += overlong_reward
if self.overlong_buffer_cfg.log:
reward_extra_info["overlong_reward"].append(overlong_reward)
reward_extra_info["overlong"].append(overlong_reward < 0)
reward_tensor[i, valid_response_length - 1] = reward
if data_source not in already_print_data_sources:
already_print_data_sources[data_source] = 0
if already_print_data_sources[data_source] < self.num_examine:
already_print_data_sources[data_source] += 1
print("[prompt]", prompt_str)
print("[response]", response_str)
print("[ground_truth]", ground_truth)
if isinstance(result, dict):
for key, value in result.items():
print(f"[{key}]", value)
else:
print("[score]", score)
if return_dict:
return {
"reward_tensor": reward_tensor,
"reward_extra_info": reward_extra_info,
}
else:
return reward_tensor
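Finally, a standalone toy illustration (not verl code) of the reward_tensor[i, valid_response_length - 1] = reward line above: the scalar outcome reward (accuracy plus any overlong penalty) is written onto the last valid response token, giving a token-level reward tensor that is zero everywhere else:

import torch

response_length = 6
valid_response_lengths = [4, 6]          # number of non-padding response tokens per sample
final_scores = [1.0, 0.5]                # e.g. accuracy + overlong penalty

reward_tensor = torch.zeros(len(final_scores), response_length, dtype=torch.float32)
for i, (length, score) in enumerate(zip(valid_response_lengths, final_scores)):
    reward_tensor[i, length - 1] = score  # reward only on the last valid token

print(reward_tensor)  # row 0 has 1.0 at index 3, row 1 has 0.5 at index 5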