选择内核注意力 SK | Selective Kernel Networks

论文名称:《Selective Kernel Networks》

论文地址:https://arxiv.org/pdf/1903.06586.pdf

代码地址:https://github.com/implus/SKNet


在标准的卷积神经网络中,每层人工神经元的感受野被设计为具有相同的大小。神经科学界已经广泛认识到,视觉皮层神经元的感受野大小会受到刺激的调节,然而在构建CNN时很少考虑这一点。我们提出了一种动态选择机制,使得CNN中的每个神经元可以根据多个输入信息的尺度自适应地调整其感受野大小。我们设计了一个称为Selective Kernel (SK)单元的构建块,其中使用softmax注意力将具有不同核大小的多个分支进行融合,这种注意力受到这些分支中的信息的指导。对这些分支的不同注意力产生了融合层中神经元的有效感受野的不同大小。多个SK单元堆叠成一个深度网络,称为Selective Kernel Networks (SKNets)。在 ImageNetCIFAR 基准测试中,我们经验证明SKNet 在模型复杂度较低的情况下胜过了现有的最先进架构。详细分析表明,SKNet中的神经元能够捕捉具有不同尺度的目标对象,从而验证了神经元根据输入自适应调整感受野大小的能力。


问题背景

在深度学习领域,卷积神经网络(CNN)通常设计为在每一层具有固定大小的感受野。这种设计忽略了视觉皮层神经元根据刺激变化而调整感受野的能力,无法充分捕获多尺度信息。为了解决这一问题,Selective Kernel Networks (SKNet) 提出了一种动态选择机制,允许神经元根据多尺度输入信息自适应地调整感受野的大小。这一机制的目标是通过软注意力机制,动态选择不同大小的卷积核,从而实现多尺度信息的聚合。


核心概念

SKNet的核心概念是"选择性卷积核"(Selective Kernel)。该机制允许网络在多条路径上使用不同大小的卷积核,并通过软注意力机制选择最合适的路径。通过这种方式,神经元可以根据输入信息的特点动态调整其感受野,从而在保持较低计算成本的同时,提高网络的性能。


模块的操作步骤

Selective Kernel的操作步骤包括三个关键环节:拆分(Split)、融合(Fuse)和选择(Select)。在拆分步骤中,模块生成多个不同大小的卷积核路径,每条路径对应不同的感受野。在融合步骤中,模块将来自多个路径的信息聚合,生成用于选择的全局表示。最后,选择步骤使用软注意力机制,根据之前生成的全局表示来选择最佳的路径,并将选择权重应用于特征图。这样,网络可以自适应地选择不同的感受野,从而增强对目标对象的感知能力。


文章贡献

这篇文章的主要贡献在于提出了选择性卷积核机制,通过动态选择不同大小的卷积核实现多尺度信息的聚合。作者通过在ImageNet和CIFAR等基准测试上进行实验,证明了SKNet的有效性。实验结果显示,SKNet在保持较低模型复杂度的同时,性能优于现有的多种先进架构。此外,SKNet的选择机制为神经元的感受野大小自适应调整提供了新的方法,这可能是提高网络在目标识别任务中性能的关键。


实验结果与应用

实验结果显示,SKNet在ImageNet和CIFAR等基准测试上都取得了优异的表现。在ImageNet上,SKNet-50比ResNeXt-50降低了1.44%的Top-1错误率,尽管两者的模型复杂度相近。此外,SKNet还可以应用于轻量级模型,如ShuffleNetV2,证明了其广泛的适用性和有效性。SKNet在对象检测和语义分割等下游任务中也表现出色,这进一步表明其在多种视觉任务中的潜力。


对未来工作的启示

SKNet的成功启示了卷积神经网络中动态选择机制的潜力。未来的工作可以探索将SKNet应用于其他类型的神经网络,或者将其与其他注意力机制相结合。此外,SKNet的选择机制可能在轻量级模型的设计中发挥重要作用,为移动设备和嵌入式系统提供高效且有效的解决方案。研究人员还可以考虑将SKNet应用于其他领域,如自然语言处理和音频分析,以进一步拓展其应用范围。


代码

python 复制代码
import torch.nn as nn
import torch


class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels),
        )

        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):

        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out


if __name__ == "__main__":
    input = torch.randn(1, 64, 20, 20)
    model = GAM_Attention(in_channels=64)
    output = model(input)
    print(output.size())
相关推荐
丕羽7 小时前
【Pytorch】基本语法
人工智能·pytorch·python
Shy96041818 小时前
Pytorch实现transformer语言模型
人工智能·pytorch
Mephostopheles19 小时前
0基础读顶会论文—流程即服务(PraaS):通过无服务器流程统一弹性云和有状态云
论文
周末不下雨1 天前
跟着小土堆学习pytorch(六)——神经网络的基本骨架(nn.model)
pytorch·神经网络·学习
spssau1 天前
13类高频数据分析方法分类汇总
大数据·数据分析·论文·spss·spssau
蜡笔小新星1 天前
针对初学者的PyTorch项目推荐
开发语言·人工智能·pytorch·经验分享·python·深度学习·学习
矩阵猫咪1 天前
【深度学习】时间序列预测、分类、异常检测、概率预测项目实战案例
人工智能·pytorch·深度学习·神经网络·机器学习·transformer·时间序列预测
zs1996_1 天前
深度学习注意力机制类型总结&pytorch实现代码
人工智能·pytorch·深度学习
阿亨仔1 天前
Pytorch猴痘病识别
人工智能·pytorch·python·深度学习·算法·机器学习
AI视觉网奇2 天前
nvlink 训练笔记
pytorch·笔记·深度学习