Ultralytics: Understanding the MSDeformAttn Module


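Preface

Background

About Ultralytics

Building on years of foundational research in computer vision and artificial intelligence, Ultralytics develops state-of-the-art (SOTA) YOLO models. The models are continually updated for performance and flexibility, and they are fast, accurate, and easy to use. They handle object detection, tracking, instance segmentation, semantic segmentation, image classification, and pose estimation.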

Prerequisites

  • Familiarity with Python and PyTorch

Environment

Package                  Version
------------------------ ------------
Python                   3.11.8
absl-py                  2.4.0
accelerate               1.13.0
annotated-doc            0.0.4
anyio                    4.13.0
calflops                 0.3.2
certifi                  2026.4.22
charset-normalizer       3.4.7
click                    8.3.3
colorama                 0.4.6
contourpy                1.3.3
cycler                   0.12.1
filelock                 3.29.0
flatbuffers              25.12.19
fonttools                4.62.1
fsspec                   2026.4.0
grpcio                   1.80.0
h11                      0.16.0
hf-xet                   1.5.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface_hub          1.14.0
idna                     3.15
Jinja2                   3.1.6
kiwisolver               1.5.0
Markdown                 3.10.2
markdown-it-py           4.2.0
MarkupSafe               3.0.3
matplotlib               3.10.9
mdurl                    0.1.2
ml_dtypes                0.5.0
mpmath                   1.3.0
networkx                 3.6.1
numpy                    1.26.4
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
onnx                     1.19.0
onnxruntime-gpu          1.26.0
onnxslim                 0.1.94
opencv-python            4.6.0.66
packaging                26.2
pillow                   12.2.0
pip                      24.0
polars                   1.40.1
polars-runtime-32        1.40.1
protobuf                 7.34.1
psutil                   7.2.2
pycocotools              2.0.11
Pygments                 2.20.0
pyparsing                3.3.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.3
regex                    2026.5.9
requests                 2.34.1
rich                     15.0.0
safetensors              0.7.0
scipy                    1.16.0
setuptools               65.5.0
shellingham              1.5.4
six                      1.17.0
sympy                    1.14.0
tabulate                 0.10.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.2
torch                    2.7.1+cu128
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
tqdm                     4.67.3
transformers             5.8.1
triton                   3.3.1
typer                    0.25.1
typing_extensions        4.15.0
ultralytics              8.4.58
ultralytics-thop         2.0.19
urllib3                  2.7.0
Werkzeug                 3.1.8

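MSDeformAttn (Multi-Scale Deformable Attention Module)

MSDeformAttn is a PyTorch implementation of the Multi-Scale Deformable Attention mechanism proposed in Deformable DETR. Using learnable sampling offsets and attention weights, it adaptively samples key points on feature maps at several scales and aggregates their spatial information. Compared with the global attention of a standard Transformer, it greatly reduces computational cost and improves performance on tasks such as object detection. This implementation follows the Deformable-DETR and PaddleDetection code.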
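
For reference, the operator can be written compactly in the notation of the Deformable DETR paper. The equation below is taken from that paper rather than from the Ultralytics source, and the mapping to code names in the note after it is an interpretation, not something the source states.

```latex
\mathrm{MSDeformAttn}\big(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}\big)
  = \sum_{m=1}^{M} W_m \Big[ \sum_{l=1}^{L} \sum_{k=1}^{K}
      A_{mlqk} \cdot W'_m \, x^l\big(\phi_l(\hat{p}_q) + \Delta p_{mlqk}\big) \Big],
\qquad \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} = 1
```

Here z_q is the query feature, p̂_q its normalized reference point, x^l the feature map at level l, and φ_l rescales normalized coordinates to level l. M, L, and K correspond to n_heads, n_levels, and n_points. Δp_mlqk and A_mlqk are the outputs of the sampling_offsets and attention_weights layers, while W'_m and W_m play the roles of value_proj and output_proj.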


Code Implementation

import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_

def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: list,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """Implement multi-scale deformable attention in PyTorch.

    Folds the (num_levels, num_points) axes into a single num_total_points axis so every traced tensor stays at rank <=
    5, the maximum rank supported by CoreML's MIL converter. Numerically equivalent to the rank-6 reference
    implementation on CUDA and CPU.

    Args:
        value (torch.Tensor): Value tensor with shape (bs, num_keys, num_heads, embed_dims).
        value_spatial_shapes (list): Per-level spatial shapes as [(H_0, W_0), ..., (H_{L-1}, W_{L-1})].
        sampling_locations (torch.Tensor): Sampling locations with shape (bs, num_queries, num_heads, num_levels *
            num_points, 2).
        attention_weights (torch.Tensor): Attention weights with shape (bs, num_queries, num_heads, num_levels *
            num_points).

    Returns:
        (torch.Tensor): Output tensor with shape (bs, num_queries, num_heads * embed_dims).

    References:
        https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, _, num_total_points, _ = sampling_locations.shape
    num_points = num_total_points // len(value_spatial_shapes)

    # (bs, num_keys, num_heads, embed_dims) -> tuple of (bs*num_heads, embed_dims, H*W) per level
    value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split([h * w for h, w in value_spatial_shapes], dim=-1)
    # Map to grid_sample coords in [-1, 1] and split per level: tuple of (bs*num_heads, num_queries, num_points, 2)
    sampling_grids = (2 * sampling_locations - 1).permute(0, 2, 1, 3, 4).flatten(0, 1).split(num_points, dim=-2)

    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        value_l = value_list[level].reshape(bs * num_heads, embed_dims, h, w)
        sampling_value_list.append(
            F.grid_sample(value_l, sampling_grids[level], mode="bilinear", padding_mode="zeros", align_corners=False)
        )
    attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * num_heads, 1, num_queries, num_total_points)
    output = (
        (torch.cat(sampling_value_list, dim=-1) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()

class MSDeformAttn(nn.Module):
    """Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.

    This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
    sampling locations and attention weights.

    Attributes:
        im2col_step (int): Step size for im2col operations.
        d_model (int): Model dimension.
        n_levels (int): Number of feature levels.
        n_heads (int): Number of attention heads.
        n_points (int): Number of sampling points per attention head per feature level.
        sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
        attention_weights (nn.Linear): Linear layer for generating attention weights.
        value_proj (nn.Linear): Linear layer for projecting values.
        output_proj (nn.Linear): Linear layer for projecting output.

    References:
        https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
    """

    def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
        """Initialize MSDeformAttn with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_levels (int): Number of feature levels.
            n_heads (int): Number of attention heads.
            n_points (int): Number of sampling points per attention head per feature level.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
        _d_per_head = d_model // n_heads
        # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
        assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset module parameters."""
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        refer_bbox: torch.Tensor,
        value: torch.Tensor,
        value_shapes: list,
        value_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform forward pass for multiscale deformable attention.

        Args:
            query (torch.Tensor): Query tensor with shape [bs, query_length, C].
            refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, 1, 2 or 4], range in [0,
                1], top-left (0,0), bottom-right (1, 1). The size-1 axis broadcasts across n_levels.
            value (torch.Tensor): Value tensor with shape [bs, value_length, C].
            value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
            value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for padding elements,
                False for non-padding elements.

        Returns:
            (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].

        References:
            https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
        """
        bs, len_q = query.shape[:2]
        len_v = value.shape[1]
        assert sum(s[0] * s[1] for s in value_shapes) == len_v

        value = self.value_proj(value)
        if value_mask is not None:
            value = value.masked_fill(value_mask[..., None], float(0))
        value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
        # Fold (n_levels, n_points) into one axis so every traced tensor stays at rank <= 5 (required for CoreML
        # export); refer_bbox arrives as (bs, len_q, 1, 2 or 4) and its size-1 axis broadcasts implicitly.
        n_total_points = self.n_levels * self.n_points
        sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, n_total_points, 2)
        attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, n_total_points)
        attention_weights = F.softmax(attention_weights, -1)
        num_points = refer_bbox.shape[-1]
        if num_points == 2:
            offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
            offset_normalizer = offset_normalizer[:, None, :].expand(-1, self.n_points, -1).reshape(n_total_points, 2)
            sampling_locations = refer_bbox[:, :, None, :, :] + sampling_offsets / offset_normalizer
        elif num_points == 4:
            sampling_locations = (
                refer_bbox[:, :, None, :, :2] + sampling_offsets / self.n_points * refer_bbox[:, :, None, :, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
        output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
        return self.output_proj(output)

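Features

  • Multi-scale feature sampling: for each query, learnable offsets generate sampling points on feature maps at several scales, so the sampling adapts to the content.
  • Learned attention weights: each sampling point receives an attention weight, and the sampled features are aggregated as a weighted sum.
  • Efficient computation: because each query attends to a limited number of sampling points (n_points), complexity drops from O(N²) to O(N·K), where K is the total number of sampling points.
  • Plug and play: it can replace standard cross-attention in a Transformer and suits tasks such as object detection and instance segmentation.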
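
To make the complexity claim concrete, the short sketch below counts query-key interactions per attention head for the shapes used later in this article. The variable names are illustrative only, and the counts ignore constant factors such as the four neighbors read by bilinear interpolation.

```python
# Per-head count of query-key interactions for the shapes used in this article.
value_shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]  # (H, W) per level
num_queries, n_levels, n_points = 100, 4, 4

len_v = sum(h * w for h, w in value_shapes)     # total number of keys over all levels
full_attention = num_queries * len_v            # every query attends to every key
deformable = num_queries * n_levels * n_points  # every query attends to K sampled points

print(len_v, full_attention, deformable, full_attention / deformable)
```

With these shapes there are 1360 keys, so full attention needs 136000 interactions per head against 1600 for deformable attention, a factor of 85.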

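Initialization Parameters

Parameter  Type  Description
---------  ----  -----------
d_model    int   Feature dimension of the model; must be divisible by n_heads (default 256)
n_levels   int   Number of feature pyramid levels (default 4)
n_heads    int   Number of attention heads; must divide d_model (default 8)
n_points   int   Number of sampling points per attention head per feature level (default 4)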
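
As a rough illustration of how these parameters determine layer sizes, the sketch below rebuilds the four linear layers with the default values and counts their parameters. It mirrors the layer definitions in __init__ but is standalone code, not part of the class.

```python
from torch import nn

d_model, n_levels, n_heads, n_points = 256, 4, 8, 4

# Same layer shapes as in MSDeformAttn.__init__, built standalone for inspection
layers = {
    "sampling_offsets": nn.Linear(d_model, n_heads * n_levels * n_points * 2),
    "attention_weights": nn.Linear(d_model, n_heads * n_levels * n_points),
    "value_proj": nn.Linear(d_model, d_model),
    "output_proj": nn.Linear(d_model, d_model),
}
param_counts = {name: sum(p.numel() for p in layer.parameters()) for name, layer in layers.items()}
for name, layer in layers.items():
    print(f"{name}: out_features={layer.out_features}, params={param_counts[name]}")
total_params = sum(param_counts.values())
print("total:", total_params)
```

With the defaults, sampling_offsets outputs 256 values per query (8 heads × 4 levels × 4 points × 2 coordinates) and attention_weights outputs 128, for 230272 parameters in total.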

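Forward Method

  • forward(query, refer_bbox, value, value_shapes, value_mask=None)
    • query: query tensor of shape [bs, len_q, C].
    • refer_bbox: reference points or boxes of shape [bs, len_q, 1, 2] or [bs, len_q, 1, 4], with values in [0, 1].
    • value: multi-scale features flattened and concatenated into a tensor of shape [bs, total_len_v, C].
    • value_shapes: list of per-level spatial sizes, e.g. [(H0, W0), (H1, W1), ...].
    • value_mask (optional): padding mask of shape [bs, total_len_v]; True marks positions to ignore.

Output: a tensor of shape [bs, len_q, C].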


Usage Example

if __name__ == '__main__':
    # Configuration
    bs, num_query, d_model = 2, 100, 256
    n_levels, n_heads, n_points = 4, 8, 4
    # Multi-scale features (each level has a different spatial size)
    value_shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]  # H, W
    total_len_v = sum(h * w for h, w in value_shapes)

    # Dummy inputs
    query = torch.randn(bs, num_query, d_model)
    refer_bbox = torch.rand(bs, num_query, 1, 2)  # normalized coordinates
    value = torch.randn(bs, total_len_v, d_model)

    # Build MSDeformAttn
    attn = MSDeformAttn(d_model=d_model, n_levels=n_levels, n_heads=n_heads, n_points=n_points)

    # Forward pass
    with torch.no_grad():
        out = attn(query, refer_bbox, value, value_shapes)
    print("Input query shape:", query.shape)   # [2, 100, 256]
    print("Output shape:", out.shape)          # [2, 100, 256]

Example Output

Input query shape: torch.Size([2, 100, 256])
Output shape: torch.Size([2, 100, 256])

Flow Diagram

[Figure: flow diagram of MSDeformAttn]

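Code Walkthrough

__init__ method
  • Checks that d_model % n_heads == 0.
  • Defines the linear layers:
    • sampling_offsets: produces a 2-D offset for every query, head, level, and sampling point.
    • attention_weights: produces an attention logit for every sampling point.
    • value_proj: projects value to d_model dimensions.
    • output_proj: projects the attention output back to d_model dimensions.
  • Calls _reset_parameters to initialize the weights.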
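_reset_parameters method
  • sampling_offsets: the weights are set to 0 and the bias is set to a predefined grid (grid_init). Each head gets its own direction, evenly spaced in angle, and the i-th point of a head lies i + 1 steps along that direction, so the initial sampling points spread out around the reference point.
  • attention_weights: the weights and bias are set to 0, so the softmax output is uniform at initialization.
  • value_proj and output_proj: the weights use Xavier uniform initialization and the biases are set to 0.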
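
The initialization described above can be inspected outside the class. The sketch below copies the grid_init computation from _reset_parameters into standalone code with the default sizes; the final softmax line is added here only to illustrate the uniform initial weights.

```python
import math

import torch

n_heads, n_levels, n_points = 8, 4, 4

# One direction per head, evenly spaced in angle
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# Normalize so the larger of |x|, |y| is 1, then copy to every level and point
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(n_heads, 1, 1, 2)
grid_init = grid_init.repeat(1, n_levels, n_points, 1)
# The i-th point of a head sits (i + 1) steps along that head's direction
for i in range(n_points):
    grid_init[:, :, i, :] *= i + 1

print(grid_init[:, 0, 0, :])  # first point of each head on level 0
print(grid_init[1, 0, :, :])  # all points of head 1 on level 0

# Zero weights and zero bias give zero logits, so softmax is uniform over the 16 points
uniform = torch.softmax(torch.zeros(n_levels * n_points), dim=-1)
print(uniform)
```

Head 0 starts out pointing along (1, 0) and head 1 along (1, 1), with the four points of head 1 at roughly (1, 1), (2, 2), (3, 3), and (4, 4). These offsets are in pixel units of each level when the reference is a 2-D point.
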
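forward method
  1. Value projection and masking: value = self.value_proj(value); if value_mask is given, the masked positions are set to zero.
  2. Reshape value to (bs, len_v, n_heads, d_head). Inside the helper function it is then permuted to (bs, n_heads, d_head, len_v), flattened to (bs*n_heads, d_head, len_v), and split by level into separate feature maps.
  3. Generate sampling offsets and attention weights:
    • sampling_offsets: (bs, len_q, n_heads, n_levels*n_points, 2).
    • attention_weights: (bs, len_q, n_heads, n_levels*n_points), normalized with softmax over the last axis.
  4. Compute sampling locations: refer_bbox and the offsets are combined to give the sampling coordinates of every query on every level. These coordinates are normalized to [0, 1]; the helper function later maps them to the [-1, 1] range that F.grid_sample expects.
    • If the last dimension of refer_bbox is 2 (a point), sampling_locations = refer_bbox + offsets / normalizer, where normalizer holds the (W, H) of each level.
    • If it is 4 (a box given as x, y, w, h), the offsets are scaled by the box size: offsets / n_points * wh * 0.5.
  5. Multi-scale sampling: for each level, F.grid_sample samples the corresponding feature map at the sampling locations.
  6. Weighted aggregation: the sampled values are multiplied by the attention weights and summed over the sampling-point axis, giving the output for each query.
  7. Output projection: output_proj maps the result back to d_model dimensions.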
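
Step 4 is the part that is easiest to get wrong, so the sketch below reproduces both branches in isolation with tiny, hand-picked sizes. The tensor expressions follow forward; the sizes, the fixed offset of one pixel to the right, and the reading of the box as (center x, center y, width, height) are assumptions made for illustration.

```python
import torch

bs, len_q, n_heads, n_levels, n_points = 1, 1, 1, 2, 2
value_shapes = [(32, 32), (16, 16)]  # (H, W) per level
n_total_points = n_levels * n_points

# Hand-picked offsets: (1, 0) for every sampling point
sampling_offsets = torch.zeros(bs, len_q, n_heads, n_total_points, 2)
sampling_offsets[..., 0] = 1.0

# Case 1: reference points (x, y) in [0, 1]; offsets are divided by each level's (W, H)
refer_point = torch.tensor([0.5, 0.5]).view(bs, len_q, 1, 2)
normalizer = torch.as_tensor(value_shapes, dtype=torch.float32).flip(-1)  # (W, H) per level
normalizer = normalizer[:, None, :].expand(-1, n_points, -1).reshape(n_total_points, 2)
loc_point = refer_point[:, :, None, :, :] + sampling_offsets / normalizer

# Case 2: reference boxes in [0, 1]; offsets are scaled by half the box size
refer_box = torch.tensor([0.5, 0.5, 0.2, 0.4]).view(bs, len_q, 1, 4)
loc_box = refer_box[:, :, None, :, :2] + sampling_offsets / n_points * refer_box[:, :, None, :, 2:] * 0.5

print(loc_point[0, 0, 0])  # one (x, y) row per sampling point, levels in order
print(loc_box[0, 0, 0])
```

In the point case the same one-pixel offset moves x to 0.5 + 1/32 = 0.53125 on the 32×32 level and to 0.5 + 1/16 = 0.5625 on the 16×16 level. In the box case the shift no longer depends on the level: x becomes 0.5 + (1 / 2) × 0.2 × 0.5 = 0.55 for every point.
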
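Helper function multi_scale_deformable_attn_pytorch
  • Implements the core loop of multi-scale deformable attention: it splits value by level, runs grid_sample on each level, then concatenates the results and takes the weighted sum.
  • To support CoreML export, the (num_levels, num_points) axes are folded into a single axis so that every tensor stays at rank 5 or lower.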
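
To see what the helper does on a single level, the sketch below samples a tiny 4×4 feature map at three hand-picked locations and combines them with made-up attention weights. It uses the same F.grid_sample arguments as the helper; the feature values, locations, and weights are chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# A single 4x4 feature map whose value at (row, col) is row * 4 + col
h, w = 4, 4
feature = torch.arange(h * w, dtype=torch.float32).view(1, 1, h, w)

# Locations are (x, y) in [0, 1]; with align_corners=False the center of
# pixel (row, col) is at ((col + 0.5) / w, (row + 0.5) / h)
locations = torch.tensor([
    [0.125, 0.125],  # center of pixel (0, 0)
    [0.625, 0.875],  # center of pixel (3, 2)
    [0.250, 0.125],  # halfway between pixels (0, 0) and (0, 1)
])
# grid_sample expects a grid of shape (N, H_out, W_out, 2) with values in [-1, 1]
grid = (2 * locations - 1).view(1, 1, -1, 2)
sampled = F.grid_sample(feature, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
sampled = sampled.flatten()

# Weighted aggregation with illustrative attention weights that sum to 1
weights = torch.tensor([0.5, 0.25, 0.25])
output = (sampled * weights).sum()
print(sampled, output)
```

The three samples come out as 0, 14, and 0.5, and the weighted sum is 3.625. The third sample shows the bilinear interpolation that lets gradients flow back to the sampling offsets.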

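Notes

  1. Input shape requirements:
    • The feature dimension of query and value must be d_model.
    • The sum of H × W over the levels in value_shapes must equal the size of the second dimension of value.
    • The last dimension of refer_bbox must be 2 (a point) or 4 (a box).
  2. Normalized coordinates: refer_bbox should lie in [0, 1]; the conversion to grid_sample coordinates is handled internally.
  3. Number of sampling points: each query samples n_heads * n_levels * n_points points in total; when the feature maps are large, n_points can be reduced to speed up computation.
  4. Mask handling: if value_mask is provided, the masked positions of value are set to zero.
  5. Memory overhead: the sampled values of every level must be stored, so GPU memory usage is relatively high and a smaller batch size helps. im2col_step is stored as an attribute but is not used by this implementation, so changing it has no effect.
  6. Difference from standard attention: standard attention computes attention over all positions, whereas deformable attention only attends to a small number of sampling points, which greatly reduces computation.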
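
Notes 1 and 4 can be checked on a toy tensor. The sketch below repeats the length assertion and the masked_fill call from forward; the tensor sizes and mask values are made up for the example.

```python
import torch

value_shapes = [(1, 2), (1, 2)]  # two tiny levels, (H, W) each
value = torch.ones(1, 4, 2)      # (bs, value_length, C)

# Note 1: the level areas must add up to value_length
len_v = value.shape[1]
assert sum(s[0] * s[1] for s in value_shapes) == len_v

# Note 4: True marks padding; those rows of value become zero
value_mask = torch.tensor([[False, True, False, True]])
masked = value.masked_fill(value_mask[..., None], float(0))
print(masked[0])
```

Masking only zeroes the values. The attention weights come from the query alone, so a sampling point that lands on a padded position still keeps its softmax weight and simply contributes zero to the sum.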

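Pros and Cons

Pros
  1. Computationally efficient: the total number of sampling points is far smaller than the feature map size, so complexity drops from O(Lq·Lv) to O(Lq·N_sampled), which suits high-resolution features.
  2. Adaptive sampling: learnable offsets let the model focus on informative features near the object, which improves detection accuracy.
  3. Multi-scale fusion: features from several scales are fused at once, which strengthens scale invariance.
  4. Differentiable: all operations are differentiable, so end-to-end training is straightforward.
Cons
  1. Complex implementation: grid sampling and multi-scale splitting make the code longer and harder to understand and modify.
  2. Sensitive to initialization: the initialization of the sampling offsets (grid_init) affects training convergence and must be designed with care.
  3. Parameter overhead: the extra linear layers (sampling_offsets, attention_weights) add parameters.
  4. Memory usage: grid_sample produces large intermediate tensors, so GPU memory consumption is high, especially with many scales and many queries.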
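
For a sense of scale on the memory point, the sketch below estimates the size of the largest intermediate in the helper function, the concatenated sampled values, for the shapes of the usage example. It assumes float32 and counts only this one tensor, so real usage during training will be higher.

```python
bs, num_queries, d_model = 2, 100, 256
n_levels, n_heads, n_points = 4, 8, 4

d_head = d_model // n_heads
n_total_points = n_levels * n_points
# Shape of torch.cat(sampling_value_list, dim=-1) in multi_scale_deformable_attn_pytorch
sampled_shape = (bs * n_heads, d_head, num_queries, n_total_points)

num_elements = 1
for dim in sampled_shape:
    num_elements *= dim
bytes_fp32 = num_elements * 4  # float32 uses 4 bytes per element
print(sampled_shape, num_elements, f"{bytes_fp32 / 2**20:.3f} MiB")
```

Here the tensor holds 819200 elements, or 3.125 MiB. The size grows linearly with batch size, number of queries, d_model, n_levels, and n_points, so using the module with one query per pixel (as in an encoder) multiplies it accordingly.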

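In detection models such as Deformable DETR and RT-DETR, MSDeformAttn typically serves as the cross-attention layer of the decoder; Deformable DETR also uses it for self-attention in the encoder. In practice, adjust n_points and n_levels to fit the available GPU memory and the task, and fine-tune from pretrained weights where possible.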

References

[1] https://docs.ultralytics.com/

[2] https://github.com/ultralytics/ultralytics.git