【大模型训练】MTPLossLoggingHelper中get_tensor_and_context_parallel_group使用

cpp 复制代码
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )

avg_group=parallel_state.get_tensor_and_context_parallel_group(), 有必要用这个嘛?只在CP组内是不是即可。

您提出了一个很好的问题:在 MTP loss logging 中是否真的需要在 TP 维度上平均 loss 。让我分析一下实际的计算流程。

您的理解是正确的

您说得对!在 Megatron-LM 的实现中,输出层的 logits 在计算 loss 之前就已经被聚合了

实际的计算流程

查看 GPTModel._postprocess() 中 MTP loss 的计算: 1

关键在于 self.output_layer() 的调用。这个 output_layer 是一个 ColumnParallelLinear,它的定义中有一个重要参数: 2

注意 gather_output=not self.parallel_output 这个参数。

ColumnParallelLinear 的行为

ColumnParallelLinear 中,如果 gather_output=True,会在返回前执行 gather_from_tensor_model_parallel_region(),将所有 TP ranks 的输出聚合成完整的 logits。

对于 MTP 的情况,由于 parallel_output=True 被传入 GPTModel,所以 gather_output=False,logits 不会被自动聚合

vocab_parallel_cross_entropy 的处理

但是,loss 计算使用的是 vocab_parallel_cross_entropy,它会处理切分的 vocab: 3

这个函数内部会:

  1. 在每个 TP rank 上计算本地 vocab 范围内的 logits 的 softmax
  2. 通过 all_reduce 聚合 predicted_logitssum_exp_logits
  3. 最终每个 TP rank 得到相同的 loss 值

结论

您的理解是对的:vocab_parallel_cross_entropy 内部,通过 all_reduce 操作,每个 TP rank 最终计算出的 loss 是相同的 。因此,在 TP 维度上再做一次平均确实是冗余的,不会改变结果。

为什么代码中使用了 TP+CP 组

代码中使用 get_tensor_and_context_parallel_group() 可能有以下原因:

  1. 保守的设计:确保在所有情况下 loss 都被正确聚合
  2. 未来扩展性:如果将来有不同的 loss 计算方式,这个设计更安全
  3. 代码一致性:与其他地方的 loss logging 保持一致的模式

但从当前的实现来看,只使用 CP 组应该就足够了 ,因为 TP 维度上的 loss 已经通过 vocab_parallel_cross_entropy 内部的 all_reduce 保证一致。

Notes

这是一个很细节的观察。在实际使用中,使用 TP+CP 组不会导致错误(因为对相同的值求平均结果不变),但确实存在优化空间。如果您想修改这部分代码,只使用 CP 组应该是安全的,但需要确保 vocab_parallel_cross_entropy 的行为不会在未来版本中改变。

Wiki pages you might want to explore:

Citations

File: megatron/core/models/gpt/gpt_model.py (L234-247)

python 复制代码
            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
                tp_group=self.pg_collection.tp,
            )

File: megatron/core/models/gpt/gpt_model.py (L547-558)

python 复制代码
                # output
                mtp_logits, _ = self.output_layer(
                    hidden_states_list[mtp_layer_number + 1],
                    weight=output_weight,
                    runtime_gather_output=runtime_gather_output,
                )
                # Calc loss for the current Multi-Token Prediction (MTP) layers.
                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
                loss_mask, num_tokens = roll_tensor(
                    loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
                )
                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)

File: megatron/core/tensor_parallel/cross_entropy.py (L219-233)

python 复制代码
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]

        target: correct vocab ids of dimseion [sequence_length, micro_batch_size]

        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
相关推荐
IT WorryFree4 分钟前
OpenClaw-Medical-Skills 仓库介绍
人工智能·skill·openclaw
多年小白5 分钟前
今日AI科技简报 | 2026年3月19日
人工智能·科技·ai编程
逄逄不是胖胖12 分钟前
《动手学深度学习》-69预训练bert数据集实现
人工智能·深度学习·bert
IT_陈寒17 分钟前
Python开发者的效率革命:这5个技巧让你的代码提速50%!
前端·人工智能·后端
用户693717500138419 分钟前
不卷AI速度,我卷自己的从容——北京程序员手记
android·前端·人工智能
love530love23 分钟前
不用聊天软件 OpenClaw 手机浏览器远程访问控制:Tailscale 配置、设备配对与常见问题全解
人工智能·windows·python·智能手机·tailscale·openclaw·远程访问控制
lifallen30 分钟前
从零推导多 Agent 协作网络 (Flow Agent)
人工智能·语言模型
CoovallyAIHub32 分钟前
2.5GB 塞进浏览器:Mistral 开源实时语音识别,延迟不到半秒
深度学习·算法·计算机视觉
guoji778834 分钟前
2026年Gemini 3 Pro vs 豆包2.0深度评测:海外顶流与国产黑马谁更强?
大数据·人工智能·架构
NAGNIP39 分钟前
一文搞懂深度学习中的损失函数设计!
人工智能·算法