LLM-leetcode TASK05

07. MoE Load Balancing Loss | MoE 进阶:负载均衡损失函数 (Load Balancing Loss)

难度: Hard | 标签: MoE, Loss Function, Mixtral | 目标人群: 核心 Infra 与算子开发

在上一节 06_MoE_Router 中,我们实现了 Top-K 路由。但在真实的 MoE 模型(如 Mixtral 8x7B, DeepSeek)训练中,会遇到一个非常严重的问题:路由崩塌 (Router Collapse)

即门控网络"偷懒",把所有的 Token 都发给了第 0 号和第 1 号专家,导致其他专家被饿死(闲置),不仅失去了 MoE 的意义,还会导致算力非常不均衡(OOM)。

因此,面试官非常爱考:如何用代码实现 MoE 的辅助损失函数 (Auxiliary Loss) 来强制负载均衡?

Step 1: 核心数学公式

为了让 NNN 个 Token 均匀地分配给 EEE 个专家,我们需要设计一个惩罚项,加到总的 CrossEntropy Loss 里。

Mixtral / Switch Transformer 使用的经典公式:

Laux=α⋅E∑i=1Efi⋅Pi L_{aux} = \alpha \cdot E \sum_{i=1}^E f_i \cdot P_i Laux=α⋅Ei=1∑Efi⋅Pi

  • EEE: 专家总数。
  • fif_ifi: 专家 iii 被路由到的 Token 比例 (即选了专家 iii 的 token 数 / 总 token 数)。
  • PiP_iPi: 专家 iii 在所有 Token 上的 平均路由概率得分(Softmax 之后的概率的均值)。
  • α\alphaα: 辅助损失的权重系数(通常很小,如 0.01)。

为什么这个公式有效?

根据均值不等式,给定总和为 1 的 fff 和 PPP,当且仅当所有的 fi=1/Ef_i = 1/Efi=1/E 且 Pi=1/EP_i = 1/EPi=1/E 时(即绝对均匀分配),它们的内积(点积)之和最小。优化器为了降低这个 Loss,会拼命把 Token 往不同的专家那里赶!

Step 2: 代码实现框架

你需要统计在当前批次中每个专家实际被选中的次数(形成频率分布 fif_ifi),同时求出门控概率的均值分布(PiP_iPi)。将这两个分布点乘并乘以专家总数 EEE 和超参数 α\alphaα,即可得到最终的 Load Balancing Loss。

关键点 :本实现支持 Top-K 路由(不仅限于 Top-1),通过 top_k 参数控制每个 Token 选择的专家数量。

Step 3: 动手实战

要求 :请补全下方 compute_load_balancing_loss 的逻辑。

注意:本实现支持 Top-K 路由,即每个 Token 可以选择 K 个专家(通常 K=2)。

复制代码
def compute_load_balancing_loss(
    routing_weights: torch.Tensor, 
    selected_experts: torch.Tensor, 
    num_experts: int, 
    top_k: int,
    alpha: float = 0.01
):
    """
    计算 MoE 的负载均衡辅助损失(支持 Top-K 路由)
    
    Args:
        routing_weights: [batch_size * seq_len, top_k],每个 token 选中的 K 个专家的权重(已归一化)
        selected_experts: [batch_size * seq_len, top_k],每个 token 选中的 K 个专家的索引
        num_experts: 专家总数 E
        top_k: 每个 token 选择的专家数量 K
        alpha: 损失权重系数
    
    Returns:
        aux_loss: 标量,负载均衡损失
    """
    batch_size_x_seq_len, _ = selected_experts.shape
    total_tokens = batch_size_x_seq_len
    
    # ==========================================
    # TODO 1: 计算 P_i(每个专家的平均路由概率得分)
    # ==========================================
    # P_i = ???
    P_i = torch.zeros(num_experts, dtype=routing_weights.dtype, device=routing_weights.device)
    P_i.scatter_add_(0, selected_experts.flatten(), routing_weights.flatten())
    P_i = P_i / (total_tokens * top_k)
    
    # ==========================================
    # TODO 2: 计算 f_i(每个专家实际分到的 Token 比例)
    # ==========================================
    # expert_mask = ???
    # tokens_per_expert = ???
    # f_i = ???
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts)
    tokens_per_expert = expert_mask.sum(dim=(0, 1)).float()
    f_i = tokens_per_expert / (total_tokens * top_k)
    
    # ==========================================
    # TODO 3: 计算最终的 auxiliary loss
    # ==========================================
    # aux_loss = ???

    aux_loss = alpha * num_experts * (f_i * P_i).sum()
                                                                                                                                                                                  
    return aux_loss

解析

1. TODO 1: 计算 P_i(平均路由概率)

  • 实现方式

    python 复制代码
    P_i = torch.zeros(num_experts, dtype=routing_weights.dtype, device=routing_weights.device)
    P_i.scatter_add_(0, selected_experts.flatten(), routing_weights.flatten())
    P_i = P_i / (total_tokens * top_k)
  • 核心逻辑 :使用 scatter_add_ 将每个 token 对选中专家的权重累加到对应专家的位置。

  • 归一化 :除以总的选择次数 (total_tokens * top_k) 得到平均权重。

  • 物理含义 :PiP_iPi 表示专家 iii 在所有 token 上的平均被选中概率。

2. TODO 2: 计算 f_i(Token 分配比例)

  • 实现方式

    python 复制代码
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts)
    tokens_per_expert = expert_mask.sum(dim=(0, 1)).float()
    f_i = tokens_per_expert / (total_tokens * top_k)
  • 核心逻辑F.one_hot 将专家索引转换为 one-hot 编码,形状为 [batch_size_x_seq_len, top_k, num_experts]

  • 统计方法:沿前两个维度求和,统计每个专家被选中的总次数。

  • 归一化:除以总的选择次数得到比例。

  • 物理含义 :fif_ifi 表示专家 iii 实际分到的 token 比例。

3. TODO 3: 计算辅助损失

  • 实现方式aux_loss = alpha * num_experts * (f_i * P_i).sum()
  • 数学公式 :Laux=α⋅E∑i=1Efi⋅PiL_{aux} = \alpha \cdot E \sum_{i=1}^E f_i \cdot P_iLaux=α⋅E∑i=1Efi⋅Pi
  • 最小值分析 :根据均值不等式,当 fi=Pi=1/Ef_i = P_i = 1/Efi=Pi=1/E 时(完全均匀),损失最小。对于 Top-K 路由,理论最小值为 α/K\alpha / Kα/K。
  • 优化目标:优化器为了降低这个 Loss,会强制将 Token 均匀分配给所有专家,防止路由崩塌。

工程要点

  • Top-K 兼容性 :代码支持任意 K 值,通过 (total_tokens * top_k) 归一化确保比例计算正确。
  • 数值稳定性 :使用 scatter_add_ 而非循环累加,提升计算效率和数值稳定性。
  • 超参数调优 :α\alphaα 通常设为 0.01,过大会影响主任务性能,过小则无法有效平衡负载。
  • 与主损失结合 :在实际训练中,将 aux_loss 加到 CrossEntropy Loss 上:total_loss = ce_loss + aux_loss

08. Architecture Tricks | 经典架构变体:Qwen 与 Gemma 的核心机制 (Architecture Tricks)

难度: Easy | 标签: 模型架构, Qwen, Gemma | 目标人群: 模型微调与工程部署

06_LLaMA3_Block_Tutorial 中我们搭建了 LLaMA 的骨架。但如果你去面试阿里云(通义千问团队)或者谷歌,他们必然会问自家模型与 LLaMA 的区别。

本节我们将以"打补丁"的方式,在 PyTorch 中快速实现 Qwen 的 Tie Word Embeddings 以及 Gemma 的带偏置 RMSNorm

Step 1: 核心差异与机制

Trick 1: Tie Word Embeddings (权重绑定) - Qwen 系列 / GPT-2

  • 做法 :在绝大多数模型(如 LLaMA)中,最开始的 Token Embedding 矩阵(把 ID 变向量)和最后的 LM Head 矩阵(把向量变概率)是两个独立的权重矩阵。但在 Qwen 中,这两个矩阵共享同一份物理内存的参数!

  • 意义:极大减少了参数量(词表动辄 15 万,非常占参数),并且在训练时能让 Embedding 获得更直接的梯度更新。
    Trick 2: RMSNorm 的 "+1 缩放" - Gemma 系列

  • 做法 :标准的 RMSNorm 公式是 y=xRMS⋅wy = \frac{x}{RMS} \cdot wy=RMSx⋅w。而 Google 的 Gemma 把它改成了 y=xRMS⋅(1+w)y = \frac{x}{RMS} \cdot (1 + w)y=RMSx⋅(1+w)。

  • 意义 :在 PyTorch 中,权重的默认初始化通常是 0(或者很小的值)。Gemma 加上 1,使得在训练的极早期(wpprox0w pprox 0wpprox0 时),RMSNorm 直接等价于一个不做任何缩放的纯归一化层,这带来了非常平滑的梯度和非常稳定的早期训练!

Step 2: Weight Tying 与偏置项的权衡

Weight Tying(权重绑定)强制 Embedding 层和最终的 LM Head 线性层共享同一个权重矩阵。这种方法在早期的模型中很流行,因为它大幅减少了参数量。但在现代极大规模 LLM 中,解绑通常能获得更好的容量表达。此外,取消大部分 Linear 和 Norm 层中的 Bias 项,可以略微提高计算效率并防止显存浪费。

Step 3: 代码实现框架

要实现权重绑定,只需在网络初始化时将 LM Head 的 weight 引用直接指向 Embedding 层的 weight。注意,这意味着隐藏层维度必须与词表维度兼容(或者存在中间投影层)。

Step 4: 动手实战

要求

  1. 补全 GemmaRMSNorm 的公式。
  2. 补全 QwenTieEmbeddings 中的参数共享逻辑。
python 复制代码
# --- Trick 1: Gemma 风格的 RMSNorm ---
class GemmaRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # weight 初始化为全 0
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 计算均方根
        x_f32 = x.float()
        variance = x_f32.pow(2).mean(-1, keepdim=True)
        x_norm = x_f32 * torch.rsqrt(variance + self.eps)
        
        # ==========================================
        # TODO 1: 实现 Gemma 的 +1 缩放
        # 注意类型转换回 x.dtype
        # ==========================================
        # output = ???

        # 占位初始化(返回错误值,确保数值测试失败)                                                                                                                                 
        output = x_norm * (1 + self.weight)                                                                                                                                             
  
        return output     
        



# --- Trick 2: Qwen 风格的权重绑定 ---
class QwenTieEmbeddings(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        # 1. 定义标准的 Embedding 层
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        
        # 2. 定义最后的 LM Head 预测层,注意不要 bias
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        
        # ==========================================
        # TODO 2: 将 lm_head 的权重在内存级别绑定到 embed_tokens 上
        # 提示: 在 PyTorch 中,可以直接赋值 nn.Parameter 或是底层 tensor
        # self.lm_head.weight = ???
        # ==========================================
        # ???
        self.lm_head.weight = self.embed_tokens.weight
        
    def forward_embed(self, input_ids):
        return self.embed_tokens(input_ids)
        
    def forward_lm_head(self, hidden_states):
        return self.lm_head(hidden_states)

解析

1. TODO 1: Gemma 的 +1 缩放机制

  • 实现方式output = x_norm * (1 + self.weight)
  • 核心思想 :在标准 RMSNorm 的基础上,将缩放因子从 w 改为 (1 + w)
  • 初始化优势 :权重初始化为 0 时,(1 + 0) = 1,此时 RMSNorm 等价于纯归一化层(无缩放),梯度非常平滑。
  • 训练稳定性:在训练早期(权重接近 0),避免了因权重过小导致的梯度消失问题。随着训练进行,权重逐渐学习到合适的缩放值。
  • 工程细节 :必须先转换为 FP32 计算(x.float()),最后再转回原始精度(type_as(x)),防止 FP16/BF16 下的数值不稳定。

2. TODO 2: Qwen 的权重绑定(Weight Tying)

  • 实现方式self.lm_head.weight = self.embed_tokens.weight
  • 物理指针级共享 :这不是复制权重,而是让两个模块的 weight 参数指向同一块内存。修改其中一个,另一个自动同步。
  • 参数量优势:词表通常很大(15万+),绑定后可以节省一半的参数量。例如,词表 150k、隐藏层 4096 的模型,可以节省 150k × 4096 × 4 bytes ≈ 2.4GB 显存。
  • 梯度更新:训练时,Embedding 层和 LM Head 的梯度会累加到同一个权重上,使得 Embedding 获得更直接的监督信号。
  • 适用场景:Qwen、GPT-2 等模型使用此技巧。但在超大规模模型(如 LLaMA 70B)中,解绑通常能获得更好的表达能力。

工程要点

  • 内存验证 :可以通过 data_ptr() 检查两个权重是否指向同一内存地址。
  • 训练同步:由于是物理指针共享,更新 Embedding 权重时,LM Head 权重会自动同步,无需手动处理。
  • 架构权衡:权重绑定减少参数但可能限制表达能力;+1 缩放提升训练稳定性但增加计算量(需要额外的加法)。