【深度学习注意力机制系列】—— SCSE注意力机制(附pytorch实现)

SCSE注意力模块 (来自论文[1803.02579] Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks (arxiv.org))。其对SE注意力模块进行了改进,提出了cSE、sSE、scSE 三个模块变体,这些模块可以增强有意义的特征,抑制无用特征。今天我们就分别讲解一下这三个注意力模块。

1、cSE模块(通道维度的SE注意力机制)

cSE模块 引入了通道注意力机制,可有效的对通道维度的特征信息进行整合增强,这一点与SE等传统通道注意力机制近似,其最大不同的是其对得到的注意力权重进行了降维再升维的操作,类似与resnet中的瓶颈结构以及Fast RCNN目标检测网络最后的全连接加速层,这种操作方式有些奇异值分解的意思,在深度学习模型中十分常见,可有效的整合通道信息,并且简化模块复杂度,减小模型计算量,提升计算速度

实现机制

  • 将特征图通过全局平均池化层将维度从[C, H, W]变为[C, 1, 1]。
  • 然后使用两个1×1卷积进行信息的处理(即降维与升维操作),最终得到C维的向量。
  • 然后使用sigmoid函数进行归一化,得到对应的权重向量文件。
  • 最后通过channel-wise与原始特征图相乘,得到经过通道信息真个校准过的特征图。

代码实现

python 复制代码
class CSE(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(CSE, self).__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
   
    def forward(self, x):
        return x * self.cSE(x)

2、sSE模块(空间维度的SE注意力机制)

sSE模块 在特征图的空间维度 展开信息增强整合,同通道维度一样,其也是通过先提取权重信息,再将权重信息同原始特征图相乘得到注意力增强效果,不过在提取权重信息时是在空间维度展开,不再是使用全局平均池化层,而是使用输出通道为1,卷积核大小为1×1 的卷积层,进行信息整合

这里我们顺便简介一下1×1卷积层的作用

改变通道数 (即升维降维)

信息整合(可实现跨通道的信息交互)

增加非线性(基于奇异值分解,结合非线性激活函数,加深模型)

实现机制

  • 将特征图通过一个输出通道为1,卷积核大小为1×1 的卷积层,得到一个维度为(1, H, W)的权重矩阵。
  • 将权重矩阵进行sigmod归一化处理,得到最终的权重矩阵。
  • 将权重矩阵同原始特征图在空间维度相乘,得到最终空间信息增强特征图结果。

代码实现

python 复制代码
class SSE(nn.Module):
    def __init__(self, in_channels):
        super(SSE, self).__init__()
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.sSE(x)

3、scSE模块(混合维度的SE注意力机制)

scSE模块是sSE模块和cSE模块的综合体,即同时对空间维度和通道维度进行信息整合增强,将两者的特征结果沿着通道维度进行相加(结果和原始特征图维度相同)。

实现机制

  • 将特征图通过cSE模块,得到特征图结果1。
  • 将特征图通过sSE模块,得到特征图结果2.
  • 将特征图结果1和2沿着通道维度相加,得到最终信息校正结果(前后特征图维度不变)。

代码实现

python 复制代码
class SCSE(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SCSE, self).__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)
相关推荐
Codebee6 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º6 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys6 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56786 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能7 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144877 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile7 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5777 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥7 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造