【大模型训练】MTPLossLoggingHelper中get_tensor_and_context_parallel_group使用

cpp 复制代码
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )

avg_group=parallel_state.get_tensor_and_context_parallel_group(), 有必要用这个嘛?只在CP组内是不是即可。

您提出了一个很好的问题:在 MTP loss logging 中是否真的需要在 TP 维度上平均 loss 。让我分析一下实际的计算流程。

您的理解是正确的

您说得对!在 Megatron-LM 的实现中,输出层的 logits 在计算 loss 之前就已经被聚合了

实际的计算流程

查看 GPTModel._postprocess() 中 MTP loss 的计算: 1

关键在于 self.output_layer() 的调用。这个 output_layer 是一个 ColumnParallelLinear,它的定义中有一个重要参数: 2

注意 gather_output=not self.parallel_output 这个参数。

ColumnParallelLinear 的行为

ColumnParallelLinear 中,如果 gather_output=True,会在返回前执行 gather_from_tensor_model_parallel_region(),将所有 TP ranks 的输出聚合成完整的 logits。

对于 MTP 的情况,由于 parallel_output=True 被传入 GPTModel,所以 gather_output=False,logits 不会被自动聚合

vocab_parallel_cross_entropy 的处理

但是,loss 计算使用的是 vocab_parallel_cross_entropy,它会处理切分的 vocab: 3

这个函数内部会:

  1. 在每个 TP rank 上计算本地 vocab 范围内的 logits 的 softmax
  2. 通过 all_reduce 聚合 predicted_logitssum_exp_logits
  3. 最终每个 TP rank 得到相同的 loss 值

结论

您的理解是对的:vocab_parallel_cross_entropy 内部,通过 all_reduce 操作,每个 TP rank 最终计算出的 loss 是相同的 。因此,在 TP 维度上再做一次平均确实是冗余的,不会改变结果。

为什么代码中使用了 TP+CP 组

代码中使用 get_tensor_and_context_parallel_group() 可能有以下原因:

  1. 保守的设计:确保在所有情况下 loss 都被正确聚合
  2. 未来扩展性:如果将来有不同的 loss 计算方式,这个设计更安全
  3. 代码一致性:与其他地方的 loss logging 保持一致的模式

但从当前的实现来看,只使用 CP 组应该就足够了 ,因为 TP 维度上的 loss 已经通过 vocab_parallel_cross_entropy 内部的 all_reduce 保证一致。

Notes

这是一个很细节的观察。在实际使用中,使用 TP+CP 组不会导致错误(因为对相同的值求平均结果不变),但确实存在优化空间。如果您想修改这部分代码,只使用 CP 组应该是安全的,但需要确保 vocab_parallel_cross_entropy 的行为不会在未来版本中改变。

Wiki pages you might want to explore:

Citations

File: megatron/core/models/gpt/gpt_model.py (L234-247)

python 复制代码
            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
                tp_group=self.pg_collection.tp,
            )

File: megatron/core/models/gpt/gpt_model.py (L547-558)

python 复制代码
                # output
                mtp_logits, _ = self.output_layer(
                    hidden_states_list[mtp_layer_number + 1],
                    weight=output_weight,
                    runtime_gather_output=runtime_gather_output,
                )
                # Calc loss for the current Multi-Token Prediction (MTP) layers.
                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
                loss_mask, num_tokens = roll_tensor(
                    loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
                )
                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)

File: megatron/core/tensor_parallel/cross_entropy.py (L219-233)

python 复制代码
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]

        target: correct vocab ids of dimseion [sequence_length, micro_batch_size]

        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
相关推荐
这张生成的图像能检测吗1 小时前
(论文速读)SpiralMLP:一个轻量级的视觉MLP架构
图像处理·人工智能·深度学习·计算机视觉·mlp框架·分类、检测、分割
doubao361 小时前
如何在海量文献中高效筛选有价值信息
人工智能·学习·自然语言处理·aigc·ai工具·ai检索
美团技术团队1 小时前
美团 LongCat Interaction 团队发布大模型交互系统技术报告 WOWService
人工智能
烤麻辣烫1 小时前
AI(新手)
人工智能·学习·机器学习·ai编程
CoovallyAIHub1 小时前
分割万事万物的AI,再进化!Meta SAM 3 来了,支持中文提示词!
深度学习·算法·计算机视觉
虹科网络安全1 小时前
从AI模型到云生态:构建系统化的企业AI安全管理体系【系列文章(1)】
人工智能·安全
G***E3161 小时前
机器学习特征选择方法
人工智能·机器学习
赋创小助手2 小时前
英特尔确认取消 8 通道 Diamond Rapids:服务器 CPU 战局再度升级
服务器·图像处理·人工智能·深度学习·计算机视觉·自然语言处理·自动驾驶
用户5191495848452 小时前
Rust 1.91.0 发布:新增平台支持与安全增强
人工智能·aigc