选择内核注意力 SK | Selective Kernel Networks

论文名称:《Selective Kernel Networks》

论文地址:https://arxiv.org/pdf/1903.06586.pdf

代码地址:https://github.com/implus/SKNet


在标准的卷积神经网络中,每层人工神经元的感受野被设计为具有相同的大小。神经科学界已经广泛认识到,视觉皮层神经元的感受野大小会受到刺激的调节,然而在构建CNN时很少考虑这一点。我们提出了一种动态选择机制,使得CNN中的每个神经元可以根据多个输入信息的尺度自适应地调整其感受野大小。我们设计了一个称为Selective Kernel (SK)单元的构建块,其中使用softmax注意力将具有不同核大小的多个分支进行融合,这种注意力受到这些分支中的信息的指导。对这些分支的不同注意力产生了融合层中神经元的有效感受野的不同大小。多个SK单元堆叠成一个深度网络,称为Selective Kernel Networks (SKNets)。在 ImageNetCIFAR 基准测试中,我们经验证明SKNet 在模型复杂度较低的情况下胜过了现有的最先进架构。详细分析表明,SKNet中的神经元能够捕捉具有不同尺度的目标对象,从而验证了神经元根据输入自适应调整感受野大小的能力。


问题背景

在深度学习领域,卷积神经网络(CNN)通常设计为在每一层具有固定大小的感受野。这种设计忽略了视觉皮层神经元根据刺激变化而调整感受野的能力,无法充分捕获多尺度信息。为了解决这一问题,Selective Kernel Networks (SKNet) 提出了一种动态选择机制,允许神经元根据多尺度输入信息自适应地调整感受野的大小。这一机制的目标是通过软注意力机制,动态选择不同大小的卷积核,从而实现多尺度信息的聚合。


核心概念

SKNet的核心概念是"选择性卷积核"(Selective Kernel)。该机制允许网络在多条路径上使用不同大小的卷积核,并通过软注意力机制选择最合适的路径。通过这种方式,神经元可以根据输入信息的特点动态调整其感受野,从而在保持较低计算成本的同时,提高网络的性能。


模块的操作步骤

Selective Kernel的操作步骤包括三个关键环节:拆分(Split)、融合(Fuse)和选择(Select)。在拆分步骤中,模块生成多个不同大小的卷积核路径,每条路径对应不同的感受野。在融合步骤中,模块将来自多个路径的信息聚合,生成用于选择的全局表示。最后,选择步骤使用软注意力机制,根据之前生成的全局表示来选择最佳的路径,并将选择权重应用于特征图。这样,网络可以自适应地选择不同的感受野,从而增强对目标对象的感知能力。


文章贡献

这篇文章的主要贡献在于提出了选择性卷积核机制,通过动态选择不同大小的卷积核实现多尺度信息的聚合。作者通过在ImageNet和CIFAR等基准测试上进行实验,证明了SKNet的有效性。实验结果显示,SKNet在保持较低模型复杂度的同时,性能优于现有的多种先进架构。此外,SKNet的选择机制为神经元的感受野大小自适应调整提供了新的方法,这可能是提高网络在目标识别任务中性能的关键。


实验结果与应用

实验结果显示,SKNet在ImageNet和CIFAR等基准测试上都取得了优异的表现。在ImageNet上,SKNet-50比ResNeXt-50降低了1.44%的Top-1错误率,尽管两者的模型复杂度相近。此外,SKNet还可以应用于轻量级模型,如ShuffleNetV2,证明了其广泛的适用性和有效性。SKNet在对象检测和语义分割等下游任务中也表现出色,这进一步表明其在多种视觉任务中的潜力。


对未来工作的启示

SKNet的成功启示了卷积神经网络中动态选择机制的潜力。未来的工作可以探索将SKNet应用于其他类型的神经网络,或者将其与其他注意力机制相结合。此外,SKNet的选择机制可能在轻量级模型的设计中发挥重要作用,为移动设备和嵌入式系统提供高效且有效的解决方案。研究人员还可以考虑将SKNet应用于其他领域,如自然语言处理和音频分析,以进一步拓展其应用范围。


代码

python 复制代码
import torch.nn as nn
import torch


class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels),
        )

        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):

        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out


if __name__ == "__main__":
    input = torch.randn(1, 64, 20, 20)
    model = GAM_Attention(in_channels=64)
    output = model(input)
    print(output.size())
相关推荐
lanboAI31 分钟前
基于卷积神经网络的蔬菜水果识别系统,resnet50,mobilenet模型【pytorch框架+python源码】
pytorch·python·cnn
RockLiu@8052 小时前
探索PyTorch中的空间与通道双重注意力机制:实现concise的scSE模块
人工智能·pytorch·python
进取星辰2 小时前
PyTorch 深度学习实战(23):多任务强化学习(Multi-Task RL)之扩展
人工智能·pytorch·深度学习
Psycho_MrZhang3 小时前
Pytorch 反向传播
人工智能·pytorch·python
KY_chenzhao12 小时前
ChatGPT与DeepSeek在科研论文撰写中的整体科研流程与案例解析
人工智能·机器学习·chatgpt·论文·科研·deepseek
墨顿17 小时前
Transformer数学推导——Q29 推导语音识别中流式注意力(Streaming Attention)的延迟约束优化
人工智能·深度学习·transformer·注意力机制·跨模态与多模态
MatpyMaster19 小时前
液体神经网络LNN-Attention创新结合——基于液体神经网络的时间序列预测(PyTorch框架)
人工智能·pytorch·神经网络·时间序列预测
白熊18819 小时前
【计算机视觉】TorchVision 深度解析:从核心功能到实战应用 ——PyTorch 官方计算机视觉库的全面指南
人工智能·pytorch·计算机视觉
夜松云21 小时前
从对数变换到深度框架:逻辑回归与交叉熵的数学原理及PyTorch实战
pytorch·算法·逻辑回归·梯度下降·交叉熵·对数变换·sigmoid函数