即插即用显著位置注意力spab,涨点起飞

题目:Salient Positions based Attention Network for Image Classification

论文地址:https://arxiv.org/pdf/2106.04996

创新点

  • 提出了基于显著位置的注意力机制:论文提出了一种名为SPAblock的显著位置选择算法(SPS),通过在注意力计算中仅选择显著位置,减少了计算复杂度和内存需求,同时提取了对图像分类有用的上下文信息。这种方法有效减少了对非关键信息的处理,特别是在图像背景复杂的情况下能更好地避免噪声干扰。

  • 采用了通道维度的聚合:与传统的非局部模块相比,SPAblock在通道维度上对信息进行聚合,而不是空间维度。这种聚合方式在减少计算资源的同时,能够更有效地提取特征信息,提高图像分类的准确性。

  • 引入了适用于特征图的显著性度量:针对特征图的高维特性,论文设计了一种基于平方和的显著性度量方法,利用其近似高斯分布的特点,通过平方和来选择显著位置。这种方法适用于神经网络生成的高维特征图,而非传统的视觉图像。

  • 在低层网络中取得更好效果:实验表明,SPAblock在低层网络中表现更佳,尤其在CIFAR和Tiny-ImageNet等数据集上优于传统的非局部模块,并显著减少了内存使用。这种设计更适合于低层网络的特性,能够在低层网络上更好地进行上下文建模。

方法

整体结构

这篇论文提出了SPAblock模型结构,其核心是基于显著位置的注意力机制(SPS算法),在输入特征中选择少量显著位置进行注意力计算,从而减少计算量和内存占用。模型首先生成查询和数值矩阵,通过SPS算法筛选显著位置,计算注意力矩阵后将数值矩阵的上下文信息聚合到输出特征中,并通过1×11 \times 1 卷积更新特征,最终与输入特征相加形成输出。该设计在图像分类任务中提升了精度,特别适合应用在网络的低层次。

  • 输入特征生成查询和数值矩阵:特征图首先经过两个二维卷积层,分别生成查询矩阵QQ 和数值矩阵VV。这样可以将输入特征图转化为适合注意力计算的形式。

  • 显著位置选择(Salient Positions Selection, SPS)算法:SPS算法根据查询矩阵的平方和选择出前kk 个显著位置。具体来说,SPS算法先计算查询矩阵各通道的平方和,并对每个通道进行求和,再根据该值选择显著位置。这一步骤减少了关注位置的数量,从而降低了计算复杂度。

  • 计算注意力矩阵:利用SPS选出的显著位置构建注意力矩阵AA,并将其进行softmax归一化。此过程相当于计算查询和键的相似度,但仅限于显著位置,节省了大量计算资源。

  • 特征聚合与更新:使用数值矩阵VV和注意力矩阵AA 进行矩阵乘法,将结果重新整形为与输入相同的尺寸。然后通过一个1×11 \times 1 卷积进行变换,并将结果与输入特征相加,以形成输出特征。

  • 逐层级应用的灵活性:SPAblock可以插入到ResNet等深度网络的不同层级,尤其是在低层级的效果尤为显著。这是因为低层特征图通常包含更多空间细节,而SPAblock能够有效地提取其中的显著信息。

即插即用模块作用

SPAblock 作为一个即插即用模块,主要适用于:

  • **图像分类:**在分类任务中增强对重要特征的关注,忽略无关背景。

  • 目标检测:提高对目标区域的聚焦,减少背景噪声的影响。

  • **实时应用:**在资源受限的环境中,如移动设备或嵌入式系统中,用于减少计算量和内存需求。

  • **深度网络的低层或中层:**在特征图信息丰富的低层或中层加入SPAblock,可以更有效地提取关键细节。

消融实验结果

该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。

该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。

即插即用模块

python 复制代码
import torch
from torch import nn




class SPABlock(nn.Module):
    def __init__(self, in_channels, k=784, adaptive = False, reduction=16, learning=False, mode='pow'):

        super(SPABlock, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.k = k
        self.adptive = adaptive
        self.reduction = reduction
        self.learing = learning
        if self.learing is True:
            self.k = nn.Parameter(torch.tensor(self.k))

        self.mode = mode
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, x, return_info=False):
        input_shape = x.shape
        if len(input_shape)==4:
            x = x.view(x.size(0), self.in_channels, -1)
            x = x.permute(0, 2, 1)
        batch_size,N = x.size(0),x.size(1)

        #(B, H*W,C)
        if self.mode == 'pow':
            x_pow = torch.pow(x,2)# (batchsize,H*W,channel)
            x_powsum = torch.sum(x_pow,dim=2)# (batchsize,H*W)

        if self.adptive is True:
            self.k = N//self.reduction
            if self.k == 0:
                self.k = 1

        outvalue, outindices = x_powsum.topk(k=self.k, dim=-1, largest=True, sorted=True)

        outindices = outindices.unsqueeze(2).expand(batch_size, self.k, x.size(2))
        out = x.gather(dim=1, index=outindices).to(self.device)

        if return_info is True:
            return out, outindices, outvalue
        else:
            return out

if __name__ == '__main__':
    block = SPABlock(in_channels=128)
    input = torch.rand(32, 784, 128)
    output = block(input)
    print(input.size())    print(output.size())
相关推荐
HPC_fac130520678161 分钟前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd3 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
如若1238 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
加密新世界10 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
WeeJot嵌入式13 小时前
OpenCV:计算机视觉的瑞士军刀
计算机视觉
思通数科多模态大模型13 小时前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
学不会lostfound13 小时前
三、计算机视觉_05MTCNN人脸检测
pytorch·深度学习·计算机视觉·mtcnn·p-net·r-net·o-net
Mr.谢尔比14 小时前
李宏毅机器学习课程知识点摘要(1-5集)
人工智能·pytorch·深度学习·神经网络·算法·机器学习·计算机视觉
思通数科AI全行业智能NLP系统15 小时前
六大核心应用场景,解锁AI检测系统的智能安全之道
图像处理·人工智能·深度学习·安全·目标检测·计算机视觉·知识图谱
李歘歘18 小时前
Stable Diffusion经典应用场景
人工智能·深度学习·计算机视觉