CBAM-2018学习笔记

名称:

Convolutional Block Attention Module (CBAM)

来源:

CBAM: Convolutional Block Attention Module

相关工作:

#ResNet #GoogleNet #ResNeXt #Network-engineering #Attention-mechanism

创新点:

贡献:

  • 提出CBAM
  • 验证了其有效性
  • 改善提高了以往模型的性能

代码:

python 复制代码
  
import torch  
from torch import nn  
  
  
class ChannelAttention(nn.Module):  
    def __init__(self, in_planes, ratio=16):  
        super(ChannelAttention, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.max_pool = nn.AdaptiveMaxPool2d(1)  
  
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  
        self.relu1 = nn.ReLU()  
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))  
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))  
        out = avg_out + max_out  
        return self.sigmoid(out)  
  
  
class SpatialAttention(nn.Module):  
    def __init__(self, kernel_size=7):  
        super(SpatialAttention, self).__init__()  
  
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'  
        padding = 3 if kernel_size == 7 else 1  
  
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = torch.mean(x, dim=1, keepdim=True)  
        max_out, _ = torch.max(x, dim=1, keepdim=True)  
        x = torch.cat([avg_out, max_out], dim=1)  
        x = self.conv1(x)  
        return self.sigmoid(x)  
  
  
class CBAM(nn.Module):  
    def __init__(self, in_planes, ratio=16, kernel_size=7):  
        super(CBAM, self).__init__()  
        self.ca = ChannelAttention(in_planes, ratio)  
        self.sa = SpatialAttention(kernel_size)  
  
    def forward(self, x):  
        out = x * self.ca(x)  
        result = out * self.sa(out)  
        return result  
  
  
# 输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    block = CBAM(64)  
    input = torch.rand(3, 64, 32, 32)  
    output = block(input)  
    print(input.size(), output.size())
相关推荐
一棵开花的树,枝芽无限靠近你8 分钟前
【Pytorch】(一)使用 PyTorch 进行深度学习:60 分钟速成
人工智能·pytorch·深度学习
超龄超能程序猿9 分钟前
Docker常用中间件部署笔记:MongoDB、Redis、MySQL、Tomcat快速搭建
笔记·docker·中间件
奔波霸的伶俐虫14 分钟前
windows docker desktop 安装修改镜像学习
学习·docker·容器
时兮兮时15 分钟前
CALIPSO垂直特征掩膜(VFM)—使用python绘制主类型、气溶胶和云的子类型
笔记·python·calipso
时兮兮时18 分钟前
MODIS Land Cover (MCD12Q1 and MCD12C1) Product—官方文档的中文翻译
笔记·mcd12q1
时兮兮时33 分钟前
Linux 服务器后台任务生存指南
linux·服务器·笔记
BullSmall40 分钟前
《逍遥游》
学习
奔波霸的伶俐虫40 分钟前
spring boot集成kafka学习
spring boot·学习·kafka
CCPC不拿奖不改名41 分钟前
面向对象编程:继承与多态+面试习题
开发语言·数据结构·python·学习·面试·职场和发展
GHL28427109042 分钟前
通义千问的 Function Call - demo学习
学习·ai·ai编程