Multi-Head Attention:为什么要分多个头

读到这里,大多数读者应该已经能把 Scaled Dot-Product Attention 的基本流程复述出来: <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 做内积、除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 、过 softmax,再用得到的权重聚合 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V。问题在于,这台机器每次只给你一组注意力分布。对位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 来说,最终只有一份权重决定它往哪里看、看多少。

这在玩具例子里没有问题,在真实语言里就很快碰到上限。同一个 token 往往同时需要处理句法关系、指代关系、局部邻近关系、语义主题关系。如果把这些判断全部压进同一个 softmax,模型只能在多种关系之间做妥协,而不是并行建模。

Multi-Head Attention 做的事情其实非常直接:把 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 维表示投到 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个独立子空间里,让每个子空间各自形成一份注意力分布,最后再把这些结果拼回去。它看起来像一次简单的切分,但恰恰是这一步,让 Transformer 从单一的相似度度量,变成了同一步内并行处理多种关系的架构。

读完这一篇,你应该能回答:

  • 为什么单头 attention 迟早会卡住。
  • 为什么多头几乎不增加参数量。
  • 不同头到底学到了什么,以及为什么不能把可视化当成因果解释。
  • 为什么现代大模型训练时保留多头,推理时却大量使用 GQA 或 MQA。
  • 生产代码里为什么一定是大矩阵乘法加 reshape,而不是 for 循环跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次。

原文链接

一、为什么一定要多头

1. 单头 attention 的上限在哪里

先把单头形式写清楚。给定输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ R n × d X \in \mathbb{R}^{n \times d} </math>X∈Rn×d,标准 attention 做的是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = X W Q , K = X W K , V = X W V A = softmax ⁡ ( Q K T d k ) , Z = A V \begin{aligned} Q &= XW^Q,\quad K = XW^K,\quad V = XW^V \\ A &= \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad Z = AV \end{aligned} </math>QA=XWQ,K=XWK,V=XWV=softmax(dk QKT),Z=AV

这里真正决定模型怎么看世界的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A。对位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 来说,它只有一行 softmax 概率分布,只能在所有候选位置里分出一套权重。这意味着单头 attention 有两个硬约束。

第一,它一次只能表达一种关系。如果当前位置既要看主语和动词之间的句法链路,又要看代词和先行词之间的指代链路,那么这一切都要挤在同一份分布里完成。结果往往不是两种关系都学好,而是两边都被摊薄。

第二,它只能依赖一套相似度度量。 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T QK^T </math>QKT 之所以能得到权重,是因为模型假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 在同一空间里的点积足以衡量相似度。但句法相似度、位置相似度、主题相似度并不天然属于同一种空间。要求一组 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK 同时支撑这些判断,本质上是在逼同一把尺子量多种不同性质的东西。

这就是单头的核心瓶颈:它不缺一次聚合的能力,缺的是同一步里并行处理多种关系的能力。

2. 为什么不靠堆更深的层解决

一个自然反问是:单层只算一种关系也没关系,多堆几层不就行了?

问题在于,深度解决的是逐层组合,宽度解决的是同一步并行。Transformer 每一层的输出都会进入残差流,再交给下一层继续处理。第一层如果已经把一部分信息按某种关系混合了,第二层看到的就是混合后的表示,而不是原始 token 表示。你当然可以让下一层再学另一种关系,但这已经不是同一步并行完成,而是先后改写。

从建模目标上说,多头更像同一层里的多组滤波器,而不是更多层的重复堆叠。CNN 不会指望一个卷积核学完所有局部模式,再靠更深的层把它们拆开;同理,attention 也不应该只有一套相似度度量,然后把所有关系都往后推。

所以多头解决的不是层数不够,而是单层表达过窄的问题。


二、多头到底是怎么工作的

1. 标准定义

Multi-Head Attention 的标准定义是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MultiHead ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , ... , head ⁡ h ) W O head ⁡ i = Attention ⁡ ( Q W i Q , K W i K , V W i V ) \begin{aligned} \operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O \\ \operatorname{head}_i &= \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) \end{aligned} </math>MultiHead(Q,K,V)headi=Concat(head1,...,headh)WO=Attention(QWiQ,KWiK,VWiV)

关键不在 concat,而在每个头都有自己独立的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i Q W_i^Q </math>WiQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i K W_i^K </math>WiK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i V W_i^V </math>WiV。同一份输入会被投到 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个不同的子空间里,每个子空间各自形成一份 softmax 分布。头和头之间参数不共享,所以模型有机会把不同头训练成不同的关系探测器。

最后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 也不是可有可无的装饰。它的作用是把各个头的输出重新混回统一的残差流,让下一层能够在一个共享空间里继续处理,而不是面对 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 个互不沟通的孤岛。

2. 参数量为什么几乎不变

很多教程会说多头不怎么增加参数量,但这句话如果不算账,很容易被误解。

把所有头的投影矩阵沿最后一维拼起来,可以得到:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W f u l l Q ∈ R d m o d e l × h d k W f u l l K ∈ R d m o d e l × h d k W f u l l V ∈ R d m o d e l × h d v \begin{aligned} W_{\mathrm{full}}^Q &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^K &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^V &\in \mathbb{R}^{d_{model} \times h d_v} \end{aligned} </math>WfullQWfullKWfullV∈Rdmodel×hdk∈Rdmodel×hdk∈Rdmodel×hdv

在最常见的设置里, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = d m o d e l / h d_k = d_v = d_{model} / h </math>dk=dv=dmodel/h,于是 <math xmlns="http://www.w3.org/1998/Math/MathML"> h d k = d m o d e l h d_k = d_{model} </math>hdk=dmodel。这意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l Q W_{\mathrm{full}}^Q </math>WfullQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l K W_{\mathrm{full}}^K </math>WfullK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W f u l l V W_{\mathrm{full}}^V </math>WfullV 都退回成 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d m o d e l d_{model} \times d_{model} </math>dmodel×dmodel 的方阵,再加上一个同样大小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO,多头 attention 整体上仍然只是 4 个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d m o d e l d_{model} \times d_{model} </math>dmodel×dmodel 矩阵。

用 Transformer-base 的常见配置举例:

  • 单头大版本: <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l = 512 d_{model} = 512 </math>dmodel=512,那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W V W^V </math>WV 各是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 512 512 \times 512 </math>512×512。
  • 8 头版本:每个头的矩阵是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 64 512 \times 64 </math>512×64,一共 8 头,拼起来还是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 512 × 512 512 \times 512 </math>512×512。

也就是说,多头买的不是更多参数,多头买的是更多独立 softmax 的并行度。

这一点非常关键。一个更大的单头只能给你一份更精细的分布,但还是只有一份分布;多头给你的是 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 份独立分布,它们可以同时盯住不同关系。

3. 一个最小数值例子

为了把直觉落地,考虑一个极小的例子: <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l = 4 d_{model} = 4 </math>dmodel=4, <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 2 h = 2 </math>h=2,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = 2 d_k = d_v = 2 </math>dk=dv=2,序列长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> n = 3 n = 3 </math>n=3。输入设为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( 1 0 1 0 0 1 0 1 1 1 0 0 ) X = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix} </math>X= 101011100010

再取最简单的投影: <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q = W K = W V = I W^Q = W^K = W^V = I </math>WQ=WK=WV=I。

这时第 1 个头只看前两维,第 2 个头只看后两维。对同一个 token 来说,这两个头看到的是不同的几何结构。第 1 个头里,第三个 token 同时和前两个 token 有相似性;第 2 个头里,第三个 token 恰好变成零向量,对谁都不特别相似。

如果把第 1 个头的打分写出来,有:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> scores ⁡ 1 = Q 1 K 1 T 2 = ( 1 / 2 0 1 / 2 0 1 / 2 1 / 2 1 / 2 1 / 2 2 / 2 ) \operatorname{scores}_1 = \frac{Q_1K_1^T}{\sqrt{2}} = \begin{pmatrix} 1/\sqrt{2} & 0 & 1/\sqrt{2} \\ 0 & 1/\sqrt{2} & 1/\sqrt{2} \\ 1/\sqrt{2} & 1/\sqrt{2} & 2/\sqrt{2} \end{pmatrix} </math>scores1=2 Q1K1T= 1/2 01/2 01/2 1/2 1/2 1/2 2/2

而第 2 个头里,第三行会全部变成 0,softmax 之后就是均匀分布。于是同一个 query 在两个头里得到的注意力模式完全不同:一个头把权重集中到和自己最相近的位置,另一个头因为分辨不出差异,只能平均分配。

这就是多头和单头最本质的差别:多头不是把一个大空间切碎而已,而是给每个子空间独立保留一份 softmax 表达能力。


三、不同的头到底学到了什么

1. BERT 里最常见的四类头

BERT 火起来之后,研究者第一次系统地把多头逐个可视化。Clark 等人的分析里,最常见的头大致可以分成四类。

第一类是位置型。它们几乎只看相邻 token,或者只看自己,像是在做局部 n-gram 聚合。

第二类是锚点型。它们把大量权重给 [CLS]、[SEP]、句号,或者序列开头的若干位置。后来这类模式在长上下文推理里演化成了 attention sink 的重要现象。

第三类是句法型。某些头会稳定地把注意力放到主语对应的动词、介词对应的宾语、修饰语对应的中心词上。模型从来没被显式教过依存语法,但它会自发学出这类结构。

第四类是指代型。它们更稀有,通常出现在中后层,用来追踪 pronoun 和先行词之间的关系。

这些结果至少说明一件事:多头并不是训练出很多完全一样的副本。它们确实会分工,而且分工经常与我们关心的语言结构对应。

2. 可视化很有用,但不是因果解释

看到这里很容易走到另一个极端:把某个好看的注意力图直接当成模型解释。

这一步需要非常克制。Jain 与 Wallace 的结论非常明确:注意力分布可以和某种解释相一致,但不能直接等同于模型的因果机制。因为最终输出不仅取决于注意力权重,还取决于被加权的 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 本身,以及更早层已经写进残差流里的信息。

所以更稳妥的理解是:

  • 可视化适合生成假设。它能告诉你某个头看起来像句法头、像 sink 头、像位置头。
  • 消融和干预才更接近验证。把头置零、替换输出、观察性能下降,才更能说明这个头是不是在承担关键功能。

换句话说,注意力图能帮你看见模式,但不能替你完成归因。

3. 跨层分工与头剪枝

如果把视角从单层拉到多层,现象会更有意思。Tenney 等人的 probing 结果显示,BERT 的浅层更接近词法和局部邻近特征,中层更偏句法,深层更偏语义和篇章。这意味着多头不只是横向并行,也在纵向上形成了层级分工。

另一方面,Michel 和 Voita 的剪头实验也说明:并不是每个头都同等重要。很多头可以被单独剪掉而几乎不掉点,但也有少数头一旦剪掉,性能会明显下滑。这说明多头内部既有专责头,也有冗余头。

这对工程的启发非常直接:训练阶段保留较多头,有利于模型探索不同关系;部署阶段则可以把部分冗余结构压缩掉,于是才有了后来的 GQA、MQA 和各种头剪枝方案。


四、从 MHA 到 GQA:工程上的现实约束

1. 头数怎么选

原始 Transformer 的经验其实已经给出了很强的约束:头数不是越多越好,而是要和每头维度一起看。

<math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 典型结论
1 512 表达力不足,单一分布太受限
4 128 明显改善
8 64 经典甜点区间
16 32 开始变窄
32 16 每头维度过小,效果回落

后来大模型的配置大体沿着这个经验走:

模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk
Transformer-base 512 8 64
BERT-base 768 12 64
BERT-large 1024 16 64
GPT-3 175B 12288 96 128
LLaMA-2 7B 4096 32 128
LLaMA-2 70B 8192 64 128

最稳定的经验不是头数本身,而是每头维度通常锁在 64 或 128。头太少,关系不够并行;头太多,单头维度又太瘦,连基本的相似度判断都做不扎实。

2. 为什么推理端开始大量砍头

训练时多头是优势,推理时多头却很快变成负担,问题集中在 KV cache。

标准 MHA 里,每个头都有自己的一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V。当上下文很长时,这部分缓存会迅速吃光显存。于是工程上出现了两条典型路线:

  • MQA:所有 query 头共享同一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,KV cache 最省,但表达力损失更明显。
  • GQA:把 query 头分组,每组共享一份 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,在质量和速度之间取折中。
变体 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 头数 <math xmlns="http://www.w3.org/1998/Math/MathML"> K / V K/V </math>K/V 头数 KV cache 常见取舍
MHA <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 最大 训练最好,推理最慢
GQA <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g 中等 质量接近 MHA,推理显著更快
MQA <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 1 最小 最省显存,但更容易掉点

这也是为什么现代大模型常常呈现一个看上去矛盾的趋势:训练时保留较多 query 头,推理时尽量共享 <math xmlns="http://www.w3.org/1998/Math/MathML"> K / V K/V </math>K/V。

3. 训练稳定性的几个注意点

多头本身不神秘,但大模型里它会和训练稳定性强耦合,最常见的注意点有三个。

第一,pre-LN 比 post-LN 更稳。深层模型中,attention 输出会不断写回残差流,post-LN 更容易让梯度方差沿层数积累,pre-LN 在 GPT、LLaMA 这类大模型里已经几乎成为默认选择。

第二,训练前期不同头往往都很像。softmax 输入接近零时,各头的分布都接近均匀,分工是在训练中后期逐渐拉开的。不要拿训练早期的注意力图去解释模型行为。

第三, <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 的初始化值得认真对待。GPT-2 之后很常见的做法,是按层数缩小 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO 的初始方差,减少 attention 输出反复写回残差流时的方差放大。这不是多头独有的数学性质,但它直接影响多头模块在深层网络里的稳定性。

4. 自注意力和交叉注意力有什么不同

多头机制不仅用于 self-attention,也同样用于 cross-attention。形式上两者完全一样,差别只在来源:

  • self-attention 里, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 都来自同一份输入。
  • cross-attention 里, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 来自 decoder 当前状态, <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 来自 encoder 输出。

从多头的角度看,变化不在公式,而在任务含义上。self-attention 更像序列内部关系建模,cross-attention 更像目标序列对源序列做可寻址检索。很多翻译和多模态模型里的对齐能力,靠的正是 cross-attention 中不同头的分工。


五、工程实现:一次大矩阵乘法加 reshape

1. 为什么生产代码不是 for 循环

概念上,多头好像就是把 attention 跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次,然后把结果拼起来。但生产代码从不会真的写一个 for 循环。

原因很简单:GPU 喜欢一次大矩阵乘法,不喜欢很多次小矩阵乘法。真正高效的实现会先用一到三次大 GEMM 一次性算出全部头的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V,然后 reshape 成 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , h , N , d k ) (B, h, N, d_k) </math>(B,h,N,dk) 的形状,再把头维度当成 batched matmul 的一个批次维来统一处理。

也就是说,工程实现的骨架其实是:

python 复制代码
# X: (B, N, D)
qkv = X @ W_qkv
q, k, v = split_and_reshape(qkv)   # (B, h, N, d_k)
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores, dim=-1)
out = attn @ v
out = merge_heads(out) @ W_o

核心思想不是把一个头复制很多次,而是把所有头并进同一套张量运算里。

2. 一份完整的 PyTorch 实现

下面是一份简洁但已经接近生产习惯的 PyTorch 写法:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model 必须能被 num_heads 整除'
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        qkv = self.W_qkv(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.W_o(out)
        return out

这段实现里有几个值得特别注意的点。

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W^Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W K W^K </math>WK、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W V W^V </math>WV 被合成了一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q k v W_{qkv} </math>Wqkv,这是为了减少 GEMM 次数。
  • viewpermute 的顺序不能错,错了通常不会立刻报错,但模型会学不起来。
  • contiguous() 不是多余的,它是在张量转置之后为后续 reshape 和 matmul 保底。
  • 真正上 GPU 跑大模型时,通常会直接调用 F.scaled_dot_product_attention 或者底层 fused kernel,而不是自己手写 softmax。

3. 最容易踩的几个坑

多头实现里最常见的坑,基本都不是理论错误,而是张量细节错误。

第一, <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 和头数不整除。这是最简单也最常见的 bug。

第二,reshape 顺序错。把 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , N , h , d k ) (B, N, h, d_k) </math>(B,N,h,dk) 写成 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( B , h , N , d k ) (B, h, N, d_k) </math>(B,h,N,dk),代码可能照样能跑,但 token 维和 head 维已经被弄乱。

第三,mask 形状或 dtype 不对。实践里最好显式把 mask 写成带 head 维的 bool tensor,不要依赖隐式 broadcast。

第四,不要轻易删掉 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO。concat 之后虽然已经回到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel 维,但没有 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W^O </math>WO,各头之间就失去了重新混合和重新写入残差流的机会。


六、把答案收回到核心问题

如果把整篇内容压缩成一句话,那么 Multi-Head Attention 的作用就是:把一次 attention 从单一 softmax 升级成多组并行 softmax,让模型在同一步里同时建模多种关系。

它真正厉害的地方不在于公式有多复杂,而在于设计非常节制:参数量基本不变,计算结构仍然适合大矩阵乘法,表达力却从一组关系扩展成了一组子空间里的并行关系。后续从 BERT、GPT 到 LLaMA,再到 GQA、MQA 和 FlashAttention,本质上都仍然在围绕这个设计继续打磨。


关键概念回顾

  • 多头的本质不是把维度切碎,而是给不同子空间各自保留一份独立的 softmax 分布。
  • 在 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d m o d e l / h d_k = d_{model} / h </math>dk=dmodel/h 的标准设置下,多头几乎不比单头多参数;它换来的主要是并行建模不同关系的能力。
  • 不同头确实会分工,但注意力图只能作为线索,不能直接当成因果解释。
  • 训练喜欢保留较多独立头,推理则更关心 KV cache,所以现代大模型才会大量使用 GQA 和 MQA。
  • 真正高效的实现一定是大矩阵乘法加 reshape,而不是 for 循环跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次 attention。

常见误解

  • 误解一:头越多越好。错。头数必须和每头维度一起看,单头太瘦会直接掉表达力。
  • 误解二:多头比单头多很多参数。错。标准配置下参数量几乎等价。
  • 误解三:一个漂亮的注意力图就等于模型学会了句法。错。可视化只能给出相关性线索,不是因果证明。
  • 误解四:把同一个 attention 跑 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 次再平均就是多头。错。多头的关键是每个头有自己独立的投影矩阵。
  • 误解五:推理时继续保留完整 MHA 一定最好。错。部署场景里,GQA 和 MQA 往往是更合理的工程折中。

下一步

参考文献

  • Vaswani A., Shazeer N., Parmar N., Uszkoreit J., Jones L., Gomez A. N., Kaiser L., Polosukhin I. Attention Is All You Need. NeurIPS 2017.
  • Clark K., Khandelwal U., Levy O., Manning C. D. What Does BERT Look At? An Analysis of BERT's Attention. EMNLP 2019.
  • Voita E., Talbot D., Moiseev F., Sennrich R., Titov I. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019.
  • Michel P., Levy O., Neubig G. Are Sixteen Heads Really Better than One? NeurIPS 2019.
  • Jain S., Wallace B. C. Attention is not Explanation. NAACL 2019.
  • Tenney I., Das D., Pavlick E. BERT Rediscovers the Classical NLP Pipeline. ACL 2019.
  • Shazeer N. Fast Transformer Decoding: One Write-Head is All You Need. 2019.
  • Ainslie J., Lee-Thorp J., de Jong M., Zemlyanskiy Y., Lebrón F., Sanghai S. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
  • Xiao G., Tian Y., Chen B., Han S., Lewis M. Efficient Streaming Language Models with Attention Sinks. ICLR 2024.
  • Xiong R., Yang Y., He D., Zheng K., Zheng S., Xing C., Zhang H., Lan Y., Wang L., Liu T. On Layer Normalization in the Transformer Architecture. ICML 2020.

← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask

相关推荐
ltl2 小时前
Scaled Dot-Product:那个根号 d_k 是怎么来的'
后端
折哥的程序人生 · 物流技术专研4 小时前
《Java 100 天进阶之路》第17篇:Java常用包装类与自动装箱拆箱深入
java·开发语言·后端·面试
IT_陈寒4 小时前
为什么Java的Stream并行处理反而变慢了?
前端·人工智能·后端
孙6903425 小时前
swf 图片转 pdf
java·后端
长安不见5 小时前
从CompletionService的一个错误用法谈起
后端
空山返景6 小时前
Dify RAG知识库-自部署完整指南
后端
苏三的开发日记6 小时前
如何规避死锁
后端
该用户已不存在6 小时前
用 Claude Code Agents 与 CI/CD 搭建自动化研发团队(Part 3)
后端·ai编程·claude
豹哥学前端7 小时前
agent智能体经典范式构建
人工智能·后端