Day44 CBAM

CBAM = 通道注意力(SE Block) + 空间注意力(Spatial Attention)

  • 先做通道注意力:给每个特征通道打分(重视有用通道,忽略无用通道);

  • 再做空间注意力:给特征图每个像素位置打分(重视关键区域,忽略背景区域);

  • 两步串行执行,双重聚焦关键信息,比单独的 SE 注意力效果更好。

    python 复制代码
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    # ===================== 1. 通道注意力模块(Channel Attention) =====================
    class ChannelAttention(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(ChannelAttention, self).__init__()
            # 全局平均池化 + 全局最大池化(比SE只用人均池化更鲁棒)
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            
            # 共享的两层全连接(压缩维度→恢复维度)
            self.fc = nn.Sequential(
                nn.Linear(in_channels, in_channels // reduction),
                nn.ReLU(inplace=True),
                nn.Linear(in_channels // reduction, in_channels)
            )
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            batch, C, H, W = x.size()
            
            # 平均池化分支
            avg_out = self.avg_pool(x).view(batch, C)
            avg_out = self.fc(avg_out)
            
            # 最大池化分支
            max_out = self.max_pool(x).view(batch, C)
            max_out = self.fc(max_out)
            
            # 两个分支相加后激活,得到通道权重
            channel_weight = self.sigmoid(avg_out + max_out)
            channel_weight = channel_weight.view(batch, C, 1, 1)
            
            # 权重 × 原特征(广播机制)
            return x * channel_weight
    
    # ===================== 2. 空间注意力模块(Spatial Attention) =====================
    class SpatialAttention(nn.Module):
        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
            # 卷积核必须是奇数,保证中心对齐
            assert kernel_size in (3, 7), "kernel size must be 3 or 7"
            padding = kernel_size // 2  # 保持特征图尺寸不变
            
            # 卷积层学习空间权重(输入2通道→输出1通道)
            self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            # 对通道维度做平均/最大池化,得到两个空间特征图
            avg_out = torch.mean(x, dim=1, keepdim=True)  # (B,1,H,W)
            max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B,1,H,W)
            
            # 拼接两个特征图 → (B,2,H,W)
            spatial_feature = torch.cat([avg_out, max_out], dim=1)
            
            # 卷积学习空间权重 → (B,1,H,W)
            spatial_weight = self.sigmoid(self.conv(spatial_feature))
            
            # 权重 × 原特征(广播机制)
            return x * spatial_weight
    
    # ===================== 3. CBAM 整体模块(通道+空间串行) =====================
    class CBAM(nn.Module):
        def __init__(self, in_channels, reduction=16, spatial_kernel_size=7):
            super(CBAM, self).__init__()
            self.channel_attention = ChannelAttention(in_channels, reduction)
            self.spatial_attention = SpatialAttention(spatial_kernel_size)
    
        def forward(self, x):
            # 先通道注意力,再空间注意力
            x = self.channel_attention(x)
            x = self.spatial_attention(x)
            return x
    
    # ===================== 4. CBAM 集成到CNN(实战示例) =====================
    class CNN_WITH_CBAM(nn.Module):
        def __init__(self, num_classes=10):
            super(CNN_WITH_CBAM, self).__init__()
            # 基础卷积层
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)  # 加BN更稳定
            self.cbam1 = CBAM(32)  # 第一个CBAM
            
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.cbam2 = CBAM(64)  # 第二个CBAM
            
            self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
            self.bn3 = nn.BatchNorm2d(128)
            self.cbam3 = CBAM(128)  # 第三个CBAM
            
            # 池化+全连接
            self.pool = nn.MaxPool2d(2)
            self.fc1 = nn.Linear(128 * 4 * 4, 256)  # CIFAR10输入32×32,3次池化后4×4
            self.fc2 = nn.Linear(256, num_classes)
    
        def forward(self, x):
            # 卷积+BN+ReLU+池化+CBAM
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.cbam1(x)
            
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            x = self.cbam2(x)
            
            x = self.pool(F.relu(self.bn3(self.conv3(x))))
            x = self.cbam3(x)
            
            # 展平+全连接
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            
            return x
    
    # ===================== 5. 测试CBAM模块 =====================
    if __name__ == "__main__":
        # 1. 测试CBAM模块输出形状
        cbam = CBAM(in_channels=32)
        x = torch.randn(4, 32, 16, 16)  # (batch, channels, H, W)
        out = cbam(x)
        print(f"CBAM输入形状:{x.shape},输出形状:{out.shape}")  # 形状不变:(4,32,16,16)
        
        # 2. 测试集成CBAM的CNN
        model = CNN_WITH_CBAM(num_classes=10)
        x = torch.randn(8, 3, 32, 32)  # CIFAR10输入形状
        out = model(x)
        print(f"CNN+CBAM输入形状:{x.shape},输出形状:{out.shape}")  # (8,10)
        
        # 3. 打印模型结构(查看参数)
        print("\n模型结构:")
        print(model)

总结

  1. CBAM 核心优势

    • 双注意力机制:通道 + 空间,精准聚焦关键信息;
    • 轻量化:计算量增加极少,效果提升显著;
    • 即插即用:适配所有 CNN 模型(自定义 / 预训练)。
  2. 核心代码要点

    • 通道注意力:平均 + 最大池化双分支,全连接学习权重;
    • 空间注意力:通道维度池化 + 卷积学习空间权重;
    • 串行执行:先通道后空间,层层筛选特征。
  3. 实战建议

    • 小模型(如自定义 CNN):加 1-2 个 CBAM 在关键卷积块后;
    • 大模型(如 ResNet):每层卷积后加,配合预训练权重微调;
    • 超参数:reduction=16kernel_size=7 是最优默认值。
相关推荐
芯片-嵌入式1 小时前
具身智能(2):OpenExplorer下的模型量化
人工智能·深度学习·算法
崔高杰2 小时前
训练数据选择又有新方法了?——两篇文章的阅读笔记 Less is Enough和 OPUS
人工智能·笔记·机器学习
Westward-sun.2 小时前
Python argparse 模块:命令行参数解析实战全攻略
python·opencv·机器学习·rpc
机器学习之心2 小时前
GRU锂电池剩余寿命预测,NASA数据集(5号电池训练6号电池测试),MATLAB代码
深度学习·matlab·gru·gru锂电池剩余寿命预测
zh路西法2 小时前
【宇树机器人强化学习】(二):ActorCritic网络和ActorCriticRecurrent网络的python实现与解析
开发语言·python·深度学习·机器学习·机器人
AI科技星3 小时前
基于v≡c空间光速螺旋量子几何归一化统一场论第一性原理的时间势差本源理论
人工智能·线性代数·算法·机器学习·平面
AC赳赳老秦3 小时前
智能协同新纪元:DeepSeek驱动的跨岗位、跨工具多智能体实操体系展望(2026)
大数据·运维·人工智能·深度学习·机器学习·ai-native·deepseek
fanxianshi3 小时前
2026 年 3 月行业动态与开源生态全景报告
人工智能·深度学习·神经网络·机器学习·计算机视觉·开源·语音识别
zadyd3 小时前
Langgraph开发:先有Graph还是先有State
人工智能·机器学习