CBAM-2018学习笔记

名称:

Convolutional Block Attention Module (CBAM)

来源:

CBAM: Convolutional Block Attention Module

相关工作:

#ResNet #GoogleNet #ResNeXt #Network-engineering #Attention-mechanism

创新点:

贡献:

  • 提出CBAM
  • 验证了其有效性
  • 改善提高了以往模型的性能

代码:

python 复制代码
  
import torch  
from torch import nn  
  
  
class ChannelAttention(nn.Module):  
    def __init__(self, in_planes, ratio=16):  
        super(ChannelAttention, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.max_pool = nn.AdaptiveMaxPool2d(1)  
  
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  
        self.relu1 = nn.ReLU()  
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))  
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))  
        out = avg_out + max_out  
        return self.sigmoid(out)  
  
  
class SpatialAttention(nn.Module):  
    def __init__(self, kernel_size=7):  
        super(SpatialAttention, self).__init__()  
  
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'  
        padding = 3 if kernel_size == 7 else 1  
  
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = torch.mean(x, dim=1, keepdim=True)  
        max_out, _ = torch.max(x, dim=1, keepdim=True)  
        x = torch.cat([avg_out, max_out], dim=1)  
        x = self.conv1(x)  
        return self.sigmoid(x)  
  
  
class CBAM(nn.Module):  
    def __init__(self, in_planes, ratio=16, kernel_size=7):  
        super(CBAM, self).__init__()  
        self.ca = ChannelAttention(in_planes, ratio)  
        self.sa = SpatialAttention(kernel_size)  
  
    def forward(self, x):  
        out = x * self.ca(x)  
        result = out * self.sa(out)  
        return result  
  
  
# 输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    block = CBAM(64)  
    input = torch.rand(3, 64, 32, 32)  
    output = block(input)  
    print(input.size(), output.size())
相关推荐
悠哉悠哉愿意2 小时前
【电赛学习笔记】MaixCAM 的OCR图片文字识别
笔记·python·嵌入式硬件·学习·视觉检测·ocr
_Kayo_3 小时前
VUE2 学习笔记5 动态绑定class、条件渲染、列表过滤与排序
笔记·学习
waveee1233 小时前
学习嵌入式的第三十四天-数据结构-(2025.7.29)数据库
数据结构·数据库·学习
点云SLAM3 小时前
PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例
人工智能·pytorch·python·计算机视觉·3d深度学习·张量flatten操作·张量数据结构
爱分享的飘哥3 小时前
第三篇:VAE架构详解与PyTorch实现:从零构建AI的“视觉压缩引擎”
人工智能·pytorch·python·aigc·教程·生成模型·代码实战
charlie1145141915 小时前
设计自己的小传输协议 导论与概念
c++·笔记·qt·网络协议·设计·通信协议
xiaoxiaoxiaolll5 小时前
Adv. Sci. 前沿:非零高斯曲率3D结构可逆转换!液晶弹性体多级形变新策略
学习
xiaoli23276 小时前
课题学习笔记3——SBERT
笔记·学习·nlp·bert
缘友一世7 小时前
Agent常用搜索引擎Tavily使用学习
学习·搜索引擎·agent
超浪的晨8 小时前
JavaWeb 入门:JavaScript 基础与实战详解(Java 开发者视角)
java·开发语言·前端·javascript·后端·学习·个人开发