论文阅读(二十四):SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

文章目录


论文:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks(SA-Net:置换注意力机制)

论文链接:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks

代码链接:Github

Abstract

计算机视觉的注意力机制主要有空间注意力机制和通道注意力机制两种,分别旨在捕获像素(空间域)依赖性和通道依赖性,尽管将它们融合在一起可能会获得更好的性能,但也会增加计算开销。本文提出一种高效的置换注意力机制 S h u f f l e    A t t e n t i o n ( S A ) Shuffle\;Attention(SA) ShuffleAttention(SA),其将通道维度分组到多个子特征中,然后再并行处理。对于每个子特征,SA 利用 S h u f f l e    U n i t Shuffle\;Unit ShuffleUnit来描述空间和通道维度的特征依赖关系。之后将所有子特征聚合,并采用 c h a n n e l s h u f f l e channel shuffle channelshuffle运算来实现不同子特征之间的信息通信。

1.Introduction

常见的注意力机制,如GCNet(Gcnet: Non-local networks meet squeeze-excitation networks and beyond)、CBAM(CBAM: convolutional block attention module),将空间注意力和通道注意力整合到一个模块中,但也带来较大的计算负担。受ShuffleNet v2(Shufflenet V2: practical guidelines for efficient CNN architecture design)的启发,本文针对深度卷积神经网络提出了置换注意力机制SA(Shuffle Attention)。它将通道维度分为多个子特征,然后利用Shuffle Unit为每个子特征集成互补的通道和空间注意力模块。

2.Shuffle Attention

S h u f f l e    A t t e n t i o n Shuffle\;Attention ShuffleAttention机制包含两种运算:

【1.特征分组】

设有特征图 X ∈ R C × H × W X∈R^{C×H×W} X∈RC×H×W,SA沿通道维度将 X X X分为 G G G组,即, X = X 1 , X 2 , . . . X g , X k ∈ R C g × H × W X={X_1,X_2,...X_g},X_k∈R^{\frac{C}{g}×H×W} X=X1,X2,...Xg,Xk∈RgC×H×W。通过attention模块为每个子特征生成相应的重要性系数。具体来说,在每个注意力单元的开始,将输入 X k X_k Xk沿通道维度拆分为两个分支 X k 1 、 X k 2 ∈ R C 2 g × H × W X_{k1}、X_{k2}∈R^{\frac{C}{2g}×H×W} Xk1、Xk2∈R2gC×H×W。一个分支利用通道的相互关系生成通道注意力图,另一个分支利用特征的空间关系生成空间注意力图。

【2.通道注意力图】

完全捕获通道之间的依赖关系的常见模块,如SE(Squeeze-and-Excitation Networks)模块,其会带来太多的参数。本文提出了一种替代方案,与SE模块的思想一样,先通过全局平均池化(GAP)操作来收集空间域的所有信息,将 X k 1 X_{k1} Xk1转换为向量 1 × 1 × C 2 g 1×1×\frac{C}{2g} 1×1×2gC。计算公式:

之后通过简单的门控机制( F c F_c Fc)与 s i g m o i d sigmoid sigmoid函数( σ σ σ)生成通道注意力图,将其与 X k 1 X_{k1} Xk1相乘,即可完全捕获通道之间的依赖关系。计算公式:

【3.空间注意力图】

空间注意力图用于捕获位置信息(语义信息),其一般是通道注意力的补充。具体来说,对 X k 2 X_{k2} Xk2使用组归一化来捕获空间域的统计信息,与生成通道注意力图的方式相同,使用简单的门控机制( F c F_c Fc)与 s i g m o i d sigmoid sigmoid函数( σ σ σ)生成空间注意力图,将其与 X k 2 X_{k2} Xk2相乘,即可完全捕获空间域信息。计算公式:

【4.特征融合】

先通过 C o n c a t Concat Concat操作将特征图融合得到 X k ′ = [ X k 1 ′ , X k 2 ′ ] ∈ R C 2 G × H × W X'k=[X'{k1},X'_{k2}]∈R^{\frac{C}{2G}×H×W} Xk′=[Xk1′,Xk2′]∈R2GC×H×W。最后采用与ShuffleNetV2相同的思想,采用通道置换操作(channel shuffle)。进行组间通信。SA的最终输出具有与输入相同的尺寸。

3.Code

py 复制代码
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class sa_layer(nn.Module):
    """Constructs a Channel Spatial Group module.
    Args:
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, groups=64):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x
    def forward(self, x):
    	#1.特征分组
        b, c, h, w = x.shape
        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)
        #2.通道注意力图
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)
        #3.空间注意力图
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)
        #特征融合
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)
        out = self.channel_shuffle(out, 2)
        return out
相关推荐
DuHz1 天前
无线通信与雷达感知融合的波形设计与信号处理——论文阅读(上)
论文阅读·信号处理
DuHz1 天前
无线通信与雷达感知融合的波形设计与信号处理——论文阅读(下)
论文阅读·汽车·信息与通信·信号处理
张较瘦_2 天前
[论文阅读] AI + 软件工程 | LLM救场Serverless开发!SlsReuse框架让函数复用率飙升至91%,还快了44%
论文阅读·人工智能·软件工程
m0_650108243 天前
InstructBLIP:面向通用视觉语言模型的指令微调技术解析
论文阅读·人工智能·q-former·指令微调的视觉语言大模型·零样本跨任务泛化·通用视觉语言模型
做cv的小昊3 天前
VLM经典论文阅读:【综述】An Introduction to Vision-Language Modeling
论文阅读·人工智能·计算机视觉·语言模型·自然语言处理·bert·transformer
m0_650108244 天前
PaLM-E:具身智能的多模态语言模型新范式
论文阅读·人工智能·机器人·具身智能·多模态大语言模型·palm-e·大模型驱动
m0_650108244 天前
PaLM:Pathways 驱动的大规模语言模型 scaling 实践
论文阅读·人工智能·palm·谷歌大模型·大规模语言模型·全面评估与行为分析·scaling效应
小殊小殊4 天前
【论文笔记】视频RAG-Vgent:基于图结构的视频检索推理框架
论文阅读·人工智能·深度学习
有点不太正常4 天前
《ShadowCoT: Cognitive Hijacking for Stealthy Reasoning Backdoors in LLMs》——论文阅读
论文阅读·大模型·agent安全
小殊小殊4 天前
【论文笔记】大型语言模型的知识蒸馏与数据集蒸馏
论文阅读·人工智能·深度学习