Day44 CBAM

CBAM = 通道注意力(SE Block) + 空间注意力(Spatial Attention)

  • 先做通道注意力:给每个特征通道打分(重视有用通道,忽略无用通道);

  • 再做空间注意力:给特征图每个像素位置打分(重视关键区域,忽略背景区域);

  • 两步串行执行,双重聚焦关键信息,比单独的 SE 注意力效果更好。

    python 复制代码
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    # ===================== 1. 通道注意力模块(Channel Attention) =====================
    class ChannelAttention(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(ChannelAttention, self).__init__()
            # 全局平均池化 + 全局最大池化(比SE只用人均池化更鲁棒)
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            
            # 共享的两层全连接(压缩维度→恢复维度)
            self.fc = nn.Sequential(
                nn.Linear(in_channels, in_channels // reduction),
                nn.ReLU(inplace=True),
                nn.Linear(in_channels // reduction, in_channels)
            )
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            batch, C, H, W = x.size()
            
            # 平均池化分支
            avg_out = self.avg_pool(x).view(batch, C)
            avg_out = self.fc(avg_out)
            
            # 最大池化分支
            max_out = self.max_pool(x).view(batch, C)
            max_out = self.fc(max_out)
            
            # 两个分支相加后激活,得到通道权重
            channel_weight = self.sigmoid(avg_out + max_out)
            channel_weight = channel_weight.view(batch, C, 1, 1)
            
            # 权重 × 原特征(广播机制)
            return x * channel_weight
    
    # ===================== 2. 空间注意力模块(Spatial Attention) =====================
    class SpatialAttention(nn.Module):
        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
            # 卷积核必须是奇数,保证中心对齐
            assert kernel_size in (3, 7), "kernel size must be 3 or 7"
            padding = kernel_size // 2  # 保持特征图尺寸不变
            
            # 卷积层学习空间权重(输入2通道→输出1通道)
            self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            # 对通道维度做平均/最大池化,得到两个空间特征图
            avg_out = torch.mean(x, dim=1, keepdim=True)  # (B,1,H,W)
            max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B,1,H,W)
            
            # 拼接两个特征图 → (B,2,H,W)
            spatial_feature = torch.cat([avg_out, max_out], dim=1)
            
            # 卷积学习空间权重 → (B,1,H,W)
            spatial_weight = self.sigmoid(self.conv(spatial_feature))
            
            # 权重 × 原特征(广播机制)
            return x * spatial_weight
    
    # ===================== 3. CBAM 整体模块(通道+空间串行) =====================
    class CBAM(nn.Module):
        def __init__(self, in_channels, reduction=16, spatial_kernel_size=7):
            super(CBAM, self).__init__()
            self.channel_attention = ChannelAttention(in_channels, reduction)
            self.spatial_attention = SpatialAttention(spatial_kernel_size)
    
        def forward(self, x):
            # 先通道注意力,再空间注意力
            x = self.channel_attention(x)
            x = self.spatial_attention(x)
            return x
    
    # ===================== 4. CBAM 集成到CNN(实战示例) =====================
    class CNN_WITH_CBAM(nn.Module):
        def __init__(self, num_classes=10):
            super(CNN_WITH_CBAM, self).__init__()
            # 基础卷积层
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)  # 加BN更稳定
            self.cbam1 = CBAM(32)  # 第一个CBAM
            
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.cbam2 = CBAM(64)  # 第二个CBAM
            
            self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
            self.bn3 = nn.BatchNorm2d(128)
            self.cbam3 = CBAM(128)  # 第三个CBAM
            
            # 池化+全连接
            self.pool = nn.MaxPool2d(2)
            self.fc1 = nn.Linear(128 * 4 * 4, 256)  # CIFAR10输入32×32,3次池化后4×4
            self.fc2 = nn.Linear(256, num_classes)
    
        def forward(self, x):
            # 卷积+BN+ReLU+池化+CBAM
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.cbam1(x)
            
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            x = self.cbam2(x)
            
            x = self.pool(F.relu(self.bn3(self.conv3(x))))
            x = self.cbam3(x)
            
            # 展平+全连接
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            
            return x
    
    # ===================== 5. 测试CBAM模块 =====================
    if __name__ == "__main__":
        # 1. 测试CBAM模块输出形状
        cbam = CBAM(in_channels=32)
        x = torch.randn(4, 32, 16, 16)  # (batch, channels, H, W)
        out = cbam(x)
        print(f"CBAM输入形状:{x.shape},输出形状:{out.shape}")  # 形状不变:(4,32,16,16)
        
        # 2. 测试集成CBAM的CNN
        model = CNN_WITH_CBAM(num_classes=10)
        x = torch.randn(8, 3, 32, 32)  # CIFAR10输入形状
        out = model(x)
        print(f"CNN+CBAM输入形状:{x.shape},输出形状:{out.shape}")  # (8,10)
        
        # 3. 打印模型结构(查看参数)
        print("\n模型结构:")
        print(model)

总结

  1. CBAM 核心优势

    • 双注意力机制:通道 + 空间,精准聚焦关键信息;
    • 轻量化:计算量增加极少,效果提升显著;
    • 即插即用:适配所有 CNN 模型(自定义 / 预训练)。
  2. 核心代码要点

    • 通道注意力:平均 + 最大池化双分支,全连接学习权重;
    • 空间注意力:通道维度池化 + 卷积学习空间权重;
    • 串行执行:先通道后空间,层层筛选特征。
  3. 实战建议

    • 小模型(如自定义 CNN):加 1-2 个 CBAM 在关键卷积块后;
    • 大模型(如 ResNet):每层卷积后加,配合预训练权重微调;
    • 超参数:reduction=16kernel_size=7 是最优默认值。
相关推荐
香蕉鼠片3 小时前
数字化图像的过程
人工智能·深度学习·计算机视觉
lqqjuly4 小时前
深度学习理论:从神经网络到Transformer—前馈网络、反向传播、注意力机制与训练
深度学习·神经网络·transformer
chsmiao4 小时前
张量(Tensor)
深度学习·ai编程
A_Sinon4 小时前
卷积神经网络
人工智能·神经网络·cnn
chsmiao5 小时前
深度学习之线性代数
人工智能·深度学习·线性代数
HyperAI超神经7 小时前
MiniCPM5-1B采用RL+OPD训练,多项复杂任务达SOTA;面向复杂医疗业务自动化:医疗智能体评测数据集 CHI-Bench
人工智能·深度学习·ai·计算化学
一个王同学7 小时前
从零到一 | CV转多模态大模型 | week12 | 整理 MiniLLaVA 工程与文档
人工智能·深度学习·算法·机器学习·计算机视觉
chsmiao7 小时前
深度学习之微积分
人工智能·深度学习
阳明山水7 小时前
LightGBM为何胜过Prophet做销量预测
人工智能·深度学习·机器学习·微信公众平台·微信开放平台
硅谷秋水7 小时前
世界模型:架构、方法、推理与应用的综述(下)
人工智能·机器学习·计算机视觉·语言模型·机器人