【论文笔记】独属于CV的注意力机制CBAM-Convolutional Block Attention Module

目录

写在前面

一、基数和宽度

[二、通道注意力模块(Channel Attention Module)](#二、通道注意力模块(Channel Attention Module))

[三、空间注意力模块(Spatial Attention Module)](#三、空间注意力模块(Spatial Attention Module))

[四、CBAM(Convolutional Block Attention Module)](#四、CBAM(Convolutional Block Attention Module))

五、总结


写在前面

CBAM论文地址:https://arxiv.org/abs/1807.06521

CBAM(Convolutional Block Attention Module)是2018年被提出的,不同于ViT的Attention,CBAM是为CNN量身定做的Attention模块,实现简单、效果好,你值得拥有。

为了提高CNN的性能,我们可以从深度(depth)、宽度(width)和基数(cardinality)三个方面入手。深度很好理解,就是模型的层数,层数越多模型越深。下面说一说基数和宽度的区别。

一、基数和宽度

**基数(cardinality):**指的是并行分支的数量。

**宽度(width):**指每一层卷积的卷积核数量(即输出特征图的通道数)。

举个GoogLeNet的例子:

GoogLeNet 的设计中,Inception模块通过组合多个不同大小的卷积核(例如 1x1、3x3、5x5)和池化操作来提取不同尺度的特征。

增加"宽度"的效果: 我们有一个 Inception 模块,包含 1x1、3x3 和 5x5 的卷积层,以及一个 3x3 的最大池化层。如果我们在该 Inception 模块中增加每个卷积操作的通道数,例如将 1x1 卷积层的输出通道数从 32 增加到 64,将 3x3 卷积层的输出通道数从 64 增加到 128,这种操作就增加了网络的"宽度"。增加"宽度"意味着每个 Inception 模块可以提取更多的特征信息,但同时也增加了计算成本。

增加基数: 如果我们将每个卷积和池化操作进一步拆分为多个组,例如在每个卷积操作中使用组卷积(group convolution),那么这些并行组卷积操作的数量就类似于增加了"基数"。每个组卷积操作都是一个独立的路径,这些路径的数量增加就代表了基数的增加。基数不仅节省了参数的总数,而且比深度和宽度这两个因素具有更强的表示能力。

可以看下图,蓝色的线表示模型的基数,红色的数字表示宽度。

CBAM由两个顺序的子模块组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。CAM解决的问题是重要的信息是什么('what' is meaningful given an input image)、SAM解决重要的信息在哪里('where' is an informative part)。这两个模块都使用了增加基数的方式,提升模型的表达能力。

二、通道注意力模块(Channel Attention Module)

我们知道CNN的每个通道可以提取不同的特征(也就是Feature Map),通道注意力模块(Channel Attention Module)的主要作用是自适应地调整和增强输入Feature Map 中每个通道的重要性。它通过学习每个通道对于当前任务的重要性权重,从而对不同通道进行加权,增强关键信息的表达,同时抑制不相关或冗余的特征。这种机制能够使神经网络更高效地利用信息,提高模型的性能和表达能力。

通俗的说,CAM就是判断每个Feature Map的重要程度。即下图中相同颜色的部分会有一个标记重要程度的权重。

CAM使用并行的平均池化和最大池化,每个池化分别经过两个卷积模块,然后相加经过sigmoid得到注意力概率图,公式如下:

其中σ为Sigmoid型函数,MLP这里使用的是两个卷积层,权重是共享的。

公式不直观,示意图如下,假设输入是一个32通道244x244的Feature Map,输出是32x1x1,表示32个Feature Map的重要程度。

使用pooling可以捕获每个通道特征的强度。平均池化(Average Pooling) 能够捕获每个通道的整体激活分布信息。它反映了一个通道中所有特征点的平均响应,能够平滑地代表特征的整体强度。最大池化(Max Pooling) 捕获的是特征图中的最强激活信号。它能够突出特征图中最显著的特征,强调特征中的极端值。

平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,通道注意力模块能够更好地理解哪些特征通道在给定输入图像中最重要。

两种池化之后的卷积层共享参数,而且两个卷积层的中间维度使用in_planes//16,最大限度的减少参数,因为注意力模块只提供一个注意力概率图,提取特征并不是它的首要任务,所以不需要太多参数。

代码示例:

python 复制代码
class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

三、空间注意力模块(Spatial Attention Module)

SAM首先沿着通道轴分别使用平均池化和最大池化操作,并将它们连接起来,然后经过一个7x7的卷积层将通道数变成1,最后是Sigmoid得到结果,是一个宽高和输入一致通道数为1的概率图。

SAM就是判断Feature Map中每个部分(像素)的重要程度,需要综合每个部分所有通道的特征。即下图中相同颜色的部分会有一个标记重要程度的权重。

公式如下:

其中,σ表示Sigmoid函数,表示大小为7×7的卷积运算。

补充一下,这里的卷积核大小是7x7,而上面CAM的是1x1,这是因为CAM关注的是整个特征图中的全局通道信息1x1 卷积核适合这种只需在通道维度上操作的情况。SAM关注的是特征图中的局部和全局空间信息7x7 卷积核适合捕捉空间维度上的局部和全局特征。

示意图如下,仍然假设输入是一个32通道244x244的Feature Map,输出是1x244x244。

同时使用平均池化和最大池化的原因和CAM类似,平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,SAM能够更好地理解Feature Map中哪些区域最重要。

四、CBAM(Convolutional Block Attention Module)

有了CAM和SAM,就剩最后一个问题,这两个模块怎么摆放,是并行放置还是顺行方式呢?作者发现顺序排列比平行排列的效果更好。下面是CBAM完整的结构图,我们随意在CBAM之前放几个卷积层:

CBAM完整的代码:

python 复制代码
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class DemoNet(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(DemoNet, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        out = self.relu(out)

        return out


if __name__ == '__main__':
    input = torch.randn(1, 3, 224, 224)
    demo_net = DemoNet(inplanes=3, planes=32)
    output = demo_net(input)
    print(output)

这里还有一个CBAM应用到ResNet的完整例子:https://github.com/luuuyi/CBAM.PyTorch

五、总结

1.CBAM是一个轻量级和通用的模块,它可以无缝地集成到任何CNN架构中,而开销可以忽略不计,并且可以与基础CNN一起进行端到端训练;

2.通道注意力模块(Channel Attention Module)关注每个通道的Feature Map的重要程度;

3.空间注意力模块(Spatial Attention Module)关注Feature Map每个部分(像素)的重要程度;

4.CBAM由通道注意力模块和空间注意力模块两个模块组成,同时兼顾了通道与空间特征的表达,相比传统的卷积层参数更少、效果更好。

CBAM就介绍到这里,关注不迷路(*^▽^*)

关注订阅号了解更多精品文章

交流探讨、商务合作请加微信

相关推荐
忘梓.1 小时前
解锁动态规划的奥秘:从零到精通的创新思维解析(3)
算法·动态规划
人机与认知实验室1 小时前
人、机、环境中各有其神经网络系统
人工智能·深度学习·神经网络·机器学习
黑色叉腰丶大魔王1 小时前
基于 MATLAB 的图像增强技术分享
图像处理·人工智能·计算机视觉
tinker在coding3 小时前
Coding Caprice - Linked-List 1
算法·leetcode
迅易科技4 小时前
借助腾讯云质检平台的新范式,做工业制造企业质检的“AI慧眼”
人工智能·视觉检测·制造
古希腊掌管学习的神5 小时前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI5 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长6 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AI_NEW_COME7 小时前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室7 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习