【模块】 ASFF 模块

ASFF (Adaptively Spatial Feature Fusion) 方法针对单次射击物体检测器的特征金字塔中存在的不同特征尺度之间的不一致性问题,提出了一种新颖的数据驱动策略进行金字塔特征融合。通过学习空间上筛选冲突信息的方法,减少了特征之间的不一致性,提高了特征的尺度不变性,并且几乎不增加推理开销。

机制

ASFF策略首先将不同层级的特征调整到相同的分辨率,然后通过训练找到最优的融合方式。在每个空间位置上,不同层级的特征被适应性融合,即某些特征因为携带矛盾信息而被过滤掉,而某些特征则因含有更多判别性线索而占主导地位。这一过程是可微分的,因此可以通过反向传播轻松学习。

优势

1、提高准确性:

利用ASFF策略和一个坚实的YOLOv3基线,在MS COCO数据集上实现了最佳的速度-精度权衡,达到了38.1%的AP(平均精度)和60 FPS(每秒帧数)的检测速度。

2、模型通用性:

该方法与基础模型无关,适用于具有特征金字塔结构的单次射击检测器,实现简单,额外计算成本较低。

3、解决特征尺度不一致问题:

通过适应性学习特征融合权重,有效解决了特征金字塔中不同尺度特征之间的一致性问题,避免了在训练过程中的梯度不一致现象,提高了训练效率和检测准确性。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))


class ASFF(nn.Module):
    def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """
        multiplier should be 1, 0.5
        which means, the channel of ASFF can be
        512, 256, 128 -> multiplier=0.5
        1024, 512, 256 -> multiplier=1
        For even smaller, you need change code manually.
        """
        # init asff_module = ASFF(level=1, multiplier=1, rfb=False, vis=False)
        super(ASFF, self).__init__()
        self.level = level
        self.dim = [int(1024 * multiplier), int(512 * multiplier),
                    int(256 * multiplier)]
        # print(self.dim)

        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = Conv(int(512 * multiplier), self.inter_dim, 3, 2)

            self.stride_level_2 = Conv(int(256 * multiplier), self.inter_dim, 3, 2)

            self.expand = Conv(self.inter_dim, int(
                1024 * multiplier), 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(
                int(1024 * multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(
                int(256 * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(512 * multiplier), 3, 1)
        elif level == 2:
            self.compress_level_0 = Conv(
                int(1024 * multiplier), self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(
                int(512 * multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(
                256 * multiplier), 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(
            self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(
            self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(
            self.inter_dim, compress_c, 1, 1)

        self.weight_levels = Conv(
            compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):  # l,m,s
        """
        #
        256, 512, 1024
        from small -> large
        """
        # forward output_feature = asff_module([level_2_feature, level_1_feature, level_0_feature])
        x_level_0 = x[2]  # 最大特征层 level_0_feature = (1, 1024, 20, 20)  # 大尺寸特征图 尺寸小通道多
        x_level_1 = x[1]  # 中间特征层 level_1_feature = (1, 512, 40, 40)   # 中尺寸特征图
        x_level_2 = x[0]  # 最小特征层 level_2_feature = (1, 256, 80, 80)   # 小尺寸特征图

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0) # (1, 1024, 20, 20) → self.compress_level_0 =  Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) → (1, 512, 20, 20)
            
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest') # (1, 512, 20, 20) → F.interpolate→ [1, 512, 40, 40]
            level_1_resized = x_level_1 #  [1, 512, 40, 40] → = → [1, 512, 40, 40]
            level_2_resized = self.stride_level_2(x_level_2) # [1, 256, 80, 80] → self.stride_level_2 = Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) → [1, 512, 40, 40]
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                x_level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized) # [1, 512, 40, 40] → self.weight_level_0 = (conv): Conv2d(512, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) → [1, 16, 40, 40]
        level_1_weight_v = self.weight_level_1(level_1_resized) # [1, 512, 40, 40] → self.weight_level_1 = Conv2d(512, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) → [1, 16, 40, 40]
        level_2_weight_v = self.weight_level_2(level_2_resized) # [1, 512, 40, 40] → self.weight_level_2 = Conv2d(512, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) → [1, 16, 40, 40]

        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1) # [1, 16, 40, 40],[1, 16, 40, 40],[1, 16, 40, 40] → cat → [1, 48, 40, 40]
        levels_weight = self.weight_levels(levels_weight_v) # [1, 48, 40, 40] → self.weight_levels = (conv): Conv2d(48, 3, kernel_size=(1, 1), stride=(1, 1), bias=False) → [1, 3, 40, 40]
        levels_weight = F.softmax(levels_weight, dim=1) # [1, 3, 40, 40] → F.softmax → [1, 3, 40, 40]

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                            level_1_resized * levels_weight[:, 1:2, :, :] + \
                            level_2_resized * levels_weight[:, 2:, :, :]
        # [1, 512, 40, 40] * [1, 1, 40, 40] + [1, 512, 40, 40] * [1, 1, 40, 40] + [1, 512, 40, 40] * [1, 1, 40, 40] →  [1, 512, 40, 40]
        out = self.expand(fused_out_reduced) # [1, 512, 40, 40] → self.expand = (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) → [1, 512, 40, 40]

        if self.vis:# self.vis = False
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out


if __name__ == "__main__":
    # 模拟的输入特征图,模拟三个不同尺度的特征图,例如来自一个多尺度特征提取网络的输出
    level_0_feature = torch.randn(1, 1024, 20, 20)  # 大尺寸特征图
    level_1_feature = torch.randn(1, 512, 40, 40)   # 中尺寸特征图
    level_2_feature = torch.randn(1, 256, 80, 80)   # 小尺寸特征图

    # 初始化ASFF模块,level表示当前ASFF模块处理的是哪个尺度的特征层,这里以处理中尺寸特征层为例
    # multiplier用于调整通道数,rfb和vis分别表示是否使用更丰富的特征表示和是否可视化
    asff_module = ASFF(level=1, multiplier=1, rfb=False, vis=False)

    # 通过ASFF模块传递特征图
    output_feature = asff_module([level_2_feature, level_1_feature, level_0_feature])

    # 打印输出特征图的形状,确保ASFF模块正常工作
    print(f"Output feature shape: {output_feature.shape}")

# TODO 计算流程图、原文框架图、公式表示
相关推荐
Daitu_Adam35 分钟前
Windows11安装GPU版本Pytorch2.6教程
人工智能·pytorch·python·深度学习
阿正的梦工坊38 分钟前
Grouped-Query Attention(GQA)详解: Pytorch实现
人工智能·pytorch·python
Best_Me071 小时前
【CVPR2024-工业异常检测】PromptAD:与只有正常样本的少样本异常检测的学习提示
人工智能·学习·算法·计算机视觉
山海青风1 小时前
从零开始玩转TensorFlow:小明的机器学习故事 4
人工智能·机器学习·tensorflow
YoseZang1 小时前
【机器学习】信息熵 交叉熵和相对熵
人工智能·深度学习·机器学习
Ronin-Lotus1 小时前
图像处理篇---图像处理中常见参数
图像处理·人工智能·信噪比·分贝·峰值信噪比·动态范围
数据智能老司机2 小时前
深度学习架构师手册——理解神经网络变换器(Transformers)
深度学习·架构
机器视觉知识推荐、就业指导2 小时前
【数字图像处理三】图像变换与频域处理
图像处理·人工智能·计算机视觉
next_travel2 小时前
图像分割UNet、生成模型SD及IP-Adapter
pytorch·深度学习·计算机视觉
东木月2 小时前
windows安装pytorch
人工智能·pytorch·windows