CBAM-2018学习笔记

名称:

Convolutional Block Attention Module (CBAM)

来源:

CBAM: Convolutional Block Attention Module

相关工作:

#ResNet #GoogleNet #ResNeXt #Network-engineering #Attention-mechanism

创新点:

贡献:

  • 提出CBAM
  • 验证了其有效性
  • 改善提高了以往模型的性能

代码:

python 复制代码
  
import torch  
from torch import nn  
  
  
class ChannelAttention(nn.Module):  
    def __init__(self, in_planes, ratio=16):  
        super(ChannelAttention, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.max_pool = nn.AdaptiveMaxPool2d(1)  
  
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  
        self.relu1 = nn.ReLU()  
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))  
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))  
        out = avg_out + max_out  
        return self.sigmoid(out)  
  
  
class SpatialAttention(nn.Module):  
    def __init__(self, kernel_size=7):  
        super(SpatialAttention, self).__init__()  
  
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'  
        padding = 3 if kernel_size == 7 else 1  
  
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = torch.mean(x, dim=1, keepdim=True)  
        max_out, _ = torch.max(x, dim=1, keepdim=True)  
        x = torch.cat([avg_out, max_out], dim=1)  
        x = self.conv1(x)  
        return self.sigmoid(x)  
  
  
class CBAM(nn.Module):  
    def __init__(self, in_planes, ratio=16, kernel_size=7):  
        super(CBAM, self).__init__()  
        self.ca = ChannelAttention(in_planes, ratio)  
        self.sa = SpatialAttention(kernel_size)  
  
    def forward(self, x):  
        out = x * self.ca(x)  
        result = out * self.sa(out)  
        return result  
  
  
# 输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    block = CBAM(64)  
    input = torch.rand(3, 64, 32, 32)  
    output = block(input)  
    print(input.size(), output.size())
相关推荐
无心水10 分钟前
【Java面试笔记:基础】8.对比Vector、ArrayList、LinkedList有何区别?
java·笔记·面试·vector·arraylist·linkedlist
卡皮巴拉爱吃小蛋糕18 分钟前
MySQL的MVCC【学习笔记】
数据库·笔记·mysql
清流君28 分钟前
【MySQL】数据库 Navicat 可视化工具与 MySQL 命令行基本操作
数据库·人工智能·笔记·mysql·ue5·数字孪生
Angindem1 小时前
SpringClound 微服务分布式Nacos学习笔记
分布式·学习·微服务
虾球xz2 小时前
游戏引擎学习第244天: 完成异步纹理下载
c++·学习·游戏引擎
BOB-wangbaohai2 小时前
Flowable7.x学习笔记(十四)查看部署流程Bpmn2.0-xml
xml·笔记·学习
先生沉默先2 小时前
c#接口_抽象类_多态学习
开发语言·学习·c#
豆芽8192 小时前
图解YOLO(You Only Look Once)目标检测(v1-v5)
人工智能·深度学习·学习·yolo·目标检测·计算机视觉
友善啊,朋友2 小时前
《普通逻辑》学习记录——性质命题及其推理
学习·逻辑学
北上ing2 小时前
从FP32到BF16,再到混合精度的全景解析
人工智能·pytorch·深度学习·计算机视觉·stable diffusion