手写Multi-Head Attention多头注意力机制,Pytorch实现与原理详解

前言

注意力机制是Transformer架构的核心组件,而**多头注意力(Multi-Head Attention)**更是其中的精髓------通过将特征映射到多个头进行并行的注意力计算,能让模型同时捕捉不同维度、不同尺度的特征关联,相比单头注意力拥有更强的特征表达能力。

本文将从原理出发,结合纯Pytorch实现的多头注意力代码(无高级框架封装),逐段拆解代码+对应核心公式,让公式落地代码、代码印证公式,彻底打通原理与实现的壁垒,代码可直接嵌入Transformer、ViT、LLM等相关模型中使用。

一、多头注意力核心原理与核心公式

在进入代码前,先明确多头注意力的核心步骤与对应公式,为代码解析做精准铺垫,后续每段代码都会对应到具体公式环节。

多头自注意力(Q/K/V由同一输入生成)是最常用的形式,核心执行流程 :线性投影→分头发→尺度缩放→注意力权重计算→加权求和→特征拼接→输出投影,对应的核心公式如下:

整体多头注意力公式

MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,head_2,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO

单个注意力头计算公式

headi=Attention(QWiQ,KWiK,VWiV)=Softmax(QWiQ(KWiK)Tdk)VWiVhead_i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V) = \text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)VW_i^Vheadi=Attention(QWiQ,KWiK,VWiV)=Softmax(dk QWiQ(KWiK)T)VWiV

符号说明

  • hhh:注意力头的数量,是多头注意力的核心超参数;
  • dkd_kdk:单个注意力头的特征维度,决定了单头的特征表达能力;
  • QWiQ/KWiK/VWiVQW_i^Q/ KW_i^K/ VW_i^VQWiQ/KWiK/VWiV:输入分别映射到第iii个头的查询/键/值特征;
  • Concat\text{Concat}Concat:将所有注意力头的输出特征进行拼接;
  • WOW^OWO:拼接后特征的输出投影矩阵,用于将拼接维度映射回原输入维度。

关键设计要点 :引入尺度因子1dk\frac{1}{\sqrt{d_k}}dk 1,缓解dkd_kdk较大时QK点积结果数值过大,导致Softmax函数梯度消失的问题。

二、完整Pytorch代码实现

本文实现的多头自注意力基于Pytorch和einops的rearrange(特征维度重排,比手动view更直观,避免维度混乱),先安装依赖:

bash 复制代码
pip install torch einops

完整可复用代码:

python 复制代码
import torch
import torch.nn as nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

输入输出维度约定 :输入特征形状为[batch_size,seq_len,dim][batch\_size, seq\_len, dim][batch_size,seq_len,dim](记为[b,n,d][b, n, d][b,n,d]),输出与输入维度完全一致,可无缝嵌入各类模型。

三、代码逐段拆解 + 公式对应详解

3.1 初始化函数__init__:定义超参数与网络层

初始化函数的核心是为后续注意力计算准备超参数可学习层,每段代码对应多头注意力的前置设计,无直接公式但为公式落地做铺垫。

代码段1:核心超参数计算
python 复制代码
inner_dim = dim_head *  heads
project_out = not (heads == 1 and dim_head == dim)
  • inner_dim:所有注意力头的总特征维度,即h×dkh \times d_kh×dk(hhh为heads,dkd_kdk为dim_head),是Q/K/V各自的投影维度;
  • project_out:判定是否需要最后的输出投影层WOW^OWO:仅当单头且单头维度=输入维度 (h=1h=1h=1且dk=dimd_k=dimdk=dim)时,拼接后维度与输入一致,无需WOW^OWO,否则必须通过WOW^OWO映射回原维度。
代码段2:注意力基础超参数保存
python 复制代码
self.heads = heads
self.scale = dim_head ** -0.5
  • self.heads:保存注意力头数hhh,为后续分头发计算做准备;
  • self.scale:尺度缩放因子,对应公式中的1dk\frac{1}{\sqrt{d_k}}dk 1,用dim_head ** -0.5等价实现,避免浮点数除法,计算更高效。
代码段3:注意力权重计算与QKV投影层
python 复制代码
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
  • self.attend:Softmax层,对应公式中的Softmax(⋅)\text{Softmax}(\cdot)Softmax(⋅),指定dim=-1是为了对每个位置的相似度维度做归一化,得到合法的注意力权重;
  • self.to_qkv:QKV联合投影层,是工业界高效实现方案 :通过1个线性层替代3个独立线性层,直接将输入映射为3×innerdim3 \times inner_dim3×innerdim维度(Q/K/V拼接),再通过拆分得到各自特征,减少层的数量与计算开销;
  • bias=False:Transformer经典设计,实践中偏置对注意力计算的效果提升有限,还会增加参数量,因此省略。
代码段4:输出投影层定义
python 复制代码
self.to_out = nn.Sequential(
    nn.Linear(inner_dim, dim),
    nn.Dropout(dropout)
) if project_out else nn.Identity()
  • 当需要投影时,nn.Linear(inner_dim, dim)就是公式中的输出投影矩阵WOW^OWO ,将拼接后的h×dkh \times d_kh×dk维度映射回原输入维度dimdimdim;
  • nn.Dropout(dropout):为投影后的特征添加正则化,防止过拟合,是工程实践的必要补充;
  • 无需投影时,用nn.Identity()(恒等映射)替代,保证代码逻辑统一,输入输出维度一致。

3.2 前向传播forward:核心注意力计算(代码+公式一一对应)

前向传播是多头注意力的核心执行环节 ,每段代码都直接对应多头注意力的核心公式,是原理落地的关键,输入x形状 :[b,n,dim][b, n, dim][b,n,dim]。

代码段1:QKV线性投影与拆分 → 对应QWiQ、KWiK、VWiVQW_i^Q、KW_i^K、VW_i^VQWiQ、KWiK、VWiV
python 复制代码
qkv = self.to_qkv(x).chunk(3, dim = -1)
  • 执行逻辑:[b,n,dim]→Linear[b,n,3×innerdim]→chunk[b,n,innerdim]×3[b, n, dim] \xrightarrow{\text{Linear}} [b, n, 3 \times inner_dim] \xrightarrow{\text{chunk}} [b, n, inner_dim] \times 3[b,n,dim]Linear [b,n,3×innerdim]chunk [b,n,innerdim]×3;
  • self.to_qkv(x):完成输入到Q/K/V的联合线性投影,对应公式中所有头的QWiQ、KWiK、VWiVQW_i^Q、KW_i^K、VW_i^VQWiQ、KWiK、VWiV拼接;
  • chunk(3, dim=-1):在最后一维将拼接特征拆分为3个等长张量,分别得到Q、K、V,每个张量形状为[b,n,h×dk][b, n, h \times d_k][b,n,h×dk]。
代码段2:分头发维度重排 → 为并行计算head1∼headhhead_1 \sim head_hhead1∼headh做准备
python 复制代码
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
  • 执行逻辑:[b,n,h×dk]→rearrange[b,h,n,dk][b, n, h \times d_k] \xrightarrow{\text{rearrange}} [b, h, n, d_k][b,n,h×dk]rearrange [b,h,n,dk](q/k/v均完成此变换);
  • 维度含义变化:将注意力头维度hhh从特征维度中抽离,成为独立维度,让每个头的Q/K/VQ/K/VQ/K/V特征独立,可通过Pytorch广播机制并行计算所有头的注意力,无需循环遍历,大幅提升效率;
  • 替代实现:若不用einops,可通过view+transpose实现:t.view(b, n, h, d_k).transpose(1, 2),效果一致但可读性差;
  • 对应公式:重排后,q[b,h,:,:]q[b, h, :, :]q[b,h,:,:]就是第bbb个样本的第hhh个头的查询特征QWhQQW_h^QQWhQ,k/v同理,为后续单头注意力计算headihead_iheadi做准备。
代码段3:QK点积 + 尺度缩放 → 对应QWiQ(KWiK)Tdk\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}dk QWiQ(KWiK)T
python 复制代码
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
  • 执行逻辑:[b,h,n,dk]×[b,h,dk,n]→matmul[b,h,n,n]→×scale[b,h,n,n][b, h, n, d_k] \times [b, h, d_k, n] \xrightarrow{\text{matmul}} [b, h, n, n] \xrightarrow{\times scale} [b, h, n, n][b,h,n,dk]×[b,h,dk,n]matmul [b,h,n,n]×scale [b,h,n,n];
  • k.transpose(-1, -2):将K的最后两个维度交换,形状从[b,h,n,dk][b, h, n, d_k][b,h,n,dk]变为[b,h,dk,n][b, h, d_k, n][b,h,dk,n],满足矩阵乘法"前一个最后一维=后一个倒数第二维"的要求;
  • torch.matmul(q, k.transpose(-1, -2)):计算QK点积,对应公式中的QWiQ(KWiK)TQW_i^Q (KW_i^K)^TQWiQ(KWiK)T,结果形状为[b,h,n,n][b, h, n, n][b,h,n,n],其中dots[b,h,i,j]dots[b, h, i, j]dots[b,h,i,j]表示第b个样本、第h个头中,第i个位置对第j个位置的特征相似度
  • * self.scale:乘以尺度缩放因子,对应公式中的1dk\frac{1}{\sqrt{d_k}}dk 1,核心作用是缓解点积后数值过大导致的Softmax梯度消失。
  • 对应公式 :此步骤直接实现QWiQ(KWiK)Tdk\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}dk QWiQ(KWiK)T,得到每个头、每个位置的未归一化相似度。
代码段4:Softmax计算注意力权重 → 对应Softmax(QWiQ(KWiK)Tdk)\text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)Softmax(dk QWiQ(KWiK)T)
python 复制代码
attn = self.attend(dots)
  • 执行逻辑:[b,h,n,n]→Softmax[b,h,n,n][b, h, n, n] \xrightarrow{\text{Softmax}} [b, h, n, n][b,h,n,n]Softmax [b,h,n,n];
  • 执行细节:对dots的最后一维(相似度维度)做Softmax,将未归一化的相似度转换为0~1之间的注意力权重,且每行权重之和为1;
  • 结果含义:attn[b,h,i,j]attn[b, h, i, j]attn[b,h,i,j]表示第b个样本、第h个头中,第i个位置对第j个位置的关注程度
  • 对应公式 :此步骤直接实现Softmax(QWiQ(KWiK)Tdk)\text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)Softmax(dk QWiQ(KWiK)T),得到合法的注意力权重矩阵。
代码段5:注意力权重与V加权求和 → 对应headi=Softmax(⋅)VWiVhead_i = \text{Softmax}(\cdot)VW_i^Vheadi=Softmax(⋅)VWiV
python 复制代码
out = torch.matmul(attn, v)
  • 执行逻辑:[b,h,n,n]×[b,h,n,dk]→matmul[b,h,n,dk][b, h, n, n] \times [b, h, n, d_k] \xrightarrow{\text{matmul}} [b, h, n, d_k][b,h,n,n]×[b,h,n,dk]matmul [b,h,n,dk];
  • 计算细节:注意力权重矩阵与值特征V相乘,将每个位置的特征按照注意力权重进行加权求和,得到每个注意力头的独立输出;
  • 结果含义:out[b,h,:,:]out[b, h, :, :]out[b,h,:,:]就是第bbb个样本的第hhh个头的注意力输出headhhead_hheadh;
  • 对应公式 :此步骤直接实现单个注意力头的计算headi=Softmax(QWiQ(KWiK)Tdk)VWiVhead_i = \text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)VW_i^Vheadi=Softmax(dk QWiQ(KWiK)T)VWiV,最终得到所有头的输出head1∼headhhead_1 \sim head_hhead1∼headh。
代码段6:多头发特征拼接 → 对应Concat(head1,head2,...,headh)\text{Concat}(head_1,head_2,...,head_h)Concat(head1,head2,...,headh)
python 复制代码
out = rearrange(out, 'b h n d -> b n (h d)')
  • 执行逻辑:[b,h,n,dk]→rearrange[b,n,h×dk][b, h, n, d_k] \xrightarrow{\text{rearrange}} [b, n, h \times d_k][b,h,n,dk]rearrange [b,n,h×dk];
  • 执行细节:将注意力头维度hhh重新拼接到特征维度,把所有头的输出head1∼headhhead_1 \sim head_hhead1∼headh按特征维度拼接,得到融合了多头信息的特征;
  • 对应公式 :此步骤直接实现Concat(head1,head2,...,headh)\text{Concat}(head_1,head_2,...,head_h)Concat(head1,head2,...,headh),拼接后特征维度为h×dkh \times d_kh×dk(即inner_dim)。
代码段7:输出投影与返回 → 对应Concat(⋅)WO\text{Concat}(\cdot)W^OConcat(⋅)WO
python 复制代码
return self.to_out(out)
  • 执行逻辑:[b,n,h×dk]→WO[b,n,dim][b, n, h \times d_k] \xrightarrow{W^O} [b, n, dim][b,n,h×dk]WO [b,n,dim](若需要投影);
  • 执行细节:将拼接后的特征输入到self.to_out,若需要投影则通过WOW^OWO(nn.Linear)映射回原输入维度dimdimdim,并添加Dropout;若无需投影则直接返回;
  • 结果:最终输出形状为[b,n,dim][b, n, dim][b,n,dim],与输入维度完全一致;
  • 对应公式 :此步骤直接实现多头注意力的整体公式MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,head_2,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO,完成整个多头注意力计算。

四、维度验证:让公式与代码的维度对应更直观

为了让大家更清晰地看到每段代码的维度变化公式中维度的对应关系,这里给出完整的维度验证示例,输入为随机生成的特征,模拟实际使用场景。

验证代码

python 复制代码
# 初始化注意力层:dim=512(输入维度),heads=8(h),dim_head=64(d_k),dropout=0.1
attn = Attention(dim=512, heads=8, dim_head=64, dropout=0.1)
# 构造输入:batch_size=2(b),seq_len=16(n),dim=512
x = torch.randn(2, 16, 512)
# 前向传播并打印关键步骤维度
qkv = attn.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = attn.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * attn.scale
attn_weight = attn.attend(dots)
out = torch.matmul(attn_weight, v)
out_concat = rearrange(out, 'b h n d -> b n (h d)')
out_final = attn.to_out(out_concat)

# 打印关键步骤维度
print(f"输入x维度:{x.shape}")
print(f"Q/K/V各自维度:{q.shape}")
print(f"QK点积后维度:{dots.shape}")
print(f"注意力权重维度:{attn_weight.shape}")
print(f"单头加权求和后维度:{out.shape}")
print(f"多头发拼接后维度:{out_concat.shape}")
print(f"最终输出维度:{out_final.shape}")

输出结果

复制代码
输入x维度:torch.Size([2, 16, 512])
Q/K/V各自维度:torch.Size([2, 8, 16, 64])
QK点积后维度:torch.Size([2, 8, 16, 16])
注意力权重维度:torch.Size([2, 8, 16, 16])
单头加权求和后维度:torch.Size([2, 8, 16, 64])
多头发拼接后维度:torch.Size([2, 16, 512])
最终输出维度:torch.Size([2, 16, 512])

维度结论

所有步骤的维度完全符合公式设计,且输入输出维度一致,验证了代码的正确性与公式的落地性。

五、代码的工程实践亮点

本文实现的多头注意力代码并非简单的原理复现,而是结合了工业界主流的工程实践 ,让公式落地的同时,保证代码的高效性、可复用性、可读性,核心亮点如下:

  1. QKV联合投影:1个线性层替代3个独立线性层,大幅提升计算效率,是Transformer相关模型的标准设计;
  2. 维度一致性:输入输出维度完全一致,可无缝嵌入Transformer、ViT、LLM等模型,无需额外维度转换;
  3. 可选输出投影 :通过project_out动态判定是否需要WOW^OWO,减少单头场景下的不必要计算;
  4. 无偏置设计:符合Transformer经典实现,减少参数量且不影响效果;
  5. einops维度重排 :让维度变换更直观,避免手动view+transpose的维度混乱,提升代码可读性;
  6. 模块化封装 :基于nn.Module实现,支持Pytorch所有原生操作(求导、保存、加载、分布式训练等);
  7. 超参数可配置:支持自定义头数、单头维度、dropout概率,适配不同模型的需求。

六、拓展:从自注意力到交叉注意力(公式+代码修改)

本文实现的是自注意力(Self-Attention) (Q/K/V由同一输入生成),而Transformer解码器中常用交叉注意力(Cross-Attention)(Q来自解码器,K/V来自编码器),核心公式不变,仅需少量代码修改即可实现。

交叉注意力核心公式

CrossAttention(Q,K,V)=Concat(head1,...,headh)WO\text{CrossAttention}(Q,K,V) = \text{Concat}(head_1,...,head_h)W^OCrossAttention(Q,K,V)=Concat(head1,...,headh)WO
headi=Softmax(QdecWiQ(KencWiK)Tdk)VencWiVhead_i = \text{Softmax}\left(\frac{Q_{dec}W_i^Q (K_{enc}W_i^K)^T}{\sqrt{d_k}}\right)V_{enc}W_i^Vheadi=Softmax(dk QdecWiQ(KencWiK)T)VencWiV

(仅Q/K/V的输入源不同,单头计算逻辑与自注意力完全一致)

代码修改要点

  1. 初始化函数中,将to_qkv拆分为to_qto_kv两个线性层;
  2. 前向传播函数中,分别输入Q的特征(解码器)和K/V的特征(编码器),再分别投影、分头发计算;
  3. 其余计算逻辑(尺度缩放、Softmax、加权求和、拼接、投影)与自注意力完全一致。

七、总结

本文通过代码逐段拆解+核心公式一一对应的方式,实现了多头注意力机制的原理落地,让公式不再是抽象的数学表达,代码不再是孤立的代码块,核心要点总结:

  1. 多头注意力的核心公式分为单头计算整体拼接投影两部分,代码的前向传播完全按照公式的执行步骤实现;
  2. 代码中self.scale对应1dk\frac{1}{\sqrt{d_k}}dk 1,Softmax对应Softmax(⋅)\text{Softmax}(\cdot)Softmax(⋅),to_out中的线性层对应输出投影矩阵WOW^OWO;
  3. 前向传播的7个核心代码段,分别对应多头注意力公式的7个核心环节,从QKV投影到最终输出一一落地;
  4. 工业界的高效实现技巧(QKV联合投影、无偏置、可选投影)让公式落地的同时,保证代码的计算效率与可复用性;
  5. 输入输出维度一致的设计,让代码可无缝嵌入各类基于Transformer的模型。

本文的代码可直接作为Transformer、ViT、Swin-Transformer、LLM等模型的基础组件,大家可以在此基础上根据任务需求拓展窗口注意力、轴向注意力、交叉注意力等变体。


点赞+收藏,让更多人打通多头注意力的原理与实现壁垒!如果有公式理解或代码修改的疑问,欢迎在评论区留言交流~

相关推荐
Gavin在路上2 小时前
SpringAIAlibaba之深度剖析序列化过程中LinkedHashMap类型转换异常(十)
人工智能
wfeqhfxz25887822 小时前
击剑运动员姿态识别与关键部位检测_YOLOv26模型应用与优化
人工智能·yolo·目标跟踪
克里斯蒂亚诺更新2 小时前
vue展示node express调用python解析tdms
服务器·python·express
idwangzhen2 小时前
2026郑州GEO优化哪个平台靠谱
python·信息可视化
OpenCSG2 小时前
OpenCSG(开放传神)开源数据贡献解析:3大标杆数据集,筑牢中文AI基建
人工智能·开源
国产化创客2 小时前
RK3588平台基于RKNN-SDK的NPU加速推理与YOLOv5模型部署全流程
人工智能·边缘计算·智能硬件
CHrisFC2 小时前
江苏硕晟 LIMS 系统:加速环境检测机构合规化进程的利器
大数据·人工智能
SEO_juper2 小时前
Query Fan-Out:AI搜索时代,内容如何突破“隐形壁垒”被引用?
人工智能·ai·seo·数字营销
Wilber的技术分享2 小时前
【Transformer原理详解2】Decoder结构解析、Decoder-Only结构中的Decoder
人工智能·笔记·深度学习·llm·transformer