序言:一个除号背后藏着的整门数学课
很多人第一次读 Vaswani 2017 的公式时,都会卡在那一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 上。
公式本身写得简洁:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q K T d k ) V \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V </math>Attention(Q,K,V)=softmax(dk QKT)V
但那个分母上的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 看起来像是凭空冒出来的常数。
「为什么是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk?」
「为什么不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l \sqrt{d_{\mathrm{model}}} </math>dmodel ?」
「为什么不是其他什么数?」
如果你只读论文的那一句话------"to counteract the effect of large dot products"------你会觉得这是一个 经验技巧。
事实远不止此。
那个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是一个 从概率论第一性原理推出来的、几乎别无选择的数字。
它涉及:随机变量的方差加法、softmax 的饱和性、链式法则下的梯度衰减、训练动力学的稳定性------你能想到的关于「为什么神经网络能优化」的核心问题,全都串在这一个除号上。
本文要做的事情,就是把这个除号拆开------一步一步、不跳逻辑------告诉你:为什么是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ?它到底拯救了什么?以及,到 2026 年,对这个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的现代理解(包括 NTK 视角、FlashAttention 数值稳定性、Muon 优化器对 attention 的影响)有哪些新维度。
读完之后,你应该能在被人问到「为什么除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 」的时候,给出一个 5 分钟版本、一个 30 分钟版本、和一个「我可以为你推一遍」版本。
一、问题缘起:先看不除会发生什么
1.1 复盘公式
第 13 篇我们看到:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> attention ( q , k , v ) = softmax ( q K T ) V \operatorname{attention}(q, k, v) = \operatorname{softmax}(qK^{\mathsf{T}}) V </math>attention(q,k,v)=softmax(qKT)V
第 14 篇我们看到:self-attention 让每个 token 同时扮演 q、k、v。
但实际工程里写的、Vaswani 论文里写的、所有 PyTorch 实现里写的,都是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> attention ( Q , K , V ) = softmax ( Q K T d k ) V \operatorname{attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V </math>attention(Q,K,V)=softmax(dk QKT)V
那个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是什么?为什么必须有?
我们做一个思想实验:把 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 拿掉,看会发生什么。
1.2 一个具体的数值实验
设 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64。
<math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 是两个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 维向量,每一维独立、均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0、方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1(就当它们是从标准正态采样)。
我们想知道: <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 的分布是什么?
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ⋅ k = ∑ i q i k i q \cdot k = \sum_i q_i k_i </math>q⋅k=i∑qiki
每一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i k i q_i k_i </math>qiki 是两个独立标准正态的乘积------均值是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0,方差是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 × 1 = 1 1 \times 1 = 1 </math>1×1=1。因为
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Var ( X Y ) = E [ X 2 ] E [ Y 2 ] − E [ X ] 2 E [ Y ] 2 = 1 ⋅ 1 − 0 = 1 , \operatorname{Var}(XY) = \mathbb{E}[X^2]\mathbb{E}[Y^2] - \mathbb{E}[X]^2 \mathbb{E}[Y]^2 = 1 \cdot 1 - 0 = 1, </math>Var(XY)=E[X2]E[Y2]−E[X]2E[Y]2=1⋅1−0=1,
对独立零均值变量成立。
这 <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 64 </math>64 个独立项相加:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ q ⋅ k ] = 0 , Var ( q ⋅ k ) = 64 , σ = 8. \mathbb{E}[q \cdot k] = 0, \qquad \operatorname{Var}(q \cdot k) = 64, \qquad \sigma = 8. </math>E[q⋅k]=0,Var(q⋅k)=64,σ=8.
所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 的取值范围大概在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± 24 \pm 24 </math>±24( <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 σ 3\sigma </math>3σ)以内浮动。
1.3 把 q·k = 24 喂进 softmax
假设我们有 <math xmlns="http://www.w3.org/1998/Math/MathML"> 8 8 </math>8 个 key,对应 <math xmlns="http://www.w3.org/1998/Math/MathML"> 8 8 </math>8 个点积,碰巧最大那个是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 24 24 </math>24,其它都接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax ( [ 24 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] ) = ? \operatorname{softmax}([24, 0, 0, 0, 0, 0, 0, 0]) = ? </math>softmax([24,0,0,0,0,0,0,0])=?
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> e 24 ≈ 2.6 × 1 0 10 , e 0 = 1. e^{24} \approx 2.6 \times 10^{10}, \qquad e^0 = 1. </math>e24≈2.6×1010,e0=1.
归一化后: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ ≈ 1 , ≈ 0 , ≈ 0 , ... , ≈ 0 ] [\approx 1, \approx 0, \approx 0, \ldots, \approx 0] </math>[≈1,≈0,≈0,...,≈0]------几乎纯 one-hot。
这看起来不是好事吗?模型「下定决心」选了一个 token?
恰恰相反------这是一场灾难。
1.4 灾难的来源:梯度消失
来看 softmax 的 Jacobian。
设 <math xmlns="http://www.w3.org/1998/Math/MathML"> p = softmax ( s ) p = \operatorname{softmax}(s) </math>p=softmax(s),那么:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ p i ∂ s j = p i ( δ i j − p j ) \frac{\partial p_i}{\partial s_j} = p_i (\delta_{ij} - p_j) </math>∂sj∂pi=pi(δij−pj)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> δ i j \delta_{ij} </math>δij 是 Kronecker delta( <math xmlns="http://www.w3.org/1998/Math/MathML"> i = j i = j </math>i=j 时为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1,否则为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0)。
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p 接近 one-hot,比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> p 1 ≈ 1 p_1 \approx 1 </math>p1≈1,其它 <math xmlns="http://www.w3.org/1998/Math/MathML"> ≈ 0 \approx 0 </math>≈0,那么:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ p 1 ∂ s 1 = 1 ⋅ ( 1 − 1 ) = 0 , ∂ p 1 ∂ s j = 1 ⋅ ( 0 − 0 ) = 0 ( j ≠ 1 ) , ∂ p i ∂ s j ≈ 0 ( i ≠ 1 ) . \frac{\partial p_1}{\partial s_1} = 1 \cdot (1 - 1) = 0, \qquad \frac{\partial p_1}{\partial s_j} = 1 \cdot (0 - 0) = 0 \; (j \neq 1), \qquad \frac{\partial p_i}{\partial s_j} \approx 0 \; (i \neq 1). </math>∂s1∂p1=1⋅(1−1)=0,∂sj∂p1=1⋅(0−0)=0(j=1),∂sj∂pi≈0(i=1).
整个 Jacobian 几乎为零矩阵。
这意味着:通过 softmax 反向传播的梯度被掐死了。
下游的 loss 想告诉 attention 层「你应该多看看 token 5」,但这个信号被 softmax 饱和性吃掉了------logits <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s 几乎不会被更新。
1.5 这就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 要解决的问题
如果 logits 的方差是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk,那把它们除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ,方差就变成 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
logits 不再因为维度放大而漂到饱和区,softmax 输出保持在「有梯度的工作点」。
训练就能进行。
这是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的全部直觉------剩下的都是细节。
二、点积方差的严格推导
2.1 假设
我们在以下假设下推 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( q ⋅ k ) = d k \operatorname{Var}(q \cdot k) = d_k </math>Var(q⋅k)=dk:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 维向量
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 的每一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 之间独立
- <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 的每一维 <math xmlns="http://www.w3.org/1998/Math/MathML"> k j k_j </math>kj 之间独立
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 之间也独立
- 所有 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k j k_j </math>kj 都是均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0、方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1
后面我们会讨论这些假设在真实模型中成立到什么程度。
2.2 推导
设
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = q ⋅ k = ∑ i q i k i . X = q \cdot k = \sum_i q_i k_i. </math>X=q⋅k=i∑qiki.
第一步:均值。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ X ] = ∑ i E [ q i k i ] = ∑ i E [ q i ] E [ k i ] = 0. \mathbb{E}[X] = \sum_i \mathbb{E}[q_i k_i] = \sum_i \mathbb{E}[q_i] \mathbb{E}[k_i] = 0. </math>E[X]=i∑E[qiki]=i∑E[qi]E[ki]=0.
第二步:方差。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Var ( X ) = E [ X 2 ] − E [ X ] 2 = E [ X 2 ] . \operatorname{Var}(X) = \mathbb{E}[X^2] - \mathbb{E}[X]^2 = \mathbb{E}[X^2]. </math>Var(X)=E[X2]−E[X]2=E[X2].
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ X 2 ] = E [ ( ∑ i q i k i ) 2 ] = ∑ i ∑ j E [ q i k i q j k j ] . \mathbb{E}[X^2] = \mathbb{E}\left[\left(\sum_i q_i k_i\right)^2\right] = \sum_i \sum_j \mathbb{E}[q_i k_i q_j k_j]. </math>E[X2]=E (i∑qiki)2 =i∑j∑E[qikiqjkj].
对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> i ≠ j i \neq j </math>i=j:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ q i k i q j k j ] = E [ q i ] E [ k i ] E [ q j ] E [ k j ] = 0. \mathbb{E}[q_i k_i q_j k_j] = \mathbb{E}[q_i] \mathbb{E}[k_i] \mathbb{E}[q_j] \mathbb{E}[k_j] = 0. </math>E[qikiqjkj]=E[qi]E[ki]E[qj]E[kj]=0.
对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> i = j i = j </math>i=j:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ q i 2 k i 2 ] = E [ q i 2 ] E [ k i 2 ] = 1 × 1 = 1. \mathbb{E}[q_i^2 k_i^2] = \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] = 1 \times 1 = 1. </math>E[qi2ki2]=E[qi2]E[ki2]=1×1=1.
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ X 2 ] = ∑ i 1 = d k . \mathbb{E}[X^2] = \sum_i 1 = d_k. </math>E[X2]=i∑1=dk.
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Var ( X ) = d k . \operatorname{Var}(X) = d_k. </math>Var(X)=dk.
2.3 标准差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> σ = d k . \sigma = \sqrt{d_k}. </math>σ=dk .
这就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的来源------不是任何「经验试出来的常数」,而是 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var \operatorname{Var} </math>Var 加法的直接结果。
2.4 缩放后的分布
定义
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X ′ = X d k . X' = \frac{X}{\sqrt{d_k}}. </math>X′=dk X.
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Var ( X ′ ) = Var ( X ) d k = 1. \operatorname{Var}(X') = \frac{\operatorname{Var}(X)}{d_k} = 1. </math>Var(X′)=dkVar(X)=1.
不管 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 64 </math>64、 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 512 </math>512、还是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 4096 4096 </math>4096, <math xmlns="http://www.w3.org/1998/Math/MathML"> X ′ X' </math>X′ 的方差永远是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
logits 的尺度被「归一化」到了一个不依赖于维度的水平。
2.5 为什么是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk
有人问:「除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 不是更彻底吗?」
不行。
如果除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk,那么
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Var ( X d k ) = d k d k 2 = 1 d k → 0 \operatorname{Var}\left(\frac{X}{d_k}\right) = \frac{d_k}{d_k^2} = \frac{1}{d_k} \to 0 </math>Var(dkX)=dk2dk=dk1→0
当 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 很大时就会发生这个退化。
logits 全部接近 0,softmax 输出接近均匀分布------attention 失去了「选择性」。
我们要的是「方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> = 1 = 1 </math>=1」(既不太尖锐也不太平),所以分母必须是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 。
这是一个临界点,不是一个「随便挑的数」。
2.6 有人问:为什么不除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 d k \sqrt{2 d_k} </math>2dk 之类
也可以,只要常数因子合理(比如让方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> = 0.5 = 0.5 </math>=0.5 而不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1)。
但这只会让 softmax 略偏平缓------本质和 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 没区别。
Vaswani 选 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是因为它最自然------把方差归一化到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1,保留了「方差为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1」这个统计学最常用的标准化。
后续工作(比如 RoFormer、LLaMA)也都沿用这个选择。
2.7 一个数值表
| <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = d k \sigma = \sqrt{d_k} </math>σ=dk | logits 范围( <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 σ 3\sigma </math>3σ) |
|---|---|---|
| 8 | 2.83 | ±8.5 |
| 32 | 5.66 | ±17 |
| 64 | 8 | ±24 |
| 128 | 11.3 | ±34 |
| 256 | 16 | ±48 |
| 512 | 22.6 | ±68 |
如果不缩放, <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 512 </math>512 维的点积可能跑到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± 68 \pm 68 </math>±68------softmax 看到这种 logits,对应的 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 68 ≈ 1 0 29 e^{68} \approx 10^{29} </math>e68≈1029------任何对手 logit 都被压成 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0。
缩放后 logits 永远在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± 3 \pm 3 </math>±3 左右------softmax 仍能区分大小,但梯度不会断流。
三、softmax 饱和性的可视化
3.1 直观图景
右侧子图(unscaled):一个 logit 比其它大很多------softmax 集中在那一个点上。
左侧子图(scaled):logits 接近,softmax 平缓,多个 token 都有可见权重。
不是说「平缓就一定好」------实际训练中,模型最终会学到「需要 sharp 时 sharp」的能力。
但训练初期必须从「平缓 + 有梯度」的状态出发,否则模型一开始就被卡在饱和区出不来。
3.2 一个比喻
把 softmax 想象成一根弹簧。
弹簧未饱和时,你拉它它会动,反馈给你力------你能学到「拉的方向」。
弹簧饱和时(拉到极限),你怎么拉它都不动------你什么也学不到。
logits 越大,softmax 越饱和;scaled dot-product 就是把弹簧从「饱和区」拉回到「线性区」,让训练能进行。
3.3 数学复盘
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax ( s ) i = e s i ∑ j e s j . \operatorname{softmax}(s)_i = \frac{e^{s_i}}{\sum_j e^{s_j}}. </math>softmax(s)i=∑jesjesi.
求导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ softmax ( s ) i ∂ s j = p i ( δ i j − p j ) \frac{\partial \operatorname{softmax}(s)i}{\partial s_j} = p_i (\delta{ij} - p_j) </math>∂sj∂softmax(s)i=pi(δij−pj)
最大值的对角项: <math xmlns="http://www.w3.org/1998/Math/MathML"> p i ( 1 − p i ) p_i (1 - p_i) </math>pi(1−pi),当 <math xmlns="http://www.w3.org/1998/Math/MathML"> p i → 1 p_i \to 1 </math>pi→1 时为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0;当 <math xmlns="http://www.w3.org/1998/Math/MathML"> p i → 0 p_i \to 0 </math>pi→0 时也为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0;最大在 <math xmlns="http://www.w3.org/1998/Math/MathML"> p i = 0.5 p_i = 0.5 </math>pi=0.5。
非对角项: <math xmlns="http://www.w3.org/1998/Math/MathML"> − p i p j -p_i p_j </math>−pipj,仅当两个都不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0 也不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1 时才有效。
所以「梯度最大」的工作点是 <math xmlns="http://www.w3.org/1998/Math/MathML"> p i ∈ [ 0.1 , 0.9 ] p_i \in [0.1, 0.9] </math>pi∈[0.1,0.9]------这正是 logits 适中( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ≈ 1 \sigma \approx 1 </math>σ≈1)时的状态。
3.4 与温度参数的关系
很多人熟悉「softmax 温度」 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax T ( s ) i = e s i / T ∑ j e s j / T \operatorname{softmax}_T(s)_i = \frac{e^{s_i / T}}{\sum_j e^{s_j / T}} </math>softmaxT(s)i=∑jesj/Tesi/T
<math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 大,输出平缓; <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 小,输出 sharp。
scaled dot-product 中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 就扮演温度的角色------具体来说 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = d k T = \sqrt{d_k} </math>T=dk 。
但与 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 不同, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不是调参 ,而是定参------它的值由维度决定,不是用户选择。
3.5 一个常见的混淆
「我可以学一个 temperature 替代 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 吗?」
理论上可以,但实践中很少这么做。
因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 已经把 logits 归一化到 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ≈ 1 \sigma \approx 1 </math>σ≈1,再学一个 temperature 等于多此一举------除非你想做「shaped attention」之类的研究。
LLaMA、GPT、PaLM 等都没有学习 temperature,全用 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 。
但有一些工作(如 NormFormer、QK-norm)提出在 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 上做 LayerNorm,再不除 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------效果近似但实现略有不同。
到 2026 年,QK-norm 方案在大模型训练中越来越常见。
四、为什么这件事到 d_k = 64 才显著
4.1 一个有趣的现象
在最早的 attention 工作(Bahdanau 2014)中,用的是加性注意力 ------ <math xmlns="http://www.w3.org/1998/Math/MathML"> s c o r e = v T tanh ( W q q + W k k ) \mathrm{score} = v^{\mathsf{T}} \tanh(W_q q + W_k k) </math>score=vTtanh(Wqq+Wkk)------根本没有 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的除法。
为什么 Bahdanau 不需要?
因为 Bahdanau 用的是 RNN 的 hidden state(典型 <math xmlns="http://www.w3.org/1998/Math/MathML"> d = 256 d = 256 </math>d=256 但走 <math xmlns="http://www.w3.org/1998/Math/MathML"> tanh \tanh </math>tanh,输出落在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 1 , 1 ] [-1, 1] </math>[−1,1])+ 学习的 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v------score 永远在一个有限的 bounded 区间,不会因为维度爆炸。
dot-product attention(Luong 2015)开始有这个问题------因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 没有 <math xmlns="http://www.w3.org/1998/Math/MathML"> tanh \tanh </math>tanh 包住,方差直接随 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 增长。
但 Luong 的实验里 d 不大,问题不严重。
到 Vaswani 2017 multi-head 时代, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64(每 head 的维度), <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 的来源是线性投影后的向量------方差大约是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1(因为初始化 + LayerNorm)------这时候 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 的方差就接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 64 </math>64,问题就显现出来了。
4.2 d_k 越大,问题越严重
到 GPT-3: <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 128 d_k = 128 </math>dk=128(每 head),问题更严重。
到 PaLM: <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 256 d_k = 256 </math>dk=256(每 head),不缩放训练直接发散。
Vaswani 的论文里有一段话:「我们怀疑对于大的 d_k 值,dot products 在量级上变大,从而把 softmax 推到具有极小梯度的区域。」
这是一句经验观察------他们看到了「不缩放训练崩」,做了缩放,发现「训练好了」。
后来的理论分析(Xiong 2020 "On Layer Normalization in the Transformer Architecture")才把这件事讲透。
4.3 为什么加性注意力没这个问题
<math xmlns="http://www.w3.org/1998/Math/MathML"> v T tanh ( W q q + W k k ) v^{\mathsf{T}} \tanh(W_q q + W_k k) </math>vTtanh(Wqq+Wkk) 中, <math xmlns="http://www.w3.org/1998/Math/MathML"> tanh \tanh </math>tanh 的输出落在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 1 , 1 ] [-1, 1] </math>[−1,1]。
随后 <math xmlns="http://www.w3.org/1998/Math/MathML"> v T ( ⋅ ) v^{\mathsf{T}}(\cdot) </math>vT(⋅) 是一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 维点积,这一步也会有方差放大。
但因为 tanh 已经把每一维 bound 住了,方差不会无界放大。
所以加性注意力天然「自带稳定性」,但代价是计算更慢(多一次矩阵乘 + 非线性)。
dot-product attention 要更快------因为它就是一个 matmul------但代价是必须手工加 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 来保稳定。
4.4 一个关于 norm 的细节
Vaswani 假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 每一维方差是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
实际模型里,这通过 LayerNorm 大致成立------LayerNorm 把每一层输出的 mean 归零、std 归一。
但有些层(比如 attention 输出)是 LayerNorm 之前还是之后?这就涉及 Pre-LN vs Post-LN 的选择。
Pre-LN(LayerNorm 在 sublayer 之前)让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 在进入 attention 时严格 normalized------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的假设最契合。
Post-LN(LayerNorm 在 sublayer 之后)让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 在进入 attention 时未必 normalized------可能需要 warmup 来稳住训练。
到 2026 年,Pre-LN 是主流(GPT、LLaMA 都用 Pre-LN)。
4.5 那 Q、K 不是单位方差怎么办
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q W_q </math>Wq、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W k W_k </math>Wk 初始化合理(比如 Xavier 或 Kaiming),且输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 经过 LayerNorm,那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> q = W q x q = W_q x </math>q=Wqx 的方差大致就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
但如果你不做 LayerNorm、用奇怪初始化、或者训练到某一步参数漂移------方差就不是 1 了。
QK-norm(在 q、k 上做 LayerNorm)就是把这个假设显式强制------不再依靠「希望 LayerNorm 保住」。
五、点积方差的可视化
5.1 三个直方图
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 8 d_k = 8 </math>dk=8 时分布窄, <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ≈ 2.83 \sigma \approx 2.83 </math>σ≈2.83。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64 时分布宽, <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = 8 \sigma = 8 </math>σ=8。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 512 d_k = 512 </math>dk=512 时分布很宽, <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ≈ 22.6 \sigma \approx 22.6 </math>σ≈22.6。
直方图的横轴是点积值------纵轴是出现频率。
5.2 为什么这个图重要
看到这张图,你应该马上意识到:
不缩放时,点积尺度完全由维度决定------你换一个模型规模,点积尺度就变了------你的训练超参(学习率、初始化等)就要重调。
缩放后,点积尺度永远是 1------超参可以跨规模迁移。
这是 scaling laws 能成立的一个隐性前提:架构内的统计尺度必须不依赖于规模。
5.3 Chinchilla scaling 的隐含条件
Hoffmann 2022 给出 Chinchilla 定律:参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> N p N_p </math>Np 与 token 数 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 的最优比例 <math xmlns="http://www.w3.org/1998/Math/MathML"> N p ≈ D / 20 N_p \approx D/20 </math>Np≈D/20。
这条定律的成立依赖于「同样的架构、同样的训练超参在不同规模下都能稳定训练」。
如果你不缩放点积,训练在 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64 时还稳定,到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 512 d_k = 512 </math>dk=512 时就发散------scaling laws 整个就不成立。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是 scaling laws 的「隐性基础设施」之一。
六、训练曲线对比
6.1 定性差异
红线(unscaled):早期 loss 下降慢,很快卡在某个高位------softmax 饱和导致的优化困难。
绿线(scaled):稳定下降。
初值相同(损失大约是 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ( N ) \log(N) </math>log(N) 那个均匀分布的 cross-entropy)。
6.2 一个真实的定量例子
Vaswani 2017 §3.2.1 没给具体的对比训练曲线(论文较老),但后续工作(Xiong 2020)做过实验。
在 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64 的 Transformer-base 上,去掉 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk :
- 不调任何其它超参:loss 卡在 6.x,几乎不动。
- 把学习率降低 <math xmlns="http://www.w3.org/1998/Math/MathML"> 10 × 10\times </math>10×:训练能进行,但 BLEU 显著低于带 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的版本。
也就是说,没有 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不是「完全不能训练」,而是「需要付出极大的超参代价、且最终质量更差」。
加上 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 等价于一个免费的、零计算开销的稳定性优化------为什么不用呢?
6.3 一个反直觉发现
有人发现:如果模型足够小( <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 16 d_k = 16 </math>dk=16 之类),不缩放也能训。
这与第二节的方差分析一致------ <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = 4 \sigma = 4 </math>σ=4,logits 不会饱和。
但你不可能因为「小模型不需要」就在大模型里也省掉它------大模型里这个除号是必需品。
七、缩放与梯度下降稳定性
7.1 学习率与梯度的关系
如果不缩放,attention 的 logits 在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± d k \pm \sqrt{d_k} </math>±dk 量级------softmax 输出近似 one-hot------梯度近似 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0------参数几乎不更新。
但偶尔某个 batch 有「比较平的 logits」,梯度突然爆发------参数大跳------loss 飞涨。
这是「饱和 + 偶尔不饱和」的混合模式------非常不稳定。
7.2 缩放后的 Lipschitz 性质
缩放后,softmax 的输入永远在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − 3 σ , 3 σ ] = [ − 3 , 3 ] [-3\sigma, 3\sigma] = [-3, 3] </math>[−3σ,3σ]=[−3,3] 左右------softmax 在这个区间内是 Lipschitz 连续的,导数有 bounded 上限。
这意味着「同样大小的输入扰动 <math xmlns="http://www.w3.org/1998/Math/MathML"> → \to </math>→ 同样大小的输出扰动」------训练动力学是稳定的。
八、参考资料
- Vaswani 2017: Ashish Vaswani et al., "Attention Is All You Need" (首次提出 Transformer 与 Scaled Dot-Product).
- Bahdanau 2014 : Dzmitry Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate" (提出 Additive Attention, <math xmlns="http://www.w3.org/1998/Math/MathML"> tanh \tanh </math>tanh 限制方差).
- Luong 2015: Minh-Thang Luong et al., "Effective Approaches to Attention-based Neural Machine Translation" (提出 Dot-Product Attention).
- Xiong 2020: Ruibin Xiong et al., "On Layer Normalization in the Transformer Architecture" (分析缩放机制与 LayerNorm 对训练稳定性的深度影响).
- Hoffmann 2022: Jordan Hoffmann et al., "Training Compute-Optimal Large Language Models" (Chinchilla 定律,隐含架构缩放的统计不变性假设).
Lipschitz 常数大致是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1(用 <math xmlns="http://www.w3.org/1998/Math/MathML"> ℓ ∞ \ell_{\infty} </math>ℓ∞ 范数估计)。
7.3 与梯度裁剪的关系
很多 Transformer 训练里都有「gradient clipping」(梯度裁剪),把过大的梯度截断到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∥ g ∥ ≤ c \lVert g \rVert \le c </math>∥g∥≤c。
为什么需要梯度裁剪?因为偶尔会有「outlier batch」让某些参数的梯度爆掉------比如某个 batch 里所有 token 都是同一个。
scaled dot-product 让这种 outlier 的破坏力降低------但不能完全消除------所以梯度裁剪仍是必需。
7.4 与 warmup 的关系
Transformer 训练几乎都用 learning rate warmup(前若干步学习率从 0 线性涨到峰值)。
为什么?因为训练初期参数随机,logits 分布可能极不平衡------warmup 给模型时间「找到稳定区域」再放学习率。
scaled dot-product 让初期 logits 不那么大------warmup 期可以更短------但不能省略。
八、与 NTK / 无限宽神经网络理论的联系
8.1 NTK 是什么
NTK(Neural Tangent Kernel,Jacot 2018)是一种刻画「无限宽网络在小学习率下的训练动力学」的理论。
核心结论是:在某些假设下,无限宽网络的训练等价于一个线性化模型 + 核回归,其中的「核」就叫 NTK。
NTK 给我们一个工具:预测网络在不同初始化、不同尺度下的行为。
8.2 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 与 NTK
NTK 理论强调一个 principle:网络中每一层的输入与输出的统计尺度必须一致------否则梯度传播会失衡。
scaled dot-product 正是这个 principle 在 attention 层的体现------把点积归一化到 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = 1 \sigma = 1 </math>σ=1,让 attention 层的「输入尺度」与「输出尺度」一致。
如果不缩放,attention 层把方差从 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1 放大到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------下一层 LayerNorm 又把它拉回 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1------但中间这一段不稳定。
8.3 muP(Maximal Update Parametrization)
Yang & Hu 2021 的 muP 是 NTK 思想在工程上的实现:通过精心设计每层的 init scale 和 LR scale,让模型在改变宽度时超参不变。
muP 框架下,attention 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是一个特殊处理------它不是 muP 自动推出的,而是早就独立存在的设计------但它与 muP 的精神高度契合。
到 2026 年,muP(特别是 mup-transfer 思路)成为大模型训练前调超参的重要工具------基础假设之一就是 attention 已经被 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 正则化过了。
8.4 NTK 视角的 attention
在 NTK 视角下:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> attention ( Q , K , V ) = softmax ( Q K T d k ) V \operatorname{attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V </math>attention(Q,K,V)=softmax(dk QKT)V 是一个 bilinear 算子。
- bilinear 部分 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^{\mathsf{T}} </math>QKT 把维度从 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 映射到 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × N N \times N </math>N×N 的相似度矩阵。
- softmax 是一个非线性归一化。
- 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的乘法把 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × N N \times N </math>N×N 映射回 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × d v N \times d_v </math>N×dv。
每一步的统计尺度都需要被控制------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是控制 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^{\mathsf{T}} </math>QKT 这一步的尺度的工具。
V 那边没有显式缩放,因为 softmax 输出已经是概率(行和为 1)------V 的均值和方差只取决于 V 自己的统计------这一步通常不需要额外正则化。
九、Vaswani 论文里的原话
9.1 § 3.2.1 的关键段落
原文(NeurIPS 2017):
"We suspect that for large values of d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/√d_k."
翻译:「我们怀疑对于大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk,点积量级变大,从而把 softmax 推入梯度极小的区域。为抵消这个效应,我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / d k 1 / \sqrt{d_k} </math>1/dk 缩放点积。」
9.2 这段话其实没给完整证明
注意「我们怀疑」(we suspect)------Vaswani 没有给出 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( q ⋅ k ) = d k \operatorname{Var}(q \cdot k) = d_k </math>Var(q⋅k)=dk 的形式化推导,也没有大规模消融实验来证明。
后来的工作(Xiong 2020, On Layer Normalization in the Transformer Architecture)才把这件事详细分析。
但工程上,Vaswani 的「直觉 + 简单理论」已经够用------大家用了 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ,模型能训,事情就成立了。
这是科学研究里很常见的模式:实践先于理论------直觉推动实验,实验验证后再被理论补全。
9.3 注释:dot product vs scaled dot product 的对比实验
Vaswani 论文里的 Table 3 提到:
"While for small values of d_k the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of d_k. We suspect that..."
也就是说,Vaswani 团队做过对照实验------在 d_k 大时不缩放的 dot product 比加性 attention 差------所以加了缩放。
这是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 设计的直接动机。
十、对应的 PyTorch 实现
10.1 最朴素版
python
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5) # ← 这里就是 √d_k
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
return attn @ V, attn
注意 d_k ** 0.5 就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 。
Q.size(-1) 自动取最后一维------所以这段代码不需要传 d_k 参数。
10.2 数值稳定版(log-sum-exp trick)
朴素 softmax 在大 logits 时可能溢出(exp(700) = inf)。
实践中 PyTorch 的 F.softmax 已经内置了 log-sum-exp trick------把所有 logits 减去 max 再 exp,结果不变但数值稳定。
这一点对 scaled dot-product 也有意义------因为缩放后 logits 仍可能在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± 10 \pm 10 </math>±10 量级(在某些 head 学到 sharp pattern 时),log-sum-exp 仍是必要的。
10.3 PyTorch 2.0+ 的内置实现
python
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
这一行调用底层 CUDA / Metal 实现------可能是 FlashAttention,也可能是 Memory-Efficient Attention,由 backend 自动选择。
但「除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 」这件事仍然在背后发生------只是你不用手写。
10.4 FlashAttention 中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk
FlashAttention 的核心是「tile-by-tile 计算 softmax」------在 SRAM 内做 streaming softmax。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放发生在每个 tile 计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^{\mathsf{T}} </math>QKT 的瞬间------和朴素实现没有本质区别。
工程上的难点是「数值稳定的 streaming softmax」(要保持 running max 和 running sum_exp)------但 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 这一步是简单乘法,不影响 FlashAttention 的核心算法。
10.5 一个常见 bug:缩放放在哪里
有些实现会写:
python
Q = Q / (d_k ** 0.5) # 提前缩放 Q
scores = Q @ K.transpose(-2, -1)
这等价于在 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^{\mathsf{T}} </math>QKT 上除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------结果一样,但计算更高效(少一次矩阵元素级除法)。
但要注意:如果 K 有特殊处理(比如 RoPE),缩放放在哪里可能影响 RoPE 的正确性------一般推荐放在 score 上,最稳。
十一、几个常见的变体与争议
11.1 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk vs <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l \sqrt{d_{\mathrm{model}}} </math>dmodel
有人混淆:
<math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{\mathrm{model}} </math>dmodel 是 token 嵌入维度(比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 512 </math>512)。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 是每个 head 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q / K Q/K </math>Q/K 维度(比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 64 </math>64,如果有 <math xmlns="http://www.w3.org/1998/Math/MathML"> 8 8 </math>8 个 head)。
scaled dot-product 用的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ,不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l \sqrt{d_{\mathrm{model}}} </math>dmodel。
为什么?因为方差推导中,加和的项数是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk------每个 head 的点积只涉及 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 维。
如果你写错成 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l \sqrt{d_{\mathrm{model}}} </math>dmodel (除得太多),attention 会过于平缓------softmax 输出近似均匀------模型失去选择性。
11.2 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / d k 1/d_k </math>1/dk vs <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / d k 1/\sqrt{d_k} </math>1/dk
如果你看到某些代码或论文写「 <math xmlns="http://www.w3.org/1998/Math/MathML"> ÷ d k \div d_k </math>÷dk」(不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ),那是错的------除非他们定义 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 原 d k \sqrt{\text{原 } d_k} </math>原 dk 。
把方差推导记牢: <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = d k \sigma = \sqrt{d_k} </math>σ=dk ,所以分母是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 而不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk。
11.3 学习的温度参数 vs 固定的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk
有些工作(Shaped Attention、Stable Attention)把 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 替换成可学习的 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ,让模型自适应温度。
这通常需要额外的稳定化(比如把 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ clamp 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ d k / 2 , 2 d k ] [\sqrt{d_k}/2, 2\sqrt{d_k}] </math>[dk /2,2dk ]),否则 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ 容易学到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 0 </math>0 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∞ \infty </math>∞。
主流大模型仍然用固定 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------因为它已经够好了,省掉一个超参。
11.4 logit-cap:另一个稳定性技巧
Gemini 和某些 Anthropic 模型用「logit cap」技巧:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s c o r e s = c ⋅ tanh ( Q K T d k c ) \mathrm{scores} = c \cdot \tanh\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k} \, c}\right) </math>scores=c⋅tanh(dk cQKT)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 是某个 cap 值(比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> 50 50 </math>50)。
这把 logits 强行 clip 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ − c , c ] [-c, c] </math>[−c,c] 区间,防止极端 outlier。
这是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 之后的进一步加强------不替代它,而是补充它。
11.5 query 缩放还是 score 缩放
一些工程实现把缩放放在 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 上:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q : = Q d k 0.25 , K : = K d k 0.25 Q := \frac{Q}{d_k^{0.25}}, \qquad K := \frac{K}{d_k^{0.25}} </math>Q:=dk0.25Q,K:=dk0.25K
两个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 0.25 0.25 </math>0.25 次方相乘恰好得 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 。
这种「分散到 Q 和 K」的写法在某些硬件上更高效,但数学上完全等价。
PyTorch 默认实现把缩放放在 score 上。
十二、与训练分布假设的关系
12.1 「q, k 是单位方差」这个假设有多严格
我们推导 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( q ⋅ k ) = d k \operatorname{Var}(q \cdot k) = d_k </math>Var(q⋅k)=dk 的关键假设是:每一维独立、零均值、单位方差。
实际中:
- 独立:不严格,但近似成立(参数学到的 Q、K 投影把不同维度去相关到一定程度)。
- 零均值:通过 LayerNorm 严格成立。
- 单位方差:通过 LayerNorm 严格成立。
整体上,「点积方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> ≈ d k \approx d_k </math>≈dk」是一个近似------但近似得相当好。
12.2 训练后期的偏移
训练后期, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 的分布可能偏离单位方差(特别是某些 head 学到 sharp pattern 时)------logits 的实际方差可能比 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 小很多(因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 学到了对齐方向,使 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 偏正)。
这时 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 给出的「过度缩放」让 softmax 仍然平缓------模型需要学习一个更大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q W_q </math>Wq 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> W k W_k </math>Wk 来「弥补」缩放。
QK-norm 的提案就是为了解决这个:让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 在训练全程都保持单位方差, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放始终精确。
12.3 一个反直觉发现
Su 2024 等人发现:训练初期 logits 近似高斯,但训练到收敛时,logits 分布严重偏离高斯------出现一些「极端 outlier」(某些位置 logit 突变到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ± 50 \pm 50 </math>±50 以上)。
这种 outlier 对训练稳定性是灾难------logit-cap 就是为了 cap 这些 outlier。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的「单位方差」假设在训练稳定期成立,但在收敛附近可能开始失效------这是一个开放研究方向。
12.4 高斯假设到底有多重要
有人会问:如果 Q、K 的分布不是高斯(而是 t 分布、混合高斯、甚至离散),方差推导还成立吗?
成立。
<math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( X + Y ) = Var ( X ) + Var ( Y ) \operatorname{Var}(X+Y) = \operatorname{Var}(X) + \operatorname{Var}(Y) </math>Var(X+Y)=Var(X)+Var(Y) 对任何独立随机变量都成立------和分布无关。
我们推 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( q ⋅ k ) = d k \operatorname{Var}(q \cdot k) = d_k </math>Var(q⋅k)=dk 时也只用了「独立 + 零均值 + 单位方差」------没有用到高斯假设。
但软 预测(softmax 输出形状)要用 Central Limit Theorem------ <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 个项之和, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 大时近似高斯。
到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64 已经足够 CLT 生效------分布看起来很高斯。
12.5 重尾分布(heavy tail)的影响
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 不是高斯而是 heavy tail(比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 分布、Cauchy),那方差推导可能不成立------或者方差为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∞ \infty </math>∞。
实际上深度网络的中间表示确实会有 heavy tail(参考 Martin & Mahoney 2018 关于深度网络中间表示的 heavy tail spectrum)。
但 LayerNorm + 标准初始化让 Q、K 的尾部不至于失控------这是工程上的「经验救场」。
到 2026 年,理解 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 的真实分布、以及 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 在 heavy tail 下的有效性,仍是开放问题。
十三、当 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不够用时
13.1 long-context 的 logits 暴涨
当上下文长度 N 很大时(比如 100k),同一个 query 要 attend 100k 个 key------每个 key 都贡献 logits 候选。
即使每个 logits 都是 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = 1 \sigma = 1 </math>σ=1(缩放后),最大值 <math xmlns="http://www.w3.org/1998/Math/MathML"> max i s i \max_i s_i </math>maxisi 在 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 大时按 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 ln N \sqrt{2 \ln N} </math>2lnN 增长(极值统计)------logits 仍然漂移到大值。
但 <math xmlns="http://www.w3.org/1998/Math/MathML"> softmax ( s ) \operatorname{softmax}(s) </math>softmax(s) 中真正起作用的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> s max − s s_{\max} - s </math>smax−s------这个差仍然是 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ ≈ 1 \sigma \approx 1 </math>σ≈1 的量级------所以 attention 仍然分布合理。
13.2 attention sink
Xiao 2023 (StreamingLLM) 发现:在 long-context 中,第一个 token(BOS)会「吸走」大量 attention 权重------这是 softmax 归一化的副产物。
具体机制:当所有 logits 都接近 0 时(没有特别匹配的 key),softmax 趋向均匀------但每个 token 都倾向把「无信息」的 attention 转移到「最早出现的 token」(BOS)。
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 与 attention sink 没有直接关系------但在 long-context 中, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 提供了基础稳定性,attention sink 现象在此基础上才能被研究。
13.3 ALiBi 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的相互作用
ALiBi(Press 2021)在 logits 上加一个负的距离偏置:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s i j = q i ⋅ k j d k − m ⋅ ∣ i − j ∣ s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} - m \cdot |i - j| </math>sij=dk qi⋅kj−m⋅∣i−j∣
<math xmlns="http://www.w3.org/1998/Math/MathML"> m m </math>m 是一个固定的负向斜率。
ALiBi 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是叠加 关系------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 控制方差,ALiBi 控制位置偏置------两者各司其职。
13.4 RoPE 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的相互作用
RoPE(Su 2021)在 Q、K 上做旋转编码:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ′ = R ( θ ) q , k ′ = R ( θ ) k q' = R(\theta) q, \qquad k' = R(\theta) k </math>q′=R(θ)q,k′=R(θ)k
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> R ( θ ) R(\theta) </math>R(θ) 是旋转矩阵------保持向量长度不变。
因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ′ ⋅ k ′ q' \cdot k' </math>q′⋅k′ 的方差与 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 一致------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放仍然有效。
RoPE 是「不影响 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 假设」的位置编码------这是它能广泛应用的一个隐性原因。
十四、Muon 优化器与 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的现代视角
14.1 Muon 是什么
Muon(2024)是一个新型优化器,专为 Transformer 设计------它对 attention 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q W_q </math>Wq、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W k W_k </math>Wk 矩阵做特殊正交化。
核心思想: <math xmlns="http://www.w3.org/1998/Math/MathML"> W q W_q </math>Wq、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W k W_k </math>Wk 在训练中容易变得「不正交」------这让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 的统计性质偏离原始假设------Muon 强制周期性正交化。
14.2 Muon 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的关系
Muon 维持 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q W_q </math>Wq、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W k W_k </math>Wk 的正交性 <math xmlns="http://www.w3.org/1998/Math/MathML"> → \to </math>→ 维持 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 的单位方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> → \to </math>→ 维持 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的精确性。
也就是说,Muon 让「 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 假设」在训练全程都接近精确------这反过来让 attention 训练更稳定。
到 2026 年,Muon 在某些大模型预训练中开始被采用(比如 Kimi 的 K2 模型)------这印证了「保护 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 假设」的工程价值。
14.3 一种联合视角
把 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 、QK-norm、Muon、logit-cap 放到一起,你会发现一条主线:
保护 attention logits 的统计性质,让 softmax 始终在「有效梯度区」工作。
每一项技术都是这条主线的一个工具------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是最基础的、最便宜的、必须有的------其它都是渐进改进。
十五、一个完整的数值小例子
15.1 设置
设 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 4 d_k = 4 </math>dk=4。 <math xmlns="http://www.w3.org/1998/Math/MathML"> q q </math>q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k 都是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 4 4 </math>4 维。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q = [ 1 , 0.5 , − 0.5 , 1 ] q = [1, 0.5, -0.5, 1] </math>q=[1,0.5,−0.5,1]
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> k 1 = [ 0.5 , 1 , 0 , − 1 ] k_1 = [0.5, 1, 0, -1] </math>k1=[0.5,1,0,−1]
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> k 2 = [ 1 , 0.5 , 0.5 , 0 ] k_2 = [1, 0.5, 0.5, 0] </math>k2=[1,0.5,0.5,0]
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> k 3 = [ − 1 , 0 , 1 , 0.5 ] k_3 = [-1, 0, 1, 0.5] </math>k3=[−1,0,1,0.5]
15.2 不缩放的 logits
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 1 = q ⋅ k 1 = 0.5 + 0.5 + 0 − 1 = 0 s_1 = q \cdot k_1 = 0.5 + 0.5 + 0 - 1 = 0 </math>s1=q⋅k1=0.5+0.5+0−1=0
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 2 = q ⋅ k 2 = 1 + 0.25 − 0.25 + 0 = 1 s_2 = q \cdot k_2 = 1 + 0.25 - 0.25 + 0 = 1 </math>s2=q⋅k2=1+0.25−0.25+0=1
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 3 = q ⋅ k 3 = − 1 + 0 − 0.5 + 0.5 = − 1 s_3 = q \cdot k_3 = -1 + 0 - 0.5 + 0.5 = -1 </math>s3=q⋅k3=−1+0−0.5+0.5=−1
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> l o g i t s = [ 0 , 1 , − 1 ] \mathrm{logits} = [0, 1, -1] </math>logits=[0,1,−1]
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax ( [ 0 , 1 , − 1 ] ) ≈ [ 0.244 , 0.665 , 0.090 ] \operatorname{softmax}([0, 1, -1]) \approx [0.244, 0.665, 0.090] </math>softmax([0,1,−1])≈[0.244,0.665,0.090]
15.3 缩放的 logits
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d k = 2 \sqrt{d_k} = 2 </math>dk =2
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 1 ′ = 0 / 2 = 0 s'_1 = 0 / 2 = 0 </math>s1′=0/2=0
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 2 ′ = 1 / 2 = 0.5 s'_2 = 1 / 2 = 0.5 </math>s2′=1/2=0.5
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s 3 ′ = − 1 / 2 = − 0.5 s'_3 = -1 / 2 = -0.5 </math>s3′=−1/2=−0.5
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> l o g i t s ′ = [ 0 , 0.5 , − 0.5 ] \mathrm{logits}' = [0, 0.5, -0.5] </math>logits′=[0,0.5,−0.5]
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax ( [ 0 , 0.5 , − 0.5 ] ) ≈ [ 0.295 , 0.487 , 0.218 ] \operatorname{softmax}([0, 0.5, -0.5]) \approx [0.295, 0.487, 0.218] </math>softmax([0,0.5,−0.5])≈[0.295,0.487,0.218]
15.4 比较
不缩放: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.244 , 0.665 , 0.090 ] [0.244, 0.665, 0.090] </math>[0.244,0.665,0.090]------中间项更突出。
缩放: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.295 , 0.487 , 0.218 ] [0.295, 0.487, 0.218] </math>[0.295,0.487,0.218]------分布更平。
在 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 4 d_k = 4 </math>dk=4 这种小尺度下,差别不大------缩放只让 softmax 略缓和。
但当 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k = 64 </math>dk=64 时,原始 logits 范围会扩大 <math xmlns="http://www.w3.org/1998/Math/MathML"> 4 4 </math>4 倍( <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 / 4 = 4 \sqrt{64}/\sqrt{4} = 4 </math>64 /4 =4),不缩放 logits 是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 4 , − 4 ] [0, 4, -4] </math>[0,4,−4], <math xmlns="http://www.w3.org/1998/Math/MathML"> softmax ( [ 0 , 4 , − 4 ] ) ≈ [ 0.018 , 0.964 , 0.000 ] \operatorname{softmax}([0, 4, -4]) \approx [0.018, 0.964, 0.000] </math>softmax([0,4,−4])≈[0.018,0.964,0.000]------几乎 one-hot!
而缩放后仍然是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 0.5 , − 0.5 ] [0, 0.5, -0.5] </math>[0,0.5,−0.5]------分布合理。
这就是「维度越大,越需要 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 」的直观体现。
15.5 backward 梯度差异
设 loss 对 attention 输出有梯度信号 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ o u t \partial L / \partial \mathrm{out} </math>∂L/∂out。
对未缩放 softmax:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ s = J s o f t m a x ⋅ ∂ L ∂ o u t \frac{\partial L}{\partial s} = J_{\mathrm{softmax}} \cdot \frac{\partial L}{\partial \mathrm{out}} </math>∂s∂L=Jsoftmax⋅∂out∂L
<math xmlns="http://www.w3.org/1998/Math/MathML"> J s o f t m a x J_{\mathrm{softmax}} </math>Jsoftmax 在「one-hot」状态下接近零------ <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L / ∂ s ≈ 0 \partial L / \partial s \approx 0 </math>∂L/∂s≈0------ <math xmlns="http://www.w3.org/1998/Math/MathML"> s = Q K T s = QK^{\mathsf{T}} </math>s=QKT 的梯度也接近零------ <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 的更新极缓慢。
对缩放 softmax:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ s ′ = J s o f t m a x ⋅ ∂ L ∂ o u t \frac{\partial L}{\partial s'} = J_{\mathrm{softmax}} \cdot \frac{\partial L}{\partial \mathrm{out}} </math>∂s′∂L=Jsoftmax⋅∂out∂L
这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> J s o f t m a x J_{\mathrm{softmax}} </math>Jsoftmax 不饱和,传播正常。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ s = ∂ L ∂ s ′ ⋅ 1 d k \frac{\partial L}{\partial s} = \frac{\partial L}{\partial s'} \cdot \frac{1}{\sqrt{d_k}} </math>∂s∂L=∂s′∂L⋅dk 1
注意还多了一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 / d k 1/\sqrt{d_k} </math>1/dk ------但这只让梯度等比缩小,不让它消失。
整体梯度量级仍然合理,训练能进行。
十六、关键概念回顾
-
点积方差 : <math xmlns="http://www.w3.org/1998/Math/MathML"> q ⋅ k q \cdot k </math>q⋅k 的方差等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk(在标准假设下)。
-
softmax 饱和:当 logits 量级远大于 1 时,softmax 输出近似 one-hot------梯度近似为零。
-
<math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放 :把 logits 方差归一化到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1,避免 softmax 饱和。
-
临界点 :分母必须是 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 太多,除以更小则不够。
-
scaling laws 隐性基础 : <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 让 attention 在不同维度下有相同的「统计工作点」------这是 Chinchilla scaling 能成立的前提之一。
-
NTK 视角 : <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是「保持每一层输入输出尺度一致」的具体实现。
-
现代变体 :QK-norm、logit-cap、Muon 都是「保护 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 假设」的延伸。
-
训练稳定性 : <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 不能完全替代 LayerNorm、warmup、gradient clipping------但它是这些手段的基础。
-
Vaswani 原文:只是一句「we suspect」+ 简单实验------后续工作才补全理论。
-
PyTorch 实现 :
F.scaled_dot_product_attention已内置 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ------但理解原理仍然重要。
十七、常见误解
17.1 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是经验技巧
错。这是一个有严格概率论推导的设计------不是「随便选一个数」。
17.4 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是 attention 唯一的稳定化机制
错。LayerNorm、warmup、gradient clipping、初始化都是稳定性的一部分------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 只是其中一项。
17.5 缩放 = 退化
错。缩放后 attention 仍然能学到 sharp pattern------只是初始时不卡饱和------训练完成后该 sharp 还是 sharp。
17.12 缩放只为加速训练
不仅如此。
缩放更核心的目的是「让训练能进行」------而不是「让训练加快」。
不缩放时,训练在大模型上根本无法成功------加上缩放后训练才能稳定走完。
这是「质」的区别,不是「量」的区别。
17.14 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放破坏了 attention 的「概率含义」
不严格。
attention 输出仍然是「key 上的概率分布」(softmax 输出每行和为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1)------ <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 只影响这个分布有多 sharp,不影响它是不是概率。
事实上, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 让初始的概率分布「更接近均匀」------这反而是更好的概率初始化。
十八、下一步
到这里,我们已经把 attention 机制原理这一块的核心讲完了。
下一篇(第 16 篇)开始进入【Part 3:Transformer 架构】------讨论完整的 Transformer 块如何把 attention、FFN、residual、LayerNorm 串起来。
我们会从「2017 原始 Transformer」讲起,逐步看到「现代 LLaMA-style Transformer」演化的每一个改动是为什么------Pre-LN vs Post-LN、SwiGLU vs ReLU、RMSNorm vs LayerNorm 等。
如果你已经掌握了:
- 第 11 篇的「attention 是什么」直觉
- 第 12 篇的 Bahdanau 加性注意力
- 第 13 篇的 Q/K/V 三件套
- 第 14 篇的 self-attention 概念
- 第 15 篇的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 缩放原理
那你已经有了进入 Transformer 架构层的所有理论基础------下一篇就把这些拼起来。
十九、参考文献
下面按相关度排序列出本篇直接引用与延伸阅读,每条附一句话提示其在本篇中的角色。
阅读建议:1、2、3、12 是核心,其余是延伸。
- Vaswani, A. et al. "Attention Is All You Need." NeurIPS 2017. §3.2.1 给出 √d_k 的最早动机。
- Xiong, R. et al. "On Layer Normalization in the Transformer Architecture." ICML 2020. 形式化分析 √d_k 与 Pre-LN 的关系。
- Luong, M.-T. et al. "Effective Approaches to Attention-based Neural Machine Translation." EMNLP 2015. dot-product attention 的经典工作,没有 √d_k------展示了不缩放的问题。
- Bahdanau, D. et al. "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR 2015. 加性 attention 没有 √d_k 问题,因为 tanh bound。
- Jacot, A., Gabriel, F., Hongler, C. "Neural Tangent Kernel: Convergence and Generalization in Neural Networks." NeurIPS 2018. NTK 理论的源头。
- Yang, G., Hu, E. J. "Tensor Programs IV: Feature Learning in Infinite-Width Neural Networks." ICML 2021. muP 的理论基础。
- Yang, G. et al. "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer." NeurIPS 2021. muP 的实操版,与 √d_k 互补。
- Hoffmann, J. et al. "Training Compute-Optimal Large Language Models." NeurIPS 2022. Chinchilla scaling laws,隐性依赖架构稳定性。
- Su, J. et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding." Neurocomputing 2024. RoPE 不破坏 √d_k 假设。
- Press, O., Smith, N. A., Lewis, M. "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR 2022. ALiBi,与 √d_k 叠加使用。
- Xiao, G. et al. "Efficient Streaming Language Models with Attention Sinks." ICLR 2024. attention sink 现象。
- Dao, T. et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. 工程实现里 √d_k 的位置。
- Henry, A. et al. "Query-Key Normalization for Transformers." EMNLP Findings 2020. QK-norm 提案。
- Shazeer, N. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv 2019. MQA。
- Ainslie, J. et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. GQA。
- Martin, C. H., Mahoney, M. W. "Implicit Self-Regularization in Deep Neural Networks." JMLR 2021. 深度网络中间表示的 heavy tail 现象。
- Jordan, K. et al. "Muon: An Optimizer for the Hidden Layers of Neural Networks." 2024 blog/preprint. Muon 的提出。
← 上一篇:14|Self-Attention | 下一篇:16|Multi-Head Attention →