下面我把两段内容完全合并 ,用公式 + 代码 + 直观例子 + 图示逻辑 + 梯度解释 一次性讲透,保证你彻底理解:
为什么 LLM 都用 RMSNorm,而抛弃 LayerNorm。
0. 最终结论(先放这方便记忆)
- LayerNorm:中心化 + 归一化 → 破坏残差恒等映射 → 深层梯度难传
- RMSNorm :只缩放、不移中心 → 保护残差通道 → 训练更稳更快
→ 现代 LLM 全部使用 RMSNorm
1. 公式对比(最本质区别)
LayerNorm
LayerNorm(x)=x−μσ2+ϵγ+β \text{LayerNorm}(x)=\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}\gamma+\beta LayerNorm(x)=σ2+ϵ x−μγ+β
- 做了 去均值(中心化)
- 做了 方差归一化
- 有缩放 γ + 偏移 β
RMSNorm
RMSNorm(x)=xE[x2]+ϵγ \text{RMSNorm}(x)=\frac{x}{\sqrt{\mathbb{E}[x^2]+\epsilon}}\gamma RMSNorm(x)=E[x2]+ϵ xγ
- 没有去均值!
- 只用均方根缩放
- 只有缩放 γ,没有偏移 β
2. 极简代码对比
LayerNorm
python
class LayerNorm(nn.Module):
def __init__(self, dim):
self.gamma = nn.Parameter(torch.ones(dim))
self.beta = nn.Parameter(torch.zeros(dim))
self.eps = 1e-6
def forward(self, x):
mean = x.mean(-1, keepdim=True)
var = x.var(-1, keepdim=True)
x = (x - mean) / torch.sqrt(var + self.eps)
x = x * self.gamma + self.beta
return x
RMSNorm(你模型里的代码)
python
class RMSNorm(nn.Module):
def __init__(self, dim):
self.weight = nn.Parameter(torch.ones(dim))
self.eps = 1e-6
def forward(self, x):
x = x.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
x = x * self.weight
return x
3. 直观数字例子(一看就懂)
设
x=[1,2,3] x = [1,2,3] x=[1,2,3]
① LayerNorm 处理
mean = 2
var = 1
x−mean=[−1,0,1] x - mean = [-1,0,1] x−mean=[−1,0,1]
x−meanσ=[−1,0,1] \frac{x-mean}{\sigma} = [-1,0,1] σx−mean=[−1,0,1]
原始 x 完全被"改写"了,信息不再是原来的信号。
② RMSNorm 处理
rms=12+22+323=143≈2.16 rms = \sqrt{\frac{1^2+2^2+3^2}{3}} = \sqrt{\frac{14}{3}} \approx 2.16 rms=312+22+32 =314 ≈2.16
xrms≈[0.46,0.92,1.38] \frac{x}{rms} \approx [0.46, 0.92, 1.38] rmsx≈[0.46,0.92,1.38]
- 相对大小没变
- 中心没变
- 只是整体缩放
- 原始 x 的结构完全保留
4. 核心关键:残差恒等映射
Transformer 每一层结构:
out=x+SubLayer(x) out = x + \text{SubLayer}(x) out=x+SubLayer(x)
理想情况:
如果子网络输出 0,那么
out=x out = x out=x
这叫 恒等映射 ,梯度可以无损直通。
5. LayerNorm 为什么破坏残差?
LayerNorm 会做:
x \\rightarrow x - mean
于是结构变成:
out = (x - \\mu) + \\text{SubLayer}(...)
- 原来的 x 被修改了
- 恒等映射被破坏
- 梯度不能直接原路返回
- 深层容易梯度消失 / 不稳定
6. RMSNorm 为什么保护残差?
RMSNorm 只做:
x→xrms x \rightarrow \frac{x}{rms} x→rmsx
- 没有减去均值
- 没有改变中心
- 只是缩放幅度
- 恒等路径
out = x + ...完全保留 - 梯度一路顺畅到底
7. 图示版(文字画流程图)
LayerNorm 流程
输入 x
→ 减去均值 → 中心强制归零
→ 除以方差
→ 缩放+偏移
→ 输出完全变形
→ 残差通道断裂
→ 梯度难传
RMSNorm 流程
输入 x
→ 计算均方根
→ 只做缩放(不改中心)
→ 输出形状/趋势不变
→ 残差通道完整
→ 梯度直通
8. 为什么 LLM 必选 RMSNorm(4 条硬核理由)
-
不破坏残差恒等映射
深层网络梯度传播更稳定,不会炸也不会消失。
-
计算更快
LayerNorm 算均值+方差,RMSNorm 只算均方,快 15%~30%。
-
无多余偏移项 β
大模型下更少参数,更稳定。
-
实际效果全面领先
Llama / Qwen / Mistral / GPT 系列均验证:
RMSNorm 收敛更快、loss 更低、生成更稳。
9. 终极一句话总结
LayerNorm 移中心 → 毁残差 → 梯度难传
RMSNorm 只缩放 → 保残差 → 梯度顺畅
所以 LLM 全都用 RMSNorm!
10 图
纯手绘 ASCII 结构图
1)Transformer 残差结构原貌(理想恒等映射)
┌───────────────────────────────────┐
│ │
x ──┴──→ Norm → Attention/MLP ──→ + ────┴──→ out = x + F(x)
(子层变换)
梯度反向: out ←──────+ ←────── SubLayer ←────── x
↑ ↑
└────────直接回传───────────┘
- 理想:
out = x + 0 = x - 梯度可以直接跳回输入 x,不经过任何层
- 这就是「恒等映射」,梯度永不消失
2)LayerNorm 对残差的破坏
┌───────────────────────────────────┐
│ │
x ──┴──→ LayerNorm ──→ Attention/MLP ──→ + ───→ out
⬇️⬇️⬇️ 关键破坏
x → x - μ → 再缩放 → 信号完全变形
梯度回传变成这样:
out ←── + ←── SubLayer ←── (x - μ) ←── x
↑
这里多了一次"平移变换"
梯度被扭曲、衰减
- 原来的
x被强行去均值、中心化 - 恒等映射
x → x被打断 - 梯度多走一道变换,深层容易消失/震荡
3)RMSNorm 如何保护残差
┌───────────────────────────────────┐
│ │
x ──┴──→ RMSNorm ──→ Attention/MLP ──→ + ───→ out
⬇️⬇️⬇️ 只缩放,不移中心
x → x / rms → 形状不变、方向不变
梯度回传:
out ←── + ←── SubLayer ←── (x / rms) ←── x
↑
只是简单缩放,不扭曲信号
梯度几乎"原样通过"
- 没有减去均值
- 没有改变数据中心、相对大小、分布形状
- 只做幅度缩放
- 残差恒等路径基本完整
- 梯度传播顺滑、稳定、不衰减
4)一张图看懂本质区别
LayerNorm: 输入 x → 移中心 → 归一化 → 输出面目全非
残差断裂 → 梯度难传
RMSNorm: 输入 x → 只缩放 → 输出结构不变
残差完整 → 梯度直通
5)最容易记的口诀
- LayerNorm = 改中心 + 毁残差 + 梯度难传
- RMSNorm = 只缩放 + 保残差 + 梯度顺畅
如果你愿意,我可以再给你画一张数值变化对比图 ,把 x=[1,2,3] 在两种 Norm 里每一步怎么变,用表格画出来,彻底刻进脑子里。