Ultralytics:解读Concat模块

Ultralytics:解读Concat模块

前言

相关介绍

Ultralytics 简介

Ultralytics 基于多年的计算机视觉和人工智能基础研究,创建了最先进的 (SOTA) YOLO 模型。我们的模型不断更新性能和灵活性,快速、准确且易于使用。他们擅长对象检测、跟踪、实例分割、语义分割、图像分类和姿势估计任务。

前提条件

  • 熟悉Python、Pytorch

实验环境

bash 复制代码
Package                  Version
------------------------ ------------
Python                   3.11.8
absl-py                  2.4.0
accelerate               1.13.0
annotated-doc            0.0.4
anyio                    4.13.0
calflops                 0.3.2
certifi                  2026.4.22
charset-normalizer       3.4.7
click                    8.3.3
colorama                 0.4.6
contourpy                1.3.3
cycler                   0.12.1
filelock                 3.29.0
flatbuffers              25.12.19
fonttools                4.62.1
fsspec                   2026.4.0
grpcio                   1.80.0
h11                      0.16.0
hf-xet                   1.5.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface_hub          1.14.0
idna                     3.15
Jinja2                   3.1.6
kiwisolver               1.5.0
Markdown                 3.10.2
markdown-it-py           4.2.0
MarkupSafe               3.0.3
matplotlib               3.10.9
mdurl                    0.1.2
ml_dtypes                0.5.0
mpmath                   1.3.0
networkx                 3.6.1
numpy                    1.26.4
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
onnx                     1.19.0
onnxruntime-gpu          1.26.0
onnxslim                 0.1.94
opencv-python            4.6.0.66
packaging                26.2
pillow                   12.2.0
pip                      24.0
polars                   1.40.1
polars-runtime-32        1.40.1
protobuf                 7.34.1
psutil                   7.2.2
pycocotools              2.0.11
Pygments                 2.20.0
pyparsing                3.3.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.3
regex                    2026.5.9
requests                 2.34.1
rich                     15.0.0
safetensors              0.7.0
scipy                    1.16.0
setuptools               65.5.0
shellingham              1.5.4
six                      1.17.0
sympy                    1.14.0
tabulate                 0.10.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.2
torch                    2.7.1+cu128
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
tqdm                     4.67.3
transformers             5.8.1
triton                   3.3.1
typer                    0.25.1
typing_extensions        4.15.0
ultralytics              8.4.58
ultralytics-thop         2.0.19
urllib3                  2.7.0
Werkzeug                 3.1.8

Concat(张量拼接模块)

Concattorch.cat 的简单封装,用于在指定维度上拼接多个张量。在目标检测模型(如 YOLO)中,它常用于特征金字塔(FPN)的跨层级特征融合,将不同尺度的特征图在通道维度(或批维度)上拼接,以增强多尺度表示能力。


代码实现

python 复制代码
import cv2
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn

class Concat(nn.Module):
    """Concatenate a list of tensors along specified dimension.

    Attributes:
        d (int): Dimension along which to concatenate tensors.
    """

    def __init__(self, dimension=1):
        """Initialize Concat module.

        Args:
            dimension (int): Dimension along which to concatenate tensors.
        """
        super().__init__()
        self.d = dimension

    def forward(self, x: list[torch.Tensor]):
        """Concatenate input tensors along specified dimension.

        Args:
            x (list[torch.Tensor]): List of input tensors.

        Returns:
            (torch.Tensor): Concatenated tensor.
        """
        return torch.cat(x, self.d)

功能

  • 简单拼接:将输入张量列表按指定维度连接,不改变各张量的数值,仅沿维度方向堆叠。
  • 灵活配置 :通过 dimension 参数指定拼接维度(默认为 1,即通道维度),适应不同需求。
  • 即插即用:常用于网络结构中的特征融合,如将上采样后的深层特征与浅层特征拼接。

初始化参数

参数 类型 说明
dimension int 拼接维度(默认 1),例如 dimension=1 表示沿通道维拼接。

该模块没有可学习参数,仅做张量操作。


前向方法

  • forward(x):输入 x 为张量列表,返回值是 torch.cat(x, self.d)

使用示例

python 复制代码
if __name__ == '__main__':
    # 1. 读取图像
    img_path = "cat_640x640.png"
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        raise FileNotFoundError(f"图片 {img_path} 不存在!")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1).unsqueeze(0)  # [1, 3, 640, 640]

    # 2. 创建水平翻转图像(用于拼接对照)
    flipped = torch.flip(img_tensor, dims=[3])  # [1, 3, 640, 640]

    # 3. 沿通道维拼接(dimension=1)
    concat = Concat(dimension=1)
    out = concat([img_tensor, flipped])  # [1, 6, 640, 640]
    print("拼接后形状:", out.shape)

    # 4. 拆分拼接后的前 3 通道(原图)和后 3 通道(翻转图)
    out_rgb1 = out[0, :3, :, :].cpu().numpy()  # [3, 640, 640]
    out_rgb2 = out[0, 3:, :, :].cpu().numpy()  # [3, 640, 640]
    # 转置为 HWC 格式便于 matplotlib 显示
    out_rgb1 = np.transpose(out_rgb1, (1, 2, 0))
    out_rgb2 = np.transpose(out_rgb2, (1, 2, 0))
    # 归一化到 [0, 255] uint8(便于显示)
    def normalize(img):
        img = (img - img.min()) / (img.max() - img.min() + 1e-8) * 255
        return img.astype(np.uint8)
    out_rgb1 = normalize(out_rgb1)
    out_rgb2 = normalize(out_rgb2)

    # 原图和翻转图(用于对比)
    orig = img_rgb
    flipped_np = flipped[0].cpu().numpy().transpose(1, 2, 0).astype(np.uint8)

    # 5. 可视化:原图、翻转图、拼接前3通道、拼接后3通道
    plt.figure(figsize=(16, 5), constrained_layout=True)
    plt.subplot(1, 4, 1)
    plt.imshow(orig)
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 4, 2)
    plt.imshow(flipped_np)
    plt.title("Flipped")
    plt.axis("off")

    plt.subplot(1, 4, 3)
    plt.imshow(out_rgb1)
    plt.title("Concatenated (first 3 ch)")
    plt.axis("off")

    plt.subplot(1, 4, 4)
    plt.imshow(out_rgb2)
    plt.title("Concatenated (last 3 ch)")
    plt.axis("off")

    plt.savefig("concat_visualization.png", dpi=150)
    # plt.show()  # 若在远程服务器,建议注释
    print("可视化已保存为 concat_visualization.png")

输出示例

复制代码
拼接后形状: torch.Size([1, 6, 640, 640])
可视化已保存为 concat_visualization.png

流程示意图


代码解读

__init__ 方法
  • 存储拼接维度 self.d,默认值为 1(通道维),因为在 CNN 特征融合中,通常沿通道拼接不同层的特征。
forward 方法
  • 直接调用 torch.cat(x, self.d),对输入列表进行拼接。
  • 要求列表中各张量在非拼接维度上的尺寸完全一致,否则会报错。

注意事项

  1. 输入必须为列表或元组forward 期望接收一个可迭代对象,且内部元素均为 torch.Tensor
  2. 维度一致性 :所有张量在除拼接维度外的其他维度上必须具有相同的形状,否则 torch.cat 会抛出 RuntimeError
  3. 支持负维度索引dimension 可为负数(如 -1 表示最后一维),但使用时需注意区分。
  4. 无梯度计算影响:拼接本身是可微的,梯度会正确传播到各个输入张量。
  5. 常见场景 :在 YOLOv8 的 Head 部分,Concat 用于将上采样后的深层特征与骨干网络的同尺度特征拼接,以便后续检测头利用多层级信息。

优缺点

优点
  1. 实现极简:仅一行代码,轻量无开销。
  2. 通用性强:可用于任意维度的拼接,不仅限于通道维。
  3. 与网络结构兼容 :在 PyTorch 中,Concat 可无缝嵌入 nn.Sequential 或作为子模块使用。
  4. 易于理解:功能单一,行为清晰,便于调试和可视化。
缺点
  1. 功能过于简单 :几乎等同于直接调用 torch.cat,额外封装的意义有限(仅为了统一模块风格或便于配置文件解析)。
  2. 缺乏输入校验:不检查张量尺寸是否匹配,可能将错误推迟到运行时。
  3. 灵活性受限:无法支持拼接前的预处理(如归一化或缩放),若有需求需额外添加。

在 YOLOv8 等模型中,Concat 通常出现在 Neck 部分的特征融合阶段,配合 UpsampleC2f 一起构建特征金字塔。虽然它只是一个简单的拼接操作,但作为模块化设计的一部分,有利于网络结构的可视化配置(如通过 YAML 文件描述)。在实际开发中,若需更强大的融合机制,可考虑 Add(相加)或加权融合等替代方案。

参考文献

1 https://docs.ultralytics.com/

2 https://github.com/ultralytics/ultralytics.git