Ultralytics: Understanding the DeformableTransformerDecoder Module

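Preface

Background

About Ultralytics

Ultralytics builds on years of foundational research in computer vision and AI to create state-of-the-art (SOTA) YOLO models. Its models are continually updated for performance and flexibility, and are fast, accurate, and easy to use. They excel at object detection, tracking, instance segmentation, semantic segmentation, image classification, and pose estimation.

Prerequisites

  • Familiarity with Python and PyTorch

Environment

bash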
Package                  Version
------------------------ ------------
Python                   3.11.8
absl-py                  2.4.0
accelerate               1.13.0
annotated-doc            0.0.4
anyio                    4.13.0
calflops                 0.3.2
certifi                  2026.4.22
charset-normalizer       3.4.7
click                    8.3.3
colorama                 0.4.6
contourpy                1.3.3
cycler                   0.12.1
filelock                 3.29.0
flatbuffers              25.12.19
fonttools                4.62.1
fsspec                   2026.4.0
grpcio                   1.80.0
h11                      0.16.0
hf-xet                   1.5.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface_hub          1.14.0
idna                     3.15
Jinja2                   3.1.6
kiwisolver               1.5.0
Markdown                 3.10.2
markdown-it-py           4.2.0
MarkupSafe               3.0.3
matplotlib               3.10.9
mdurl                    0.1.2
ml_dtypes                0.5.0
mpmath                   1.3.0
networkx                 3.6.1
numpy                    1.26.4
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
onnx                     1.19.0
onnxruntime-gpu          1.26.0
onnxslim                 0.1.94
opencv-python            4.6.0.66
packaging                26.2
pillow                   12.2.0
pip                      24.0
polars                   1.40.1
polars-runtime-32        1.40.1
protobuf                 7.34.1
psutil                   7.2.2
pycocotools              2.0.11
Pygments                 2.20.0
pyparsing                3.3.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.3
regex                    2026.5.9
requests                 2.34.1
rich                     15.0.0
safetensors              0.7.0
scipy                    1.16.0
setuptools               65.5.0
shellingham              1.5.4
six                      1.17.0
sympy                    1.14.0
tabulate                 0.10.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.2
torch                    2.7.1+cu128
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
tqdm                     4.67.3
transformers             5.8.1
triton                   3.3.1
typer                    0.25.1
typing_extensions        4.15.0
ultralytics              8.4.58
ultralytics-thop         2.0.19
urllib3                  2.7.0
Werkzeug                 3.1.8

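DeformableTransformerDecoder (Deformable Transformer Decoder)

DeformableTransformerDecoder is a complete implementation of the Deformable-DETR decoder. It stacks multiple DeformableTransformerDecoderLayer modules and integrates two mechanisms: iterative bounding-box refinement and multi-layer prediction. The module takes decoder embeddings, reference boxes, and multi-scale features as input, and refines the boxes and classification scores layer by layer through deformable attention, which makes it an efficient and capable decoding component for object detection.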
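To make the "deformable attention" idea concrete, here is a minimal single-level, single-head sketch of the sampling step: a query reads a feature map at a few normalized locations with bilinear interpolation, then combines the samples with attention weights. The toy feature map, the two locations, and the weights are assumptions chosen for illustration; the real module learns offsets and weights and works across several levels and heads.

```python
import torch
import torch.nn.functional as F

# Toy single-level feature map: batch 1, 1 channel, 4x4, values 0..15.
feat = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# One query with two sampling locations, normalized to [0, 1] as (x, y).
# (0.5, 0.5) is the map center; (0.125, 0.125) is the center of the top-left pixel.
locations = torch.tensor([[[[0.5, 0.5], [0.125, 0.125]]]])  # (1, num_queries=1, num_points=2, 2)

# grid_sample expects coordinates in [-1, 1], hence the 2 * x - 1 mapping.
grid = 2 * locations - 1
sampled = F.grid_sample(feat, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
# sampled has shape (1, 1, num_queries=1, num_points=2)

# Hand-picked attention weights that sum to 1 (a softmax output in the real module).
weights = torch.tensor([0.75, 0.25]).view(1, 1, 1, 2)
out = (sampled * weights).sum(-1)  # weighted sum over sampling points, shape (1, 1, 1)
```

The center location averages the four middle pixels, while the second location hits one pixel center exactly, so no interpolation is involved there.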


Implementation

python
import math
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_

def inverse_sigmoid(x, eps=1e-5):
    """Calculate the inverse sigmoid function for a tensor.

    This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network
    operations, particularly in attention mechanisms and coordinate transformations.

    Args:
        x (torch.Tensor): Input tensor with values in range [0, 1].
        eps (float, optional): Small epsilon value to prevent numerical instability.

    Returns:
        (torch.Tensor): Tensor after applying the inverse sigmoid function.

    Examples:
        >>> x = torch.tensor([0.2, 0.5, 0.8])
        >>> inverse_sigmoid(x)
        tensor([-1.3863,  0.0000,  1.3863])
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: list,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """Implement multi-scale deformable attention in PyTorch.

    Folds the (num_levels, num_points) axes into a single num_total_points axis so every traced tensor stays at rank <=
    5, the maximum rank supported by CoreML's MIL converter. Numerically equivalent to the rank-6 reference
    implementation on CUDA and CPU.

    Args:
        value (torch.Tensor): Value tensor with shape (bs, num_keys, num_heads, embed_dims).
        value_spatial_shapes (list): Per-level spatial shapes as [(H_0, W_0), ..., (H_{L-1}, W_{L-1})].
        sampling_locations (torch.Tensor): Sampling locations with shape (bs, num_queries, num_heads, num_levels *
            num_points, 2).
        attention_weights (torch.Tensor): Attention weights with shape (bs, num_queries, num_heads, num_levels *
            num_points).

    Returns:
        (torch.Tensor): Output tensor with shape (bs, num_queries, num_heads * embed_dims).

    References:
        https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, _, num_total_points, _ = sampling_locations.shape
    num_points = num_total_points // len(value_spatial_shapes)

    # (bs, num_keys, num_heads, embed_dims) -> tuple of (bs*num_heads, embed_dims, H*W) per level
    value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split([h * w for h, w in value_spatial_shapes], dim=-1)
    # Map to grid_sample coords in [-1, 1] and split per level: tuple of (bs*num_heads, num_queries, num_points, 2)
    sampling_grids = (2 * sampling_locations - 1).permute(0, 2, 1, 3, 4).flatten(0, 1).split(num_points, dim=-2)

    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        value_l = value_list[level].reshape(bs * num_heads, embed_dims, h, w)
        sampling_value_list.append(
            F.grid_sample(value_l, sampling_grids[level], mode="bilinear", padding_mode="zeros", align_corners=False)
        )
    attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * num_heads, 1, num_queries, num_total_points)
    output = (
        (torch.cat(sampling_value_list, dim=-1) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()

class MSDeformAttn(nn.Module):
    """Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.

    This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
    sampling locations and attention weights.

    Attributes:
        im2col_step (int): Step size for im2col operations.
        d_model (int): Model dimension.
        n_levels (int): Number of feature levels.
        n_heads (int): Number of attention heads.
        n_points (int): Number of sampling points per attention head per feature level.
        sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
        attention_weights (nn.Linear): Linear layer for generating attention weights.
        value_proj (nn.Linear): Linear layer for projecting values.
        output_proj (nn.Linear): Linear layer for projecting output.

    References:
        https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
    """

    def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
        """Initialize MSDeformAttn with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_levels (int): Number of feature levels.
            n_heads (int): Number of attention heads.
            n_points (int): Number of sampling points per attention head per feature level.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
        _d_per_head = d_model // n_heads
        # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
        assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset module parameters."""
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        refer_bbox: torch.Tensor,
        value: torch.Tensor,
        value_shapes: list,
        value_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform forward pass for multiscale deformable attention.

        Args:
            query (torch.Tensor): Query tensor with shape [bs, query_length, C].
            refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, 1, 2 or 4], range in [0,
                1], top-left (0,0), bottom-right (1, 1). The size-1 axis broadcasts across n_levels.
            value (torch.Tensor): Value tensor with shape [bs, value_length, C].
            value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
            value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for padding elements,
                False for non-padding elements.

        Returns:
            (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].

        References:
            https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
        """
        bs, len_q = query.shape[:2]
        len_v = value.shape[1]
        assert sum(s[0] * s[1] for s in value_shapes) == len_v

        value = self.value_proj(value)
        if value_mask is not None:
            value = value.masked_fill(value_mask[..., None], float(0))
        value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
        # Fold (n_levels, n_points) into one axis so every traced tensor stays at rank <= 5 (required for CoreML
        # export); refer_bbox arrives as (bs, len_q, 1, 2 or 4) and its size-1 axis broadcasts implicitly.
        n_total_points = self.n_levels * self.n_points
        sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, n_total_points, 2)
        attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, n_total_points)
        attention_weights = F.softmax(attention_weights, -1)
        num_points = refer_bbox.shape[-1]
        if num_points == 2:
            offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
            offset_normalizer = offset_normalizer[:, None, :].expand(-1, self.n_points, -1).reshape(n_total_points, 2)
            sampling_locations = refer_bbox[:, :, None, :, :] + sampling_offsets / offset_normalizer
        elif num_points == 4:
            sampling_locations = (
                refer_bbox[:, :, None, :, :2] + sampling_offsets / self.n_points * refer_bbox[:, :, None, :, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
        output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
        return self.output_proj(output)

class DeformableTransformerDecoderLayer(nn.Module):
    """Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.

    This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable
    attention, and a feedforward network.

    Attributes:
        self_attn (nn.MultiheadAttention): Self-attention module.
        dropout1 (nn.Dropout): Dropout after self-attention.
        norm1 (nn.LayerNorm): Layer normalization after self-attention.
        cross_attn (MSDeformAttn): Cross-attention module.
        dropout2 (nn.Dropout): Dropout after cross-attention.
        norm2 (nn.LayerNorm): Layer normalization after cross-attention.
        linear1 (nn.Linear): First linear layer in the feedforward network.
        act (nn.Module): Activation function.
        dropout3 (nn.Dropout): Dropout in the feedforward network.
        linear2 (nn.Linear): Second linear layer in the feedforward network.
        dropout4 (nn.Dropout): Dropout after the feedforward network.
        norm3 (nn.LayerNorm): Layer normalization after the feedforward network.

    References:
        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
        https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
    """

    def __init__(
        self,
        d_model: int = 256,
        n_heads: int = 8,
        d_ffn: int = 1024,
        dropout: float = 0.0,
        act: nn.Module = nn.ReLU(),
        n_levels: int = 4,
        n_points: int = 4,
    ):
        """Initialize the DeformableTransformerDecoderLayer with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_heads (int): Number of attention heads.
            d_ffn (int): Dimension of the feedforward network.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            n_levels (int): Number of feature levels.
            n_points (int): Number of sampling points.
        """
        super().__init__()

        # Self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.act = act
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None) -> torch.Tensor:
        """Add positional embeddings to the input tensor, if provided."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:
        """Perform forward pass through the Feed-Forward Network part of the layer.

        Args:
            tgt (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after FFN.
        """
        tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        return self.norm3(tgt)

    def forward(
        self,
        embed: torch.Tensor,
        refer_bbox: torch.Tensor,
        feats: torch.Tensor,
        shapes: list,
        padding_mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        query_pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform the forward pass through the entire decoder layer.

        Args:
            embed (torch.Tensor): Input embeddings.
            refer_bbox (torch.Tensor): Reference bounding boxes.
            feats (torch.Tensor): Feature maps.
            shapes (list): Feature shapes.
            padding_mask (torch.Tensor, optional): Padding mask.
            attn_mask (torch.Tensor, optional): Attention mask.
            query_pos (torch.Tensor, optional): Query position embeddings.

        Returns:
            (torch.Tensor): Output tensor after decoder layer.
        """
        # Self attention
        q = k = self.with_pos_embed(embed, query_pos)
        tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
            0
        ].transpose(0, 1)
        embed = embed + self.dropout1(tgt)
        embed = self.norm1(embed)

        # Cross attention
        tgt = self.cross_attn(
            self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
        )
        embed = embed + self.dropout2(tgt)
        embed = self.norm2(embed)

        # FFN
        return self.forward_ffn(embed)

def _get_clones(module, n):
    """Create a list of cloned modules from the given module.

    Args:
        module (nn.Module): The module to be cloned.
        n (int): Number of clones to create.

    Returns:
        (nn.ModuleList): A ModuleList containing n clones of the input module.

    Examples:
        >>> import torch.nn as nn
        >>> layer = nn.Linear(10, 10)
        >>> clones = _get_clones(layer, 3)
        >>> len(clones)
        3
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

class DeformableTransformerDecoder(nn.Module):
    """Deformable Transformer Decoder based on PaddleDetection implementation.

    This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads
    for bounding box regression and classification.

    Attributes:
        layers (nn.ModuleList): List of decoder layers.
        num_layers (int): Number of decoder layers.
        hidden_dim (int): Hidden dimension.
        eval_idx (int): Index of the layer to use during evaluation.

    References:
        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    """

    def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):
        """Initialize the DeformableTransformerDecoder with the given parameters.

        Args:
            hidden_dim (int): Hidden dimension.
            decoder_layer (nn.Module): Decoder layer module.
            num_layers (int): Number of decoder layers.
            eval_idx (int): Index of the layer to use during evaluation.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx

    def forward(
        self,
        embed: torch.Tensor,  # decoder embeddings
        refer_bbox: torch.Tensor,  # anchor
        feats: torch.Tensor,  # image features
        shapes: list,  # feature shapes
        bbox_head: nn.Module,
        score_head: nn.Module,
        pos_mlp: nn.Module,
        attn_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
    ):
        """Perform the forward pass through the entire decoder.

        Args:
            embed (torch.Tensor): Decoder embeddings.
            refer_bbox (torch.Tensor): Reference bounding boxes.
            feats (torch.Tensor): Image features.
            shapes (list): Feature shapes.
            bbox_head (nn.Module): Bounding box prediction head.
            score_head (nn.Module): Score prediction head.
            pos_mlp (nn.Module): Position MLP.
            attn_mask (torch.Tensor, optional): Attention mask.
            padding_mask (torch.Tensor, optional): Padding mask.

        Returns:
            dec_bboxes (torch.Tensor): Decoded bounding boxes.
            dec_cls (torch.Tensor): Decoded classification scores.
        """
        output = embed
        dec_bboxes = []
        dec_cls = []
        last_refined_bbox = None
        refer_bbox = refer_bbox.sigmoid()
        for i, layer in enumerate(self.layers):
            output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))

            bbox = bbox_head[i](output)
            refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))

            if self.training:
                dec_cls.append(score_head[i](output))
                if i == 0:
                    dec_bboxes.append(refined_bbox)
                else:
                    dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
            elif i == self.eval_idx:
                dec_cls.append(score_head[i](output))
                dec_bboxes.append(refined_bbox)
                break

            last_refined_bbox = refined_bbox
            refer_bbox = refined_bbox.detach() if self.training else refined_bbox

        return torch.stack(dec_bboxes), torch.stack(dec_cls)

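Features

  • Stacked decoder layers: the decoder contains several DeformableTransformerDecoderLayer modules; each one refines the query features through self-attention followed by multi-scale deformable cross-attention.
  • Iterative bounding-box refinement: every layer outputs a refined box, and the next layer uses that box as its reference, so the prediction moves toward the target step by step.
  • Multi-layer prediction: during training every layer outputs class scores and boxes, so each layer can contribute to the loss (deep supervision), which helps convergence.
  • Inference layer selection: eval_idx selects which layer's output is returned at inference, usually the last one (-1).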
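The features above can be sketched as a toy loop. This is a simplified illustration, not the decoder itself: plain nn.Linear layers stand in for the real decoder layers (an assumption), and in training mode the sketch stores refined directly for every layer, whereas the listing recomputes later-layer boxes from last_refined_bbox.

```python
import torch
from torch import nn

def inverse_sigmoid(x, eps=1e-5):
    """Clamped logit, same formula as in the listing."""
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

torch.manual_seed(0)
bs, num_queries, dim, num_layers = 2, 5, 16, 3

# Hypothetical stand-ins: Linear layers instead of real decoder layers.
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
bbox_head = nn.ModuleList([nn.Linear(dim, 4) for _ in range(num_layers)])

def refine(embed, ref_logits, training, eval_idx=num_layers - 1):
    out, boxes = embed, []
    ref = ref_logits.sigmoid()  # reference boxes in [0, 1]
    for i, layer in enumerate(layers):
        out = layer(out)
        # Add the predicted delta in logit space, then squash back to [0, 1].
        refined = torch.sigmoid(bbox_head[i](out) + inverse_sigmoid(ref))
        if training:
            boxes.append(refined)      # keep every layer for deep supervision
        elif i == eval_idx:
            boxes.append(refined)      # keep only the selected layer
            break
        # The next layer starts from this layer's box.
        ref = refined.detach() if training else refined
    return torch.stack(boxes)

embed = torch.randn(bs, num_queries, dim)
ref_logits = torch.randn(bs, num_queries, 4)
train_boxes = refine(embed, ref_logits, training=True)
with torch.no_grad():
    eval_boxes = refine(embed, ref_logits, training=False)
```

Training returns one box tensor per layer, inference returns a single one, and both stay inside [0, 1] because every step ends with a sigmoid.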

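Initialization parameters

Parameter      Type       Description
-------------  ---------  ----------------------------------------------------------------------
hidden_dim     int        Feature dimension of the model (consistent with the encoder and decoder layers)
decoder_layer  nn.Module  A DeformableTransformerDecoderLayer instance, cloned num_layers times
num_layers     int        Number of decoder layers
eval_idx       int        Index of the layer used at inference (default -1, the last layer)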

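Forward method

  • forward(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp, attn_mask=None, padding_mask=None)

    Parameter     Type                    Description
    ------------  ----------------------  ----------------------------------------------------------------
    embed         torch.Tensor            Initial decoder embeddings, [bs, num_queries, hidden_dim]
    refer_bbox    torch.Tensor            Initial reference boxes, [bs, num_queries, 4] when the box heads output 4 values; forward applies sigmoid to map them into [0, 1]
    feats         torch.Tensor            Concatenated multi-scale features, [bs, total_len, hidden_dim]
    shapes        list                    Spatial size of each feature level, [(H0, W0), ...]
    bbox_head     nn.Module               Box prediction heads (typically an nn.ModuleList, one per layer)
    score_head    nn.Module               Classification heads (an nn.ModuleList, one per layer)
    pos_mlp       nn.Module               Position MLP that maps reference boxes to positional embeddings
    attn_mask     torch.Tensor, optional  Self-attention mask
    padding_mask  torch.Tensor, optional  Padding mask for the features

    Outputs

    • dec_bboxes: stacked box predictions of all layers, [num_layers, bs, num_queries, 4] in training, or a single layer, [1, bs, num_queries, 4], at inference.
    • dec_cls: stacked classification scores of all layers, [num_layers, bs, num_queries, num_classes] in training, or a single layer, [1, bs, num_queries, num_classes], at inference.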
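The feats and shapes arguments describe the same data in two forms. As a sketch of one way to build them, assuming the multi-scale features start as a list of (bs, C, H, W) maps (the variable feature_maps is hypothetical), each map is flattened to tokens and the levels are concatenated in order:

```python
import torch

bs, hidden_dim = 2, 256
# Hypothetical multi-scale feature maps in (bs, C, H, W) layout.
sizes = [(32, 32), (16, 16), (8, 8), (4, 4)]
feature_maps = [torch.randn(bs, hidden_dim, h, w) for h, w in sizes]

# Per-level spatial sizes, in the same order as the concatenation below.
shapes = [tuple(f.shape[2:]) for f in feature_maps]

# (bs, C, H, W) -> (bs, H*W, C), then concatenate along the token axis.
feats = torch.cat([f.flatten(2).permute(0, 2, 1) for f in feature_maps], dim=1)

# Recover level 0 from the flat tensor to confirm the layout is reversible.
h0, w0 = shapes[0]
level0 = feats[:, : h0 * w0].permute(0, 2, 1).reshape(bs, hidden_dim, h0, w0)
```

The order of levels in shapes has to match the order used in the concatenation, because the attention code splits the token axis by H * W per level.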

Usage example

python
if __name__ == '__main__':
    # Hyperparameters
    bs, num_query, hidden_dim = 2, 100, 256
    n_heads, d_ffn, dropout = 8, 1024, 0.1
    n_levels, n_points, num_layers = 4, 4, 6
    num_classes = 80

    # Build a single decoder layer
    decoder_layer = DeformableTransformerDecoderLayer(
        d_model=hidden_dim,
        n_heads=n_heads,
        d_ffn=d_ffn,
        dropout=dropout,
        n_levels=n_levels,
        n_points=n_points,
    )

    # Build the full decoder
    decoder = DeformableTransformerDecoder(
        hidden_dim=hidden_dim,
        decoder_layer=decoder_layer,
        num_layers=num_layers,
        eval_idx=-1,
    )

    # Dummy inputs: refer_bbox has 4 components (x, y, w, h)
    embed = torch.randn(bs, num_query, hidden_dim)
    refer_bbox = torch.rand(bs, num_query, 4)  # forward() applies sigmoid to these values
    # Multi-scale features, flattened and concatenated over all levels
    shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]
    total_len = sum(h * w for h, w in shapes)
    feats = torch.randn(bs, total_len, hidden_dim)

    # Prediction heads (one per layer)
    bbox_head = nn.ModuleList([nn.Linear(hidden_dim, 4) for _ in range(num_layers)])
    score_head = nn.ModuleList([nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)])
    # Position MLP: input dimension is 4 to match refer_bbox
    pos_mlp = nn.Sequential(
        nn.Linear(4, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim)
    )

    # Forward pass (training mode)
    decoder.train()
    dec_bboxes, dec_cls = decoder(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp)
    print("Training bbox shape:", dec_bboxes.shape)  # [6, 2, 100, 4]
    print("Training cls shape:", dec_cls.shape)      # [6, 2, 100, 80]

    # Inference mode
    decoder.eval()
    with torch.no_grad():
        dec_bboxes_eval, dec_cls_eval = decoder(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp)
    print("Inference bbox shape:", dec_bboxes_eval.shape)  # [1, 2, 100, 4]
    print("Inference cls shape:", dec_cls_eval.shape)      # [1, 2, 100, 80]

Sample output

Training bbox shape: torch.Size([6, 2, 100, 4])
Training cls shape: torch.Size([6, 2, 100, 80])
Inference bbox shape: torch.Size([1, 2, 100, 4])
Inference cls shape: torch.Size([1, 2, 100, 80])

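Code walkthrough

_get_clones helper
  • Clones the given module n times with copy.deepcopy and returns an nn.ModuleList, so each layer has its own parameters even though the structure is identical.
__init__ method
  • Calls _get_clones to create num_layers copies of decoder_layer.
  • Stores hidden_dim, num_layers, and eval_idx. A negative eval_idx is converted to a non-negative index (num_layers + eval_idx).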
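A small check of both points, as a sketch: get_clones mirrors the _get_clones helper from the listing, while normalize_eval_idx is a hypothetical helper that isolates the one-line index conversion from __init__.

```python
import copy
import torch
from torch import nn

def get_clones(module, n):
    """Same idea as _get_clones in the listing: n deep copies in a ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

layer = nn.Linear(4, 4)
clones = get_clones(layer, 3)

# Clones start with identical values but do not share storage.
same_values = torch.equal(clones[0].weight, clones[1].weight)
shared_storage = clones[0].weight.data_ptr() == clones[1].weight.data_ptr()

# Changing one clone leaves the others untouched.
with torch.no_grad():
    clones[0].weight.add_(1.0)
diverged = not torch.equal(clones[0].weight, clones[1].weight)

def normalize_eval_idx(eval_idx, num_layers):
    """Hypothetical helper: the negative-index conversion used in __init__."""
    return eval_idx if eval_idx >= 0 else num_layers + eval_idx

last = normalize_eval_idx(-1, 6)    # last layer of a 6-layer decoder
second_last = normalize_eval_idx(-2, 6)
```

Because the copies are deep, passing one configured decoder_layer is enough; the decoder never reuses that instance's parameters across layers.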
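forward method
  1. Initialization: output = embed, and refer_bbox = refer_bbox.sigmoid() maps the reference boxes into [0, 1].
  2. Layer iteration:
    • For layer i, call layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox)), where pos_mlp(refer_bbox) produces the positional embedding.
    • Update output with the layer's result.
    • Predict the box delta: bbox = bbox_head[i](output).
    • Compute the refined box: refined_bbox = sigmoid(bbox + inverse_sigmoid(refer_bbox)). inverse_sigmoid is the logit function; it maps refer_bbox back to the real line so that the delta is added in unbounded space, which keeps the regression stable. It is defined at the top of the listing.
    • Classification scores: score_head[i](output), computed for every layer in training and only for layer eval_idx at inference.
  3. Training and inference branches:
    • Training: store the scores and boxes of every layer. The first layer stores refined_bbox; later layers store sigmoid(bbox + inverse_sigmoid(last_refined_bbox)), where last_refined_bbox is the previous layer's refined box without detach().
    • Inference: store only the layer given by eval_idx, then break out of the loop.
  4. Reference update: last_refined_bbox = refined_bbox, and refer_bbox becomes refined_bbox.detach() in training (stopping gradients through the reference passed to the next layer) or refined_bbox unchanged at inference.
  5. Return: stack the dec_bboxes and dec_cls lists into tensors.

Note: inverse_sigmoid is defined in the listing above and needs no extra import. It clamps its input to [0, 1] and clamps both x and 1 - x to at least eps (1e-5) before taking the log, so the result stays finite at 0 and 1.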
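The effect of the clamping, and of a single refinement step, can be checked in isolation. The sketch below copies inverse_sigmoid from the listing; naive_logit and the sample values are assumptions added for comparison.

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    """Clamped logit, copied from the listing."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

def naive_logit(x):
    """Unclamped version, for comparison only."""
    return torch.log(x / (1 - x))

x = torch.tensor([0.0, 0.2, 0.5, 0.8, 1.0])
clamped = inverse_sigmoid(x)   # finite at both endpoints
naive = naive_logit(x)         # -inf at 0 and +inf at 1

# Round trip on interior points: sigmoid undoes inverse_sigmoid.
roundtrip = torch.sigmoid(clamped[1:4])

# One refinement step: the delta is added in logit space, then squashed back.
ref = torch.tensor([0.5, 0.5, 0.2, 0.2])
delta = torch.tensor([0.0, 1.0, 0.0, -1.0])
refined = torch.sigmoid(delta + inverse_sigmoid(ref))
```

A zero delta returns the reference unchanged, a positive delta moves the coordinate toward 1, and a negative delta moves it toward 0, without ever leaving the [0, 1] range.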


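Notes

  1. Iterative refinement: the reference box is passed from layer to layer and refined each time. In this implementation the reference handed to the next layer is detached during training, so gradients do not flow back through the chain of references.
  2. inverse_sigmoid: the listing already provides it. If you port the decoder elsewhere, keep the clamping so that inputs of exactly 0 or 1 do not produce infinite values.
  3. Number of prediction heads: bbox_head and score_head are indexed per layer (bbox_head[i], score_head[i]), so each must provide num_layers entries. The entries may be independent modules or the same module repeated to share weights.
  4. Training vs. inference: training returns the outputs of all layers, which can be used for auxiliary losses; inference returns only the selected layer and stops there, which saves computation when eval_idx is not the last layer.
  5. Positional encoding: pos_mlp maps the reference boxes to positional embeddings. It is learnable and is typically a 2-layer MLP.
  6. Masks: attn_mask and padding_mask are passed to every layer, so their shapes must match what the layers expect (for example, padding_mask is [bs, value_length] in MSDeformAttn).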
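The role of detach() in note 1 can be observed with a two-stage toy model. This is a sketch under assumptions: head0 and head1 are hypothetical linear box heads, and the "loss" is just the sum of the second-stage boxes.

```python
import torch
from torch import nn

def inverse_sigmoid(x, eps=1e-5):
    """Clamped logit, same formula as in the listing."""
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

torch.manual_seed(0)
feat = torch.randn(3, 8)
head0, head1 = nn.Linear(8, 4), nn.Linear(8, 4)  # hypothetical per-layer box heads
ref = torch.rand(3, 4)

def head0_grad_from_stage2(detach):
    """Backpropagate a stage-2 loss and return the gradient that reaches head0."""
    head0.zero_grad()
    head1.zero_grad()
    refined0 = torch.sigmoid(head0(feat) + inverse_sigmoid(ref))
    ref1 = refined0.detach() if detach else refined0  # reference for stage 2
    refined1 = torch.sigmoid(head1(feat) + inverse_sigmoid(ref1))
    refined1.sum().backward()
    return head0.weight.grad

grad_detached = head0_grad_from_stage2(detach=True)   # stage-2 loss cannot reach head0
grad_attached = head0_grad_from_stage2(detach=False)  # stage-2 loss updates head0 too
```

With the detached reference, each head is trained only by the loss on the boxes it produces directly; without it, later losses also push on earlier heads.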

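Advantages and disadvantages

Advantages
  1. Multi-layer refinement: iterative bounding-box refinement improves localization accuracy step by step and raises detection performance.
  2. Deep supervision: every layer contributes a prediction loss during training, which speeds up convergence and mitigates vanishing gradients.
  3. Flexible inference: any layer's output can be selected, which allows a trade-off between accuracy and speed.
  4. Drop-in design: it can replace a standard Transformer decoder, provided it is paired with deformable attention layers.
Disadvantages
  1. Implementation complexity: it relies on inverse_sigmoid and a position MLP, so it has more code dependencies.
  2. Training cost: losses are computed for all layers during training, which increases memory use and computation.
  3. Parameter count: every layer has its own prediction heads and layer parameters, so the model is larger.
  4. Tuning sensitivity: the learning rate for iterative refinement, eval_idx, and similar settings need careful tuning.

DeformableTransformerDecoder follows the decoder design used in Deformable-DETR and RT-DETR. In practice, start from pretrained weights, set a suitable eval_idx (usually -1), and pay attention to gradient clipping and loss weighting during training.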
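As a rough sketch of per-layer loss weighting combined with gradient clipping, assuming gradient-norm clipping is what is meant: the tensors below are random stand-ins for decoder outputs and already-matched targets, the loss terms (L1 and cross-entropy) and all weights are assumptions, and the target matching that a real detector needs is skipped entirely.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
num_layers, bs, num_queries, num_classes = 3, 2, 5, 7

# Hypothetical stand-ins for decoder outputs (leaf tensors so they receive gradients).
raw_boxes = nn.Parameter(torch.randn(num_layers, bs, num_queries, 4))
dec_bboxes = raw_boxes.sigmoid()
dec_cls = nn.Parameter(torch.randn(num_layers, bs, num_queries, num_classes))

# Hypothetical targets, assumed to be already matched one-to-one with the queries.
gt_boxes = torch.rand(bs, num_queries, 4)
gt_cls = torch.randint(0, num_classes, (bs, num_queries))

# Assumed weights; the source does not prescribe any values.
layer_weights = [1.0] * num_layers
box_weight, cls_weight = 5.0, 1.0

# Deep supervision: every layer's prediction contributes to the total loss.
total = 0.0
for i in range(num_layers):
    loss_box = F.l1_loss(dec_bboxes[i], gt_boxes)
    loss_cls = F.cross_entropy(dec_cls[i].flatten(0, 1), gt_cls.flatten())
    total = total + layer_weights[i] * (box_weight * loss_box + cls_weight * loss_cls)

total.backward()
# Rescale gradients in place so that their global L2 norm is at most max_norm.
norm_before = torch.nn.utils.clip_grad_norm_([raw_boxes, dec_cls], max_norm=0.1)
norm_after = torch.sqrt(raw_boxes.grad.pow(2).sum() + dec_cls.grad.pow(2).sum())
```

Lowering the weights of early layers is one knob to try if the auxiliary losses dominate; the uniform weights here are only a starting point.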

References

1 https://docs.ultralytics.com/

2 https://github.com/ultralytics/ultralytics.git