[LLM Training] When roll calls Megatron to compute the loss function, it relies on partial

What does the line forward_step_func=partial(self.inner_forward_step, forward_func) mean?

Let's take a closer look at how callbacks work in Python, and how, in the concrete scenario of the roll framework, the callback function actually gets "triggered".

The basics of Python callbacks

In Python, functions are first-class citizens. This means a function can:

  1. be assigned to a variable;
  2. be passed as an argument to another function;
  3. be returned from another function.

A "callback" is built on the second property.

The core idea: you define a function A and pass another function B to it as an argument. At some specific point during A's execution, A calls ("calls back") the function B you handed to it.

A simple example:

python
# Define a callback function; it specifies how to process one number
def square_callback(number):
    result = number * number
    print(f"Callback executed: The square of {number} is {result}")
    return result

# Define a main function that takes a list of data and a callback function
def process_data(data_list, callback_function):
    print("Main function started, processing data...")
    results = []
    for item in data_list:
        # Here the main function invokes ("calls back") the function that was passed in
        processed_item = callback_function(item)
        results.append(processed_item)
    print("Main function finished.")
    return results

# --- Main program ---
my_data = [1, 2, 3, 4]

# Call the main function, passing square_callback as an argument
final_results = process_data(my_data, square_callback)

print(f"Final results: {final_results}")

Output:

Main function started, processing data...
Callback executed: The square of 1 is 1
Callback executed: The square of 2 is 4
Callback executed: The square of 3 is 9
Callback executed: The square of 4 is 16
Main function finished.
Final results: [1, 4, 9, 16]

In this example:

  • process_data is the caller, analogous to strategy.forward_step.
  • square_callback is the callback, analogous to forward_func_log_probs.
  • process_data only handles the generic flow of iterating over the data; it does not know what computation to apply to each item.
  • square_callback defines the concrete computation (squaring).
  • The callback is triggered when process_data reaches the line processed_item = callback_function(item). At that point, callback_function is simply a reference to square_callback.

How the callback is triggered in the roll framework

Now let's apply this principle to the roll code.

The participants:

  1. Caller: strategy.forward_step
  2. Callback: self.forward_func_log_probs (a method of ActorWorker)
  3. Trigger points: the return output_tensor, partial(loss_func, data) inside inner_forward_step, together with the internal implementation of forward_backward_func

Let's trace the call path and see how the callback gets triggered.

Path 1: compute_log_probs -> forward_step

python
# ActorWorker.py
def compute_log_probs(self, data: DataProto):
    # ...
    # Here, self.forward_func_log_probs is passed to forward_step as a value (a callable object)
    results = self.strategy.forward_step(
        batch=data, 
        forward_func=self.forward_func_log_probs # <--- pass the callback
    )
    # ...

In the definition of forward_step, this callback is received and given the name forward_func:

python
# MegatronInferStrategy.py (or a similar Strategy class)
def forward_step(self, batch: DataProto, forward_func: Callable): # <--- receive the callback
    # ...
    # Pass forward_func further down
    losses_reduced = self.forward_backward_func(
        forward_step_func=partial(self.inner_forward_step, forward_func), # <--- pass it on again
        # ...
    )
    # ...

Path 2: forward_step -> forward_backward_func -> inner_forward_step

forward_backward_func is a Megatron-LM function that encapsulates the complex logic of pipeline parallelism and micro-batching. Its core job is to repeatedly call the forward_step_func you hand to it.

In our case, forward_step_func is partial(self.inner_forward_step, forward_func). That means forward_backward_func, inside its loop, executes something like the following:

python
# Pseudocode for forward_backward_func
def forward_backward_func(forward_step_func, data_iterator, ...):
    all_outputs = []
    # Loop over every micro-batch
    for i in range(num_microbatches):
        # *** Trigger point 1 ***
        # Call the function we built with partial;
        # this executes inner_forward_step(data_iterator, model)
        output_tensor, process_fn = forward_step_func(data_iterator, model)

        # *** Trigger point 2 ***
        # process_fn is partial(loss_func, data),
        # i.e. partial(forward_func_log_probs, data).
        # Here the callback forward_func_log_probs is actually executed!
        loss, metrics = process_fn(output_tensor)

        all_outputs.append(metrics)
    return all_outputs

Now let's look at what inner_forward_step does:

python
# MegatronInferStrategy.py
def inner_forward_step(self, loss_func, data_iterator, model): # loss_func is forward_func_log_probs
    # ... (prepare the data)

    # 1. Run the model forward pass
    output_tensor = model(input_ids=..., attention_mask=...)

    # 2. Prepare the callback
    # It does not call loss_func directly; instead it returns a partial object.
    # partial(loss_func, data) means:
    # "create a new callable, equivalent to calling loss_func(data, ...), whose first argument data is already filled in"
    return output_tensor, partial(loss_func, data)
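A minimal sketch of what this pre-binding buys us, using toy stand-ins (these are not the real roll functions):

python
from functools import partial

# Toy stand-in for forward_func_log_probs: batch data first, model output second
def toy_loss_func(data, output_tensor):
    return sum(output_tensor) / len(data)

data = [1, 2, 3, 4]
process_fn = partial(toy_loss_func, data)   # pre-bind the first argument

# Later, code that only has the model output can still finish the computation:
output_tensor = [0.1, 0.2, 0.3, 0.4]
assert process_fn(output_tensor) == toy_loss_func(data, output_tensor)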

To summarize the trigger sequence:

  1. compute_log_probs passes the method object forward_func_log_probs to forward_step as an argument.
  2. forward_step passes this method object further down to the underlying forward_backward_func.
  3. forward_backward_func calls inner_forward_step inside its micro-batch loop.
  4. inner_forward_step runs the model's forward pass and obtains output_tensor.
  5. inner_forward_step does not call forward_func_log_probs directly. Instead, it creates and returns a partial object that "wraps" forward_func_log_probs together with the current micro-batch data.
  6. forward_backward_func now holds output_tensor and this partial object (call it process_fn).
  7. The real trigger point: forward_backward_func calls process_fn(output_tensor). Because process_fn is a partial object wrapping forward_func_log_probs, this step is equivalent to executing:
    forward_func_log_probs(data, output_tensor)
    At this moment the callback defined in ActorWorker is finally triggered: it receives the model output output_tensor and the corresponding input data, and runs its own specific computation (i.e. it calls op_compute_log_probs).

This slightly roundabout design decouples Megatron-LM's own complex looping logic from roll's higher-level business logic: forward_backward_func only needs a function that returns (tensor, callable); it does not need to know what that callable actually is.
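To make the contract concrete, here is a toy end-to-end sketch; every name is an illustrative stand-in, not the real roll or Megatron API:

python
from functools import partial

def toy_forward_backward_func(forward_step_func, data_iterator, model, num_microbatches=2):
    """Only knows the contract: forward_step_func(data_iterator, model) -> (tensor, callable)."""
    results = []
    for _ in range(num_microbatches):
        output_tensor, process_fn = forward_step_func(data_iterator, model)
        results.append(process_fn(output_tensor))   # the callback fires here
    return results

def toy_inner_forward_step(loss_func, data_iterator, model):
    data = next(data_iterator)
    output_tensor = model(data)                     # "forward pass"
    return output_tensor, partial(loss_func, data)  # defer the callback

def toy_forward_func_log_probs(data, output_tensor):
    return {"sum": sum(output_tensor), "n": len(data)}

model = lambda batch: [x * 2 for x in batch]
data_iter = iter([[1, 2], [3, 4]])

print(toy_forward_backward_func(
    forward_step_func=partial(toy_inner_forward_step, toy_forward_func_log_probs),
    data_iterator=data_iter,
    model=model,
))
# [{'sum': 6, 'n': 2}, {'sum': 14, 'n': 2}]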

All of this is done to fit Megatron's forward_step_func interface: Megatron only ever calls forward_step_func with data_iterator and model.

But inner_forward_step(self, loss_func, data_iterator, model) takes three parameters (besides self), which is why we need:

forward_step_func=partial(self.inner_forward_step, forward_func)
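A small sketch of how the bound method plus partial turns this three-parameter function into the two-argument callable that Megatron expects (toy classes, not the real strategy):

python
from functools import partial

class ToyStrategy:
    def inner_forward_step(self, loss_func, data_iterator, model):
        # self and loss_func are already bound by the time Megatron calls us
        return f"called with loss_func={loss_func.__name__}"

def my_loss_func(data, output_tensor):
    return output_tensor

strategy = ToyStrategy()
# self is bound by attribute access, loss_func is bound by partial:
forward_step_func = partial(strategy.inner_forward_step, my_loss_func)

# A Megatron-style call site only supplies (data_iterator, model):
print(forward_step_func(iter([]), object()))  # -> "called with loss_func=my_loss_func"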

For reference, Megatron's get_forward_backward_func docstring spells out exactly this expected forward_step_func signature:

python
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and return the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)


        forward_backward_func(forward_step_func=forward_step, ...)

You are asking how forward_step_func is used inside forward_backward_func, and what the returned partial(loss_func, loss_mask, model=model) is for.

How forward_step_func is used

forward_step_func is the user-defined forward-pass function; it is called inside forward_backward_func. 1

Where it is called

forward_step_func is invoked inside the forward_step() function: 2

The concrete call looks like this:

python
output_tensor, loss_func = forward_step_func(data_iterator, model)

Here forward_step_func receives two arguments:

  1. data_iterator: the data iterator
  2. model: the model instance

It returns two values:

  1. output_tensor: the model output (typically a loss tensor)
  2. loss_func: a partial function used for the subsequent loss processing

What partial(loss_func, ...) is used for

The returned partial(loss_func, loss_mask, model=model) is consumed inside forward_step_calc_loss(). 3

Usage flow

  1. Receive loss_func: 4

  2. Call loss_func: inside forward_step_calc_loss(), if this is the last pipeline stage, the loss_func is invoked: 5

  3. Handle the return values: loss_func is called with output_tensor and returns the processed loss plus metadata.

A concrete example

pretrain_gpt.py 中可以看到完整的使用示例: 6

Here:

  1. forward_step_func calls the model to obtain output_tensor
  2. It returns partial(loss_func, loss_mask, model=model), binding loss_mask and model in advance
  3. Later, in forward_step_calc_loss(), this partial function only needs to be given output_tensor to complete the loss computation

Why use partial?

The benefits of partial are (a short sketch follows this list):

  • Deferred execution: forward_step_func does not compute the final loss immediately; it returns a function instead
  • Argument binding: context such as loss_mask and model is bound up front, so only output_tensor needs to be supplied later
  • Flexibility: different pipeline stages can apply different loss-processing logic
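A minimal sketch of this binding pattern; the loss_func signature here is an assumption, chosen to be consistent with partial(loss_func, loss_mask, model=model) followed by a later call with just output_tensor (plain lists instead of torch tensors):

python
from functools import partial

# Assumed signature: loss_mask first, output_tensor second, model as a keyword
def loss_func(loss_mask, output_tensor, model=None):
    masked = [o * m for o, m in zip(output_tensor, loss_mask)]
    loss = sum(masked) / max(sum(loss_mask), 1)
    return loss, {"lm loss": loss}

loss_mask = [1, 1, 0, 1]
bound = partial(loss_func, loss_mask, model="toy_model")  # bind the context up front

output_tensor = [3, 6, 99, 0]
loss, metrics = bound(output_tensor)   # only the model output is supplied here
print(loss, metrics)                   # 3.0 {'lm loss': 3.0}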

Notes

This design pattern is consistent across all the training scripts, including pretrain_mamba.py 7 and pretrain_t5.py 8. The uniform interface is what allows Megatron-LM to support different model types and training tasks.


Citations

File: megatron/core/pipeline_parallel/schedules.py (L199-287)

python
def forward_step_calc_loss(
    model,
    output_tensor,
    loss_func,
    config,
    vp_stage,
    collect_non_loss_data,
    num_microbatches,
    forward_data_store,
    cp_group_size=None,
    is_last_stage=None,
):
    """Calculate the loss and number of tokens for forward_step()"""

    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    model_vp_stage = getattr(model, "vp_stage", None)
    if vp_stage is not None and model_vp_stage is not None:
        assert (
            vp_stage == model_vp_stage
        ), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"

    if cp_group_size is None and is_last_stage is None:
        # fallback to parallel state
        cp_group_size = parallel_state.get_context_parallel_world_size()
        is_last_stage = parallel_state.is_pipeline_last_stage(
            ignore_virtual=False, vp_stage=vp_stage
        )
    else:
        assert (
            cp_group_size is not None and is_last_stage is not None
        ), "cp_group_size and is_last_stage must be provided"

    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_last_stage:
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    # Protect against division by zero when all tokens are masked
                    #   in a microbatch.
                    output_tensor /= torch.clamp(num_tokens, min=1)
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor *= cp_group_size
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # Set the loss scale for Multi-Token Prediction (MTP) loss.
    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
            MTPLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    return output_tensor, num_tokens

File: megatron/core/pipeline_parallel/schedules.py (L290-422)

python
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    cp_group_size,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    vp_stage=None,
    is_last_stage=True,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable):
            The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                    tensor or some kind of collection of tensors. The only hard requirement
                    for this object is that it needs to be acceptible as input into the second
                    function.
                2. A function to reduce (optionally) the output from the forward step. This
                    could be a reduction over the loss from the model, it could be a function that
                    grabs the output from the model and reformats, it could be a function that just
                    passes through the model output. This function must have one of the following
                    patterns, and depending on the pattern different things happen internally:

                        a. A tuple of reduced loss and some other data. Note that in this case
                            the first argument is divided by the number of global microbatches,
                            assuming it is a loss, so that the loss is stable as a function of
                            the number of devices the step is split across.
                        b. A triple of reduced loss, number of tokens, and some other data. This
                            is similar to case (a), but the loss is further averaged across the
                            number of tokens in the batch. If the user is not already averaging
                            across the number of tokens, this pattern is useful to use.
                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                            of tensors, etc in the case of inference). To trigger case 3 you need
                            to specify `collect_non_loss_data=True` and you may also want to
                            specify `forward_only=True` in the call to the parent forward_backward
                            function.
        data_iterator (iterator):
            The data iterator.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object):
            The configuration object.
        collect_non_loss_data (bool, optional):
            Whether to collect non-loss data. Defaults to False.
            This is the path to use if you want to collect arbitrary output from the model forward,
            such as with inference use cases. Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            The microbatch to checkpoint activations.
            Defaults to None.
        is_first_microbatch (bool, optional):
            Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional):
            The current microbatch. Defaults to None.
        vp_stage (int, optional):
            The virtual pipeline stage. Defaults to None.
        is_last_stage (bool, optional):
            Whether it is the last stage. Defaults to True.
            Also considering virtual stages.
            In case of PP/VPP, is_last_stage/is_vp_last_stage.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )
    output_tensor, num_tokens = forward_step_calc_loss(
        model,
        output_tensor,
        loss_func,
        config,
        vp_stage,
        collect_non_loss_data,
        num_microbatches,
        forward_data_store,
        cp_group_size,
        is_last_stage,
    )

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens

File: pretrain_gpt.py (L121-157)

python
def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                assert args.overlap_moe_expert_parallel_comm, \
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                output_tensor = model(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)

File: pretrain_mamba.py (L131-154)

python
def forward_step(data_iterator, model: MambaModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (MambaModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask,
                              labels=labels)

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model=model)

File: pretrain_t5.py (L174-198)

python
def forward_step(data_iterator, model: T5Model):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (T5Model): The T5 Model
    """

    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator', log_level=2).start()
    use_local = args.transformer_impl == "local"
    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = get_batch(
        data_iterator, use_local
    )
    timers('batch generator').stop()

    # Forward model lm_labels
    output_tensor = model(
        tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, lm_labels=lm_labels
    )

    return output_tensor, partial(loss_func, loss_mask)