Ultralytics:解读TransformerEncoderLayer模块
- 前言
- 相关介绍
-
- [Ultralytics 简介](#Ultralytics 简介)
- 前提条件
- 实验环境
- [TransformerEncoderLayer(Transformer 编码器层)](#TransformerEncoderLayer(Transformer 编码器层))
- 扩展
-
- [多头注意力(Multi-Head Attention)](#多头注意力(Multi-Head Attention))
- [`num_heads` 参数详细作用](#
num_heads参数详细作用) -
-
- [1. **核心作用:将特征空间划分成多个子空间**](#1. 核心作用:将特征空间划分成多个子空间)
- [2. **对模型容量和表达能力的直接影响**](#2. 对模型容量和表达能力的直接影响)
- [3. **对训练和优化的影响**](#3. 对训练和优化的影响)
- [4. **与 `embed_dim` 的严格关系**](#4. 与
embed_dim的严格关系) - [5. **实际使用中的常见配置**](#5. 实际使用中的常见配置)
- [6. **在 TransformerEncoderLayer 中的体现**](#6. 在 TransformerEncoderLayer 中的体现)
- [7. **头的可视化与可解释性**](#7. 头的可视化与可解释性)
- [8. **总结:如何调优 `num_heads`**](#8. 总结:如何调优
num_heads)
-
- 参考文献

前言
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏、人工智能混合编程实践专栏或我的个人主页查看
- YOLOs-CPP:一个免费开源的YOLO全系列C++推理库(以YOLO26为例)
- PaddleOCR:Win10上安装使用PPOCRLabel标注工具
- 目标检测:使用自己的数据集微调DEIMv2进行物体检测
- 图像分割:PyTorch从零开始实现SegFormer语义分割
- 图像超分:使用自己的数据集微调Real-ESRGAN-x4plus进行超分重建
- 图像生成:PyTorch从零开始实现一个简单的扩散模型
- Stable Diffusion:使用自己的数据集微调 Stable Diffusion 3.5 LoRA 文生图模型
- 图像超分:使用自己的数据集微调Real-ESRGAN-x2plus进行超分重建
- Anomalib:使用Anomalib 2.1.0训练自己的数据集进行异常检测
- Anomalib:在Linux服务器上安装使用Anomalib 2.1.0
- 人工智能混合编程实践:C++调用封装好的DLL进行异常检测推理
- 人工智能混合编程实践:C++调用封装好的DLL进行FP16图像超分重建(v3.0)
- 隔离系统Python:源码编译3.11.8到自定义目录(含PGO性能优化)
- 在线机的Python环境迁移到离线机上
- Nuitka 将 Python 脚本封装为 .pyd 或 .so 文件
- Ultralytics:使用 YOLO11 进行速度估计
- Ultralytics:使用 YOLO11 进行物体追踪
- Ultralytics:使用 YOLO11 进行物体计数
- Ultralytics:使用 YOLO11 进行目标打码
- 人工智能混合编程实践:C++调用Python ONNX进行YOLOv8推理
- 人工智能混合编程实践:C++调用封装好的DLL进行YOLOv8实例分割
- 人工智能混合编程实践:C++调用Python ONNX进行图像超分重建
- 人工智能混合编程实践:C++调用Python AgentOCR进行文本识别
- 通过计算实例简单地理解PatchCore异常检测
- Python将YOLO格式实例分割数据集转换为COCO格式实例分割数据集
- YOLOv8 Ultralytics:使用Ultralytics框架训练RT-DETR实时目标检测模型
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目
- Stable Diffusion:在服务器上部署使用Stable Diffusion WebUI进行AI绘图(v2.0)
- Stable Diffusion:使用自己的数据集微调训练LoRA模型(v2.0)
相关介绍
Ultralytics 简介
Ultralytics 基于多年的计算机视觉和人工智能基础研究,创建了最先进的 (SOTA) YOLO 模型。我们的模型不断更新性能和灵活性,快速、准确且易于使用。他们擅长对象检测、跟踪、实例分割、语义分割、图像分类和姿势估计任务。
前提条件
- 熟悉Python、Pytorch
实验环境
bash
Package Version
------------------------ ------------
Python 3.11.8
absl-py 2.4.0
accelerate 1.13.0
annotated-doc 0.0.4
anyio 4.13.0
calflops 0.3.2
certifi 2026.4.22
charset-normalizer 3.4.7
click 8.3.3
colorama 0.4.6
contourpy 1.3.3
cycler 0.12.1
filelock 3.29.0
flatbuffers 25.12.19
fonttools 4.62.1
fsspec 2026.4.0
grpcio 1.80.0
h11 0.16.0
hf-xet 1.5.0
httpcore 1.0.9
httpx 0.28.1
huggingface_hub 1.14.0
idna 3.15
Jinja2 3.1.6
kiwisolver 1.5.0
Markdown 3.10.2
markdown-it-py 4.2.0
MarkupSafe 3.0.3
matplotlib 3.10.9
mdurl 0.1.2
ml_dtypes 0.5.0
mpmath 1.3.0
networkx 3.6.1
numpy 1.26.4
nvidia-cublas-cu12 12.8.3.14
nvidia-cuda-cupti-cu12 12.8.57
nvidia-cuda-nvrtc-cu12 12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12 9.7.1.26
nvidia-cufft-cu12 11.3.3.41
nvidia-cufile-cu12 1.13.0.11
nvidia-curand-cu12 10.3.9.55
nvidia-cusolver-cu12 11.7.2.55
nvidia-cusparse-cu12 12.5.7.53
nvidia-cusparselt-cu12 0.6.3
nvidia-nccl-cu12 2.26.2
nvidia-nvjitlink-cu12 12.8.61
nvidia-nvtx-cu12 12.8.55
onnx 1.19.0
onnxruntime-gpu 1.26.0
onnxslim 0.1.94
opencv-python 4.6.0.66
packaging 26.2
pillow 12.2.0
pip 24.0
polars 1.40.1
polars-runtime-32 1.40.1
protobuf 7.34.1
psutil 7.2.2
pycocotools 2.0.11
Pygments 2.20.0
pyparsing 3.3.2
python-dateutil 2.9.0.post0
PyYAML 6.0.3
regex 2026.5.9
requests 2.34.1
rich 15.0.0
safetensors 0.7.0
scipy 1.16.0
setuptools 65.5.0
shellingham 1.5.4
six 1.17.0
sympy 1.14.0
tabulate 0.10.0
tensorboard 2.20.0
tensorboard-data-server 0.7.2
tokenizers 0.22.2
torch 2.7.1+cu128
torchaudio 2.7.1+cu128
torchvision 0.22.1+cu128
tqdm 4.67.3
transformers 5.8.1
triton 3.3.1
typer 0.25.1
typing_extensions 4.15.0
ultralytics 8.4.58
ultralytics-thop 2.0.19
urllib3 2.7.0
Werkzeug 3.1.8
TransformerEncoderLayer(Transformer 编码器层)
TransformerEncoderLayer 是标准 Transformer 编码器的核心组件,它由 多头自注意力 (Multi-Head Self-Attention)和 前馈网络 (Feed-Forward Network)组成,并配合残差连接和层归一化。该模块支持 pre-normalization (先归一化再子层)和 post-normalization(先子层再归一化)两种配置,灵活性高,常用于目标检测(如 RT-DETR)或视觉 Transformer 等模型中。
代码实现
python
import cv2
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
class TransformerEncoderLayer(nn.Module):
"""A single layer of the transformer encoder.
This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
supporting both pre-normalization and post-normalization configurations.
Attributes:
ma (nn.MultiheadAttention): Multi-head attention module.
fc1 (nn.Linear): First linear layer in the feedforward network.
fc2 (nn.Linear): Second linear layer in the feedforward network.
norm1 (nn.LayerNorm): Layer normalization after attention.
norm2 (nn.LayerNorm): Layer normalization after feedforward network.
dropout (nn.Dropout): Dropout layer for the feedforward network.
dropout1 (nn.Dropout): Dropout layer after attention.
dropout2 (nn.Dropout): Dropout layer after feedforward network.
act (nn.Module): Activation function.
normalize_before (bool): Whether to apply normalization before attention and feedforward.
"""
def __init__(
self,
c1: int,
cm: int = 2048,
num_heads: int = 8,
dropout: float = 0.0,
act: nn.Module = nn.GELU(),
normalize_before: bool = False,
):
"""Initialize the TransformerEncoderLayer with specified parameters.
Args:
c1 (int): Input dimension.
cm (int): Hidden dimension in the feedforward network.
num_heads (int): Number of attention heads.
dropout (float): Dropout probability.
act (nn.Module): Activation function.
normalize_before (bool): Whether to apply normalization before attention and feedforward.
"""
super().__init__()
# from ...utils.torch_utils import TORCH_1_9
# if not TORCH_1_9:
# raise ModuleNotFoundError(
# "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
# )
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
# Implementation of Feedforward model
self.fc1 = nn.Linear(c1, cm)
self.fc2 = nn.Linear(cm, c1)
self.norm1 = nn.LayerNorm(c1)
self.norm2 = nn.LayerNorm(c1)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.act = act
self.normalize_before = normalize_before
@staticmethod
def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor:
"""Add position embeddings to the tensor if provided."""
return tensor if pos is None else tensor + pos
def forward_post(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""Perform forward pass with post-normalization.
Args:
src (torch.Tensor): Input tensor.
src_mask (torch.Tensor, optional): Mask for the src sequence.
src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
pos (torch.Tensor, optional): Positional encoding.
Returns:
(torch.Tensor): Output tensor after attention and feedforward.
"""
q = k = self.with_pos_embed(src, pos)
src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
src = src + self.dropout2(src2)
return self.norm2(src)
def forward_pre(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""Perform forward pass with pre-normalization.
Args:
src (torch.Tensor): Input tensor.
src_mask (torch.Tensor, optional): Mask for the src sequence.
src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
pos (torch.Tensor, optional): Positional encoding.
Returns:
(torch.Tensor): Output tensor after attention and feedforward.
"""
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
return src + self.dropout2(src2)
def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward propagate the input through the encoder module.
Args:
src (torch.Tensor): Input tensor.
src_mask (torch.Tensor, optional): Mask for the src sequence.
src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
pos (torch.Tensor, optional): Positional encoding.
Returns:
(torch.Tensor): Output tensor after transformer encoder layer.
"""
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
功能
- 特征转换:对输入序列(如特征图展平后的 token 序列)进行自注意力建模,捕获长距离依赖关系,再通过前馈网络进行非线性变换。
- 归一化策略可选 :通过
normalize_before控制使用 pre-norm (先 LayerNorm 再子层)或 post-norm(先子层再 LayerNorm)。pre-norm 通常更稳定,训练收敛更快。 - 位置编码集成 :通过
with_pos_embed静态方法,可将位置编码添加到查询(Q)和键(K)上,使注意力能感知位置信息。
初始化参数
| 参数 | 类型 | 说明 |
|---|---|---|
c1 |
int | 输入特征维度(即每个 token 的向量维度) |
cm |
int | 前馈网络隐藏层维度(默认 2048) |
num_heads |
int | 多头注意力的头数(默认 8) |
dropout |
float | Dropout 概率(默认 0.0) |
act |
nn.Module | 前馈网络的激活函数(默认 nn.GELU()) |
normalize_before |
bool | 是否使用 pre-normalization(默认 False,即 post-norm) |
注意:该模块要求 PyTorch ≥ 1.9,因为
nn.MultiheadAttention使用了batch_first=True。
前向方法
forward_post(后归一化)
- 将位置编码加到输入
src上,得到 Q 和 K。 - 多头自注意力(Q, K, V=src)得到注意力输出
src2。 - 残差连接:
src = src + dropout(src2)。 - 第一次 LayerNorm:
src = norm1(src)。 - 前馈网络(FC1 → Act → Dropout → FC2)得到
src2。 - 残差连接:
src = src + dropout(src2)。 - 第二次 LayerNorm:
return norm2(src)。
forward_pre(前归一化)
- 第一次 LayerNorm:
src2 = norm1(src)。 - 将位置编码加到
src2上,得到 Q 和 K。 - 多头自注意力(Q, K, V=src2)得到
src2。 - 残差连接:
src = src + dropout(src2)。 - 第二次 LayerNorm:
src2 = norm2(src)。 - 前馈网络(FC1 → Act → Dropout → FC2)得到
src2。 - 残差连接:
return src + dropout(src2)。
forward
根据 normalize_before 选择调用 forward_pre 或 forward_post。
使用示例
python
if __name__ == '__main__':
# 构造输入:batch_size=2, seq_len=10, dim=128
batch_size, seq_len, dim = 2, 10, 128
src = torch.randn(batch_size, seq_len, dim)
# 创建编码器层(post-norm 风格)
encoder_layer = TransformerEncoderLayer(
c1=dim,
cm=2048,
num_heads=8,
dropout=0.1,
act=nn.GELU(),
normalize_before=False, # post-norm
)
# 前向传播
with torch.no_grad():
out = encoder_layer(src)
print("输入形状:", src.shape) # [2, 10, 128]
print("post-norm 输出形状:", out.shape) # [2, 10, 128]
# 切换为 pre-norm 风格
encoder_layer_pre = TransformerEncoderLayer(
c1=dim,
cm=2048,
num_heads=8,
dropout=0.1,
act=nn.GELU(),
normalize_before=True,
)
with torch.no_grad():
out_pre = encoder_layer_pre(src)
print("pre-norm 输出形状:", out_pre.shape)
# 演示使用 mask(可选)
# 生成一个 padding mask(假设某些 token 为填充)
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_padding_mask[0, -2:] = True # 第一个样本的最后两个 token 被 mask
with torch.no_grad():
out_masked = encoder_layer(src, src_key_padding_mask=key_padding_mask)
print("带 mask 的输出形状:", out_masked.shape)
输出示例:
输入形状: torch.Size([2, 10, 128])
post-norm 输出形状: torch.Size([2, 10, 128])
pre-norm 输出形状: torch.Size([2, 10, 128])
带 mask 的输出形状: torch.Size([2, 10, 128])
流程示意图
Post-Normalization(默认)

Pre-Normalization

代码解读
with_pos_embed:静态方法,若位置编码不为空,则将其加到输入张量上(用于 Q 和 K)。__init__:初始化 MHA、FFN 的线性层、LayerNorm 和 Dropout。注意 FFN 采用Linear -> Act -> Dropout -> Linear结构,且dropout参数统一控制所有 Dropout 层。- 版本检查 :若 PyTorch < 1.9,则抛出
ModuleNotFoundError,因为batch_first=True参数在旧版本中不可用。 forward_post和forward_pre:分别实现两种归一化顺序,完全遵循 Transformer 原始设计(Vaswani et al.)和后续改进(pre-norm)。
注意事项
- 输入格式 :
src必须是(B, T, C)形状,其中B为 batch size,T为序列长度(如特征图展平后的 token 数),C为特征维度。 - 位置编码 :需要外部提供位置编码(
pos),可通过三角函数或可学习的位置嵌入生成,并传入forward。 - Mask 使用 :
src_mask:序列内部的注意力掩码(如防止看到未来信息),形状通常为(T, T)。src_key_padding_mask:针对 batch 中不同样本的填充 token 掩码,形状为(B, T),值为True表示该位置被忽略。
- 训练与推理 :该模块包含 Dropout,训练时应设为
train()模式,推理时应设为eval()。 - 内存占用:由于 MHA 的计算复杂度为 O(T²),当序列长度较大时(如高分辨率特征图),显存和计算量会急剧增加,需谨慎使用。
优缺点
优点
- 强大的全局建模能力:自注意力机制能捕获序列中任意两个位置之间的依赖关系,优于卷积的局部感受野。
- 灵活的归一化策略:支持 pre-norm 和 post-norm,pre-norm 在深层网络中更稳定,训练更平滑。
- 标准化接口 :与 PyTorch 官方
TransformerEncoderLayer兼容,易于替换和对比。 - 掩码支持:可处理变长序列,适合检测、分割等需要 padding 的任务。
缺点
- 计算量大:自注意力的复杂度与序列长度平方成正比,对高分辨率特征图不友好。
- 依赖外部位置编码:本身不包含位置信息,需额外添加(如正弦编码或可学习嵌入),增加设计复杂度。
- 训练不稳定(post-norm):post-norm 在深层网络中可能梯度爆炸,需配合学习率预热等技术。
- 对硬件要求高 :需要 PyTorch ≥ 1.9,且 MHA 的
batch_first=True在某些旧硬件上可能不支持。
在 YOLO 系列的 RT-DETR 中,TransformerEncoderLayer 被用于编码器部分,将 CNN 提取的特征图转换为序列并进行注意力建模,从而提升检测精度。使用时建议在深层特征层采用 pre-norm 配置,并注意序列长度的控制(可通过降采样或窗口注意力缓解复杂度)。
扩展
多头注意力(Multi-Head Attention)
nn.MultiheadAttention 是 PyTorch 中实现多头注意力(Multi-Head Attention) 机制的核心模块。它源自经典的Transformer论文《Attention Is All You Need》,是构建各种Transformer架构(如BERT、GPT、ViT等)的基础组件。
简单来说,它的核心思想是:让模型能够从不同的角度(子空间)同时关注输入信息的不同部分。
核心原理:它是如何工作的?
多头注意力机制将整个计算过程分解为几个关键步骤,其核心数学定义如下:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
其中每个注意力头 head_i 的计算是:
head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)
为了更直观地理解,可以把这个过程拆解为四步:
- 线性投影(Projection) :对于输入的查询(Q)、键(K)和值(V)张量,模块会使用三个独立的线性层(全连接层)将它们分别投影到不同的空间。这里的投影维度由
embed_dim决定。 - 拆分多头(Split Heads) :投影后的Q、K、V张量会被均匀拆分 成
num_heads份。每一份就代表一个"头",每个头的维度是embed_dim // num_heads。这使得每个头都能在一个相对低维的子空间里独立工作。 - 缩放点积注意力(Scaled Dot-Product Attention) :每个头独立地执行注意力计算。其核心是缩放点积注意力 公式:
Attention(Q, K, V) = softmax(Q * K^T / √d_k) * V
这个过程可以理解为:用Q去"查询"K,计算出 attention 权重(相关性分数),然后用这个权重去加权求和 V,从而得到针对当前查询的"关注"结果。 - 合并与输出(Concatenate and Project) :所有头计算完成后,会将它们的结果拼接(Concat)起来,恢复成
embed_dim的维度。最后,再通过一个最终的线性层(W^O)进行投影,得到模块的最终输出。
流程示意图
以下是 nn.MultiheadAttention 核心流程 "投影-拆分-并行计算-合并" 的示意图,清晰展示了数据流转过程。

- 输入 :三个张量 Q、K、V,形状均为
(batch, seq_len, embed_dim)。 - 线性投影:通过三个独立的线性层(全连接)将 Q、K、V 投影到指定的特征空间。
- 拆分多头 :将投影后的向量在最后一维(
embed_dim)均匀拆分为num_heads个head_dim(head_dim = embed_dim // num_heads),形成多个头。 - 并行计算 :每个头独立执行缩放点积注意力(
softmax(Q·K^T/√d_k)·V),得到各自的输出。 - 合并头 :将所有头的输出在最后一维拼接(Concat),恢复维度为
embed_dim。 - 最终线性投影 :通过一个额外的线性层将拼接后的结果映射回
embed_dim,得到最终输出。
该流程完整体现了多头注意力机制的设计思想:通过多个并行的子空间,让模型同时关注不同方面的信息。
主要参数详解
初始化 nn.MultiheadAttention 时,最核心的参数如下:
embed_dim(int) :模型的总维度 。这是整个模块输入和输出的特征维度。注意 :embed_dim必须能够被num_heads整除。num_heads(int) :并行注意力头的数量 。每个头的维度是embed_dim // num_heads。增加头的数量可以让模型关注更多不同的子空间。dropout(float):在注意力权重上应用的 Dropout 概率,默认为 0.0。用于防止过拟合。bias(bool) :是否给线性投影层添加偏置,默认为True。batch_first(bool) :非常重要 的参数,决定了输入输出张量的形状。batch_first=True:张量形状为(batch_size, seq_len, embed_dim)。这是目前更常用、更直观的格式。batch_first=False(默认):张量形状为(seq_len, batch_size, embed_dim)。这是PyTorch早期的默认格式。
输入与输出
调用一个已初始化的 MultiheadAttention 模块时,主要输入是 query, key, value 三个张量。
-
输入:
query,key,value:形状取决于batch_first的设置。- 若
batch_first=True:形状为(batch_size, seq_len, embed_dim)。 - 若
batch_first=False:形状为(seq_len, batch_size, embed_dim)。
- 若
- 注意 :
key和value的序列长度(seq_len)可以与query不同,但它们的embed_dim必须相同。
-
输出:
attn_output:注意力机制的最终输出,形状与query输入一致。attn_output_weights:(可选)计算出的注意力权重,形状为(batch_size, num_heads, query_seq_len, key_seq_len)。
使用示例
python
import torch
import torch.nn as nn
# 1. 定义参数
embed_dim = 512 # 模型总维度
num_heads = 8 # 注意力头数量
batch_size = 2
seq_len = 10
# 2. 创建 MultiheadAttention 模块
# 设置 batch_first=True 使输入输出形状更直观
mha = nn.MultiheadAttention(embed_dim=embed_dim,
num_heads=num_heads,
batch_first=True)
# 3. 创建模拟输入 (batch_size, seq_len, embed_dim)
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len, embed_dim)
value = torch.randn(batch_size, seq_len, embed_dim)
# 4. 前向传播
attn_output, attn_weights = mha(query, key, value)
print(f"attn_output shape: {attn_output.shape}") # torch.Size([2, 10, 512])
print(f"attn_weights shape: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
在 TransformerEncoderLayer 中的使用
你提供的 TransformerEncoderLayer 正是对 nn.MultiheadAttention 的典型封装:
python
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
在编码器中,query, key, value 通常传入同一个张量 src ,这被称为自注意力(Self-Attention)。
重要注意事项
batch_first参数 :务必根据你的数据格式正确设置。PyTorch 官方教程和许多新项目都推荐使用batch_first=True。- 自注意力(Self-Attention) :当
query,key,value是同一个张量时,即进行自注意力计算。这是 Transformer 编码器层的核心操作。 - 推理加速 :在满足特定条件时(如自注意力、
batch_first=True、禁用梯度等),PyTorch 会自动使用一个fastpath来加速推理。 - 掩码(Masking) :
attn_mask:注意力掩码,形状为(query_seq_len, key_seq_len)或(batch_size * num_heads, query_seq_len, key_seq_len),用于屏蔽特定的注意力位置(如解码器中的未来信息)。key_padding_mask:键填充掩码,形状为(batch_size, key_seq_len),用于指示哪些位置是填充(padding)的,以避免模型关注到无意义的填充符。
总结
nn.MultiheadAttention 是 PyTorch 对多头注意力机制的高效实现。理解其 "投影-拆分-并行计算-合并" 的核心流程,以及 embed_dim、num_heads、batch_first 等关键参数,是使用和定制各种 Transformer 模型的基础。在实际应用中,它常作为 TransformerEncoderLayer 等更高级模块的核心组件。
num_heads 参数详细作用
在 nn.MultiheadAttention 中,num_heads 是控制注意力头数量 的核心参数。它直接决定了多头注意力(Multi-Head Attention)的行为,深刻地影响着模型的能力、效率和可解释性。
1. 核心作用:将特征空间划分成多个子空间
多头注意力的核心思想是并行的、不同的注意力机制 。num_heads 定义了并行头的数量。
- 每个注意力头都独立地进行缩放点积注意力(Scaled Dot-Product Attention)计算。
- 每个头拥有独立的线性投影矩阵 (
W_i^Q,W_i^K,W_i^V),将输入的 Q、K、V 投影到不同的低维子空间。 - 每个子空间的维度为
head_dim = embed_dim // num_heads。 - 多个头允许模型在不同的子空间中关注输入的不同部分或关系,从而捕捉更丰富的特征模式。
2. 对模型容量和表达能力的直接影响
num_heads 影响着模型的参数量 和计算复杂度:
-
参数量 :每个头有自己的投影层,但总参数量主要取决于
embed_dim和num_heads的关系。本质上,多头注意力与单头注意力的总参数量相近(因为总投影矩阵的大小是embed_dim × embed_dim),但多头的投影矩阵被拆分成多个低维矩阵,这增加了模型的多样性而非单纯增加参数量。 -
计算复杂度 :每个头的计算复杂度是
O(seq_len^2 * head_dim),总复杂度为O(seq_len^2 * embed_dim)(因为总 head_dim 和等于 embed_dim)。因此,num_heads不会显著改变理论复杂度,但会改变内存访问模式,在硬件上可能影响速度。 -
表达能力:
- 每个头可以学习关注不同的特征(例如,一个头关注词性,另一个关注语义,第三个关注距离关系等)。
- 较多的头通常能捕捉更丰富的模式,但过多的头可能导致每个子空间过小(
head_dim过小),限制每个头的表示能力,导致它们变得同质化或失效。 - 经验研究表明,在 Transformer 模型中,
num_heads通常设置为 8、12、16,并需要配合合适的embed_dim(保证head_dim至少为 32 或 64,以保持足够的表现力)。
3. 对训练和优化的影响
- 梯度流 :多个头提供了更丰富的梯度信号,有助于模型在训练初期更快收敛。
- 稳定性:多头设计使得模型对单头的噪声不敏感,因为多个头可以相互补充,提高鲁棒性。
- 与正则化的关系:多头注意力天然具有某种正则化效果(类似于集成学习),每个头学习不同的表示,最终拼接输出,这有助于缓解过拟合。
4. 与 embed_dim 的严格关系
embed_dim 必须能够被 num_heads 整除,即 embed_dim % num_heads == 0。这是因为 embed_dim 被均匀拆分到每个头,每个头的维度为 head_dim = embed_dim // num_heads。
- 为什么是整除? 因为输入输出总维度不变,但内部需要将特征向量分块到各个头。如果维度不能整除,则无法均匀拆分,PyTorch 会抛出错误。
- 如何选择? 通常选择
num_heads使得head_dim至少为 32 或 64,以保证每个头有足够的容量。例如,若embed_dim=512,可选num_heads=8(head_dim=64)或num_heads=16(head_dim=32)。
5. 实际使用中的常见配置
- 小型模型 (如 BERT-base):
embed_dim=768, num_heads=12→head_dim=64 - 大型模型 (如 BERT-large):
embed_dim=1024, num_heads=16→head_dim=64 - 视觉模型 (如 ViT-B/16):
embed_dim=768, num_heads=12→head_dim=64 - 计算受限场景 (如轻量级检测模型):可能使用
num_heads=4或8,并配合较小的embed_dim。
6. 在 TransformerEncoderLayer 中的体现
您提供的 TransformerEncoderLayer 中,num_heads 直接传递给 nn.MultiheadAttention:
python
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
其中 c1 即为 embed_dim。因此,调整 num_heads 将直接影响该编码器层的注意力行为。
7. 头的可视化与可解释性
多头注意力的一大优势是注意力权重的可解释性 。在推理时,可以提取每个头的注意力权重矩阵 attn_output_weights(形状为 [batch_size, num_heads, query_len, key_len]),并可视化不同的头关注的区域,以理解模型关注的模式。
例如:
- 在机器翻译中,有些头关注邻近词,有些头关注长距离依赖。
- 在图像分类中,不同头可能关注图像的不同区域。
8. 总结:如何调优 num_heads
- 默认值 :在大多数 Transformer 变体中,
num_heads=8或12是良好起点。 - 增大
num_heads:- 如果
embed_dim足够大(如 ≥512),可以尝试增加头数,以提高模型的表达能力和性能。 - 需要注意的是,头数增加会导致每个头的维度变小,可能削弱单头能力,需要通过实验验证。
- 如果
- 减小
num_heads:- 当模型过拟合或计算资源紧张时,可以适度减少头数,降低模型复杂度。
- 确保整除 :无论增减,必须保证
embed_dim % num_heads == 0。
总之,num_heads 是控制多头注意力机制多样性和并行性的关键超参数,合理选择有助于提升模型性能。在实际应用中,建议参考相似任务的成功配置,并进行小范围调优实验。
参考文献
1 https://docs.ultralytics.com/
2 https://github.com/ultralytics/ultralytics.git
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏、人工智能混合编程实践专栏或我的个人主页查看
- YOLOs-CPP:一个免费开源的YOLO全系列C++推理库(以YOLO26为例)
- PaddleOCR:Win10上安装使用PPOCRLabel标注工具
- 目标检测:使用自己的数据集微调DEIMv2进行物体检测
- 图像分割:PyTorch从零开始实现SegFormer语义分割
- 图像超分:使用自己的数据集微调Real-ESRGAN-x4plus进行超分重建
- 图像生成:PyTorch从零开始实现一个简单的扩散模型
- Stable Diffusion:使用自己的数据集微调 Stable Diffusion 3.5 LoRA 文生图模型
- 图像超分:使用自己的数据集微调Real-ESRGAN-x2plus进行超分重建
- Anomalib:使用Anomalib 2.1.0训练自己的数据集进行异常检测
- Anomalib:在Linux服务器上安装使用Anomalib 2.1.0
- 人工智能混合编程实践:C++调用封装好的DLL进行异常检测推理
- 人工智能混合编程实践:C++调用封装好的DLL进行FP16图像超分重建(v3.0)
- 隔离系统Python:源码编译3.11.8到自定义目录(含PGO性能优化)
- 在线机的Python环境迁移到离线机上
- Nuitka 将 Python 脚本封装为 .pyd 或 .so 文件
- Ultralytics:使用 YOLO11 进行速度估计
- Ultralytics:使用 YOLO11 进行物体追踪
- Ultralytics:使用 YOLO11 进行物体计数
- Ultralytics:使用 YOLO11 进行目标打码
- 人工智能混合编程实践:C++调用Python ONNX进行YOLOv8推理
- 人工智能混合编程实践:C++调用封装好的DLL进行YOLOv8实例分割
- 人工智能混合编程实践:C++调用Python ONNX进行图像超分重建
- 人工智能混合编程实践:C++调用Python AgentOCR进行文本识别
- 通过计算实例简单地理解PatchCore异常检测
- Python将YOLO格式实例分割数据集转换为COCO格式实例分割数据集
- YOLOv8 Ultralytics:使用Ultralytics框架训练RT-DETR实时目标检测模型
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目
- Stable Diffusion:在服务器上部署使用Stable Diffusion WebUI进行AI绘图(v2.0)
- Stable Diffusion:使用自己的数据集微调训练LoRA模型(v2.0)