在LLM中为什么用RMSNorm,可不可以分母为L2 Norm
在现代大语言模型(LLM)如 LLaMA、Gopher、Chinchilla 中,RMSNorm (Root Mean Square Layer Normalization) 确实已经取代了传统的 LayerNorm 成为主流选择。
针对你的两个问题,我将从原理、计算效率以及数学性质上分别进行解答。
一、 为什么 LLM 中要用 RMSNorm?
RMSNorm 是 LayerNorm 的一个变体,它由 Zhang & Sennrich 在 2019 年提出。相比 LayerNorm,它主要有以下优势:
1. 计算效率更高(更省显存带宽和计算量)
传统的 LayerNorm 包含两个步骤:
- Re-centering(去中心化/平移) :计算均值 μ\muμ,让输入减去均值。
- Re-scaling(缩放) :计算方差 σ2\sigma^2σ2,除以标准差进行归一化。
公式为:
xˉ=x−μσ2+ϵ∗γ+β \bar{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} * \gamma + \beta xˉ=σ2+ϵ x−μ∗γ+β
而 RMSNorm 认为 "去中心化(减均值)"并不重要,归一化的核心在于缩放数据的幅度。因此它去掉了减均值的步骤,只保留了均方根(RMS)缩放。
公式为:
xˉ=xRMS(x)∗γ,where RMS(x)=1d∑i=1dxi2 \bar{x} = \frac{x}{\text{RMS}(x)} * \gamma, \quad \text{where } \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2} xˉ=RMS(x)x∗γ,where RMS(x)=d1i=1∑dxi2
优势:
- 少算一个均值(Mean),少做一次减法操作。
- 虽然理论计算量减少不算极其巨大(约为 10%-20% 的 LayerNorm 开销),但在几百亿参数、几万亿 Token 的训练中,累积节省的训练时间非常可观。
- 代码实现更简单,GPU Kernel 优化更容易。
2. 效果并未下降(甚至更好)
实验表明,LayerNorm 中的"平移不变性"(由减均值带来)对 Transformer 的收敛并不是必须的,而"缩放不变性"(由除以方差带来)才是关键。RMSNorm 在保持模型收敛速度和最终效果上,与 LayerNorm 持平甚至略优。
3. 数值稳定性
在训练极深的网络时,RMSNorm 往往表现出比 LayerNorm 更好的数值稳定性,尤其是在使用 fp16 或 bf16 混合精度训练时。
二、 可不可以分母为 L2 Norm?
这个问题的本质是:为什么分母是 RMS(均方根)而不是 L2 范数?
我们先看定义:
- RMS (均方根): 1d∑xi2\sqrt{\frac{1}{d} \sum x_i^2}d1∑xi2
- L2 Norm (欧几里得范数): ∥x∥2=∑xi2\|x\|_2 = \sqrt{\sum x_i^2}∥x∥2=∑xi2
两者关系为:
∥x∥2=d×RMS(x) \|x\|_2 = \sqrt{d} \times \text{RMS}(x) ∥x∥2=d ×RMS(x)
答案是:理论上可以,但直接用会有"尺度(Scale)问题",需要修正。
如果你直接用 L2 Norm 作为分母,即:
xˉ=x∥x∥2 \bar{x} = \frac{x}{\|x\|_2} xˉ=∥x∥2x
这实际上是将输入向量投影到了单位超球面上(Unit Hypersphere),即输出向量的长度恒为 1。
为什么直接用 L2 Norm 不好?
-
维度的诅咒(Dimensionality Issue):
LLM 的隐藏层维度 ddd 通常很大(例如 LLaMA-7B 的 d=4096d=4096d=4096,GPT-3 的 d=12288d=12288d=12288)。
如果输出向量的总长度为 1,那么向量中每个元素的平均大小约为 1d\frac{1}{\sqrt{d}}d 1。
- 当 d=4096d=4096d=4096 时,元素值约为 1/64≈0.0151/64 \approx 0.0151/64≈0.015。
- 这会导致激活值非常小,使得后续的权重矩阵更新困难,容易引发梯度消失或导致数值精度问题(Underflow)。
-
期望的激活值分布:
神经网络通常希望神经元的激活值保持在"单位方差"附近(即数值在 -1 到 1 之间波动,而不是 0.01 级别),这样更有利于非线性激活函数(如 Swish/SiLU)的工作。RMSNorm 通过除以 1d∑x2\sqrt{\frac{1}{d}\sum x^2}d1∑x2 ,正好抵消了维度的影响,使得归一化后的数值幅度与维度 ddd 无关,保持在合理的范围内。
如何修正才能用 L2 Norm?
如果你非要用 L2 Norm,你必须人为乘上一个系数 d\sqrt{d}d 来把数值拉回正常范围:
xˉ=x∥x∥2×d \bar{x} = \frac{x}{\|x\|_2} \times \sqrt{d} xˉ=∥x∥2x×d
你会发现:
x∑x2×d=x1d∑x2=xRMS(x) \frac{x}{\sqrt{\sum x^2}} \times \sqrt{d} = \frac{x}{\sqrt{\frac{1}{d}\sum x^2}} = \frac{x}{\text{RMS}(x)} ∑x2 x×d =d1∑x2 x=RMS(x)x
结论:
乘以 d\sqrt{d}d 修正后的 L2 Norm Normalization,在数学上等价于 RMSNorm。
总结
- 用 RMSNorm 是因为它是 LayerNorm 的简化版,去掉了不必要的"减均值"操作,速度更快且效果相当。
- 不能直接用 L2 Norm,因为在大维度下会导致数值过小,模型难以训练。
- 如果用 L2 Norm 并乘以 d\sqrt{d}d,那么它本质上就是 RMSNorm。
为什么RMSNorm数值在 -1 到 1 之间波动,而不是 0.01 级别
这是一个非常敏锐的问题!说明你已经开始思考归一化背后的数学本质了。
简单直接的回答是:这是由 RMSNorm 的数学定义强制规定的。
RMSNorm 的目的就是强行把数据的**"平均能量(幅度)"**拉回到 1。
为了让你直观理解,我们通过数学推导 和分蛋糕的比喻两个角度来解释。
1. 数学推导:为什么结果一定是 1?
我们要证明:不管输入 xxx 是什么,经过 RMSNorm 处理后的输出 yyy,其 RMS 值恒等于 1。
设定:
- 输入向量:xxx
- RMSNorm 的分母(缩放因子):S=RMS(x)=1d∑xi2S = \text{RMS}(x) = \sqrt{\frac{1}{d}\sum x_i^2}S=RMS(x)=d1∑xi2
- 输出向量:y=xSy = \frac{x}{S}y=Sx
现在,我们来算一下输出向量 yyy 的 RMS 值是多少:
RMS(y)=1d∑yi2 \text{RMS}(y) = \sqrt{\frac{1}{d} \sum y_i^2} RMS(y)=d1∑yi2
把 yi=xiSy_i = \frac{x_i}{S}yi=Sxi 代入进去:
RMS(y)=1d∑(xiS)2 \text{RMS}(y) = \sqrt{\frac{1}{d} \sum (\frac{x_i}{S})^2} RMS(y)=d1∑(Sxi)2
因为 SSS 是一个算出来的常数(标量),可以提出来:
RMS(y)=1S2⋅(1d∑xi2) \text{RMS}(y) = \sqrt{\frac{1}{S^2} \cdot (\frac{1}{d} \sum x_i^2)} RMS(y)=S21⋅(d1∑xi2)
请注意括号里的 1d∑xi2\frac{1}{d} \sum x_i^2d1∑xi2 是什么?它正是 S2S^2S2(即输入 xxx 的均方值)。
所以:
RMS(y)=1S2⋅S2=1=1 \text{RMS}(y) = \sqrt{\frac{1}{S^2} \cdot S^2} = \sqrt{1} = 1 RMS(y)=S21⋅S2 =1 =1
结论:
输出向量 yyy 的 RMS 恒等于 1。
这意味着:yyy 里面的元素 yiy_iyi,其平方的平均值是 1。
那么 yiy_iyi 的数值自然就会分布在 -1 到 1 左右(大部分数值在 [-1, 1] 之间,少部分在 [-2, 2] 等,符合正态分布的特征),而绝对不可能大家都缩在 0.01 这种级别。
2. 通俗比喻:切蛋糕 vs 定餐标
为了理解为什么 RMSNorm 是"1"级别,而 L2 Norm 是"0.01"级别,我们可以打个比方。
假设我们有 d=4096d=4096d=4096 个人(神经元维度),需要给他们分配能量(数值大小)。
L2 Norm(单位球归一化):相当于"切一块固定的蛋糕"
- 规则 :不管有多少人,蛋糕总重量必须是 1kg(∑x2=1\sum x^2 = 1∑x2=1)。
- 结果:因为有 4096 个人分这 1kg 蛋糕,每个人分到的蛋糕就非常小。
- 算一下 :平均每个人分到 1/4096=1/64≈0.0151 / \sqrt{4096} = 1/64 \approx 0.0151/4096 =1/64≈0.015。
- 这就是为什么 L2 Norm 会导致数值变成 0.01 级别。
RMSNorm:相当于"定每个人平均餐标"
- 规则 :不管有多少人,我们要保证平均每个人 吃的蛋糕重量是 1kg(1d∑x2=1\frac{1}{d}\sum x^2 = 1d1∑x2=1)。
- 结果:因为我们定的是"平均标准",所以无论人多还是人少,每个人盘子里的蛋糕都在 1kg 左右。
- 代价 :总的蛋糕重量会随着人数增加而变大(总能量 ∑x2=d\sum x^2 = d∑x2=d),但这没关系,我们要的是每个神经元的感受(数值大小)是正常的。
- 这就是为什么 RMSNorm 的数值维持在 1 级别。
3. 为什么要维持在 1 级别?(不做 0.01 的坏处)
你可能会问:"0.01 就 0.01 呗,只要归一化了不就行了吗?"
在深度学习中,1 是一个非常特殊的"黄金尺度",主要是为了配合激活函数。
看看大模型常用的激活函数(如 Swish, GeLU, SiLU)的图像:
- 如果输入是 0.01(极小值):
这些激活函数在 0 附近几乎就是一条直线(近似线性)。如果你把所有数值都压缩到 0.01,神经网络就退化成了线性变换,非线性能力丢失,模型变傻,层数再深也没用。 - 如果输入是 100(极大值):
可能会进入饱和区或者梯度爆炸区。 - 当输入在 -1 到 1 之间(单位方差):
这正是激活函数非线性变化最丰富、梯度传导最健康的区域。
总结:
RMSNorm 通过定义分母带有 1d\frac{1}{d}d1,巧妙地抵消了维度 ddd 的影响,强制让输出数据的平均幅度保持在 1 左右,从而让神经网络处于最佳的工作状态。
RMSNorm 通过定义分母带有1/d,巧妙地抵消了维度 d的影响,数学上怎么表示
好的,我们用数学推导 来直观地看一下,这个 1d\frac{1}{d}d1 到底是怎么把 ddd "消灭"掉的。
为了方便理解,我们需要引入一个统计学的假设:假设神经元的值是随机分布的。
设定场景
假设输入向量 xxx 有 ddd 个元素:x=[x1,x2,...,xd]x = [x_1, x_2, ..., x_d]x=[x1,x2,...,xd]。
我们假设这些元素 xix_ixi 是独立同分布(i.i.d)的随机变量,且:
- 均值为 0 (E[xi]=0E[x_i] = 0E[xi]=0)
- 方差为 σ2\sigma^2σ2 (E[xi2]=σ2E[x_i^2] = \sigma^2E[xi2]=σ2)
这意味着每个数字的大小大概就在 σ\sigmaσ 左右。
第一步:看看分子的"能量"里藏着什么?
不管是 RMSNorm 还是 L2 Norm,核心项都是 ∑xi2\sum x_i^2∑xi2(所有元素的平方和)。
我们来看一下这个"平方和"的期望值是多少:
E[∑i=1dxi2]=∑i=1dE[xi2] E[\sum_{i=1}^{d} x_i^2] = \sum_{i=1}^{d} E[x_i^2] E[i=1∑dxi2]=i=1∑dE[xi2]
因为每个 xi2x_i^2xi2 平均下来都是 σ2\sigma^2σ2,一共有 ddd 个这样的数相加:
E[∑i=1dxi2]=d⋅σ2 E[\sum_{i=1}^{d} x_i^2] = d \cdot \sigma^2 E[i=1∑dxi2]=d⋅σ2
发现问题了吗?
这个总能量(Sum of Squares)是与维度 ddd 成正比的 。
如果 d=4096d=4096d=4096,这个和就是 4096 倍的 σ2\sigma^2σ2。维度越大,这个数越大,这就是"维度的诅咒"。
第二步:RMSNorm 如何利用 1/d 进行"精准爆破"?
RMSNorm 的分母公式是:
RMS(x)=1d∑i=1dxi2 \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2} RMS(x)=d1i=1∑dxi2
现在我们将第一步算出来的结果代入这个根号里面:
RMS(x)≈1d⋅(d⋅σ2) \text{RMS}(x) \approx \sqrt{\frac{1}{d} \cdot (d \cdot \sigma^2)} RMS(x)≈d1⋅(d⋅σ2)
关键时刻来了!
请看根号内部的运算:
1d×d=1 \frac{1}{d} \times d = 1 d1×d=1
这里的 1d\frac{1}{d}d1 是人为定义的,而 ddd 是数据累加自带的。它们相互抵消了!
于是:
RMS(x)≈1⋅σ2=σ \text{RMS}(x) \approx \sqrt{1 \cdot \sigma^2} = \sigma RMS(x)≈1⋅σ2 =σ
第三步:看最终结果
RMSNorm 的计算公式是:
xˉi=xiRMS(x) \bar{x}_i = \frac{x_i}{\text{RMS}(x)} xˉi=RMS(x)xi
代入刚才的结果:
xˉi≈xiσ \bar{x}_i \approx \frac{x_i}{\sigma} xˉi≈σxi
结论:
你看,最终结果里完全没有 ddd 的影子 。
不管你的模型维度 ddd 是 100 还是 10000:
- 分子 xix_ixi 的量级是 σ\sigmaσ。
- 分母 RMS(x)\text{RMS}(x)RMS(x) 的量级也是 σ\sigmaσ。
- 两者一除,结果就在 1 附近。
对比:如果没有 1/d 会怎样?(L2 Norm 的失败)
如果分母里没有那个 1d\frac{1}{d}d1,也就是直接用 L2 Norm (∑x2\sqrt{\sum x^2}∑x2 ),情况会变成这样:
分母≈d⋅σ2=σ⋅d \text{分母} \approx \sqrt{d \cdot \sigma^2} = \sigma \cdot \sqrt{d} 分母≈d⋅σ2 =σ⋅d
那么输出就是:
xˉi=xiσ⋅d≈1d \bar{x}_i = \frac{x_i}{\sigma \cdot \sqrt{d}} \approx \frac{1}{\sqrt{d}} xˉi=σ⋅d xi≈d 1
这就导致结果里残留了 ddd 的影响 。维度 ddd 越大,分母越大,结果就越小(趋近于0),这就是我们在前一个问题里讨论的"0.01级别"的问题。
总结
数学表达式上的"抵消"过程就是:
1d⏟RMS定义引入×∑i=1dxi2⏟随维度d线性增长≈1d×(d⋅σ2)=σ2(与d无关) \underbrace{\frac{1}{d}}{\text{RMS定义引入}} \times \underbrace{\sum{i=1}^d x_i^2}_{\text{随维度d线性增长}} \approx \frac{1}{d} \times (d \cdot \sigma^2) = \sigma^2 \quad (\text{与d无关}) RMS定义引入 d1×随维度d线性增长 i=1∑dxi2≈d1×(d⋅σ2)=σ2(与d无关)
这就是 RMSNorm 能够让大模型在不同宽度的层之间保持数值稳定性、不随维度变化而"飘"的数学原理。