理解多头注意力机制:像乐队合奏一样处理信息

一、为什么需要多个"注意力头"?

想象你正在参加一场交响乐演出,每个乐手都专注乐谱的不同部分------小提琴组负责主旋律,打击乐把控节奏,铜管组强调高潮段落。这种分工协作的方式,正是多头注意力机制的核心思想。

传统单头注意力就像只有一位听众在欣赏音乐,只能从一个角度理解整个演奏。而多头注意力让多个"虚拟听众"(头)同时工作,每个头都能:

  1. 捕捉不同距离的关联(如主歌与副歌的关系)
  2. 关注不同类型的特征(旋律、节奏、和声)
  3. 组合多种理解方式形成全面认知


图1 多头注意力:多个头连结然后线性变换

二、多头注意力如何运作?

2.1 核心计算步骤

假设我们要处理一句歌词:"雨下整夜,我的爱溢出就像雨水"。每个词的表示向量都要与其它词产生关联,具体分为三步:

步骤1:创建多重视角

为每个头创建独立视角


  • 头1_查询 = 线性变换(原始查询) <math xmlns="http://www.w3.org/1998/Math/MathML"> W q ( 1 ) Q \boxed{W_q^{(1)}Q} </math>Wq(1)Q
  • 头1_键 = 线性变换(原始键) <math xmlns="http://www.w3.org/1998/Math/MathML"> W k ( 1 ) K \boxed{W_k^{(1)}K} </math>Wk(1)K
  • 头1_值 = 线性变换(原始值) <math xmlns="http://www.w3.org/1998/Math/MathML"> W v ( 1 ) V \boxed{W_v^{(1)}V} </math>Wv(1)V

  • 头2_查询 = 线性变换(原始查询) <math xmlns="http://www.w3.org/1998/Math/MathML"> W q ( 2 ) Q \boxed{W_q^{(2)}Q} </math>Wq(2)Q
  • ...(共h个头)

数学表达式(每个头i的计算):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> head i = Attention ( W q ( i ) Q , W k ( i ) K , W v ( i ) V ) \text{head}_i = \text{Attention}(W_q^{(i)}Q, W_k^{(i)}K, W_v^{(i)}V) </math>headi=Attention(Wq(i)Q,Wk(i)K,Wv(i)V)

步骤2:并行注意力计算

每个头独立进行注意力计算(以缩放点积注意力为例):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V </math>Attention(Q,K,V)=softmax(dk QKT)V

步骤3:合并所有结果

将各头的输出拼接后做最终变换:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MultiHead = W o [ head 1 ; head 2 ; . . . ; head h ] \text{MultiHead} = W_o[\text{head}_1; \text{head}_2; ...; \text{head}_h] </math>MultiHead=Wo[head1;head2;...;headh]

2.2 维度变换图解

假设原始维度d=64,使用8个头:

  1. 每个头的维度变为64/8=8
  2. 各头计算结果拼接后恢复64维
  3. 最终线性变换保持维度一致

三、亲手搭建迷你多头注意力

3.1 简化版实现(使用PyTorch)

python 复制代码
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""

    # 输入X的形状: (batch_size, 查询或者"键-值"对的个数, num_hiddens)
    # 输出X的形状: (batch_size, 查询或者"键-值"对的个数, num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状: (batch_size, num_heads, 查询或者"键-值"对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # 最终输出的形状: (batch_size*num_heads,查询或者"键-值"对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""

    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
    """多头注意力"""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者"键-值"对的个数,num_hiddens)

        # valid_lens 的形状:
        # (batch_size,)或(batch_size, 查询的个数)

        # 经过变换后,输出的queries, keys, values 的形状:
        # (batch_size*num_heads, 查询或者"键-值"对的个数, num_hiddens/num_heads)

        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状: (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size, 查询的个数, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。多头注意力输出的形状是(batch_size, num_queries, num_hiddens)

python 复制代码
import torch

import d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(key_size=num_hiddens, query_size=num_hiddens, value_size=num_hiddens,
                                   num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5)
attention.eval()

print(attention)
text 复制代码
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
python 复制代码
batch_size, num_queries = 2, 4  # num_queries: 查询的个数
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # queries
print(X.shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)

Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # keys, values
print(Y.shape)
# torch.Size([2, 6, 100]) (batch_size, num_kvpairs, num_hiddens)

print(attention(X, Y, Y, valid_lens).shape)
# torch.Size([2, 4, 100]) (batch_size, num_queries, num_hiddens)

3.2 关键技巧解析

  1. 维度拆分:将64维拆分为8个8维头
python 复制代码
# 伪代码
q = q.view(batch_size, seq_len, 8, 8).transpose(1,2)
  1. 并行计算:利用矩阵运算同时处理所有头
  2. 结果融合:拼接后通过线性层整合信息

四、实际应用示例:歌词情感分析

假设分析周杰伦《七里香》歌词的情感:

python 复制代码
歌词 = ["窗外的麻雀", "在电线杆上多嘴", 
       "你说这一句", "很有夏天的感觉"]

# 创建词向量(假设已编码)
词向量 = torch.randn(4, 64)  # 4个词,每个64维

# 使用迷你多头注意力
注意力输出 = MiniMultiHead()(词向量, 词向量, 词向量)

print("每个词的新表示维度:", 注意力输出.shape)
# 输出: torch.Size([4, 64])

此时每个词的表示都融合了:

  • "麻雀"与"电线杆"的位置关系(空间头)
  • "多嘴"与"感觉"的情感关联(语义头)
  • "夏天"与整句的意境联系(语境头)

五、技术要点总结

关键概念 类比解释 数学表达
线性投影 给每个头配不同颜色的眼镜 <math xmlns="http://www.w3.org/1998/Math/MathML"> W q ( i ) Q W_q^{(i)}Q </math>Wq(i)Q
头拼接 乐队各声部录音的合并 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ h e a d 1 ; . . . ; h e a d h ] [head_1;...;head_h] </math>[head1;...;headh]
缩放点积 计算词语间的匹配分数 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q K T d k \frac{QK^T}{\sqrt{d_k}} </math>dk QKT
最终投影 指挥家统一协调各声部 <math xmlns="http://www.w3.org/1998/Math/MathML"> W o W_o </math>Wo

多头注意力的三大优势

  1. 并行处理:多个头同时计算,效率提升
  2. 多样化关注:捕获词语间的不同类型关系
  3. 强大表征:通过线性变换组合复杂特征

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 多头注意力 = 多个视角 + 并行计算 + 智能融合 \text{多头注意力} = \text{多个视角} + \text{并行计算} + \text{智能融合} </math>多头注意力=多个视角+并行计算+智能融合

理解多头注意力机制,就像学会用多种角度欣赏音乐。当每个"头"专注不同声部,最终合奏出的,便是深度学习最动人的智能交响。

相关推荐
大G哥29 分钟前
19_大模型微调和训练之-基于LLamaFactory+LoRA微调LLama3
人工智能·pytorch·python·深度学习·计算机视觉
njsgcs1 小时前
vison transformer vit 论文阅读
人工智能·深度学习·transformer
果冻人工智能2 小时前
AI能否取代软件架构师?我将4个大语言模型进行了测试
大数据·人工智能·深度学习·语言模型·自然语言处理·ai员工
EDPJ2 小时前
(2025,AR,NAR,GAN,Diffusion,模型对比,数据集,评估指标,性能对比)文本到图像生成和编辑:综述
深度学习·生成对抗网络·计算机视觉
背太阳的牧羊人2 小时前
[CLS] 向量是 BERT 类模型中一个特别重要的输出向量,它代表整个句子或文本的全局语义信息
人工智能·深度学习·bert
ayiya_Oese2 小时前
[数据处理] 6. 数据可视化
人工智能·pytorch·python·深度学习·机器学习·信息可视化
老艾的AI世界3 小时前
AI制作祝福视频,直播礼物收不停,广州塔、动态彩灯、LED表白(附下载链接)
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai·ai视频·ai视频生成·ai视频制作
Blossom.1183 小时前
虚拟现实(VR)与增强现实(AR)在教育领域的应用:开启沉浸式学习新时代
人工智能·深度学习·学习·机器学习·ar·制造·vr
一只安3 小时前
GoWeb开发(基础)
深度学习·学习
慕婉03075 小时前
机器学习实战:6种数据集划分方法详解与代码实现
人工智能·深度学习·机器学习·数据集划分