Ultralytics:解读ChannelAttention模块

前言
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏、人工智能混合编程实践专栏或我的个人主页查看
- YOLOs-CPP:一个免费开源的YOLO全系列C++推理库(以YOLO26为例)
- PaddleOCR:Win10上安装使用PPOCRLabel标注工具
- 目标检测:使用自己的数据集微调DEIMv2进行物体检测
- 图像分割:PyTorch从零开始实现SegFormer语义分割
- 图像超分:使用自己的数据集微调Real-ESRGAN-x4plus进行超分重建
- 图像生成:PyTorch从零开始实现一个简单的扩散模型
- Stable Diffusion:使用自己的数据集微调 Stable Diffusion 3.5 LoRA 文生图模型
- 图像超分:使用自己的数据集微调Real-ESRGAN-x2plus进行超分重建
- Anomalib:使用Anomalib 2.1.0训练自己的数据集进行异常检测
- Anomalib:在Linux服务器上安装使用Anomalib 2.1.0
- 人工智能混合编程实践:C++调用封装好的DLL进行异常检测推理
- 人工智能混合编程实践:C++调用封装好的DLL进行FP16图像超分重建(v3.0)
- 隔离系统Python:源码编译3.11.8到自定义目录(含PGO性能优化)
- 在线机的Python环境迁移到离线机上
- Nuitka 将 Python 脚本封装为 .pyd 或 .so 文件
- Ultralytics:使用 YOLO11 进行速度估计
- Ultralytics:使用 YOLO11 进行物体追踪
- Ultralytics:使用 YOLO11 进行物体计数
- Ultralytics:使用 YOLO11 进行目标打码
- 人工智能混合编程实践:C++调用Python ONNX进行YOLOv8推理
- 人工智能混合编程实践:C++调用封装好的DLL进行YOLOv8实例分割
- 人工智能混合编程实践:C++调用Python ONNX进行图像超分重建
- 人工智能混合编程实践:C++调用Python AgentOCR进行文本识别
- 通过计算实例简单地理解PatchCore异常检测
- Python将YOLO格式实例分割数据集转换为COCO格式实例分割数据集
- YOLOv8 Ultralytics:使用Ultralytics框架训练RT-DETR实时目标检测模型
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目
- Stable Diffusion:在服务器上部署使用Stable Diffusion WebUI进行AI绘图(v2.0)
- Stable Diffusion:使用自己的数据集微调训练LoRA模型(v2.0)
相关介绍
Ultralytics 简介
Ultralytics 基于多年的计算机视觉和人工智能基础研究,创建了最先进的 (SOTA) YOLO 模型。我们的模型不断更新性能和灵活性,快速、准确且易于使用。他们擅长对象检测、跟踪、实例分割、语义分割、图像分类和姿势估计任务。
前提条件
- 熟悉Python、Pytorch
实验环境
bash
Package Version
------------------------ ------------
Python 3.11.8
absl-py 2.4.0
accelerate 1.13.0
annotated-doc 0.0.4
anyio 4.13.0
calflops 0.3.2
certifi 2026.4.22
charset-normalizer 3.4.7
click 8.3.3
colorama 0.4.6
contourpy 1.3.3
cycler 0.12.1
filelock 3.29.0
flatbuffers 25.12.19
fonttools 4.62.1
fsspec 2026.4.0
grpcio 1.80.0
h11 0.16.0
hf-xet 1.5.0
httpcore 1.0.9
httpx 0.28.1
huggingface_hub 1.14.0
idna 3.15
Jinja2 3.1.6
kiwisolver 1.5.0
Markdown 3.10.2
markdown-it-py 4.2.0
MarkupSafe 3.0.3
matplotlib 3.10.9
mdurl 0.1.2
ml_dtypes 0.5.0
mpmath 1.3.0
networkx 3.6.1
numpy 1.26.4
nvidia-cublas-cu12 12.8.3.14
nvidia-cuda-cupti-cu12 12.8.57
nvidia-cuda-nvrtc-cu12 12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12 9.7.1.26
nvidia-cufft-cu12 11.3.3.41
nvidia-cufile-cu12 1.13.0.11
nvidia-curand-cu12 10.3.9.55
nvidia-cusolver-cu12 11.7.2.55
nvidia-cusparse-cu12 12.5.7.53
nvidia-cusparselt-cu12 0.6.3
nvidia-nccl-cu12 2.26.2
nvidia-nvjitlink-cu12 12.8.61
nvidia-nvtx-cu12 12.8.55
onnx 1.19.0
onnxruntime-gpu 1.26.0
onnxslim 0.1.94
opencv-python 4.6.0.66
packaging 26.2
pillow 12.2.0
pip 24.0
polars 1.40.1
polars-runtime-32 1.40.1
protobuf 7.34.1
psutil 7.2.2
pycocotools 2.0.11
Pygments 2.20.0
pyparsing 3.3.2
python-dateutil 2.9.0.post0
PyYAML 6.0.3
regex 2026.5.9
requests 2.34.1
rich 15.0.0
safetensors 0.7.0
scipy 1.16.0
setuptools 65.5.0
shellingham 1.5.4
six 1.17.0
sympy 1.14.0
tabulate 0.10.0
tensorboard 2.20.0
tensorboard-data-server 0.7.2
tokenizers 0.22.2
torch 2.7.1+cu128
torchaudio 2.7.1+cu128
torchvision 0.22.1+cu128
tqdm 4.67.3
transformers 5.8.1
triton 3.3.1
typer 0.25.1
typing_extensions 4.15.0
ultralytics 8.4.58
ultralytics-thop 2.0.19
urllib3 2.7.0
Werkzeug 3.1.8
ChannelAttention(通道注意力模块)
ChannelAttention 是一种轻量级的注意力机制,它通过 全局平均池化 和 1×1 卷积 为每个通道生成注意力权重,然后与原始特征图逐通道相乘,实现 特征重标定 (feature recalibration)。该模块源自 SENet(Squeeze-and-Excitation Networks),能有效增强重要通道的特征响应,抑制无关通道,在图像分类、目标检测等任务中被广泛使用,例如 MMDetection 的 RTMDet 中即采用了此实现。
代码实现
python
import cv2
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
class ChannelAttention(nn.Module):
"""Channel-attention module for feature recalibration.
Applies attention weights to channels based on global average pooling.
Attributes:
pool (nn.AdaptiveAvgPool2d): Global average pooling.
fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.
act (nn.Sigmoid): Sigmoid activation for attention weights.
References:
https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
"""
def __init__(self, channels: int) -> None:
"""Initialize Channel-attention module.
Args:
channels (int): Number of input channels.
"""
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
self.act = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply channel attention to input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Channel-attended output tensor.
"""
return x * self.act(self.fc(self.pool(x)))
功能
- 全局信息聚合 :通过
AdaptiveAvgPool2d(1)将每个通道的空间特征压缩为一个标量,聚合全局上下文信息。 - 通道权重生成 :通过 1×1 卷积层(等价于全连接层)学习通道间的依赖关系,并经过 Sigmoid 激活将权重映射到
(0,1)区间。 - 特征重标定:将生成的注意力权重与原始特征图逐通道相乘,突出重要通道,抑制不相关通道。
初始化参数
| 参数 | 类型 | 说明 |
|---|---|---|
channels |
int | 输入特征图的通道数,也是输出的通道数(输入输出通道不变) |
该模块不改变特征图的空间尺寸和通道数,仅对每个通道进行加权。
前向方法
forward(x):输入x(形状[B, C, H, W]),输出x * attention,其中attention形状为[B, C, 1, 1],经过广播逐元素相乘。
使用示例

python
if __name__ == '__main__':
torch.manual_seed(42) # 或任意固定值
# 1. 读取图像(请修改为实际路径)
img_path = "cat_640x640.png"
img_bgr = cv2.imread(img_path)
if img_bgr is None:
raise FileNotFoundError(f"图片 {img_path} 不存在!")
# 2. 转为张量 (1,3,640,640)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1).unsqueeze(0)
# 3. 创建 ChannelAttention 模块(输入通道数为3)
ca = ChannelAttention(channels=3)
# 4. 前向传播并获取注意力权重
with torch.no_grad():
out = ca(img_tensor)
attention_weights = ca.act(ca.fc(ca.pool(img_tensor))) # shape [1, C, 1, 1]
print("输出形状:", out.shape) # torch.Size([1, 3, 640, 640])
print("注意力权重形状:", attention_weights.shape) # [1, 3, 1, 1]
# 转换为 numpy 并打印各通道权重
weights_np = attention_weights.squeeze().cpu().numpy()
print("各通道注意力权重:", weights_np)
# 5. 获取权重最大的通道索引
max_ch = np.argmax(weights_np)
max_weight = weights_np[max_ch]
print(f"最大权重通道索引: {max_ch}, 权重值: {max_weight:.6f}")
# 6. 可视化原图和加权后最大通道的特征图
# 原始图像(RGB)
img_display = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# 加权后输出,选取最大权重通道
feat_map = out[0, max_ch, :, :].cpu().numpy()
feat_map = (feat_map - feat_map.min()) / (feat_map.max() - feat_map.min() + 1e-8)
plt.figure(figsize=(12, 5))
plt.subplot(1, 3, 1)
plt.imshow(img_display)
plt.title("Original")
plt.axis("off")
plt.subplot(1, 3, 2)
# 显示注意力权重热力图(将单值扩展为全图)
attn_heatmap = np.full((640, 640), weights_np[max_ch], dtype=np.float32)
plt.imshow(attn_heatmap, cmap='hot', vmin=0, vmax=1)
plt.title(f"Attention Weight\n(Ch{max_ch}, {max_weight:.3f})")
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(feat_map, cmap='gray')
plt.title(f"Weighted Feature (Ch{max_ch})")
plt.axis("off")
plt.tight_layout()
plt.savefig("channel_attention_max_channel.png", dpi=150)
# plt.show()
print("可视化已保存为 channel_attention_max_channel.png")

输出示例:
输出形状: torch.Size([1, 3, 640, 640])
注意力权重形状: torch.Size([1, 3, 1, 1])
各通道注意力权重: [1. 1. 1.]
最大权重通道索引: 0, 权重值: 1.000000
可视化已保存为 channel_attention_max_channel.png
流程示意图

代码解读
__init__ 方法
self.pool = nn.AdaptiveAvgPool2d(1):自适应全局平均池化,将每个通道的空间维度压缩为1×1,输出形状[B, C, 1, 1]。self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True):1×1 卷积,模拟全连接层,学习通道间的依赖关系。bias=True保留偏置以增强拟合能力。self.act = nn.Sigmoid():Sigmoid 激活函数,将权重限制在(0,1)之间。
forward 方法
- 先池化、再 1×1 卷积、再 Sigmoid,得到注意力权重。
- 使用广播机制将权重与输入逐元素相乘,输出与输入形状相同。
与 SENet 的区别
- 原始 SENet 使用两个全连接层(降维后再升维)来减少计算量,而本实现直接使用 1×1 卷积(相当于单层全连接),更为轻量。这种设计在 RTMDet 等高效模型中常见。
注意事项
- 输入通道数必须一致 :
channels需与输入特征图的通道数相同,否则无法相乘。 - 无 BN 和激活:该模块仅包含池化、卷积和 Sigmoid,不包含 BatchNorm,可即插即用。
- 计算开销:相比标准卷积,该模块增加的计算量很少(仅全局池化和 1×1 卷积),适合轻量级网络。
- 与空间注意力的区别:该模块仅作用于通道维度,不关注空间位置;可与空间注意力组合使用(如 CBAM)。
- 训练稳定性:Sigmoid 输出非负权重,能保证梯度稳定性,但也可能使权重趋近 0 或 1,需配合适当的学习率。
优缺点
优点
- 轻量高效 :仅增加少量参数(
C²个),计算量可忽略,适合移动端部署。 - 性能提升显著:在多种任务中可带来 1~2% 的精度提升,尤其在通道数较多的深层特征上效果明显。
- 即插即用:可插入任意 CNN 层之后,无需改动网络结构。
- 可解释性强:注意力权重可直观反映各通道的重要性,有助于模型分析。
缺点
- 忽略空间信息:仅通过全局平均池化聚合空间信息,可能丢失局部细节,对空间敏感的任务(如小目标检测)效果有限。
- 单层全连接容量有限:没有降维-升维的瓶颈结构,拟合能力弱于原始 SENet,可能在某些复杂任务上提升不足。
- 固定通道数:输入输出通道数必须相同,无法进行通道变换。
- 对初始值敏感:若 1×1 卷积初始化不当,可能导致早期梯度饱和(Sigmoid 输出接近 0 或 1),影响训练收敛。
在 YOLO 系列中,ChannelAttention 可嵌入 C2f 模块或检测头之前,用于增强关键特征。建议在深层特征图(如 P4、P5)使用,并配合适当的正则化策略。
参考文献
1 https://docs.ultralytics.com/
2 https://github.com/ultralytics/ultralytics.git
- 由于本人水平有限,难免出现错漏,敬请批评改正。
- 更多精彩内容,可点击进入Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏、人工智能混合编程实践专栏或我的个人主页查看
- YOLOs-CPP:一个免费开源的YOLO全系列C++推理库(以YOLO26为例)
- PaddleOCR:Win10上安装使用PPOCRLabel标注工具
- 目标检测:使用自己的数据集微调DEIMv2进行物体检测
- 图像分割:PyTorch从零开始实现SegFormer语义分割
- 图像超分:使用自己的数据集微调Real-ESRGAN-x4plus进行超分重建
- 图像生成:PyTorch从零开始实现一个简单的扩散模型
- Stable Diffusion:使用自己的数据集微调 Stable Diffusion 3.5 LoRA 文生图模型
- 图像超分:使用自己的数据集微调Real-ESRGAN-x2plus进行超分重建
- Anomalib:使用Anomalib 2.1.0训练自己的数据集进行异常检测
- Anomalib:在Linux服务器上安装使用Anomalib 2.1.0
- 人工智能混合编程实践:C++调用封装好的DLL进行异常检测推理
- 人工智能混合编程实践:C++调用封装好的DLL进行FP16图像超分重建(v3.0)
- 隔离系统Python:源码编译3.11.8到自定义目录(含PGO性能优化)
- 在线机的Python环境迁移到离线机上
- Nuitka 将 Python 脚本封装为 .pyd 或 .so 文件
- Ultralytics:使用 YOLO11 进行速度估计
- Ultralytics:使用 YOLO11 进行物体追踪
- Ultralytics:使用 YOLO11 进行物体计数
- Ultralytics:使用 YOLO11 进行目标打码
- 人工智能混合编程实践:C++调用Python ONNX进行YOLOv8推理
- 人工智能混合编程实践:C++调用封装好的DLL进行YOLOv8实例分割
- 人工智能混合编程实践:C++调用Python ONNX进行图像超分重建
- 人工智能混合编程实践:C++调用Python AgentOCR进行文本识别
- 通过计算实例简单地理解PatchCore异常检测
- Python将YOLO格式实例分割数据集转换为COCO格式实例分割数据集
- YOLOv8 Ultralytics:使用Ultralytics框架训练RT-DETR实时目标检测模型
- 基于DETR的人脸伪装检测
- YOLOv7训练自己的数据集(口罩检测)
- YOLOv8训练自己的数据集(足球检测)
- YOLOv5:TensorRT加速YOLOv5模型推理
- YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
- 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
- YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
- YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
- Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
- YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
- 使用Kaggle GPU资源免费体验Stable Diffusion开源项目
- Stable Diffusion:在服务器上部署使用Stable Diffusion WebUI进行AI绘图(v2.0)
- Stable Diffusion:使用自己的数据集微调训练LoRA模型(v2.0)