【大模型训练】MTPLossLoggingHelper中get_tensor_and_context_parallel_group使用

cpp 复制代码
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )

avg_group=parallel_state.get_tensor_and_context_parallel_group(), 有必要用这个嘛?只在CP组内是不是即可。

您提出了一个很好的问题:在 MTP loss logging 中是否真的需要在 TP 维度上平均 loss 。让我分析一下实际的计算流程。

您的理解是正确的

您说得对!在 Megatron-LM 的实现中,输出层的 logits 在计算 loss 之前就已经被聚合了

实际的计算流程

查看 GPTModel._postprocess() 中 MTP loss 的计算: 1

关键在于 self.output_layer() 的调用。这个 output_layer 是一个 ColumnParallelLinear,它的定义中有一个重要参数: 2

注意 gather_output=not self.parallel_output 这个参数。

ColumnParallelLinear 的行为

ColumnParallelLinear 中,如果 gather_output=True,会在返回前执行 gather_from_tensor_model_parallel_region(),将所有 TP ranks 的输出聚合成完整的 logits。

对于 MTP 的情况,由于 parallel_output=True 被传入 GPTModel,所以 gather_output=False,logits 不会被自动聚合

vocab_parallel_cross_entropy 的处理

但是,loss 计算使用的是 vocab_parallel_cross_entropy,它会处理切分的 vocab: 3

这个函数内部会:

  1. 在每个 TP rank 上计算本地 vocab 范围内的 logits 的 softmax
  2. 通过 all_reduce 聚合 predicted_logitssum_exp_logits
  3. 最终每个 TP rank 得到相同的 loss 值

结论

您的理解是对的:vocab_parallel_cross_entropy 内部,通过 all_reduce 操作,每个 TP rank 最终计算出的 loss 是相同的 。因此,在 TP 维度上再做一次平均确实是冗余的,不会改变结果。

为什么代码中使用了 TP+CP 组

代码中使用 get_tensor_and_context_parallel_group() 可能有以下原因:

  1. 保守的设计:确保在所有情况下 loss 都被正确聚合
  2. 未来扩展性:如果将来有不同的 loss 计算方式,这个设计更安全
  3. 代码一致性:与其他地方的 loss logging 保持一致的模式

但从当前的实现来看,只使用 CP 组应该就足够了 ,因为 TP 维度上的 loss 已经通过 vocab_parallel_cross_entropy 内部的 all_reduce 保证一致。

Notes

这是一个很细节的观察。在实际使用中,使用 TP+CP 组不会导致错误(因为对相同的值求平均结果不变),但确实存在优化空间。如果您想修改这部分代码,只使用 CP 组应该是安全的,但需要确保 vocab_parallel_cross_entropy 的行为不会在未来版本中改变。

Wiki pages you might want to explore:

Citations

File: megatron/core/models/gpt/gpt_model.py (L234-247)

python 复制代码
            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
                tp_group=self.pg_collection.tp,
            )

File: megatron/core/models/gpt/gpt_model.py (L547-558)

python 复制代码
                # output
                mtp_logits, _ = self.output_layer(
                    hidden_states_list[mtp_layer_number + 1],
                    weight=output_weight,
                    runtime_gather_output=runtime_gather_output,
                )
                # Calc loss for the current Multi-Token Prediction (MTP) layers.
                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
                loss_mask, num_tokens = roll_tensor(
                    loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
                )
                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)

File: megatron/core/tensor_parallel/cross_entropy.py (L219-233)

python 复制代码
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]

        target: correct vocab ids of dimseion [sequence_length, micro_batch_size]

        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
相关推荐
好奇龙猫20 分钟前
【人工智能学习-AI-MIT公开课11. 学习:识别树、无序】
人工智能·学习
Coder_Boy_20 分钟前
基于SpringAI企业级智能教学考试平台智能作业模块全业务闭环方案
java·人工智能·spring·spring cloud
玄同76520 分钟前
我是如何学习编程的?——从 “扳手使用” 到编程学习:踩坑式实践的底层方法论
开发语言·人工智能·经验分享·笔记·python·学习·自然语言处理
IT_陈寒21 分钟前
SpringBoot性能翻倍秘籍:5个被低估的配置项让我QPS提升200%
前端·人工智能·后端
Hcoco_me24 分钟前
大模型面试题25:Softmax函数把“得分”变成“概率”的归一化工具
人工智能·rnn·深度学习·lstm·word2vec
勇气要爆发29 分钟前
Prompt Engineering (提示词工程):如何通过“咒语”驯服 AI?
人工智能·prompt
币之互联万物30 分钟前
中象(深圳)投资集团有限公司推动“中象国际联盟”扬帆起航,面向世界
人工智能
川西胖墩墩41 分钟前
智能体在科研辅助中的自动化实验设计
人工智能·算法
努力的小雨43 分钟前
从“Agent 元年”到 AI IDE 元年——我的2025
ide·人工智能
whltaoin1 小时前
【AI Agent Skills】重塑 AI Agent 竞争力:Skills 体系的核心价值、构建方法与未来方向
大数据·人工智能·agent·agent skills