🍋改进模块🍋:三重注意力机制(TripletAttention )
🍋解决问题🍋:TripletAttention模块解决了传统注意力机制在计算注意力权重时未能充分捕捉跨维度交互的问题。
🍋改进优势🍋:几乎不增加额外的可学习参数,涨点很明显
🍋适用场景🍋:小目标检测、遮挡目标检测等
🍋思路来源🍋:WACV2021《MRotate to Attend: Convolutional Triplet Attention Module》
目录
🍉🍉2.三重注意力机制(TripletAttention)模块优势
👍👍4.将TripletAttention引入YOLOv12的python代码修改
🔔🔔1.设计动机
近乎parameter-free的即插即用attention模块
TripletAttention模块的设计动机是为了解决传统注意力机制在计算注意力权重时未能充分捕捉跨维度交互的问题。以往的注意力机制,如SENet、CBAM等,虽然在通道注意力和空间注意力方面取得了显著的性能提升,但它们通常存在以下局限性:
-
缺乏跨维度交互:传统方法在计算通道注意力时,通常通过全局池化操作将输入张量压缩为单像素通道,导致空间信息丢失,无法捕捉通道维度与空间维度之间的依赖关系。
-
冗余的降维操作:一些方法(如SENet和CBAM)在计算通道注意力时引入了降维操作,这虽然降低了计算复杂度,但也间接导致了通道权重与通道特征之间的对应关系不明确,从而影响了特征表示的丰富性。
针对这些问题,TripletAttention模块旨在通过一种轻量级且高效的方式捕捉输入张量的跨维度交互,从而提供更丰富的特征表示,同时避免不必要的降维操作。
🍉🍉2.三重注意力机制( TripletAttention**)模块优势**
👍轻量级设计
与传统的注意力机制相比,TripletAttention几乎不增加额外的可学习参数。例如,在ResNet-50上,它仅增加了约4.8K参数和4.7e-2 GFLOPs的计算量,但显著提升了模型性能。
👍高效的跨维度交互
通过旋转操作和并行分支结构,TripletAttention能够同时捕捉通道维度与空间维度(H和W)之间的交互,以及空间维度之间的交互。这种跨维度交互能够提供更丰富的特征表示,而无需额外的降维操作。
👍广泛的适用性
TripletAttention可以作为附加模块轻松集成到各种经典的骨干网络中,如ResNet、MobileNet等,适用于图像分类、目标检测等多种计算机视觉任务。
👍显著的性能提升
实验表明,TripletAttention在多个数据集(如ImageNet、COCO、PASCAL VOC)上均取得了显著的性能提升,尤其是在目标检测和实例分割任务中,其改进效果尤为明显。
🏆🏆3. TripletAttention结构

TripletAttention模块由三个并行分支组成,每个分支负责捕捉输入张量的不同维度交互。其具体结构如下:
🐸输入张量
假设输入张量的形状为 C×H×W,其中 C 表示通道数,H 和 W 分别表示特征图的高度和宽度。
🐸三个分支
-
第一分支:通过旋转操作将输入张量沿高度维度 H 逆时针旋转90°,得到形状为 W×H×C 的张量。然后通过Z-Pool操作将通道维度压缩为2,得到形状为 2×H×C 的张量。接着通过一个 k×k 的卷积层和批量归一化层,生成形状为 1×H×C 的注意力权重。最后,将注意力权重通过Sigmoid激活函数生成注意力图,并将其应用到旋转后的输入张量上,再顺时针旋转90°恢复到原始形状。
-
第二分支:与第一分支类似,但旋转操作是沿宽度维度 W 进行的。最终生成形状为 1×C×W 的注意力图,并恢复到原始形状。
-
第三分支:直接对输入张量进行Z-Pool操作,将通道维度压缩为2,得到形状为 2×H×W 的张量。然后通过一个 k×k 的卷积层和批量归一化层,生成形状为 1×H×W 的空间注意力图,并将其应用到输入张量上。
🐸输出
将三个分支生成的注意力图通过简单平均的方式进行聚合,得到最终的输出张量,其形状与输入张量相同。
👍👍4.将 TripletAttention引入YOLOv12的python代码修改
🍂修改处一
在ultralytics/nn/modules/目录下添加TripletAttention.py,TripletAttention.py定义如下:
python
import torch
import torch.nn as nn
class BasicConv(nn.Module):
def __init__(
self,
in_planes, # 输入通道数
out_planes, # 输出通道数
kernel_size, # 卷积核大小
stride=1, # 步长
padding=0, # 填充
dilation=1, # 膨胀系数
groups=1, # 组数,用于分组卷积
relu=True, # 是否使用ReLU激活函数
bn=True, # 是否使用Batch Normalization
bias=False, # 是否使用偏置
):
super(BasicConv, self).__init__()
self.out_channels = out_planes
# 卷积层定义
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
# 定义BN层(可选)
self.bn = (
nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
if bn
else None
)
# 定义ReLU激活函数(可选)
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ChannelPool(nn.Module):
"""用于通道池化,生成两个特征图:最大池化图和平均池化图。"""
def forward(self, x):
return torch.cat(
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
)
class SpatialGate(nn.Module):
"""生成空间注意力的门控机制,基于输入特征的空间分布生成注意力图。"""
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7 # 卷积核大小
# 使用ChannelPool压缩通道维度后,使用BasicConv层生成空间注意力图
self.compress = ChannelPool()
self.spatial = BasicConv(
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
)
def forward(self, x):
x_compress = self.compress(x) # 通道池化
x_out = self.spatial(x_compress) # 生成注意力图
scale = torch.sigmoid_(x_out) # 使用sigmoid激活以限制范围在0到1
return x * scale # 将输入乘以注意力权重
class TripletAttention(nn.Module):
"""三重注意力模块,通过通道方向和空间方向对特征图生成注意力。"""
def __init__(
self,
no_spatial=False, # 是否禁用空间注意力
):
super(TripletAttention, self).__init__()
self.ChannelGateH = SpatialGate() # 水平方向注意力
self.ChannelGateW = SpatialGate() # 垂直方向注意力
self.no_spatial = no_spatial # 控制是否使用空间注意力
if not no_spatial:
self.SpatialGate = SpatialGate() # 空间注意力
def forward(self, x):
x_perm1 = x.permute(0, 2, 1, 3).contiguous() # 将通道和宽度维度互换
x_out1 = self.ChannelGateH(x_perm1) # 计算水平方向注意力
x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() # 恢复原始维度顺序
x_perm2 = x.permute(0, 3, 2, 1).contiguous() # 将通道和高度维度互换
x_out2 = self.ChannelGateW(x_perm2) # 计算垂直方向注意力
x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() # 恢复原始维度顺序
if not self.no_spatial:
x_out = self.SpatialGate(x) # 计算空间注意力
x_out = (1 / 3) * (x_out + x_out11 + x_out21) # 三种注意力的平均加权
else:
x_out = (1 / 2) * (x_out11 + x_out21) # 两种注意力的平均加权
return x_out # 输出加权后的特征图
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x))
🍂修改处二
在 task.py文件中导入TripletAttention类
导入代码如下:
python
import thop
import torch
import torch.nn as nn
#修改
from ultralytics.nn.modules.TripletAttentionimport TripletAttention
#修改
🍂修改处三
在task.py中将TripletAttention的使用添加修改到1075行左右位置,修改后代码如下:
python
...
elif m is TripletAttention:
args = []
else:
c2 = ch[f]
注意:使用YOLOv12进行训练,torch版本要在2.x及以上,否则会出现scaled_dot_product_attention模块导入不成功的问题
python
from torch.nn.functional import scaled_dot_product_attention as sdpa
🍂修改处四
修改网络结构定义yolov12.yaml文件,在原来的模型结构基础上添加TripletAttention模块在第一个A2C2f之后。其他位置A2C2f之后也可以,自行调整位置,注意要根据自己的数据集修改nc,即num_class。
python
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov12n.yaml' will call yolov12.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 465 layers, 2,603,056 parameters, 2,603,040 gradients, 6.7 GFLOPs
s: [0.50, 0.50, 1024] # summary: 465 layers, 9,285,632 parameters, 9,285,616 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 501 layers, 20,201,216 parameters, 20,201,200 gradients, 68.1 GFLOPs
l: [1.00, 1.00, 512] # summary: 831 layers, 26,454,880 parameters, 26,454,864 gradients, 89.7 GFLOPs
x: [1.00, 1.50, 512] # summary: 831 layers, 59,216,928 parameters, 59,216,912 gradients, 200.3 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, TripletAttention, []]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Detect, [nc]] # Detect(P3, P4, P5)
👍👍5.成功训练后的网络结构截图

整理不易,欢迎一键三连!!!
送你们一条美丽的--分割线--
🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷