What does forward_step_func=partial(self.inner_forward_step, forward_func) mean?
Let's take a closer look at how callbacks work in Python, and how the callback function gets "triggered" in the specific context of the roll framework.
The basics of Python callbacks
In Python, functions are first-class citizens. This means a function can:
- Be assigned to a variable.
- Be passed as an argument to another function.
- Be returned from another function.
"回调"正是利用了第二点特性。
Core idea: you define a function A and pass another function B to it as an argument. At some specific point during A's execution, A calls ("calls back") the function B you handed it.
A simple example:
python
# Define a callback that specifies how to process one number
def square_callback(number):
    result = number * number
    print(f"Callback executed: The square of {number} is {result}")
    return result

# Define a main function that takes a list of data and a callback
def process_data(data_list, callback_function):
    print("Main function started, processing data...")
    results = []
    for item in data_list:
        # Here the main function invokes ("calls back") the function it was given
        processed_item = callback_function(item)
        results.append(processed_item)
    print("Main function finished.")
    return results

# --- Main program ---
my_data = [1, 2, 3, 4]
# Call the main function, passing square_callback as an argument
final_results = process_data(my_data, square_callback)
print(f"Final results: {final_results}")
Output:
Main function started, processing data...
Callback executed: The square of 1 is 1
Callback executed: The square of 2 is 4
Callback executed: The square of 3 is 9
Callback executed: The square of 4 is 16
Main function finished.
Final results: [1, 4, 9, 16]
In this example:
- process_data is the caller, analogous to strategy.forward_step; square_callback is the callback, analogous to forward_func_log_probs.
- process_data is only responsible for the generic flow of iterating over the data; it has no idea what computation will be applied to each item. square_callback defines the concrete computation (squaring).
- The callback is triggered when process_data reaches the line processed_item = callback_function(item). At that moment, callback_function is simply a reference to square_callback.
How the callback is triggered in the roll framework
Now let's map this principle onto the roll code.
The participants:
- Caller: strategy.forward_step
- Callback: self.forward_func_log_probs (i.e. a method of ActorWorker)
- Trigger points: the return output_tensor, partial(loss_func, data) inside inner_forward_step, together with the internal implementation of forward_backward_func
Let's trace the call path and see how the callback gets triggered.
路径 1: compute_log_probs -> forward_step
python
# ActorWorker.py
def compute_log_probs(self, data: DataProto):
    # ...
    # Here, self.forward_func_log_probs is passed to forward_step
    # as a value (a callable object)
    results = self.strategy.forward_step(
        batch=data,
        forward_func=self.forward_func_log_probs  # <--- pass the callback
    )
    # ...
In the definition of forward_step, the callback is received and named forward_func.
python
# MegatronInferStrategy.py (or a similar Strategy class)
def forward_step(self, batch: DataProto, forward_func: Callable):  # <--- receive the callback
    # ...
    # forward_func is passed further down
    losses_reduced = self.forward_backward_func(
        forward_step_func=partial(self.inner_forward_step, forward_func),  # <--- pass it on again
        # ...
    )
    # ...
路径 2: forward_step -> forward_backward_func -> inner_forward_step
forward_backward_func is a function from the Megatron-LM framework that encapsulates the complex logic of pipeline parallelism and micro-batching. Its core job is to repeatedly call the forward_step_func you provide.
In our case, forward_step_func is partial(self.inner_forward_step, forward_func). That means the internal loop of forward_backward_func executes code roughly like this:
python
# Pseudocode for forward_backward_func
def forward_backward_func(forward_step_func, data_iterator, ...):
    all_outputs = []
    # Loop over the micro-batches
    for i in range(num_microbatches):
        # *** Trigger point 1 ***
        # Call the function we built with partial,
        # which executes inner_forward_step(data_iterator, model)
        output_tensor, process_fn = forward_step_func(data_iterator, model)

        # *** Trigger point 2 ***
        # process_fn is partial(loss_func, data),
        # i.e. partial(forward_func_log_probs, data).
        # This is where the callback forward_func_log_probs actually runs!
        loss, metrics = process_fn(output_tensor)
        all_outputs.append(metrics)
    return all_outputs
Now let's look at what inner_forward_step does:
python
# MegatronInferStrategy.py
def inner_forward_step(self, loss_func, data_iterator, model):  # loss_func is forward_func_log_probs
    # ... (prepare the data)

    # 1. Run the model's forward pass
    output_tensor = model(input_ids=..., attention_mask=...)

    # 2. Prepare the callback
    # It does NOT call loss_func directly; it returns a partial object instead.
    # partial(loss_func, data) means:
    # "create a new callable equivalent to calling loss_func(data, ...),
    #  except that the first argument, data, is already filled in"
    return output_tensor, partial(loss_func, data)
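If functools.partial still feels abstract, here is a tiny standalone illustration (toy values only, unrelated to roll or Megatron) of what "pre-filling the first argument" means:
python
from functools import partial

def loss_func(data, output_tensor):
    # Toy stand-in: just combine the two inputs
    return data + output_tensor

# Bind the first argument now; the second one is supplied later
process_fn = partial(loss_func, 10)

print(process_fn(5))  # 15, i.e. loss_func(10, 5)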
To summarize the trigger sequence:
- compute_log_probs passes the method object forward_func_log_probs to forward_step as an argument.
- forward_step passes that method object further down to the underlying forward_backward_func.
- forward_backward_func, inside its micro-batch loop, calls inner_forward_step.
- inner_forward_step runs the model's forward pass and obtains output_tensor.
- inner_forward_step does not call forward_func_log_probs directly. Instead, it creates and returns a partial object that "wraps" forward_func_log_probs together with the current micro-batch data.
- forward_backward_func receives output_tensor and this partial object (let's call it process_fn).
- The real trigger point: forward_backward_func calls process_fn(output_tensor). Since process_fn is a partial object wrapping forward_func_log_probs, this step is equivalent to executing:
forward_func_log_probs(data, output_tensor)
At this point, the callback defined in ActorWorker has been triggered: it receives the model output output_tensor and the corresponding input data, and performs its own specific computation (namely, calling op_compute_log_probs).
This slightly roundabout design decouples Megatron-LM's own complex looping logic from the roll framework's higher-level business logic. forward_backward_func only needs a function that returns (tensor, callable); it never has to care what that callable actually is.
The partial is there to match Megatron's forward_step_func interface: Megatron calls forward_step_func with only data_iterator and model. But inner_forward_step(self, loss_func, data_iterator, model) takes an extra loss_func argument, so that argument has to be bound in advance:
forward_step_func=partial(self.inner_forward_step, forward_func)
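To see the whole pattern in one place, here is a minimal, framework-free sketch. All names are toy stand-ins chosen for illustration, not the actual roll or Megatron implementation:
python
from functools import partial

# --- "Megatron" side: only knows about (data_iterator, model) ---
def forward_backward_func(forward_step_func, data_iterator, model, num_microbatches):
    all_metrics = []
    for _ in range(num_microbatches):
        # Trigger 1: run the user-supplied forward step
        output_tensor, process_fn = forward_step_func(data_iterator, model)
        # Trigger 2: run the wrapped callback on the model output
        loss, metrics = process_fn(output_tensor)
        all_metrics.append(metrics)
    return all_metrics

# --- "Strategy" side: adapts a 3-argument function to the 2-argument interface ---
def inner_forward_step(loss_func, data_iterator, model):
    data = next(data_iterator)
    output_tensor = model(data)
    # Defer the callback: bind data now, output_tensor arrives later
    return output_tensor, partial(loss_func, data)

# --- "Worker" side: the callback that holds the business logic ---
def forward_func_log_probs(data, output_tensor):
    loss = output_tensor - data  # toy computation
    return loss, {"input": data, "output": output_tensor}

# --- Wiring it together ---
def model(x):
    return x * 2  # toy "model"

results = forward_backward_func(
    forward_step_func=partial(inner_forward_step, forward_func_log_probs),
    data_iterator=iter([1, 2, 3]),
    model=model,
    num_microbatches=3,
)
print(results)
For reference, the docstring of Megatron's get_forward_backward_func below spells out this exact contract for forward_step_func: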
python
def get_forward_backward_func():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.
Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.
Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.
The function returned takes the following arguments:
forward_step_func (required): A function that takes a data
iterator and a model as its arguments and return the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.
A third argument, checkpoint_activations_microbatch, indicates
that the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This is
used when the
num_microbatches_with_partial_activation_checkpoints is used.
For example:
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)
forward_backward_func(forward_step_func=forward_step, ...)
You are asking how forward_step_func is used inside forward_backward_func, and what the partial(loss_func, loss_mask, model=model) it returns is for.
How forward_step_func is used
forward_step_func is the user-defined forward-pass function; it is invoked inside forward_backward_func. 1
Where it is called
forward_step_func is called inside the forward_step() function: 2
The concrete call is:
python
output_tensor, loss_func = forward_step_func(data_iterator, model)
Here forward_step_func receives two arguments:
- data_iterator: the data iterator
- model: the model instance
and returns two values:
- output_tensor: the model's output (typically a loss tensor)
- loss_func: a partial function used for the subsequent loss processing
What partial(loss_func, ...) is for
The returned partial(loss_func, loss_mask, model=model) is consumed by the forward_step_calc_loss() function. 3
Usage flow
- Receive loss_func: 4
- Call loss_func: inside forward_step_calc_loss(), if this is the last pipeline stage, the loss_func is invoked: 5
- Handle the return value: when loss_func is called with output_tensor, it returns the processed loss and metadata.
A concrete example
The complete usage can be seen in pretrain_gpt.py: 6
Here:
- forward_step_func calls the model to obtain output_tensor
- it returns partial(loss_func, loss_mask, model=model), binding loss_mask and model in advance
- later, inside forward_step_calc_loss(), the partial is invoked with just output_tensor to finish the loss computation
Why use partial
The benefits of using partial are:
- Deferred execution: forward_step_func does not compute the final loss immediately; it returns a function instead
- Argument binding: context such as loss_mask and model is bound up front, so only output_tensor needs to be supplied later
- Flexibility: different pipeline stages can apply different loss-handling logic
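As a minimal sketch of this deferred-execution pattern (toy names, loosely modeled on the contract described above rather than the real Megatron code):
python
from functools import partial
import torch

def loss_func(loss_mask, output_tensor):
    # Contract: return the loss and a dict of metrics
    losses = output_tensor.float().view(-1)
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask) / loss_mask.sum()
    return loss, {"lm loss": loss.detach()}

def forward_step(batch, model):
    output_tensor = model(batch["tokens"])
    # Defer the loss computation: bind loss_mask now, output_tensor later
    return output_tensor, partial(loss_func, batch["loss_mask"])

# The scheduler decides later whether (and when) to run the deferred loss
model = torch.nn.Identity()
batch = {"tokens": torch.randn(4), "loss_mask": torch.tensor([1., 1., 0., 1.])}
output_tensor, deferred_loss = forward_step(batch, model)
loss, metrics = deferred_loss(output_tensor)  # only needed on the last pipeline stage
print(loss, metrics)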
Notes
This design pattern is consistent across all the training scripts, including pretrain_mamba.py 7 and pretrain_t5.py 8. This uniform interface is what allows Megatron-LM to support different kinds of models and training tasks.
Citations
File: megatron/core/pipeline_parallel/schedules.py (L199-287)
python
def forward_step_calc_loss(
model,
output_tensor,
loss_func,
config,
vp_stage,
collect_non_loss_data,
num_microbatches,
forward_data_store,
cp_group_size=None,
is_last_stage=None,
):
"""Calculate the loss and number of tokens for forward_step()"""
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
model_vp_stage = getattr(model, "vp_stage", None)
if vp_stage is not None and model_vp_stage is not None:
assert (
vp_stage == model_vp_stage
), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
if cp_group_size is None and is_last_stage is None:
# fallback to parallel state
cp_group_size = parallel_state.get_context_parallel_world_size()
is_last_stage = parallel_state.is_pipeline_last_stage(
ignore_virtual=False, vp_stage=vp_stage
)
else:
assert (
cp_group_size is not None and is_last_stage is not None
), "cp_group_size and is_last_stage must be provided"
num_tokens = torch.tensor(0, dtype=torch.int)
if is_last_stage:
if not collect_non_loss_data:
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
# Protect against division by zero when all tokens are masked
# in a microbatch.
output_tensor /= torch.clamp(num_tokens, min=1)
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor *= cp_group_size
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
else:
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output_tensor, num_tokens
File: megatron/core/pipeline_parallel/schedules.py (L290-422)
python
def forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
cp_group_size,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
vp_stage=None,
is_last_stage=True,
):
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
forward_step_func (callable):
The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptible as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally:
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
of tensors, etc in the case of inference). To trigger case 3 you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator):
The data iterator.
model (nn.Module):
The model to perform the forward step on.
num_microbatches (int):
The number of microbatches.
input_tensor (Tensor or list[Tensor]):
The input tensor(s) for the forward step.
forward_data_store (list):
The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object):
The configuration object.
collect_non_loss_data (bool, optional):
Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional):
The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional):
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
vp_stage (int, optional):
The virtual pipeline stage. Defaults to None.
is_last_stage (bool, optional):
Whether it is the last stage. Defaults to True.
Also considering virtual stages.
In case of PP/VPP, is_last_stage/is_vp_last_stage.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
output_tensor, num_tokens = forward_step_calc_loss(
model,
output_tensor,
loss_func,
config,
vp_stage,
collect_non_loss_data,
num_microbatches,
forward_data_store,
cp_group_size,
is_last_stage,
)
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
File: pretrain_gpt.py (L121-157)
python
def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
"""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
vp_stage = get_attr_wrapped_model(model, "vp_stage")
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
timers('batch-generator').stop()
with stimer:
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
else:
if return_schedule_plan:
assert args.overlap_moe_expert_parallel_comm, \
"overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
schedule_plan = model.build_schedule_plan(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
return schedule_plan, partial(loss_func, loss_mask, model=model)
else:
output_tensor = model(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)
File: pretrain_mamba.py (L131-154)
python
def forward_step(data_iterator, model: MambaModel):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (MambaModel): The GPT Model
"""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
vp_stage = get_attr_wrapped_model(model, "vp_stage")
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage)
timers('batch-generator').stop()
with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)
File: pretrain_t5.py (L174-198)
python
def forward_step(data_iterator, model: T5Model):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (T5Model): The T5 Model
"""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator', log_level=2).start()
use_local = args.transformer_impl == "local"
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = get_batch(
data_iterator, use_local
)
timers('batch generator').stop()
# Forward model lm_labels
output_tensor = model(
tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, lm_labels=lm_labels
)
return output_tensor, partial(loss_func, loss_mask)