即插即用显著位置注意力spab,涨点起飞

题目:Salient Positions based Attention Network for Image Classification

论文地址:https://arxiv.org/pdf/2106.04996

创新点

  • 提出了基于显著位置的注意力机制:论文提出了一种名为SPAblock的显著位置选择算法(SPS),通过在注意力计算中仅选择显著位置,减少了计算复杂度和内存需求,同时提取了对图像分类有用的上下文信息。这种方法有效减少了对非关键信息的处理,特别是在图像背景复杂的情况下能更好地避免噪声干扰。

  • 采用了通道维度的聚合:与传统的非局部模块相比,SPAblock在通道维度上对信息进行聚合,而不是空间维度。这种聚合方式在减少计算资源的同时,能够更有效地提取特征信息,提高图像分类的准确性。

  • 引入了适用于特征图的显著性度量:针对特征图的高维特性,论文设计了一种基于平方和的显著性度量方法,利用其近似高斯分布的特点,通过平方和来选择显著位置。这种方法适用于神经网络生成的高维特征图,而非传统的视觉图像。

  • 在低层网络中取得更好效果:实验表明,SPAblock在低层网络中表现更佳,尤其在CIFAR和Tiny-ImageNet等数据集上优于传统的非局部模块,并显著减少了内存使用。这种设计更适合于低层网络的特性,能够在低层网络上更好地进行上下文建模。

方法

整体结构

这篇论文提出了SPAblock模型结构,其核心是基于显著位置的注意力机制(SPS算法),在输入特征中选择少量显著位置进行注意力计算,从而减少计算量和内存占用。模型首先生成查询和数值矩阵,通过SPS算法筛选显著位置,计算注意力矩阵后将数值矩阵的上下文信息聚合到输出特征中,并通过1×11 \times 1 卷积更新特征,最终与输入特征相加形成输出。该设计在图像分类任务中提升了精度,特别适合应用在网络的低层次。

  • 输入特征生成查询和数值矩阵:特征图首先经过两个二维卷积层,分别生成查询矩阵QQ 和数值矩阵VV。这样可以将输入特征图转化为适合注意力计算的形式。

  • 显著位置选择(Salient Positions Selection, SPS)算法:SPS算法根据查询矩阵的平方和选择出前kk 个显著位置。具体来说,SPS算法先计算查询矩阵各通道的平方和,并对每个通道进行求和,再根据该值选择显著位置。这一步骤减少了关注位置的数量,从而降低了计算复杂度。

  • 计算注意力矩阵:利用SPS选出的显著位置构建注意力矩阵AA,并将其进行softmax归一化。此过程相当于计算查询和键的相似度,但仅限于显著位置,节省了大量计算资源。

  • 特征聚合与更新:使用数值矩阵VV和注意力矩阵AA 进行矩阵乘法,将结果重新整形为与输入相同的尺寸。然后通过一个1×11 \times 1 卷积进行变换,并将结果与输入特征相加,以形成输出特征。

  • 逐层级应用的灵活性:SPAblock可以插入到ResNet等深度网络的不同层级,尤其是在低层级的效果尤为显著。这是因为低层特征图通常包含更多空间细节,而SPAblock能够有效地提取其中的显著信息。

即插即用模块作用

SPAblock 作为一个即插即用模块,主要适用于:

  • **图像分类:**在分类任务中增强对重要特征的关注,忽略无关背景。

  • 目标检测:提高对目标区域的聚焦,减少背景噪声的影响。

  • **实时应用:**在资源受限的环境中,如移动设备或嵌入式系统中,用于减少计算量和内存需求。

  • **深度网络的低层或中层:**在特征图信息丰富的低层或中层加入SPAblock,可以更有效地提取关键细节。

消融实验结果

该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。

该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。

即插即用模块

python 复制代码
import torch
from torch import nn




class SPABlock(nn.Module):
    def __init__(self, in_channels, k=784, adaptive = False, reduction=16, learning=False, mode='pow'):

        super(SPABlock, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.k = k
        self.adptive = adaptive
        self.reduction = reduction
        self.learing = learning
        if self.learing is True:
            self.k = nn.Parameter(torch.tensor(self.k))

        self.mode = mode
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, x, return_info=False):
        input_shape = x.shape
        if len(input_shape)==4:
            x = x.view(x.size(0), self.in_channels, -1)
            x = x.permute(0, 2, 1)
        batch_size,N = x.size(0),x.size(1)

        #(B, H*W,C)
        if self.mode == 'pow':
            x_pow = torch.pow(x,2)# (batchsize,H*W,channel)
            x_powsum = torch.sum(x_pow,dim=2)# (batchsize,H*W)

        if self.adptive is True:
            self.k = N//self.reduction
            if self.k == 0:
                self.k = 1

        outvalue, outindices = x_powsum.topk(k=self.k, dim=-1, largest=True, sorted=True)

        outindices = outindices.unsqueeze(2).expand(batch_size, self.k, x.size(2))
        out = x.gather(dim=1, index=outindices).to(self.device)

        if return_info is True:
            return out, outindices, outvalue
        else:
            return out

if __name__ == '__main__':
    block = SPABlock(in_channels=128)
    input = torch.rand(32, 784, 128)
    output = block(input)
    print(input.size())    print(output.size())
相关推荐
硅谷秋水2 小时前
REALM:用于机器人操作泛化能力的真实-仿真验证基准测试
人工智能·机器学习·计算机视觉·语言模型·机器人
Pyeako2 小时前
opencv计算机视觉--LBPH&EigenFace&FisherFace人脸识别
人工智能·python·opencv·计算机视觉·lbph·eigenface·fisherface
工程师老罗2 小时前
举例说明YOLOv1 输出坐标到原图像素的映射关系
人工智能·yolo·计算机视觉
格林威3 小时前
Baumer相机水果表皮瘀伤识别:实现无损品质分级的 7 个核心方法,附 OpenCV+Halcon 实战代码!
人工智能·opencv·计算机视觉·视觉检测·工业相机·sdk开发·堡盟相机
【赫兹威客】浩哥3 小时前
农作物病虫害检测数据集分享及多版本YOLO模型训练验证
人工智能·计算机视觉·目标跟踪
爱打代码的小林3 小时前
基于 OpenCV 与 Dlib 的人脸替换
人工智能·opencv·计算机视觉
光泽雨15 小时前
检测阈值 匹配阈值分析 金字塔
图像处理·人工智能·计算机视觉·机器视觉·smart3
sali-tec15 小时前
C# 基于OpenCv的视觉工作流-章22-Harris角点
图像处理·人工智能·opencv·算法·计算机视觉
学电子她就能回来吗15 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
光羽隹衡17 小时前
计算机视觉——Opencv(图像拼接)
人工智能·opencv·计算机视觉