【大模型训练】MTPLossLoggingHelper中get_tensor_and_context_parallel_group使用

cpp 复制代码
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )

avg_group=parallel_state.get_tensor_and_context_parallel_group(), 有必要用这个嘛?只在CP组内是不是即可。

您提出了一个很好的问题:在 MTP loss logging 中是否真的需要在 TP 维度上平均 loss 。让我分析一下实际的计算流程。

您的理解是正确的

您说得对!在 Megatron-LM 的实现中,输出层的 logits 在计算 loss 之前就已经被聚合了

实际的计算流程

查看 GPTModel._postprocess() 中 MTP loss 的计算: 1

关键在于 self.output_layer() 的调用。这个 output_layer 是一个 ColumnParallelLinear,它的定义中有一个重要参数: 2

注意 gather_output=not self.parallel_output 这个参数。

ColumnParallelLinear 的行为

ColumnParallelLinear 中,如果 gather_output=True,会在返回前执行 gather_from_tensor_model_parallel_region(),将所有 TP ranks 的输出聚合成完整的 logits。

对于 MTP 的情况,由于 parallel_output=True 被传入 GPTModel,所以 gather_output=False,logits 不会被自动聚合

vocab_parallel_cross_entropy 的处理

但是,loss 计算使用的是 vocab_parallel_cross_entropy,它会处理切分的 vocab: 3

这个函数内部会:

  1. 在每个 TP rank 上计算本地 vocab 范围内的 logits 的 softmax
  2. 通过 all_reduce 聚合 predicted_logitssum_exp_logits
  3. 最终每个 TP rank 得到相同的 loss 值

结论

您的理解是对的:vocab_parallel_cross_entropy 内部,通过 all_reduce 操作,每个 TP rank 最终计算出的 loss 是相同的 。因此,在 TP 维度上再做一次平均确实是冗余的,不会改变结果。

为什么代码中使用了 TP+CP 组

代码中使用 get_tensor_and_context_parallel_group() 可能有以下原因:

  1. 保守的设计:确保在所有情况下 loss 都被正确聚合
  2. 未来扩展性:如果将来有不同的 loss 计算方式,这个设计更安全
  3. 代码一致性:与其他地方的 loss logging 保持一致的模式

但从当前的实现来看,只使用 CP 组应该就足够了 ,因为 TP 维度上的 loss 已经通过 vocab_parallel_cross_entropy 内部的 all_reduce 保证一致。

Notes

这是一个很细节的观察。在实际使用中,使用 TP+CP 组不会导致错误(因为对相同的值求平均结果不变),但确实存在优化空间。如果您想修改这部分代码,只使用 CP 组应该是安全的,但需要确保 vocab_parallel_cross_entropy 的行为不会在未来版本中改变。

Wiki pages you might want to explore:

Citations

File: megatron/core/models/gpt/gpt_model.py (L234-247)

python 复制代码
            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
                tp_group=self.pg_collection.tp,
            )

File: megatron/core/models/gpt/gpt_model.py (L547-558)

python 复制代码
                # output
                mtp_logits, _ = self.output_layer(
                    hidden_states_list[mtp_layer_number + 1],
                    weight=output_weight,
                    runtime_gather_output=runtime_gather_output,
                )
                # Calc loss for the current Multi-Token Prediction (MTP) layers.
                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
                loss_mask, num_tokens = roll_tensor(
                    loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
                )
                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)

File: megatron/core/tensor_parallel/cross_entropy.py (L219-233)

python 复制代码
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]

        target: correct vocab ids of dimseion [sequence_length, micro_batch_size]

        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
相关推荐
风象南2 小时前
普通人用AI加持赚到的第一个100块
人工智能·后端
牛奶3 小时前
2026年大模型怎么选?前端人实用对比
前端·人工智能·ai编程
牛奶3 小时前
前端人为什么要学AI?
前端·人工智能·ai编程
罗西的思考6 小时前
AI Agent框架探秘:拆解 OpenHands(10)--- Runtime
人工智能·算法·机器学习
冬奇Lab6 小时前
OpenClaw 源码精读(2):Channel & Routing——一条消息如何找到它的 Agent?
人工智能·开源·源码阅读
冬奇Lab6 小时前
一天一个开源项目(第38篇):Claude Code Telegram - 用 Telegram 远程用 Claude Code,随时随地聊项目
人工智能·开源·资讯
格砸8 小时前
从入门到辞职|从ChatGPT到OpenClaw,跟上智能时代的进化
前端·人工智能·后端
可观测性用观测云8 小时前
可观测性 4.0:教系统如何思考
人工智能
sunny8658 小时前
Claude Code 跨会话上下文恢复:从 8 次纠正到 0 次的工程实践
人工智能·开源·github
小笼包包仔8 小时前
OpenClaw 多Agent软件开发最佳实践指南
人工智能