MUON (MomentUm Orthogonalized by Newton-Schulz)是月之暗面(Moonshot AI)在 Kimi K2 (1.04T 参数)和 Moonlight (16B MoE)训练中大规模验证并开源的优化器。它并非 Kimi 原创,但 Kimi 团队解决了其分布式扩展性 和训练稳定性两大瓶颈,使其首次在万亿级模型上可用。
一、背景:为什么不用 AdamW?
AdamW 是 LLM 训练的事实标准,但存在结构性缺陷:
- 元素级更新 :AdamW 对每个参数独立计算一阶/二阶矩,完全忽略了权重矩阵的几何结构
- 矩阵冗余:Transformer 中 90%+ 参数是 2D 矩阵(Attention Q/K/V、MLP 权重),其梯度更新存在大量相关性,AdamW 无法利用
- 效率天花板:在计算最优条件下,AdamW 的 token 效率已接近极限
Keller Jordan 提出的 MUON 核心洞察:对 2D 权重矩阵的动量进行正交化,使更新方向保持谱范数约束,从而加速收敛并提升稳定性。月之暗面将其扩展至分布式万亿参数规模,并开源了实现。
二、核心原理:Newton-Schulz 正交化
MUON 的算法流程极简,但数学内涵深刻:
1. 标准动量累积
M_t = β · M_{t-1} + G_t # 与传统 SGD+Momentum 相同
G_t 为梯度,β 为动量系数(通常 0.95,Nesterov 风格)。
2. Newton-Schulz 迭代(核心差异)
不直接对 M_t 做 SVD(太贵),而是用 Newton-Schulz 迭代 近似其极分解的正交因子:
O_t = NewtonSchulz(M_t, T=5)
迭代公式(每次迭代 refine 正交性):
X_{k+1} = 1.5 · X_k - 0.5 · X_k · X_k^T · X_k
为什么 5 步? 实验表明 5 步 Newton-Schulz 即可达到足够正交精度,超过 5 步收益递减且增加每步计算开销。相比完整 SVD,计算成本降低 1-2 个数量级。
3. 参数更新
W_t = W_{t-1} - η · O_t # O_t 是正交化后的动量
正交化的作用 :O_t 是近似正交矩阵(O^T·O ≈ I),其谱范数被约束为 1。这防止了某些方向更新过大导致的训练震荡,天然起到谱正则化效果。
三、Kimi 的关键改进
原始 MUON 在小模型(<1B)上表现优异,但直接扩展到千亿级会崩溃。月之暗面做了三项关键工程:
1. MuonClip:解决注意力 Logits 爆炸
问题:MUON 的激进更新会导致 Attention 的 Q/K 投影权重谱范数增长,注意力分数(logits)无限增大,引发 loss spike。
解法 :每次 MUON 更新后,检查当前 batch 中每个注意力头的最大 logit 值 max_score。若超过阈值 τ(Kimi K2 使用 τ=100),则对 W_q 和 W_k 进行动态缩放:
η = τ / max_score # η < 1 时触发裁剪
W_q ← W_q · η^α
W_k ← W_k · η^(1-α) # α ≈ 0.5,平衡 Q 和 K 的调整
效果 :Kimi K2 在 15.5 万亿 token 预训练中零 loss spike,最大注意力 logits 自动衰减到正常范围。
2. 分布式 MUON:解决与 FSDP/Zero 的冲突
问题 :MUON 的 Newton-Schulz 迭代需要完整的 2D 参数矩阵,但现代分布式训练(FSDP、DeepSpeed Zero-1/2/3)会将单个张量分片到多个 GPU 上。每次优化器 step 需要 all-gather 完整矩阵,通信开销爆炸。
Kimi 的解法:
- "暴力聚合"(Brute-force gather):仅在优化器 step 时临时聚合完整矩阵,做完 Newton-Schulz 后再分片回传
- 利用 MoE 稀疏性:Kimi K2 是 MoE 架构(32B 激活 / 1.04T 总参),非激活专家的参数无需聚合,大幅缓解通信压力
- 内存优化:开源的分布式实现优化了内存使用和通信效率
3. 混合优化器策略(MUON + AdamW)
MUON 仅对 2D 矩阵参数 有意义(正交化需要矩阵结构)。Kimi 采用双优化器架构:
| 参数类型 | 优化器 | 原因 |
|---|---|---|
| 2D 权重(Attention Q/K/V/O、MLP up/down/gate) | MUON | 矩阵结构,正交化有效 |
| 1D/0D 参数(Embedding、LayerNorm、Bias) | AdamW | 向量/标量无正交化概念 |
| 输出头(LM Head) | AdamW | 稀疏 one-hot 梯度,需自适应学习率 |
代码示例:
python
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adam_params = [p for p in model.parameters() if p.ndim < 2]
muon = torch.optim.Muon(muon_params, lr=0.02, momentum=0.95)
adam = torch.optim.AdamW(adam_params, lr=8e-4, weight_decay=0.01)
四、纵横比缩放(Aspect-Ratio Scaling)
这是 MUON 极易被忽略但至关重要的细节:
不同层的权重矩阵形状差异巨大:
- MLP up-projection:
3072 × 768(高矩阵,d_out > d_in) - MLP down-projection:
768 × 3072(宽矩阵) - Attention Q/K/V:
768 × 768(方阵)
如果不做调整,高矩阵和宽矩阵的有效学习率不一致,导致某些层欠训练或过训练。
Kimi/MUON 的缩放公式:
scale = sqrt(max(1, d_out / d_in))
update = lr · scale · O_t
- 高矩阵 (
d_out > d_in):scale > 1,放大更新以匹配容量 - 宽矩阵/方阵 :
scale = 1,保持不变
警告:关闭此缩放会导致深层模型(>26 层)训练不稳定。
五、超参数配置(Kimi 官方推荐)
基于 Kimi Moonlight 和 Keller Jordan 的实验:
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| MUON 学习率 | 0.02 | 约为 AdamW(8e-4)的 25 倍 |
| MUON 动量 | 0.95 | Nesterov 风格 |
| MUON 权重衰减 | 0.0 | 由 AdamW 侧处理 |
| Newton-Schulz 步数 | 5 | 精度与速度的平衡点 |
| AdamW 学习率 | 8e-4 | 标准值 |
| 动量 Warmup | 0.85 → 0.95 | 前 300 步从低动量开始,防止早期震荡 |
六、性能数据
| 指标 | AdamW | MUON / MuonClip | 提升 |
|---|---|---|---|
| 计算效率(计算最优条件) | 基准 | 约 2× | 相同 FLOPs 下收敛更快 |
| Moonlight-16B-A3B(5.7T tokens) | Pareto 前沿 | 突破 Pareto 前沿 | 更少 FLOPs 达到更优 loss |
| Kimi K2(15.5T tokens) | 不稳定 | 零 loss spike | MuonClip 保证稳定性 |
| GB300 吞吐(NVIDIA 实测) | 1,051 TFLOPs/s/GPU | 1,080 TFLOPs/s/GPU | MFU 更高(含 Newton-Schulz FLOPs) |
七、开源实现与接入
| 项目 | 地址/来源 | 说明 |
|---|---|---|
| Keller Jordan 原版 | kellerjordan.github.io/posts/muon/ |
概念验证实现 |
| PyTorch 原生 | torch.optim.Muon(PyTorch 2.9+) |
单节点版本,仅支持 ndim==2 |
| PyTorch MuonWithFallback | pytorch/pytorch#173559 |
提案中,自动路由 2D/非 2D 参数到 MUON/AdamW |
| Kimi 分布式版 | 随 Moonlight/K2 技术报告开源 | 适配 FSDP,优化内存与通信 |
| NVIDIA Megatron | Megatron-Bridge 26.02 | NVIDIA 官方集成,GB300 验证 |
| nanochat 教育版 | github.com/daegonYu/simple-test-muon-optimization |
HuggingFace Trainer 集成示例 |
八、局限性与注意事项
- 仅适用于 2D 参数:Embedding(1D/2D 但语义不同)、LayerNorm、Bias 必须用 AdamW
- 分布式通信开销:每次 step 需 all-gather 完整矩阵,非 MoE 架构成本较高
- bfloat16 稳定性:Newton-Schulz 迭代在 bf16 下需数值稳定实现(Kimi 和 PyTorch 官方已实现)
- 学习率敏感:MUON LR 约为 AdamW 的 25 倍,误设会导致训练崩溃
总结
MUON 的本质是"矩阵感知的动量正交化":先用 SGD+Momentum 累积梯度方向,再用 Newton-Schulz 迭代将其"掰正"为正交矩阵,从而约束谱范数、加速收敛。Kimi 的贡献在于通过 MuonClip 和分布式聚合策略,把这个原本只适合小模型的技巧,首次推到了万亿参数规模。