cpp
mtp_logits, _ = output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
mtp_loss = compute_language_model_loss(labels, mtp_logits)
mtp_loss = loss_mask * mtp_loss
if self.training:
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_tensor_and_context_parallel_group(),
)
avg_group=parallel_state.get_tensor_and_context_parallel_group(), 有必要用这个嘛?只在CP组内是不是即可。
您提出了一个很好的问题:在 MTP loss logging 中是否真的需要在 TP 维度上平均 loss 。让我分析一下实际的计算流程。
您的理解是正确的
您说得对!在 Megatron-LM 的实现中,输出层的 logits 在计算 loss 之前就已经被聚合了 。
实际的计算流程
查看 GPTModel._postprocess() 中 MTP loss 的计算: 1
关键在于 self.output_layer() 的调用。这个 output_layer 是一个 ColumnParallelLinear,它的定义中有一个重要参数: 2
注意 gather_output=not self.parallel_output 这个参数。
ColumnParallelLinear 的行为
在 ColumnParallelLinear 中,如果 gather_output=True,会在返回前执行 gather_from_tensor_model_parallel_region(),将所有 TP ranks 的输出聚合成完整的 logits。
对于 MTP 的情况,由于 parallel_output=True 被传入 GPTModel,所以 gather_output=False,logits 不会被自动聚合 。
vocab_parallel_cross_entropy 的处理
但是,loss 计算使用的是 vocab_parallel_cross_entropy,它会处理切分的 vocab: 3
这个函数内部会:
- 在每个 TP rank 上计算本地 vocab 范围内的 logits 的 softmax
- 通过 all_reduce 聚合
predicted_logits和sum_exp_logits - 最终每个 TP rank 得到相同的 loss 值
结论
您的理解是对的:在 vocab_parallel_cross_entropy 内部,通过 all_reduce 操作,每个 TP rank 最终计算出的 loss 是相同的 。因此,在 TP 维度上再做一次平均确实是冗余的,不会改变结果。
为什么代码中使用了 TP+CP 组
代码中使用 get_tensor_and_context_parallel_group() 可能有以下原因:
- 保守的设计:确保在所有情况下 loss 都被正确聚合
- 未来扩展性:如果将来有不同的 loss 计算方式,这个设计更安全
- 代码一致性:与其他地方的 loss logging 保持一致的模式
但从当前的实现来看,只使用 CP 组应该就足够了 ,因为 TP 维度上的 loss 已经通过 vocab_parallel_cross_entropy 内部的 all_reduce 保证一致。
Notes
这是一个很细节的观察。在实际使用中,使用 TP+CP 组不会导致错误(因为对相同的值求平均结果不变),但确实存在优化空间。如果您想修改这部分代码,只使用 CP 组应该是安全的,但需要确保 vocab_parallel_cross_entropy 的行为不会在未来版本中改变。
Wiki pages you might want to explore:
Citations
File: megatron/core/models/gpt/gpt_model.py (L234-247)
python
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
tp_group=self.pg_collection.tp,
)
File: megatron/core/models/gpt/gpt_model.py (L547-558)
python
# output
mtp_logits, _ = self.output_layer(
hidden_states_list[mtp_layer_number + 1],
weight=output_weight,
runtime_gather_output=runtime_gather_output,
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
loss_mask, num_tokens = roll_tensor(
loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
)
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
File: megatron/core/tensor_parallel/cross_entropy.py (L219-233)
python
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]
target: correct vocab ids of dimseion [sequence_length, micro_batch_size]
label_smoothing: smoothing factor, must be in range [0.0, 1.0)
default is no smoothing (=0.0)
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)