Ultralytics:解读TransformerEncoderLayer模块

Ultralytics:解读TransformerEncoderLayer模块

前言

相关介绍

Ultralytics 简介

Ultralytics 基于多年的计算机视觉和人工智能基础研究,创建了最先进的 (SOTA) YOLO 模型。我们的模型不断更新性能和灵活性,快速、准确且易于使用。他们擅长对象检测、跟踪、实例分割、语义分割、图像分类和姿势估计任务。

前提条件

  • 熟悉Python、Pytorch

实验环境

bash 复制代码
Package                  Version
------------------------ ------------
Python                   3.11.8
absl-py                  2.4.0
accelerate               1.13.0
annotated-doc            0.0.4
anyio                    4.13.0
calflops                 0.3.2
certifi                  2026.4.22
charset-normalizer       3.4.7
click                    8.3.3
colorama                 0.4.6
contourpy                1.3.3
cycler                   0.12.1
filelock                 3.29.0
flatbuffers              25.12.19
fonttools                4.62.1
fsspec                   2026.4.0
grpcio                   1.80.0
h11                      0.16.0
hf-xet                   1.5.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface_hub          1.14.0
idna                     3.15
Jinja2                   3.1.6
kiwisolver               1.5.0
Markdown                 3.10.2
markdown-it-py           4.2.0
MarkupSafe               3.0.3
matplotlib               3.10.9
mdurl                    0.1.2
ml_dtypes                0.5.0
mpmath                   1.3.0
networkx                 3.6.1
numpy                    1.26.4
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
onnx                     1.19.0
onnxruntime-gpu          1.26.0
onnxslim                 0.1.94
opencv-python            4.6.0.66
packaging                26.2
pillow                   12.2.0
pip                      24.0
polars                   1.40.1
polars-runtime-32        1.40.1
protobuf                 7.34.1
psutil                   7.2.2
pycocotools              2.0.11
Pygments                 2.20.0
pyparsing                3.3.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.3
regex                    2026.5.9
requests                 2.34.1
rich                     15.0.0
safetensors              0.7.0
scipy                    1.16.0
setuptools               65.5.0
shellingham              1.5.4
six                      1.17.0
sympy                    1.14.0
tabulate                 0.10.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.2
torch                    2.7.1+cu128
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
tqdm                     4.67.3
transformers             5.8.1
triton                   3.3.1
typer                    0.25.1
typing_extensions        4.15.0
ultralytics              8.4.58
ultralytics-thop         2.0.19
urllib3                  2.7.0
Werkzeug                 3.1.8

TransformerEncoderLayer(Transformer 编码器层)

TransformerEncoderLayer 是标准 Transformer 编码器的核心组件,它由 多头自注意力 (Multi-Head Self-Attention)和 前馈网络 (Feed-Forward Network)组成,并配合残差连接和层归一化。该模块支持 pre-normalization (先归一化再子层)和 post-normalization(先子层再归一化)两种配置,灵活性高,常用于目标检测(如 RT-DETR)或视觉 Transformer 等模型中。


代码实现

python 复制代码
import cv2
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn

class TransformerEncoderLayer(nn.Module):
    """A single layer of the transformer encoder.

    This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
    supporting both pre-normalization and post-normalization configurations.

    Attributes:
        ma (nn.MultiheadAttention): Multi-head attention module.
        fc1 (nn.Linear): First linear layer in the feedforward network.
        fc2 (nn.Linear): Second linear layer in the feedforward network.
        norm1 (nn.LayerNorm): Layer normalization after attention.
        norm2 (nn.LayerNorm): Layer normalization after feedforward network.
        dropout (nn.Dropout): Dropout layer for the feedforward network.
        dropout1 (nn.Dropout): Dropout layer after attention.
        dropout2 (nn.Dropout): Dropout layer after feedforward network.
        act (nn.Module): Activation function.
        normalize_before (bool): Whether to apply normalization before attention and feedforward.
    """

    def __init__(
        self,
        c1: int,
        cm: int = 2048,
        num_heads: int = 8,
        dropout: float = 0.0,
        act: nn.Module = nn.GELU(),
        normalize_before: bool = False,
    ):
        """Initialize the TransformerEncoderLayer with specified parameters.

        Args:
            c1 (int): Input dimension.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__()
        # from ...utils.torch_utils import TORCH_1_9

        # if not TORCH_1_9:
        #     raise ModuleNotFoundError(
        #         "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
        #     )
        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)

        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.act = act
        self.normalize_before = normalize_before

    @staticmethod
    def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor:
        """Add position embeddings to the tensor if provided."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
        pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform forward pass with post-normalization.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after attention and feedforward.
        """
        q = k = self.with_pos_embed(src, pos)
        src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def forward_pre(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
        pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform forward pass with pre-normalization.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after attention and feedforward.
        """
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
        return src + self.dropout2(src2)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
        pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward propagate the input through the encoder module.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after transformer encoder layer.
        """
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

功能

  • 特征转换:对输入序列(如特征图展平后的 token 序列)进行自注意力建模,捕获长距离依赖关系,再通过前馈网络进行非线性变换。
  • 归一化策略可选 :通过 normalize_before 控制使用 pre-norm (先 LayerNorm 再子层)或 post-norm(先子层再 LayerNorm)。pre-norm 通常更稳定,训练收敛更快。
  • 位置编码集成 :通过 with_pos_embed 静态方法,可将位置编码添加到查询(Q)和键(K)上,使注意力能感知位置信息。

初始化参数

参数 类型 说明
c1 int 输入特征维度(即每个 token 的向量维度)
cm int 前馈网络隐藏层维度(默认 2048)
num_heads int 多头注意力的头数(默认 8)
dropout float Dropout 概率(默认 0.0)
act nn.Module 前馈网络的激活函数(默认 nn.GELU()
normalize_before bool 是否使用 pre-normalization(默认 False,即 post-norm)

注意:该模块要求 PyTorch ≥ 1.9,因为 nn.MultiheadAttention 使用了 batch_first=True


前向方法

forward_post(后归一化)
  1. 将位置编码加到输入 src 上,得到 Q 和 K。
  2. 多头自注意力(Q, K, V=src)得到注意力输出 src2
  3. 残差连接:src = src + dropout(src2)
  4. 第一次 LayerNorm:src = norm1(src)
  5. 前馈网络(FC1 → Act → Dropout → FC2)得到 src2
  6. 残差连接:src = src + dropout(src2)
  7. 第二次 LayerNorm:return norm2(src)
forward_pre(前归一化)
  1. 第一次 LayerNorm:src2 = norm1(src)
  2. 将位置编码加到 src2 上,得到 Q 和 K。
  3. 多头自注意力(Q, K, V=src2)得到 src2
  4. 残差连接:src = src + dropout(src2)
  5. 第二次 LayerNorm:src2 = norm2(src)
  6. 前馈网络(FC1 → Act → Dropout → FC2)得到 src2
  7. 残差连接:return src + dropout(src2)
forward

根据 normalize_before 选择调用 forward_preforward_post


使用示例

python 复制代码
if __name__ == '__main__':
    # 构造输入:batch_size=2, seq_len=10, dim=128
    batch_size, seq_len, dim = 2, 10, 128
    src = torch.randn(batch_size, seq_len, dim)

    # 创建编码器层(post-norm 风格)
    encoder_layer = TransformerEncoderLayer(
        c1=dim,
        cm=2048,
        num_heads=8,
        dropout=0.1,
        act=nn.GELU(),
        normalize_before=False,  # post-norm
    )

    # 前向传播
    with torch.no_grad():
        out = encoder_layer(src)
    print("输入形状:", src.shape)          # [2, 10, 128]
    print("post-norm 输出形状:", out.shape)          # [2, 10, 128]

    # 切换为 pre-norm 风格
    encoder_layer_pre = TransformerEncoderLayer(
        c1=dim,
        cm=2048,
        num_heads=8,
        dropout=0.1,
        act=nn.GELU(),
        normalize_before=True,
    )
    with torch.no_grad():
        out_pre = encoder_layer_pre(src)
    print("pre-norm 输出形状:", out_pre.shape)

    # 演示使用 mask(可选)
    # 生成一个 padding mask(假设某些 token 为填充)
    key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    key_padding_mask[0, -2:] = True  # 第一个样本的最后两个 token 被 mask
    with torch.no_grad():
        out_masked = encoder_layer(src, src_key_padding_mask=key_padding_mask)
    print("带 mask 的输出形状:", out_masked.shape)

输出示例

复制代码
输入形状: torch.Size([2, 10, 128])
post-norm 输出形状: torch.Size([2, 10, 128])
pre-norm 输出形状: torch.Size([2, 10, 128])
带 mask 的输出形状: torch.Size([2, 10, 128])

流程示意图

Post-Normalization(默认)
Pre-Normalization

代码解读

  • with_pos_embed:静态方法,若位置编码不为空,则将其加到输入张量上(用于 Q 和 K)。
  • __init__ :初始化 MHA、FFN 的线性层、LayerNorm 和 Dropout。注意 FFN 采用 Linear -> Act -> Dropout -> Linear 结构,且 dropout 参数统一控制所有 Dropout 层。
  • 版本检查 :若 PyTorch < 1.9,则抛出 ModuleNotFoundError,因为 batch_first=True 参数在旧版本中不可用。
  • forward_postforward_pre:分别实现两种归一化顺序,完全遵循 Transformer 原始设计(Vaswani et al.)和后续改进(pre-norm)。

注意事项

  1. 输入格式src 必须是 (B, T, C) 形状,其中 B 为 batch size,T 为序列长度(如特征图展平后的 token 数),C 为特征维度。
  2. 位置编码 :需要外部提供位置编码(pos),可通过三角函数或可学习的位置嵌入生成,并传入 forward
  3. Mask 使用
    • src_mask:序列内部的注意力掩码(如防止看到未来信息),形状通常为 (T, T)
    • src_key_padding_mask:针对 batch 中不同样本的填充 token 掩码,形状为 (B, T),值为 True 表示该位置被忽略。
  4. 训练与推理 :该模块包含 Dropout,训练时应设为 train() 模式,推理时应设为 eval()
  5. 内存占用:由于 MHA 的计算复杂度为 O(T²),当序列长度较大时(如高分辨率特征图),显存和计算量会急剧增加,需谨慎使用。

优缺点

优点
  1. 强大的全局建模能力:自注意力机制能捕获序列中任意两个位置之间的依赖关系,优于卷积的局部感受野。
  2. 灵活的归一化策略:支持 pre-norm 和 post-norm,pre-norm 在深层网络中更稳定,训练更平滑。
  3. 标准化接口 :与 PyTorch 官方 TransformerEncoderLayer 兼容,易于替换和对比。
  4. 掩码支持:可处理变长序列,适合检测、分割等需要 padding 的任务。
缺点
  1. 计算量大:自注意力的复杂度与序列长度平方成正比,对高分辨率特征图不友好。
  2. 依赖外部位置编码:本身不包含位置信息,需额外添加(如正弦编码或可学习嵌入),增加设计复杂度。
  3. 训练不稳定(post-norm):post-norm 在深层网络中可能梯度爆炸,需配合学习率预热等技术。
  4. 对硬件要求高 :需要 PyTorch ≥ 1.9,且 MHA 的 batch_first=True 在某些旧硬件上可能不支持。

在 YOLO 系列的 RT-DETR 中,TransformerEncoderLayer 被用于编码器部分,将 CNN 提取的特征图转换为序列并进行注意力建模,从而提升检测精度。使用时建议在深层特征层采用 pre-norm 配置,并注意序列长度的控制(可通过降采样或窗口注意力缓解复杂度)。

扩展

多头注意力(Multi-Head Attention)

nn.MultiheadAttention 是 PyTorch 中实现多头注意力(Multi-Head Attention) 机制的核心模块。它源自经典的Transformer论文《Attention Is All You Need》,是构建各种Transformer架构(如BERT、GPT、ViT等)的基础组件。

简单来说,它的核心思想是:让模型能够从不同的角度(子空间)同时关注输入信息的不同部分


核心原理:它是如何工作的?

多头注意力机制将整个计算过程分解为几个关键步骤,其核心数学定义如下:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O

其中每个注意力头 head_i 的计算是:

head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)

为了更直观地理解,可以把这个过程拆解为四步:

  1. 线性投影(Projection) :对于输入的查询(Q)、键(K)和值(V)张量,模块会使用三个独立的线性层(全连接层)将它们分别投影到不同的空间。这里的投影维度由embed_dim决定。
  2. 拆分多头(Split Heads) :投影后的Q、K、V张量会被均匀拆分num_heads 份。每一份就代表一个"头",每个头的维度是 embed_dim // num_heads。这使得每个头都能在一个相对低维的子空间里独立工作。
  3. 缩放点积注意力(Scaled Dot-Product Attention) :每个头独立地执行注意力计算。其核心是缩放点积注意力 公式:
    Attention(Q, K, V) = softmax(Q * K^T / √d_k) * V
    这个过程可以理解为:用Q去"查询"K,计算出 attention 权重(相关性分数),然后用这个权重去加权求和 V,从而得到针对当前查询的"关注"结果。
  4. 合并与输出(Concatenate and Project) :所有头计算完成后,会将它们的结果拼接(Concat)起来,恢复成 embed_dim 的维度。最后,再通过一个最终的线性层(W^O)进行投影,得到模块的最终输出。

流程示意图

以下是 nn.MultiheadAttention 核心流程 "投影-拆分-并行计算-合并" 的示意图,清晰展示了数据流转过程。

  1. 输入 :三个张量 Q、K、V,形状均为 (batch, seq_len, embed_dim)
  2. 线性投影:通过三个独立的线性层(全连接)将 Q、K、V 投影到指定的特征空间。
  3. 拆分多头 :将投影后的向量在最后一维(embed_dim)均匀拆分为 num_headshead_dimhead_dim = embed_dim // num_heads),形成多个头。
  4. 并行计算 :每个头独立执行缩放点积注意力(softmax(Q·K^T/√d_k)·V),得到各自的输出。
  5. 合并头 :将所有头的输出在最后一维拼接(Concat),恢复维度为 embed_dim
  6. 最终线性投影 :通过一个额外的线性层将拼接后的结果映射回 embed_dim,得到最终输出。

该流程完整体现了多头注意力机制的设计思想:通过多个并行的子空间,让模型同时关注不同方面的信息

主要参数详解

初始化 nn.MultiheadAttention 时,最核心的参数如下:

  • embed_dim (int)模型的总维度 。这是整个模块输入和输出的特征维度。注意embed_dim 必须能够被 num_heads 整除。
  • num_heads (int)并行注意力头的数量 。每个头的维度是 embed_dim // num_heads。增加头的数量可以让模型关注更多不同的子空间。
  • dropout (float):在注意力权重上应用的 Dropout 概率,默认为 0.0。用于防止过拟合。
  • bias (bool) :是否给线性投影层添加偏置,默认为 True
  • batch_first (bool)非常重要 的参数,决定了输入输出张量的形状。
    • batch_first=True:张量形状为 (batch_size, seq_len, embed_dim)这是目前更常用、更直观的格式
    • batch_first=False (默认):张量形状为 (seq_len, batch_size, embed_dim)。这是PyTorch早期的默认格式。

输入与输出

调用一个已初始化的 MultiheadAttention 模块时,主要输入是 query, key, value 三个张量。

  • 输入

    • query, key, value:形状取决于 batch_first 的设置。
      • batch_first=True:形状为 (batch_size, seq_len, embed_dim)
      • batch_first=False:形状为 (seq_len, batch_size, embed_dim)
    • 注意keyvalue 的序列长度(seq_len)可以与 query 不同,但它们的 embed_dim 必须相同。
  • 输出

    • attn_output :注意力机制的最终输出,形状与 query 输入一致。
    • attn_output_weights :(可选)计算出的注意力权重,形状为 (batch_size, num_heads, query_seq_len, key_seq_len)

使用示例

python 复制代码
import torch
import torch.nn as nn

# 1. 定义参数
embed_dim = 512  # 模型总维度
num_heads = 8    # 注意力头数量
batch_size = 2
seq_len = 10

# 2. 创建 MultiheadAttention 模块
# 设置 batch_first=True 使输入输出形状更直观
mha = nn.MultiheadAttention(embed_dim=embed_dim, 
                            num_heads=num_heads, 
                            batch_first=True)

# 3. 创建模拟输入 (batch_size, seq_len, embed_dim)
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len, embed_dim)
value = torch.randn(batch_size, seq_len, embed_dim)

# 4. 前向传播
attn_output, attn_weights = mha(query, key, value)

print(f"attn_output shape: {attn_output.shape}") # torch.Size([2, 10, 512])
print(f"attn_weights shape: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
TransformerEncoderLayer 中的使用

你提供的 TransformerEncoderLayer 正是对 nn.MultiheadAttention 的典型封装:

python 复制代码
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)

在编码器中,query, key, value 通常传入同一个张量 src ,这被称为自注意力(Self-Attention)

重要注意事项

  1. batch_first 参数 :务必根据你的数据格式正确设置。PyTorch 官方教程和许多新项目都推荐使用 batch_first=True
  2. 自注意力(Self-Attention) :当 query, key, value 是同一个张量时,即进行自注意力计算。这是 Transformer 编码器层的核心操作。
  3. 推理加速 :在满足特定条件时(如自注意力、batch_first=True、禁用梯度等),PyTorch 会自动使用一个 fastpath 来加速推理。
  4. 掩码(Masking)
    • attn_mask :注意力掩码,形状为 (query_seq_len, key_seq_len)(batch_size * num_heads, query_seq_len, key_seq_len),用于屏蔽特定的注意力位置(如解码器中的未来信息)。
    • key_padding_mask :键填充掩码,形状为 (batch_size, key_seq_len),用于指示哪些位置是填充(padding)的,以避免模型关注到无意义的填充符。

总结

nn.MultiheadAttention 是 PyTorch 对多头注意力机制的高效实现。理解其 "投影-拆分-并行计算-合并" 的核心流程,以及 embed_dimnum_headsbatch_first 等关键参数,是使用和定制各种 Transformer 模型的基础。在实际应用中,它常作为 TransformerEncoderLayer 等更高级模块的核心组件。

num_heads 参数详细作用

nn.MultiheadAttention 中,num_heads 是控制注意力头数量 的核心参数。它直接决定了多头注意力(Multi-Head Attention)的行为,深刻地影响着模型的能力、效率和可解释性。


1. 核心作用:将特征空间划分成多个子空间

多头注意力的核心思想是并行的、不同的注意力机制num_heads 定义了并行头的数量。

  • 每个注意力头都独立地进行缩放点积注意力(Scaled Dot-Product Attention)计算。
  • 每个头拥有独立的线性投影矩阵W_i^Q, W_i^K, W_i^V),将输入的 Q、K、V 投影到不同的低维子空间。
  • 每个子空间的维度为 head_dim = embed_dim // num_heads
  • 多个头允许模型在不同的子空间中关注输入的不同部分或关系,从而捕捉更丰富的特征模式。

2. 对模型容量和表达能力的直接影响

num_heads 影响着模型的参数量计算复杂度

  • 参数量 :每个头有自己的投影层,但总参数量主要取决于 embed_dimnum_heads 的关系。本质上,多头注意力与单头注意力的总参数量相近(因为总投影矩阵的大小是 embed_dim × embed_dim),但多头的投影矩阵被拆分成多个低维矩阵,这增加了模型的多样性而非单纯增加参数量。

  • 计算复杂度 :每个头的计算复杂度是 O(seq_len^2 * head_dim),总复杂度为 O(seq_len^2 * embed_dim)(因为总 head_dim 和等于 embed_dim)。因此,num_heads 不会显著改变理论复杂度,但会改变内存访问模式,在硬件上可能影响速度。

  • 表达能力

    • 每个头可以学习关注不同的特征(例如,一个头关注词性,另一个关注语义,第三个关注距离关系等)。
    • 较多的头通常能捕捉更丰富的模式,但过多的头可能导致每个子空间过小(head_dim 过小),限制每个头的表示能力,导致它们变得同质化或失效。
    • 经验研究表明,在 Transformer 模型中,num_heads 通常设置为 8、12、16,并需要配合合适的 embed_dim(保证 head_dim 至少为 32 或 64,以保持足够的表现力)。

3. 对训练和优化的影响
  • 梯度流 :多个头提供了更丰富的梯度信号,有助于模型在训练初期更快收敛。
  • 稳定性:多头设计使得模型对单头的噪声不敏感,因为多个头可以相互补充,提高鲁棒性。
  • 与正则化的关系:多头注意力天然具有某种正则化效果(类似于集成学习),每个头学习不同的表示,最终拼接输出,这有助于缓解过拟合。

4. embed_dim 的严格关系

embed_dim 必须能够被 num_heads 整除,即 embed_dim % num_heads == 0。这是因为 embed_dim 被均匀拆分到每个头,每个头的维度为 head_dim = embed_dim // num_heads

  • 为什么是整除? 因为输入输出总维度不变,但内部需要将特征向量分块到各个头。如果维度不能整除,则无法均匀拆分,PyTorch 会抛出错误。
  • 如何选择? 通常选择 num_heads 使得 head_dim 至少为 32 或 64,以保证每个头有足够的容量。例如,若 embed_dim=512,可选 num_heads=8head_dim=64)或 num_heads=16head_dim=32)。

5. 实际使用中的常见配置
  • 小型模型 (如 BERT-base):embed_dim=768, num_heads=12head_dim=64
  • 大型模型 (如 BERT-large):embed_dim=1024, num_heads=16head_dim=64
  • 视觉模型 (如 ViT-B/16):embed_dim=768, num_heads=12head_dim=64
  • 计算受限场景 (如轻量级检测模型):可能使用 num_heads=48,并配合较小的 embed_dim

6. 在 TransformerEncoderLayer 中的体现

您提供的 TransformerEncoderLayer 中,num_heads 直接传递给 nn.MultiheadAttention

python 复制代码
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)

其中 c1 即为 embed_dim。因此,调整 num_heads 将直接影响该编码器层的注意力行为。


7. 头的可视化与可解释性

多头注意力的一大优势是注意力权重的可解释性 。在推理时,可以提取每个头的注意力权重矩阵 attn_output_weights(形状为 [batch_size, num_heads, query_len, key_len]),并可视化不同的头关注的区域,以理解模型关注的模式。

例如:

  • 在机器翻译中,有些头关注邻近词,有些头关注长距离依赖。
  • 在图像分类中,不同头可能关注图像的不同区域。

8. 总结:如何调优 num_heads
  • 默认值 :在大多数 Transformer 变体中,num_heads=812 是良好起点。
  • 增大 num_heads
    • 如果 embed_dim 足够大(如 ≥512),可以尝试增加头数,以提高模型的表达能力和性能。
    • 需要注意的是,头数增加会导致每个头的维度变小,可能削弱单头能力,需要通过实验验证。
  • 减小 num_heads
    • 当模型过拟合或计算资源紧张时,可以适度减少头数,降低模型复杂度。
  • 确保整除 :无论增减,必须保证 embed_dim % num_heads == 0

总之,num_heads 是控制多头注意力机制多样性和并行性的关键超参数,合理选择有助于提升模型性能。在实际应用中,建议参考相似任务的成功配置,并进行小范围调优实验。

参考文献

1 https://docs.ultralytics.com/

2 https://github.com/ultralytics/ultralytics.git