Part 7. Megatron Model MLP与MoE与EP
相比纷繁复杂的Attention,MLP部分设置(也就是transformer原论文中的FFN)就显得眉清目秀了。之前Linear层已经介绍了很大一部分复杂的底层,这次来研究一下偏上面一层是怎么实现的。
还是在熟悉的get_gpt_layer_specs.py中可以找到MLP的相关内容。MLP是作为transformerlayer的一个子模块传入的,因此构建模型的时候会单独构建之后塞到完整的ModuleSpec之中。
get_mlp_module_spec函数主要的内容是,先判断是不是MoE(num experts为None),再判断是否使用transformer engine。对于Dense模型,子模块就是两层,一层Column接一层Row,原因在Part 5中已经说过。简单来说就是,如果用两个Column会导致中间需要一次All gather和一次reduce scatter。接一层Row可以去掉中间的这两次通信,改为在最后一次all reduce。
MLPSubmodules没什么特别的,主要看MoE的实现。get_moe_module_spec在单独的moe_module_specs文件中。每个小专家本质上也是两层的MLP。
MoELayer保存在transformer/moe/moe_layer.py中。一个MoELayer需要包含:
- 一个TopKRouter,路由网络。保存在moe/router下。这个模块相当重要且非常复杂
- experts和shared experts,shared experts在所有ep rank上复制,local experts有一个起始和结束索引来管理
- token dispatcher,负责all to all通信,把token发送到各个rank以及回收。all to all反向也是all to all
Init时,先设置一个router,再根据配置挑选一个dispatcher,最后设置好专家模块。接下来主要以DeepSeek R1的配置来介绍。
Router:
路由模块可以说是MoE的算法灵魂。其重要性可以直接从分配给他的dtype看出来,甚至可以开启fp64双精度。如果有expert bias,甚至会单独保持bias为fp32精度。
该模块在forward的时候,首先对输入[seq,bs,hidden]做一些扰动(原始Switch Transformers使用,DS不使用),然后input过gating获取logits [seq,bs,e],最后根据logits正式过routing获取分数[seq x bs, e]和routing map [seq x bs, e]。
-
Gating很简单,就是一层线性层。计算之前把权重和输入都换成高精度。DS使用的是FP32


-
然后进Routing。logits变成[seq x bs, expert]的形状。因为此时不同batch和seq的token已经可以独立处理了所以能变成超大矩阵处理。
-
Apply Z loss。在ST MoE论文中,通过在loss里加上一个max(logits)的损失,防止logits差异过大,最后可能只有一两个专家分数很高的情况。这里用的是mean(square(logsumexp。logsumexp是一个求最大值的近似操作,如果用max会导致只有最大值处梯度为1,剩下为0。这个公式提供了平滑的梯度(梯度其实就是softmax),同时他是max的上界。有一个突出值时logsumexp=max+log( 1+ sum (exp (x_i - max) ) ),平均时 = max+log (num_experts),有max(x) ≤ logsumexp(x) ≤ max(x) + log(n)。之后在logsumexp基础上平方、求平均、乘定义的系数。平方能保证正定,平滑。之后把loss注册到MoEAuxLossAutoScaler中保障梯度正确传入。
MoEAuxLossAutoScaler使用在反向传播的时候,前向的时候仅保存计算的z loss,apply操作不对logits处理。由于aux loss不在主梯度流之中,想要让router受到它的影响就需要在main loss 反向计算的时候,梯度到达此处,触发scaler。scaler最后会返回scaled aux loss grad,torch观察到这个tensor需要梯度时,会沿着计算图继续计算aux loss,无需显式main loss + aux loss

- 如果启用SP,则把TPrank的对应logits都取回来(保障相同TP组内有相同的logits)
- Routing有四种方式。
- 朴素的topk用在推理的时候,不关心负载均衡,直接用topk。在deepseek中,会考虑到不同节点上有不同的experts,可能会让token尽量路由到按node分组的专家中,使用group limited topk,比如256个专家分成8组,放在八个节点上。先选top 4个组(得分最高的两个专家之和为组分数),然后在这四个组中选择top k比如8个专家。这样能保证选中的专家都在确定的范围内,尽量减少跨机通信;计算专家分数使用softmax或者sigmoid,ds选择sigmoid,此时后者不会归一化分数,从而让高贡献专家的分数差一步那么大,之后再归一化。获取topk专家的得分以及对应索引后,scale一下(ds=2.5),做具体的选择。每个专家可能会有一个负载量即平均专家处理多少token,超出的部分会直接不要了(ds不丢token)。注意,DS使用的专家负载均衡很简单,设置一个bias,如果该专家token拿的比较多,就把bias往下减一个factor x offset,否则往上加,这样选的多的专家得分就小一点,反之大一点。DS的bias在global batch结束后更新。
- aux loss load balancing在算完topk softmax后,对logits归一化,计算aux_loss = sum((probs_per_expert/num_tokens) * (tokens_per_expert/(num_tokens*topk))) * num_experts * moe_aux_loss_coeff;sinkhorn版本强制均衡
- Seq aux loss进一步要求在seq维度内也添加均衡损失
以下三行比较绕。第一行:生成一个mask,[num_tokens, experts],选中的专家那个元素设置为对应的专家概率分数,没选中的为0;第二行:选中的专家为True,没选中的位置为False;第三行:每一列的值求和,每个专家有多少个token。
Dispatcher:
分发器的作用是,把token根据路由的结果分配给指定的rank。目前有四种。
- AllGather:每个 GPU 从所有其他GPU上收集所有 tokens,但只计算分配给本地专家的部分。实现和通信简单,但显存开销大(每个 GPU 都存储所有 tokens),适合专家数量少、EP 规模小的场景
- AlltoAll:大规模EP下,每个GPU点对点通信,每个 GPU 只发送/接收需要的 tokens。显存高效,适合大规模 EP,是生产环境首选。deepseek用的就是 --moe-token-dispatcher-type alltoall
- AlltoAllSeq:结合了SP的版本。由于有SP,每个GPU只有序列的部分内容,所以分发更加细致
- Flex:使用了DeepSeek提出的DeepEP的高效调度内核的版本。他是一个高度优化的AlltoAll通信内核,能重叠通信与计算。
分发之前,先对token重排列(permutation)。原本的token排序是按序列来的,但是不同的专家是要计算不同的token的,所以重排列可以让相同专家的token放在连续的位置。
在AllGather版本中,gather所有token的map,probs和hidden,转置map [e, tokens],每一行找到true在哪些位置获取索引,最后按照索引按照专家顺序开始排列token,以及准备好后面怎么把token反着放回去以及做reduce scatter输出到其他tp-ep rank。
AlltoAll:先预处理preprocess,根据map统计expert的token数,生成input/output splits,然后permutation排序,使用all to all跨EP发送tokens/probs,如果有TP则聚合切片,然后再排序本地expert的tokens。unpermute时reduce scatter回TP,all to all反向回EP,unpermute恢复原本顺序。
AlltoAll Seq已经只用作兼容了,建议统一用AlltoAll。
flex使用了fused_dispatch,一个kernel完成permute+alltoall,最后combine也是fused_combine把反向permute+alltoall合一。



AlltoAll发送的时候,需要提供input_splits和output_splits,即要发送的内容中哪一些是发给哪个EP rank的,接受的时候要从哪些rank接收。DS的dropless是:输入[num_tokens, num_experts],tokens维度求和[num_experts]即当前rank发给每个expert的数量,num_out_tokens即需要发送的总token数量(topK x num tokens )。然后input_splits初始化为[ep_size, num_local_experts],本rank发送给每个ep rank的每个专家的token数,每行求和为发送个每个rank的token总数。接着算output_splits,先all gather一下tp-ep组,在0维度上拼接,换成[tp,ep,experts]的矩阵。含义是所有EP-TP rank在整个通信过程中会向给定专家expert发送多少token。此时再取本TP-EP rank上experts的切片,再把本TP-EP上所有expert接收到的token都求和 [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]。此时根据TP,取自己的那一行即可;想知道每一个expert会收到多少token,沿着TP求和后再沿着EP求和。
一个简单例子(dropless,ep_size=2,tp_size=1,num_local_experts=2,总 expert=4):
- Rank0 本地计数
[e0:5, e1:3, e2:4, e3:2];Rank1 本地计数[e0:1, e1:6, e2:2, e3:7]。 - Rank0
input_splits= [5+3, 4+2] = [8, 6](发给 ep_rank0 和 ep_rank1)。 - Rank1
input_splits= [1+6, 2+7] = [7, 9]。 - 全局计数(gather 后求和):
- e0: 5+1=6, e1:3+6=9, e2:4+2=6, e3:2+7=9。
- Rank0 负责本地 expert e0,e1 ⇒
output_splits= [本 rank 从 ep0 收到 e0+e1 的 5+3=8, 从 ep1 收到 e0+e1 的 1+6=7] = [8,7]。 - Rank1 负责 e2,e3 ⇒
output_splits= [4+2=6, 2+7=9] = [6,9]。 - Permute1 时,每个 rank 把 tokens 按目标 ep_rank 拼块,长度就是
sum(input_splits);alltoall 用 input/output_splits 对齐收发。之后再按本地 expert 重新排序。
input_split长度 ep_size:本 rank 发给每个 EP rank 的总 token 数。 output_splits长度 ep_size:本 rank 将从每个 EP rank 收到的 token 数。output_splits_tp 长度 tp_size:TP 方向每个 tp_rank 收到的 token 数。num_global_tokens_per_local_expert:每个本地 expert 的总 token 数(用来后续排序)。
DtoH Data Copying 就是把数据从 GPU 设备内存拷到主机内存(Device-to-Host,简称 D2H/DtoH)。在这个代码里(_maybe_dtoh_and_synchronize),主要目的是把一些元数据(比如input_splits/output_splits/num_out_tokens等)转移到 CPU,以便给通信库/后续逻辑使用,同时在合适的点对 side stream 做同步,确保这些 CPU 侧数据已经可用。
专家:
此时我们获得了本地专家需要处理的所有token以及对应的矩阵。此时的logits是[token
共享专家就比较简单,他的输出最后是被加到普通专家之后的。这里的gate是一个可学习的参数,作用是对共享专家的输出做一个加权。但是基本没什么模型在用。当使用SP的时候,计算前要all gather,前向后要reduce scatter。为了减少计算开销,可以把共享专家的计算和token dispatch并行起来。
路由专家有两种,一种是普通的sequential实现,还有一种是更高效的group GEMM实现。
- 分组实现把专家权重分成FC1和FC2两大块,所有专辑的FC1和FC2分别拼接成巨大的块,weight1和weight2.他们的形状是[hidden, moe_hidden * num_experts / TP]以及反过来降维回去。所有专家的权重在内存中连续排列,不单独为每个专家分配Module实例,在前向时可以通过.view()重塑专家维度。
- 假如收到了num_tokens, hidden的输入,以及tokens_per_expert [expert]和对应的概率permuted_probs [numtokens]。
- 如果采用SwiGLU,FC1层的输出宽度要乘2,这是因为GLU需要一半参数来算门控,另一半来算具体的输出。
- 把权重矩阵1[hidden, experts x FFNhidden x 2] 重塑为[experts, hidden, FFNhidden x 2],相当于experts个矩阵。矩阵2就是[experts, FFNhidden, hidden]
- 调用gg,grouped gemm utils的批量矩阵乘法获得FC1输出。FC1输出后和SwiGLU的gate分数相乘,然后和输入的probs相乘。这种做法保障了普通MoE的等价性,即P乘在输出之后。如果在FC1之前乘(也有可选项),数学上不是等价的(权重会影响到门控的分数),目前只有topK=1允许这样做。DS就是中间乘,并且MLP层没有bias
- 序列实现使用的是sequentialMLP,实现非常简单。把输入按照tokens per expert拆分,然后for循环每个专家乘。速度慢,但是整体实现复杂度比较低。
GLU家族,Gated Linear Unit,门控线性单元,其激活函数是
GLU(x) = (xW_gate ⊗ σ) ⊙ (xW_up) = σ(xW_gate) ⊙ (xW_up),
W_gate是门控投影权重,W_up是上投影,σ是激活函数,⊗是Hadamard乘法逐元素乘法。
SwiGLU(x) = SiLU(xW_gate) ⊙ (xW_up)= (xW_gate × sigmoid(xW_gate)) ⊙ (xW_up)。SiLU(x) = x × sigmoid(x) = x × (1 / (1 + e^(-x)))也叫 Swish,是一种平滑的非线性激活函数,当 x → +∞ 时,SiLU(x) → x,当 x → -∞ 时,SiLU(x) → 0,在 x=0 附近平滑过渡。

