深度学习进阶(二)多头自注意力机制(Multi-Head Attention)

第一篇中,我们已经得到了自注意力的核心公式:

\\\mathrm{Attention}(\\mathbf{Q},\\mathbf{K},\\mathbf{V})=\\mathrm{softmax}\\left(\\frac{ \\mathbf{Q}\\mathbf{K}\^T}{\\sqrt{d_k}}\\right) \\mathbf{V} \\

再概述一下自注意力的本质:通过一次全局加权,将序列中的所有信息重新融合到每一个位置上,最终强化信息表示。

但单头的自注意力还是有些局限:一组 \((\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V)\) 只能用一种方式去理解序列

这其实和在卷积层中使用多个卷积核是相似的道理,我们不能只用一个卷积核去提取纹理、色彩、形状等所有特征。

同理,我们不能指望一组参数矩阵就能学习到序列在语义、语法、情感等多个方面的关联。

因此,我们在实际算法设计中,使用的往往是多头自注意力

其原理并不复杂,只是在单头自注意力基础上的简单改进。

1. 多头注意力的核心思想

多头注意力的核心思想很直观:

在"一层自注意力层"中,不只做一次注意力,而是做多次注意力,每次关注不同的信息子空间。

具体来说就是:不再只用一组 \((\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V)\), 而是同时学习 \(h\) 组不同的参数矩阵,这样的每一组参数矩阵就是一个"头",综合所有头的注意力信息,得到最终输出。

比如,第 \(t\) 个头为:

\\\mathbf{Q}_t = \\mathbf{X} \\mathbf{W}_Q\^{(t)} \\

\\\mathbf{K}_t = \\mathbf{X} \\mathbf{W}_K\^{(t)} \\

\\\mathbf{V}_t = \\mathbf{X} \\mathbf{W}_V\^{(t)} \\

然后,每一个头都独立进行注意力计算:

\\\mathbf{Z}_t = \\mathrm{Attention}(\\mathbf{Q}_t, \\mathbf{K}_t, \\mathbf{V}_t) \\

由于每个注意力头拥有独立的初始化参数矩阵 ,所以每一个头都是一个"观察角度",它们分别回答在不同语义空间下相关性问题。

其计算过程和单头自注意力并无区别,但一个新的问题是:

如何融合多头输出?

2. 多头的融合方式

2.1 拼接

通过多头自注意力,现在我们得到了多个输出:

\\\mathbf{Z}_1, \\mathbf{Z}_2, \\dots, \\mathbf{Z}_h \\

而要回答这些信息怎么合在一起,首先要了解 Transformer 对每个头的维度划分设计

\d_k = d_v = \\frac{d}{h} \\

这里的 \(d\) 仍是表示序列中一个位置,或者说一个 token 的特征维度。

举个例子来说明:

假定每个 token进入模型的特征维度:

\d = 8 \\

并且使用:

\h = 2 \\quad (\\text{两个注意力头}) \\

那么每个头的维度为:

\d_k = d_v = \\frac{d}{h} = \\frac{8}{2} = 4 \\

这代表每个头的 Query / Key 向量维度和 Value 输出维度都为 4,分别计算注意力得到输出:

\\\mathbf{Z_1},\\mathbf{Z_2} \\in \\mathbb{R}\^{n \\times 4} \\

明确了维度变化后,我们就能进行多头输出融合的第一步:拼接

\\\mathbf{Z} = \\text{Concat}(\\mathbf{Z}_1, \\mathbf{Z}_2, \\dots, \\mathbf{Z}_h) \\

思路很明确,就是把所有头的输出先直接拼在一起

\\\mathbf{Z} \\in \\mathbb{R}\^{n \\times (h \\cdot d_v=d)} \\

可以发现,拼接后的维度重新恢复到了原始的模型维度 \(d\) 。

这样的切分逻辑可以在固定模型维度 \(d\) 的前提下,使多头不会增加总体计算复杂度的数量级,从而避免因头数增加而产生计算量爆炸问题。

同时,这种让输入输出维度相同的设计也和 Transformer 的后续逻辑相关。

2.2 线性变换

拼接完还没有结束,实际上,在这之后,\(\mathbf{Z}\) 还需要经过一个线性层:

\\\mathbf{Z}_{\\text{final}} = \\mathbf{Z} \\mathbf{W}_O,\\mathbf{W}_O \\in \\mathbb{R}\^{d \\times d} \\

你会发现,这里没有加入偏置

实际上,线性层本身自然是拥有偏置的,但许多 Transformer 实现都会选择关闭偏置,这是因为 Transformer block 的结构中仍存在后续线性变换以及归一化操作,这里的单个线性层中的偏置项对整体表达能力的影响较小,因此在理论公式中经常省略。

现在把整体写成一行如下:

\\\mathrm{MultiHead}(\\mathbf{Q},\\mathbf{K},\\mathbf{V}) = \\text{Concat}(\\mathbf{Z}_1,\\dots,\\mathbf{Z}_h)\\mathbf{W}_O \\

这就是多头注意力的融合公式。

到这里,你可能有这样一个问题:只拼接不行吗?为什么还要再过一个线性层?

我们举个例子来回答这个问题:

假设对于某个位置 \(i\),经过多头注意力后,我们得到了两个头的输出:

\\\mathbf{z}_1\^{(i)} = \[语法,结构,顺序,主谓 \]

\\\mathbf{z}_2\^{(i)} = \[语义,情感,主题,语境 \]

拼接之后得到:

\\\mathbf{z}\^{(i)} = \[语法,结构,顺序,主谓 , 语义,情感,主题,语境 \]

此时,不同头的信息只是被"并排放在一起",但它们之间并没有发生任何交互或关联。

也就是说:有点用,但还不够。

现在再进行融合:

\\\mathbf{z}_{final}\^{(i)} = \\mathbf{z}\^{(i)} \\mathbf{W}_O \\

这步计算实际上是对拼接后的所有特征进行一次"重新加权组合"。

假设经过学习后,\(\mathbf{W}_O\) 做出的组合类似于:

  1. 把"主谓" + "语义"组合,得到更准确的句法语义关系。
  2. 把"结构" + "语境"组合,得到更高层次的上下文理解。
  3. 把"情感"适当放大或抑制。

最终形成新的表示:

\\\mathbf{z}_{final}\^{(i)} = \[综合特征_1, 综合特征_2, \\dots \]

由此,所有头的信息被打散并重新组合,模型可以自由地学习跨多头的特征关系。

这就是多头自注意力机制的详细内容,我们由此实现了从多个角度对输入信息的强化表示,从模型整体角度来说,多头注意力本质上是一个用于建模序列内部关系的计算模块。

因此,在 Transformer 中,多头注意力并不是单独使用的,而是被嵌入到一个更完整的结构单元中,这个单元就是 Transformer Block,我们在下一篇中再对其展开介绍。

相关推荐
垚森16 小时前
我用 GLM-5.2 造了个炸裂主题后台:16 套主题随心切,可在线体验
ai·react
doiito20 小时前
【Agent Harness】Gliding Horse 工具结果压缩体系:如何用“指针”驯服上下文膨胀
ai·rust·架构设计·系统设计·ai agent
Lihua奏2 天前
# 机器学习:机器是怎么从数据里学出规则的
机器学习
饼干哥哥2 天前
用AI全自动剪辑,日更 100条爆款视频——HyperFrames、Remotion、Git使用入门
人工智能·机器学习·ai编程
doiito2 天前
【Agent Harness】Gliding Horse 上下文动态感知与智能压缩:让 Agent 真正“听得进”每一句话
ai·rust·架构设计·系统设计·ai agent
探索云原生3 天前
K8s 1.36 这个 GA 特性,把 initContainer 拉模型的 hack 干掉了
ai·云原生·kubernetes
Zy宇3 天前
从养 OpenClaw 到养社区 AI:一套 Multi-Agent 社区的设计思路
人工智能·ai
魏祖潇3 天前
我在飞书里养了个“分身”——私聊喊它办事,群里 @ 它干活,还能替我传话
人工智能·机器学习