读到这里,大多数读者应该已经能把 Scaled Dot-Product Attention 的基本流程复述出来: <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 做内积、除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 、过 softmax,再用得到的权重聚合 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V。问题在于,这台机器每次只给你一组注意力分布。对位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 来说,最终只有一份权重决定它往哪里看、看多少。
这在玩具例子里没有问题,在真实语言里就很快碰到上限。同一个 token 往往同时需要处理句法关系、指代关系、局部邻近关系、语义主题关系。如果把这些判断全部压进同一个 softmax,模型只能在多种关系之间做妥协,而不是并行建模。
Multi-Head Attention 做的事情其实非常直接:把 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 维表示投到 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个独立子空间里,让每个子空间各自形成一份注意力分布,最后再把这些结果拼回去。它看起来像一次简单的切分,但恰恰是这一步,让 Transformer 从单一的相似度度量,变成了同一步内并行处理多种关系的架构。
读完这一篇,你应该能回答:
- 为什么单头 attention 迟早会卡住。
- 为什么多头几乎不增加参数量。
- 不同头到底学到了什么,以及为什么不能把可视化当成因果解释。
- 为什么现代大模型训练时保留多头,推理时却大量使用 GQA 或 MQA。
- 生产代码里为什么一定是大矩阵乘法加 reshape,而不是 for 循环跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次。
一、为什么一定要多头
1. 单头 attention 的上限在哪里
先把单头形式写清楚。给定输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ R n × d X \in \mathbb{R}^{n \times d} </math>X∈Rn×d,标准 attention 做的是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = X W Q , K = X W K , V = X W V A = softmax ( Q K T d k ) , Z = A V \begin{aligned} Q &= XW^Q,\quad K = XW^K,\quad V = XW^V \\ A &= \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad Z = AV \end{aligned} </math>QA=XWQ,K=XWK,V=XWV=softmax(dk QKT),Z=AV
这里真正决定模型怎么看世界的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A。对位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 来说,它只有一行 softmax 概率分布,只能在所有候选位置里分出一套权重。这意味着单头 attention 有两个硬约束。
第一,它一次只能表达一种关系。如果当前位置既要看主语和动词之间的句法链路,又要看代词和先行词之间的指代链路,那么这一切都要挤在同一份分布里完成。结果往往不是两种关系都学好,而是两边都被摊薄。
第二,它只能依赖一套相似度度量。 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 之所以能得到权重,是因为模型假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 在同一空间里的点积足以衡量相似度。但句法相似度、位置相似度、主题相似度并不天然属于同一种空间。要求一组 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK 同时支撑这些判断,本质上是在逼同一把尺子量多种不同性质的东西。
这就是单头的核心瓶颈:它不缺一次聚合的能力,缺的是同一步里并行处理多种关系的能力。
2. 为什么不靠堆更深的层解决
一个自然反问是:单层只算一种关系也没关系,多堆几层不就行了?
问题在于,深度解决的是逐层组合,宽度解决的是同一步并行。Transformer 每一层的输出都会进入残差流,再交给下一层继续处理。第一层如果已经把一部分信息按某种关系混合了,第二层看到的就是混合后的表示,而不是原始 token 表示。你当然可以让下一层再学另一种关系,但这已经不是同一步并行完成,而是先后改写。
从建模目标上说,多头更像同一层里的多组滤波器,而不是更多层的重复堆叠。CNN 不会指望一个卷积核学完所有局部模式,再靠更深的层把它们拆开;同理,attention 也不应该只有一套相似度度量,然后把所有关系都往后推。
所以多头解决的不是层数不够,而是单层表达过窄的问题。
二、多头到底是怎么工作的
1. 标准定义
Multi-Head Attention 的标准定义是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MultiHead ( Q , K , V ) = Concat ( head 1 , ... , head h ) W O head i = Attention ( Q W i Q , K W i K , V W i V ) \begin{aligned} \operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O \\ \operatorname{head}_i &= \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) \end{aligned} </math>MultiHead(Q,K,V)headi=Concat(head1,...,headh)WO=Attention(QWiQ,KWiK,VWiV)
关键不在 concat,而在每个头都有自己独立的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i Q W_i^Q </math>WiQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i K W_i^K </math>WiK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i V W_i^V </math>WiV。同一份输入会被投到 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个不同的子空间里,每个子空间各自形成一份 softmax 分布。头和头之间参数不共享,所以模型有机会把不同头训练成不同的关系探测器。
最后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 也不是可有可无的装饰。它的作用是把各个头的输出重新混回统一的残差流,让下一层能够在一个共享空间里继续处理,而不是面对 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个互不沟通的孤岛。
2. 参数量为什么几乎不变
很多教程会说多头不怎么增加参数量,但这句话如果不算账,很容易被误解。
把所有头的投影矩阵沿最后一维拼起来,可以得到:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W f u l l Q ∈ R d m o d e l × h d k W f u l l K ∈ R d m o d e l × h d k W f u l l V ∈ R d m o d e l × h d v \begin{aligned} W_{\mathrm{full}}^Q &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^K &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^V &\in \mathbb{R}^{d_{model} \times h d_v} \end{aligned} </math>WfullQWfullKWfullV∈Rdmodel×hdk∈Rdmodel×hdk∈Rdmodel×hdv
在最常见的设置里, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = d m o d e l / h d_k = d_v = d_{model} / h </math>dk=dv=dmodel/h,于是 <math xmlns="http://www.w3.org/1998/Math/MathML"> h d k = d m o d e l h d_k = d_{model} </math>hdk=dmodel。这意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l Q W_{\mathrm{full}}^Q </math>WfullQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l K W_{\mathrm{full}}^K </math>WfullK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l V W_{\mathrm{full}}^V </math>WfullV 都退回成 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d m o d e l d_{model} \times d_{model} </math>dmodel×dmodel 的方阵,再加上一个同样大小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO,多头 attention 整体上仍然只是 4 个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d m o d e l d_{model} \times d_{model} </math>dmodel×dmodel 矩阵。
用 Transformer-base 的常见配置举例:
- 单头大版本: <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l = 512 d_{model} = 512 </math>dmodel=512,那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W V W^V </math>WV 各是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 512 512 \times 512 </math>512×512。
- 8 头版本:每个头的矩阵是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 64 512 \times 64 </math>512×64,一共 8 头,拼起来还是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 512 512 \times 512 </math>512×512。
也就是说,多头买的不是更多参数,多头买的是更多独立 softmax 的并行度。
这一点非常关键。一个更大的单头只能给你一份更精细的分布,但还是只有一份分布;多头给你的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 份独立分布,它们可以同时盯住不同关系。
3. 一个最小数值例子
为了把直觉落地,考虑一个极小的例子: <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l = 4 d_{model} = 4 </math>dmodel=4, <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 2 h = 2 </math>h=2,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = 2 d_k = d_v = 2 </math>dk=dv=2,序列长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> n = 3 n = 3 </math>n=3。输入设为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( 1 0 1 0 0 1 0 1 1 1 0 0 ) X = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix} </math>X= 101011100010
再取最简单的投影: <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q = W K = W V = I W^Q = W^K = W^V = I </math>WQ=WK=WV=I。
这时第 1 个头只看前两维,第 2 个头只看后两维。对同一个 token 来说,这两个头看到的是不同的几何结构。第 1 个头里,第三个 token 同时和前两个 token 有相似性;第 2 个头里,第三个 token 恰好变成零向量,对谁都不特别相似。
如果把第 1 个头的打分写出来,有:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> scores 1 = Q 1 K 1 T 2 = ( 1 / 2 0 1 / 2 0 1 / 2 1 / 2 1 / 2 1 / 2 2 / 2 ) \operatorname{scores}_1 = \frac{Q_1K_1^T}{\sqrt{2}} = \begin{pmatrix} 1/\sqrt{2} & 0 & 1/\sqrt{2} \\ 0 & 1/\sqrt{2} & 1/\sqrt{2} \\ 1/\sqrt{2} & 1/\sqrt{2} & 2/\sqrt{2} \end{pmatrix} </math>scores1=2 Q1K1T= 1/2 01/2 01/2 1/2 1/2 1/2 2/2
而第 2 个头里,第三行会全部变成 0,softmax 之后就是均匀分布。于是同一个 query 在两个头里得到的注意力模式完全不同:一个头把权重集中到和自己最相近的位置,另一个头因为分辨不出差异,只能平均分配。
这就是多头和单头最本质的差别:多头不是把一个大空间切碎而已,而是给每个子空间独立保留一份 softmax 表达能力。
三、不同的头到底学到了什么
1. BERT 里最常见的四类头
BERT 火起来之后,研究者第一次系统地把多头逐个可视化。Clark 等人的分析里,最常见的头大致可以分成四类。
第一类是位置型。它们几乎只看相邻 token,或者只看自己,像是在做局部 n-gram 聚合。
第二类是锚点型。它们把大量权重给 [CLS]、[SEP]、句号,或者序列开头的若干位置。后来这类模式在长上下文推理里演化成了 attention sink 的重要现象。
第三类是句法型。某些头会稳定地把注意力放到主语对应的动词、介词对应的宾语、修饰语对应的中心词上。模型从来没被显式教过依存语法,但它会自发学出这类结构。
第四类是指代型。它们更稀有,通常出现在中后层,用来追踪 pronoun 和先行词之间的关系。
这些结果至少说明一件事:多头并不是训练出很多完全一样的副本。它们确实会分工,而且分工经常与我们关心的语言结构对应。
2. 可视化很有用,但不是因果解释
看到这里很容易走到另一个极端:把某个好看的注意力图直接当成模型解释。
这一步需要非常克制。Jain 与 Wallace 的结论非常明确:注意力分布可以和某种解释相一致,但不能直接等同于模型的因果机制。因为最终输出不仅取决于注意力权重,还取决于被加权的 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 本身,以及更早层已经写进残差流里的信息。
所以更稳妥的理解是:
- 可视化适合生成假设。它能告诉你某个头看起来像句法头、像 sink 头、像位置头。
- 消融和干预才更接近验证。把头置零、替换输出、观察性能下降,才更能说明这个头是不是在承担关键功能。
换句话说,注意力图能帮你看见模式,但不能替你完成归因。
3. 跨层分工与头剪枝
如果把视角从单层拉到多层,现象会更有意思。Tenney 等人的 probing 结果显示,BERT 的浅层更接近词法和局部邻近特征,中层更偏句法,深层更偏语义和篇章。这意味着多头不只是横向并行,也在纵向上形成了层级分工。
另一方面,Michel 和 Voita 的剪头实验也说明:并不是每个头都同等重要。很多头可以被单独剪掉而几乎不掉点,但也有少数头一旦剪掉,性能会明显下滑。这说明多头内部既有专责头,也有冗余头。
这对工程的启发非常直接:训练阶段保留较多头,有利于模型探索不同关系;部署阶段则可以把部分冗余结构压缩掉,于是才有了后来的 GQA、MQA 和各种头剪枝方案。
四、从 MHA 到 GQA:工程上的现实约束
1. 头数怎么选
原始 Transformer 的经验其实已经给出了很强的约束:头数不是越多越好,而是要和每头维度一起看。
| <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk | 典型结论 |
|---|---|---|
| 1 | 512 | 表达力不足,单一分布太受限 |
| 4 | 128 | 明显改善 |
| 8 | 64 | 经典甜点区间 |
| 16 | 32 | 开始变窄 |
| 32 | 16 | 每头维度过小,效果回落 |
后来大模型的配置大体沿着这个经验走:
| 模型 | <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel | <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk |
|---|---|---|---|
| Transformer-base | 512 | 8 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-3 175B | 12288 | 96 | 128 |
| LLaMA-2 7B | 4096 | 32 | 128 |
| LLaMA-2 70B | 8192 | 64 | 128 |
最稳定的经验不是头数本身,而是每头维度通常锁在 64 或 128。头太少,关系不够并行;头太多,单头维度又太瘦,连基本的相似度判断都做不扎实。
2. 为什么推理端开始大量砍头
训练时多头是优势,推理时多头却很快变成负担,问题集中在 KV cache。
标准 MHA 里,每个头都有自己的一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V。当上下文很长时,这部分缓存会迅速吃光显存。于是工程上出现了两条典型路线:
- MQA:所有 query 头共享同一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,KV cache 最省,但表达力损失更明显。
- GQA:把 query 头分组,每组共享一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,在质量和速度之间取折中。
| 变体 | <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 头数 | <math xmlns="http://www.w3.org/1998/Math/MathML"> K / V K/V </math>K/V 头数 | KV cache | 常见取舍 |
|---|---|---|---|---|
| MHA | <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | 最大 | 训练最好,推理最慢 |
| GQA | <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g | 中等 | 质量接近 MHA,推理显著更快 |
| MQA | <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h | 1 | 最小 | 最省显存,但更容易掉点 |
这也是为什么现代大模型常常呈现一个看上去矛盾的趋势:训练时保留较多 query 头,推理时尽量共享 <math xmlns="http://www.w3.org/1998/Math/MathML"> K / V K/V </math>K/V。
3. 训练稳定性的几个注意点
多头本身不神秘,但大模型里它会和训练稳定性强耦合,最常见的注意点有三个。
第一,pre-LN 比 post-LN 更稳。深层模型中,attention 输出会不断写回残差流,post-LN 更容易让梯度方差沿层数积累,pre-LN 在 GPT、LLaMA 这类大模型里已经几乎成为默认选择。
第二,训练前期不同头往往都很像。softmax 输入接近零时,各头的分布都接近均匀,分工是在训练中后期逐渐拉开的。不要拿训练早期的注意力图去解释模型行为。
第三, <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 的初始化值得认真对待。GPT-2 之后很常见的做法,是按层数缩小 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 的初始方差,减少 attention 输出反复写回残差流时的方差放大。这不是多头独有的数学性质,但它直接影响多头模块在深层网络里的稳定性。
4. 自注意力和交叉注意力有什么不同
多头机制不仅用于 self-attention,也同样用于 cross-attention。形式上两者完全一样,差别只在来源:
- self-attention 里, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 都来自同一份输入。
- cross-attention 里, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 来自 decoder 当前状态, <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 来自 encoder 输出。
从多头的角度看,变化不在公式,而在任务含义上。self-attention 更像序列内部关系建模,cross-attention 更像目标序列对源序列做可寻址检索。很多翻译和多模态模型里的对齐能力,靠的正是 cross-attention 中不同头的分工。
五、工程实现:一次大矩阵乘法加 reshape
1. 为什么生产代码不是 for 循环
概念上,多头好像就是把 attention 跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次,然后把结果拼起来。但生产代码从不会真的写一个 for 循环。
原因很简单:GPU 喜欢一次大矩阵乘法,不喜欢很多次小矩阵乘法。真正高效的实现会先用一到三次大 GEMM 一次性算出全部头的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,然后 reshape 成 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , h , N , d k ) (B, h, N, d_k) </math>(B,h,N,dk) 的形状,再把头维度当成 batched matmul 的一个批次维来统一处理。
也就是说,工程实现的骨架其实是:
python
# X: (B, N, D)
qkv = X @ W_qkv
q, k, v = split_and_reshape(qkv) # (B, h, N, d_k)
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores, dim=-1)
out = attn @ v
out = merge_heads(out) @ W_o
核心思想不是把一个头复制很多次,而是把所有头并进同一套张量运算里。
2. 一份完整的 PyTorch 实现
下面是一份简洁但已经接近生产习惯的 PyTorch 写法:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, 'd_model 必须能被 num_heads 整除'
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
qkv = self.W_qkv(x)
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
out = self.W_o(out)
return out
这段实现里有几个值得特别注意的点。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W V W^V </math>WV 被合成了一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q k v W_{qkv} </math>Wqkv,这是为了减少 GEMM 次数。
view和permute的顺序不能错,错了通常不会立刻报错,但模型会学不起来。contiguous()不是多余的,它是在张量转置之后为后续 reshape 和 matmul 保底。- 真正上 GPU 跑大模型时,通常会直接调用
F.scaled_dot_product_attention或者底层 fused kernel,而不是自己手写 softmax。
3. 最容易踩的几个坑
多头实现里最常见的坑,基本都不是理论错误,而是张量细节错误。
第一, <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 和头数不整除。这是最简单也最常见的 bug。
第二,reshape 顺序错。把 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , N , h , d k ) (B, N, h, d_k) </math>(B,N,h,dk) 写成 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , h , N , d k ) (B, h, N, d_k) </math>(B,h,N,dk),代码可能照样能跑,但 token 维和 head 维已经被弄乱。
第三,mask 形状或 dtype 不对。实践里最好显式把 mask 写成带 head 维的 bool tensor,不要依赖隐式 broadcast。
第四,不要轻易删掉 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO。concat 之后虽然已经回到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 维,但没有 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO,各头之间就失去了重新混合和重新写入残差流的机会。
六、把答案收回到核心问题
如果把整篇内容压缩成一句话,那么 Multi-Head Attention 的作用就是:把一次 attention 从单一 softmax 升级成多组并行 softmax,让模型在同一步里同时建模多种关系。
它真正厉害的地方不在于公式有多复杂,而在于设计非常节制:参数量基本不变,计算结构仍然适合大矩阵乘法,表达力却从一组关系扩展成了一组子空间里的并行关系。后续从 BERT、GPT 到 LLaMA,再到 GQA、MQA 和 FlashAttention,本质上都仍然在围绕这个设计继续打磨。
关键概念回顾
- 多头的本质不是把维度切碎,而是给不同子空间各自保留一份独立的 softmax 分布。
- 在 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d m o d e l / h d_k = d_{model} / h </math>dk=dmodel/h 的标准设置下,多头几乎不比单头多参数;它换来的主要是并行建模不同关系的能力。
- 不同头确实会分工,但注意力图只能作为线索,不能直接当成因果解释。
- 训练喜欢保留较多独立头,推理则更关心 KV cache,所以现代大模型才会大量使用 GQA 和 MQA。
- 真正高效的实现一定是大矩阵乘法加 reshape,而不是 for 循环跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次 attention。
常见误解
- 误解一:头越多越好。错。头数必须和每头维度一起看,单头太瘦会直接掉表达力。
- 误解二:多头比单头多很多参数。错。标准配置下参数量几乎等价。
- 误解三:一个漂亮的注意力图就等于模型学会了句法。错。可视化只能给出相关性线索,不是因果证明。
- 误解四:把同一个 attention 跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次再平均就是多头。错。多头的关键是每个头有自己独立的投影矩阵。
- 误解五:推理时继续保留完整 MHA 一定最好。错。部署场景里,GQA 和 MQA 往往是更合理的工程折中。
下一步
- 想理解 decoder 为什么不能看未来:去看 17. Causal Mask。
- 想接着看 attention 的复杂度和长上下文瓶颈:18. 注意力的复杂度问题。
- 想回到总图看 encoder、decoder 和 FFN 怎么拼起来:20. Transformer 整体架构。
参考文献
- Vaswani A., Shazeer N., Parmar N., Uszkoreit J., Jones L., Gomez A. N., Kaiser L., Polosukhin I. Attention Is All You Need. NeurIPS 2017.
- Clark K., Khandelwal U., Levy O., Manning C. D. What Does BERT Look At? An Analysis of BERT's Attention. EMNLP 2019.
- Voita E., Talbot D., Moiseev F., Sennrich R., Titov I. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019.
- Michel P., Levy O., Neubig G. Are Sixteen Heads Really Better than One? NeurIPS 2019.
- Jain S., Wallace B. C. Attention is not Explanation. NAACL 2019.
- Tenney I., Das D., Pavlick E. BERT Rediscovers the Classical NLP Pipeline. ACL 2019.
- Shazeer N. Fast Transformer Decoding: One Write-Head is All You Need. 2019.
- Ainslie J., Lee-Thorp J., de Jong M., Zemlyanskiy Y., Lebrón F., Sanghai S. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
- Xiao G., Tian Y., Chen B., Han S., Lewis M. Efficient Streaming Language Models with Attention Sinks. ICLR 2024.
- Xiong R., Yang Y., He D., Zheng K., Zheng S., Xing C., Zhang H., Lan Y., Wang L., Liu T. On Layer Normalization in the Transformer Architecture. ICML 2020.
← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask →