前言
注意力机制是Transformer架构的核心组件,而**多头注意力(Multi-Head Attention)**更是其中的精髓------通过将特征映射到多个头进行并行的注意力计算,能让模型同时捕捉不同维度、不同尺度的特征关联,相比单头注意力拥有更强的特征表达能力。
本文将从原理出发,结合纯Pytorch实现的多头注意力代码(无高级框架封装),逐段拆解代码+对应核心公式,让公式落地代码、代码印证公式,彻底打通原理与实现的壁垒,代码可直接嵌入Transformer、ViT、LLM等相关模型中使用。
一、多头注意力核心原理与核心公式
在进入代码前,先明确多头注意力的核心步骤与对应公式,为代码解析做精准铺垫,后续每段代码都会对应到具体公式环节。
多头自注意力(Q/K/V由同一输入生成)是最常用的形式,核心执行流程 :线性投影→分头发→尺度缩放→注意力权重计算→加权求和→特征拼接→输出投影,对应的核心公式如下:
整体多头注意力公式
MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,head_2,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO
单个注意力头计算公式
headi=Attention(QWiQ,KWiK,VWiV)=Softmax(QWiQ(KWiK)Tdk)VWiVhead_i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V) = \text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)VW_i^Vheadi=Attention(QWiQ,KWiK,VWiV)=Softmax(dk QWiQ(KWiK)T)VWiV
符号说明
- hhh:注意力头的数量,是多头注意力的核心超参数;
- dkd_kdk:单个注意力头的特征维度,决定了单头的特征表达能力;
- QWiQ/KWiK/VWiVQW_i^Q/ KW_i^K/ VW_i^VQWiQ/KWiK/VWiV:输入分别映射到第iii个头的查询/键/值特征;
- Concat\text{Concat}Concat:将所有注意力头的输出特征进行拼接;
- WOW^OWO:拼接后特征的输出投影矩阵,用于将拼接维度映射回原输入维度。
关键设计要点 :引入尺度因子1dk\frac{1}{\sqrt{d_k}}dk 1,缓解dkd_kdk较大时QK点积结果数值过大,导致Softmax函数梯度消失的问题。
二、完整Pytorch代码实现
本文实现的多头自注意力基于Pytorch和einops的rearrange(特征维度重排,比手动view更直观,避免维度混乱),先安装依赖:
bash
pip install torch einops
完整可复用代码:
python
import torch
import torch.nn as nn
from einops import rearrange
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
输入输出维度约定 :输入特征形状为[batch_size,seq_len,dim][batch\_size, seq\_len, dim][batch_size,seq_len,dim](记为[b,n,d][b, n, d][b,n,d]),输出与输入维度完全一致,可无缝嵌入各类模型。
三、代码逐段拆解 + 公式对应详解
3.1 初始化函数__init__:定义超参数与网络层
初始化函数的核心是为后续注意力计算准备超参数 和可学习层,每段代码对应多头注意力的前置设计,无直接公式但为公式落地做铺垫。
代码段1:核心超参数计算
python
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
inner_dim:所有注意力头的总特征维度,即h×dkh \times d_kh×dk(hhh为heads,dkd_kdk为dim_head),是Q/K/V各自的投影维度;project_out:判定是否需要最后的输出投影层WOW^OWO:仅当单头且单头维度=输入维度 (h=1h=1h=1且dk=dimd_k=dimdk=dim)时,拼接后维度与输入一致,无需WOW^OWO,否则必须通过WOW^OWO映射回原维度。
代码段2:注意力基础超参数保存
python
self.heads = heads
self.scale = dim_head ** -0.5
self.heads:保存注意力头数hhh,为后续分头发计算做准备;self.scale:尺度缩放因子,对应公式中的1dk\frac{1}{\sqrt{d_k}}dk 1,用dim_head ** -0.5等价实现,避免浮点数除法,计算更高效。
代码段3:注意力权重计算与QKV投影层
python
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.attend:Softmax层,对应公式中的Softmax(⋅)\text{Softmax}(\cdot)Softmax(⋅),指定dim=-1是为了对每个位置的相似度维度做归一化,得到合法的注意力权重;self.to_qkv:QKV联合投影层,是工业界高效实现方案 :通过1个线性层替代3个独立线性层,直接将输入映射为3×innerdim3 \times inner_dim3×innerdim维度(Q/K/V拼接),再通过拆分得到各自特征,减少层的数量与计算开销;bias=False:Transformer经典设计,实践中偏置对注意力计算的效果提升有限,还会增加参数量,因此省略。
代码段4:输出投影层定义
python
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
- 当需要投影时,
nn.Linear(inner_dim, dim)就是公式中的输出投影矩阵WOW^OWO ,将拼接后的h×dkh \times d_kh×dk维度映射回原输入维度dimdimdim; nn.Dropout(dropout):为投影后的特征添加正则化,防止过拟合,是工程实践的必要补充;- 无需投影时,用
nn.Identity()(恒等映射)替代,保证代码逻辑统一,输入输出维度一致。
3.2 前向传播forward:核心注意力计算(代码+公式一一对应)
前向传播是多头注意力的核心执行环节 ,每段代码都直接对应多头注意力的核心公式,是原理落地的关键,输入x形状 :[b,n,dim][b, n, dim][b,n,dim]。
代码段1:QKV线性投影与拆分 → 对应QWiQ、KWiK、VWiVQW_i^Q、KW_i^K、VW_i^VQWiQ、KWiK、VWiV
python
qkv = self.to_qkv(x).chunk(3, dim = -1)
- 执行逻辑:[b,n,dim]→Linear[b,n,3×innerdim]→chunk[b,n,innerdim]×3[b, n, dim] \xrightarrow{\text{Linear}} [b, n, 3 \times inner_dim] \xrightarrow{\text{chunk}} [b, n, inner_dim] \times 3[b,n,dim]Linear [b,n,3×innerdim]chunk [b,n,innerdim]×3;
self.to_qkv(x):完成输入到Q/K/V的联合线性投影,对应公式中所有头的QWiQ、KWiK、VWiVQW_i^Q、KW_i^K、VW_i^VQWiQ、KWiK、VWiV拼接;chunk(3, dim=-1):在最后一维将拼接特征拆分为3个等长张量,分别得到Q、K、V,每个张量形状为[b,n,h×dk][b, n, h \times d_k][b,n,h×dk]。
代码段2:分头发维度重排 → 为并行计算head1∼headhhead_1 \sim head_hhead1∼headh做准备
python
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
- 执行逻辑:[b,n,h×dk]→rearrange[b,h,n,dk][b, n, h \times d_k] \xrightarrow{\text{rearrange}} [b, h, n, d_k][b,n,h×dk]rearrange [b,h,n,dk](q/k/v均完成此变换);
- 维度含义变化:将注意力头维度hhh从特征维度中抽离,成为独立维度,让每个头的Q/K/VQ/K/VQ/K/V特征独立,可通过Pytorch广播机制并行计算所有头的注意力,无需循环遍历,大幅提升效率;
- 替代实现:若不用einops,可通过
view+transpose实现:t.view(b, n, h, d_k).transpose(1, 2),效果一致但可读性差; - 对应公式:重排后,q[b,h,:,:]q[b, h, :, :]q[b,h,:,:]就是第bbb个样本的第hhh个头的查询特征QWhQQW_h^QQWhQ,k/v同理,为后续单头注意力计算headihead_iheadi做准备。
代码段3:QK点积 + 尺度缩放 → 对应QWiQ(KWiK)Tdk\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}dk QWiQ(KWiK)T
python
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
- 执行逻辑:[b,h,n,dk]×[b,h,dk,n]→matmul[b,h,n,n]→×scale[b,h,n,n][b, h, n, d_k] \times [b, h, d_k, n] \xrightarrow{\text{matmul}} [b, h, n, n] \xrightarrow{\times scale} [b, h, n, n][b,h,n,dk]×[b,h,dk,n]matmul [b,h,n,n]×scale [b,h,n,n];
k.transpose(-1, -2):将K的最后两个维度交换,形状从[b,h,n,dk][b, h, n, d_k][b,h,n,dk]变为[b,h,dk,n][b, h, d_k, n][b,h,dk,n],满足矩阵乘法"前一个最后一维=后一个倒数第二维"的要求;torch.matmul(q, k.transpose(-1, -2)):计算QK点积,对应公式中的QWiQ(KWiK)TQW_i^Q (KW_i^K)^TQWiQ(KWiK)T,结果形状为[b,h,n,n][b, h, n, n][b,h,n,n],其中dots[b,h,i,j]dots[b, h, i, j]dots[b,h,i,j]表示第b个样本、第h个头中,第i个位置对第j个位置的特征相似度;* self.scale:乘以尺度缩放因子,对应公式中的1dk\frac{1}{\sqrt{d_k}}dk 1,核心作用是缓解点积后数值过大导致的Softmax梯度消失。- 对应公式 :此步骤直接实现QWiQ(KWiK)Tdk\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}dk QWiQ(KWiK)T,得到每个头、每个位置的未归一化相似度。
代码段4:Softmax计算注意力权重 → 对应Softmax(QWiQ(KWiK)Tdk)\text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)Softmax(dk QWiQ(KWiK)T)
python
attn = self.attend(dots)
- 执行逻辑:[b,h,n,n]→Softmax[b,h,n,n][b, h, n, n] \xrightarrow{\text{Softmax}} [b, h, n, n][b,h,n,n]Softmax [b,h,n,n];
- 执行细节:对
dots的最后一维(相似度维度)做Softmax,将未归一化的相似度转换为0~1之间的注意力权重,且每行权重之和为1; - 结果含义:attn[b,h,i,j]attn[b, h, i, j]attn[b,h,i,j]表示第b个样本、第h个头中,第i个位置对第j个位置的关注程度;
- 对应公式 :此步骤直接实现Softmax(QWiQ(KWiK)Tdk)\text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)Softmax(dk QWiQ(KWiK)T),得到合法的注意力权重矩阵。
代码段5:注意力权重与V加权求和 → 对应headi=Softmax(⋅)VWiVhead_i = \text{Softmax}(\cdot)VW_i^Vheadi=Softmax(⋅)VWiV
python
out = torch.matmul(attn, v)
- 执行逻辑:[b,h,n,n]×[b,h,n,dk]→matmul[b,h,n,dk][b, h, n, n] \times [b, h, n, d_k] \xrightarrow{\text{matmul}} [b, h, n, d_k][b,h,n,n]×[b,h,n,dk]matmul [b,h,n,dk];
- 计算细节:注意力权重矩阵与值特征V相乘,将每个位置的特征按照注意力权重进行加权求和,得到每个注意力头的独立输出;
- 结果含义:out[b,h,:,:]out[b, h, :, :]out[b,h,:,:]就是第bbb个样本的第hhh个头的注意力输出headhhead_hheadh;
- 对应公式 :此步骤直接实现单个注意力头的计算headi=Softmax(QWiQ(KWiK)Tdk)VWiVhead_i = \text{Softmax}\left(\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right)VW_i^Vheadi=Softmax(dk QWiQ(KWiK)T)VWiV,最终得到所有头的输出head1∼headhhead_1 \sim head_hhead1∼headh。
代码段6:多头发特征拼接 → 对应Concat(head1,head2,...,headh)\text{Concat}(head_1,head_2,...,head_h)Concat(head1,head2,...,headh)
python
out = rearrange(out, 'b h n d -> b n (h d)')
- 执行逻辑:[b,h,n,dk]→rearrange[b,n,h×dk][b, h, n, d_k] \xrightarrow{\text{rearrange}} [b, n, h \times d_k][b,h,n,dk]rearrange [b,n,h×dk];
- 执行细节:将注意力头维度hhh重新拼接到特征维度,把所有头的输出head1∼headhhead_1 \sim head_hhead1∼headh按特征维度拼接,得到融合了多头信息的特征;
- 对应公式 :此步骤直接实现Concat(head1,head2,...,headh)\text{Concat}(head_1,head_2,...,head_h)Concat(head1,head2,...,headh),拼接后特征维度为h×dkh \times d_kh×dk(即inner_dim)。
代码段7:输出投影与返回 → 对应Concat(⋅)WO\text{Concat}(\cdot)W^OConcat(⋅)WO
python
return self.to_out(out)
- 执行逻辑:[b,n,h×dk]→WO[b,n,dim][b, n, h \times d_k] \xrightarrow{W^O} [b, n, dim][b,n,h×dk]WO [b,n,dim](若需要投影);
- 执行细节:将拼接后的特征输入到
self.to_out,若需要投影则通过WOW^OWO(nn.Linear)映射回原输入维度dimdimdim,并添加Dropout;若无需投影则直接返回; - 结果:最终输出形状为[b,n,dim][b, n, dim][b,n,dim],与输入维度完全一致;
- 对应公式 :此步骤直接实现多头注意力的整体公式MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,head_2,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO,完成整个多头注意力计算。
四、维度验证:让公式与代码的维度对应更直观
为了让大家更清晰地看到每段代码的维度变化 与公式中维度的对应关系,这里给出完整的维度验证示例,输入为随机生成的特征,模拟实际使用场景。
验证代码
python
# 初始化注意力层:dim=512(输入维度),heads=8(h),dim_head=64(d_k),dropout=0.1
attn = Attention(dim=512, heads=8, dim_head=64, dropout=0.1)
# 构造输入:batch_size=2(b),seq_len=16(n),dim=512
x = torch.randn(2, 16, 512)
# 前向传播并打印关键步骤维度
qkv = attn.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = attn.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * attn.scale
attn_weight = attn.attend(dots)
out = torch.matmul(attn_weight, v)
out_concat = rearrange(out, 'b h n d -> b n (h d)')
out_final = attn.to_out(out_concat)
# 打印关键步骤维度
print(f"输入x维度:{x.shape}")
print(f"Q/K/V各自维度:{q.shape}")
print(f"QK点积后维度:{dots.shape}")
print(f"注意力权重维度:{attn_weight.shape}")
print(f"单头加权求和后维度:{out.shape}")
print(f"多头发拼接后维度:{out_concat.shape}")
print(f"最终输出维度:{out_final.shape}")
输出结果
输入x维度:torch.Size([2, 16, 512])
Q/K/V各自维度:torch.Size([2, 8, 16, 64])
QK点积后维度:torch.Size([2, 8, 16, 16])
注意力权重维度:torch.Size([2, 8, 16, 16])
单头加权求和后维度:torch.Size([2, 8, 16, 64])
多头发拼接后维度:torch.Size([2, 16, 512])
最终输出维度:torch.Size([2, 16, 512])
维度结论
所有步骤的维度完全符合公式设计,且输入输出维度一致,验证了代码的正确性与公式的落地性。
五、代码的工程实践亮点
本文实现的多头注意力代码并非简单的原理复现,而是结合了工业界主流的工程实践 ,让公式落地的同时,保证代码的高效性、可复用性、可读性,核心亮点如下:
- QKV联合投影:1个线性层替代3个独立线性层,大幅提升计算效率,是Transformer相关模型的标准设计;
- 维度一致性:输入输出维度完全一致,可无缝嵌入Transformer、ViT、LLM等模型,无需额外维度转换;
- 可选输出投影 :通过
project_out动态判定是否需要WOW^OWO,减少单头场景下的不必要计算; - 无偏置设计:符合Transformer经典实现,减少参数量且不影响效果;
- einops维度重排 :让维度变换更直观,避免手动
view+transpose的维度混乱,提升代码可读性; - 模块化封装 :基于
nn.Module实现,支持Pytorch所有原生操作(求导、保存、加载、分布式训练等); - 超参数可配置:支持自定义头数、单头维度、dropout概率,适配不同模型的需求。
六、拓展:从自注意力到交叉注意力(公式+代码修改)
本文实现的是自注意力(Self-Attention) (Q/K/V由同一输入生成),而Transformer解码器中常用交叉注意力(Cross-Attention)(Q来自解码器,K/V来自编码器),核心公式不变,仅需少量代码修改即可实现。
交叉注意力核心公式
CrossAttention(Q,K,V)=Concat(head1,...,headh)WO\text{CrossAttention}(Q,K,V) = \text{Concat}(head_1,...,head_h)W^OCrossAttention(Q,K,V)=Concat(head1,...,headh)WO
headi=Softmax(QdecWiQ(KencWiK)Tdk)VencWiVhead_i = \text{Softmax}\left(\frac{Q_{dec}W_i^Q (K_{enc}W_i^K)^T}{\sqrt{d_k}}\right)V_{enc}W_i^Vheadi=Softmax(dk QdecWiQ(KencWiK)T)VencWiV
(仅Q/K/V的输入源不同,单头计算逻辑与自注意力完全一致)
代码修改要点
- 初始化函数中,将
to_qkv拆分为to_q和to_kv两个线性层; - 前向传播函数中,分别输入Q的特征(解码器)和K/V的特征(编码器),再分别投影、分头发计算;
- 其余计算逻辑(尺度缩放、Softmax、加权求和、拼接、投影)与自注意力完全一致。
七、总结
本文通过代码逐段拆解+核心公式一一对应的方式,实现了多头注意力机制的原理落地,让公式不再是抽象的数学表达,代码不再是孤立的代码块,核心要点总结:
- 多头注意力的核心公式分为单头计算 和整体拼接投影两部分,代码的前向传播完全按照公式的执行步骤实现;
- 代码中
self.scale对应1dk\frac{1}{\sqrt{d_k}}dk 1,Softmax对应Softmax(⋅)\text{Softmax}(\cdot)Softmax(⋅),to_out中的线性层对应输出投影矩阵WOW^OWO; - 前向传播的7个核心代码段,分别对应多头注意力公式的7个核心环节,从QKV投影到最终输出一一落地;
- 工业界的高效实现技巧(QKV联合投影、无偏置、可选投影)让公式落地的同时,保证代码的计算效率与可复用性;
- 输入输出维度一致的设计,让代码可无缝嵌入各类基于Transformer的模型。
本文的代码可直接作为Transformer、ViT、Swin-Transformer、LLM等模型的基础组件,大家可以在此基础上根据任务需求拓展窗口注意力、轴向注意力、交叉注意力等变体。
点赞+收藏,让更多人打通多头注意力的原理与实现壁垒!如果有公式理解或代码修改的疑问,欢迎在评论区留言交流~