1. MoE 的负载均衡损失 (aux_loss) 为什么能逼迫专家均匀接客?
背景: 在 MoE 模型中,有一个"路由器(Router)"网络负责把每个 Token 分配给 EEE 个专家中的几个(比如 E=8E=8E=8,每次挑 2 个)。路由器的输出是通过 Softmax 计算出的概率分布。
为什么网络会"偷懒"? 因为神经网络存在**"马太效应(赢家通吃)"。如果在随机初始化的第一步,专家 A 刚好比别人表现好一点点,路由器就会给 A 更高的权重;因为 A 被分配了更多数据,A 的权重得到了更多的更新,它就变得更强;下一轮,路由器就更倾向于把所有 Token 都扔给 A。最终,8 个专家里 1 个累死,7 个闲死,这就叫路由崩塌(Routing Collapse)**。
为了打破这个死循环,我们在正常的预测 Loss 之外,加上了一个负载均衡损失(Load Balancing Loss)。以最经典的 Switch Transformer 的公式为例:
我们定义当前这一个 Batch 的数据(假设有 NNN 个 Token):
- fif_ifi:实际分配比例(Fraction) 。即这 NNN 个 Token 中,有多大比例被实际派给了第 iii 个专家。
- PiP_iPi:平均路由概率(Probability) 。即这 NNN 个 Token 在第 iii 个专家上的 Softmax 路由概率的平均值。
负载均衡损失 LauxL_{aux}Laux 的计算公式如下(其中 α\alphaα 是一个缩放常数,用来控制这个惩罚的力度):
Laux=α⋅E∑i=1Efi⋅PiL_{aux} = \alpha \cdot E \sum_{i=1}^{E} f_i \cdot P_iLaux=α⋅Ei=1∑Efi⋅Pi
原理解析(极其巧妙的数学惩罚):
这个公式的核心是一个**内积(点乘)**操作。因为 ∑i=1Efi=1\sum_{i=1}^{E} f_i = 1∑i=1Efi=1 且 ∑i=1EPi=1\sum_{i=1}^{E} P_i = 1∑i=1EPi=1,我们要最小化 LauxL_{aux}Laux 这个和式。
根据数学中的均值不等式 ,当两个非负向量的和为定值时,只有当这两个向量的分布越均匀,它们的点乘积才越小;分布越极端,点乘积越大。
-
反面教材(所有人都挤向专家 1):
假设有 4 个专家。系统偷懒了,把所有 Token 都给了专家 1。
那么 f=[1.0,0,0,0]f = [1.0, 0, 0, 0]f=[1.0,0,0,0],同时路由概率 P=[1.0,0,0,0]P = [1.0, 0, 0, 0]P=[1.0,0,0,0]。
此时 ∑(fi⋅Pi)=1.0×1.0+0+0+0=1.0\sum (f_i \cdot P_i) = 1.0 \times 1.0 + 0 + 0 + 0 = 1.0∑(fi⋅Pi)=1.0×1.0+0+0+0=1.0。
此时的 LauxL_{aux}Laux 达到了最大值,模型受到了极其严厉的惩罚(Loss 飙升)。
-
理想状态(大家均匀接客):
假设 4 个专家平分了 Token,大家各自干了 25% 的活。
那么 f=[0.25,0.25,0.25,0.25]f = [0.25, 0.25, 0.25, 0.25]f=[0.25,0.25,0.25,0.25],P=[0.25,0.25,0.25,0.25]P = [0.25, 0.25, 0.25, 0.25]P=[0.25,0.25,0.25,0.25]。
此时 ∑(fi⋅Pi)=4×(0.25×0.25)=4×0.0625=0.25\sum (f_i \cdot P_i) = 4 \times (0.25 \times 0.25) = 4 \times 0.0625 = 0.25∑(fi⋅Pi)=4×(0.25×0.25)=4×0.0625=0.25。
乘上前面的 EEE (即 4),结果是 4×0.25=1.04 \times 0.25 = 1.04×0.25=1.0。通过这种方式,LauxL_{aux}Laux 被最小化了。
举例说明:
我们直接用一个最极简的"排排坐、分果果"的具体数字例子 来拆解,你立刻就能明白 fif_ifi 和 PiP_iPi 到底是怎么算出来的。
为了方便计算,我们做一个极其简单的假设:
- 专家数量 (EEE): 只有 2 个专家(专家 1 和专家 2)。
- 当前批次的 Token 数量 (NNN): 只有 4 个 Token(字词),分别是 T1,T2,T3,T4T_1, T_2, T_3, T_4T1,T2,T3,T4。
- 路由规则: 每个 Token 只分配给概率最高的 1 个专家(Top-1 路由)。
第一步:收集路由器的原始数据 (Softmax 概率)
当这 4 个 Token 经过 MoE 路由器时,路由器会给每个 Token 计算出一个分配概率(总和为 100%)。假设输出如下:
- T1T_1T1 的概率: [0.8 , 0.2] 👉 0.8 > 0.2,所以 T1T_1T1 实际被派给专家 1。
- T2T_2T2 的概率: [0.6 , 0.4] 👉 0.6 > 0.4,所以 T2T_2T2 实际被派给专家 1。
- T3T_3T3 的概率: [0.3, 0.7 ] 👉 0.7 > 0.3,所以 T3T_3T3 实际被派给专家 2。
- T4T_4T4 的概率: [0.9 , 0.1] 👉 0.9 > 0.1,所以 T4T_4T4 实际被派给专家 1。
第二步:计算 fif_ifi (实际分配比例 Fraction)
fif_ifi 极其简单,就是数人头。这 4 个 Token 里,到底有百分之几去了这个专家那里?
- 分配给专家 1 的 Token: T1,T2,T4T_1, T_2, T_4T1,T2,T4(共 3 个)。
👉 f1=3/4=0.75f_1 = 3 / 4 = 0.75f1=3/4=0.75(75% 的数据实际给了专家 1) - 分配给专家 2 的 Token: T3T_3T3(共 1 个)。
👉 f2=1/4=0.25f_2 = 1 / 4 = 0.25f2=1/4=0.25(25% 的数据实际给了专家 2)
(验证:f1+f2=0.75+0.25=1.0f_1 + f_2 = 0.75 + 0.25 = 1.0f1+f2=0.75+0.25=1.0,没毛病)
第三步:计算 PiP_iPi (平均路由概率 Probability)
PiP_iPi 不看最终去向,只看路由器在一开始对所有 Token 展现出的"倾向性"平均值。把所有 Token 对应专家的概率加起来求平均:
- 对于专家 1 的平均概率: (0.8+0.6+0.3+0.9)/4=2.6/4(0.8 + 0.6 + 0.3 + 0.9) / 4 = 2.6 / 4(0.8+0.6+0.3+0.9)/4=2.6/4
👉 P1=0.65P_1 = 0.65P1=0.65 - 对于专家 2 的平均概率: (0.2+0.4+0.7+0.1)/4=1.4/4(0.2 + 0.4 + 0.7 + 0.1) / 4 = 1.4 / 4(0.2+0.4+0.7+0.1)/4=1.4/4
👉 P2=0.35P_2 = 0.35P2=0.35
(验证:P1+P2=0.65+0.35=1.0P_1 + P_2 = 0.65 + 0.35 = 1.0P1+P2=0.65+0.35=1.0,没毛病)
第四步:计算最终的负载均衡损失 LauxL_{aux}Laux
现在我们有了:
f1=0.75f_1 = 0.75f1=0.75, P1=0.65P_1 = 0.65P1=0.65
f2=0.25f_2 = 0.25f2=0.25, P2=0.35P_2 = 0.35P2=0.35
专家数量 E=2E = 2E=2。假设缩放系数 α=1\alpha = 1α=1。
套用公式:Laux=E⋅(f1⋅P1+f2⋅P2)L_{aux} = E \cdot (f_1 \cdot P_1 + f_2 \cdot P_2)Laux=E⋅(f1⋅P1+f2⋅P2)
Laux=2×(0.75×0.65+0.25×0.35)L_{aux} = 2 \times (0.75 \times 0.65 + 0.25 \times 0.35)Laux=2×(0.75×0.65+0.25×0.35)
Laux=2×(0.4875+0.0875)L_{aux} = 2 \times (0.4875 + 0.0875)Laux=2×(0.4875+0.0875)
Laux=2×0.575=1.15L_{aux} = 2 \times 0.575 = 1.15Laux=2×0.575=1.15
核心结论: 因为当前分配极其不均衡(3个全给了专家1,专家2快饿死了),导致算出来的 LauxL_{aux}Laux 高达 1.15 。这个高额的惩罚数值会作为 Loss 的一部分,通过反向传播告诉路由器:"你太偏心专家 1 了,下一轮给我把 T1,T2,T4T_1, T_2, T_4T1,T2,T4 的概率往专家 2 身上匀一点!"
如果分配是完美的(每个专家拿 2 个 Token),LauxL_{aux}Laux 的最小值会等于 1.0。
2. 梯度累加除法:为什么 loss = loss / accumulation_steps?
这个问题纯粹是微积分中导数(梯度)的线性性质问题。
目标(我们真正想要的):
假设我们的目标大 Batch Size 是 B=100B=100B=100。根据深度学习标准定义,这 100 条数据的真实平均 Loss 是:
Ltrue=1100∑k=1100lkL_{true} = \frac{1}{100} \sum_{k=1}^{100} l_kLtrue=1001k=1∑100lk
(其中 lkl_klk 是第 kkk 条数据的单条 Loss)。
对其求梯度更新模型,真正的完美梯度应该是:
∇Ltrue=1100∑k=1100∇lk\nabla L_{true} = \frac{1}{100} \sum_{k=1}^{100} \nabla l_k∇Ltrue=1001k=1∑100∇lk
现实(显卡装不下):
显卡一次只能塞下 b=25b=25b=25 条数据(微批次 Micro-batch)。我们需要走 4 步才能攒够 100 条数据。所以 accumulation_steps = 4。
如果我们不除以 accumulation_steps 会发生什么?
第一步(前 25 条数据),PyTorch 自动算出的平均 Loss 是:Lstep1=125∑k=125lkL_{step1} = \frac{1}{25} \sum_{k=1}^{25} l_kLstep1=251∑k=125lk
第二、三、四步同理。
如果你直接调用 .backward(),PyTorch 会把这 4 步的梯度直接相加 (这就是梯度的累加特性):
∇accumulated=∇Lstep1+∇Lstep2+∇Lstep3+∇Lstep4\nabla_{accumulated} = \nabla L_{step1} + \nabla L_{step2} + \nabla L_{step3} + \nabla L_{step4}∇accumulated=∇Lstep1+∇Lstep2+∇Lstep3+∇Lstep4
∇accumulated=125∑k=125∇lk+125∑k=2650∇lk+⋯=125∑k=1100∇lk\nabla_{accumulated} = \frac{1}{25} \sum_{k=1}^{25} \nabla l_k + \frac{1}{25} \sum_{k=26}^{50} \nabla l_k + \dots = \frac{1}{25} \sum_{k=1}^{100} \nabla l_k∇accumulated=251k=1∑25∇lk+251k=26∑50∇lk+⋯=251k=1∑100∇lk
发现问题了吗?累加出来的梯度,分母是 25,而不是 100!这意味着,你算出来的梯度,足足比真实的完美梯度 ∇Ltrue\nabla L_{true}∇Ltrue 放大了 4 倍! 这会导致优化器步子迈得太大,直接越过最优解,甚至导致模型崩溃(Loss 飞到天上去)。
加上这行代码后的修正:
loss = loss / args.accumulation_steps(即除以 4)。
根据导数的线性法则,函数的常数倍,等于导数的常数倍:∇(c⋅f(x))=c⋅∇f(x)\nabla (c \cdot f(x)) = c \cdot \nabla f(x)∇(c⋅f(x))=c⋅∇f(x)。
所以,第一步修正后的梯度变成了:
∇Lstep1_fixed=14(125∑k=125∇lk)=1100∑k=125∇lk\nabla L_{step1\fixed} = \frac{1}{4} \left( \frac{1}{25} \sum{k=1}^{25} \nabla l_k \right) = \frac{1}{100} \sum_{k=1}^{25} \nabla l_k∇Lstep1_fixed=41(251k=1∑25∇lk)=1001k=1∑25∇lk
当你把 4 步修正后的梯度加起来:
∇fixed_accumulated=1100∑k=1100∇lk\nabla_{fixed\accumulated} = \frac{1}{100} \sum{k=1}^{100} \nabla l_k∇fixed_accumulated=1001k=1∑100∇lk
看!这个结果与我们一开始写的 ∇Ltrue\nabla L_{true}∇Ltrue 完美相等。
结论: 先把微批次的 Loss 除以累加步数,是为了保证多次累加后的梯度在数学期望上,与一口气把所有数据塞进显卡算出来的梯度绝对一致,一分不差。