CBAM = Channel Attention (an SE-style block) + Spatial Attention
- First, channel attention scores each feature channel (emphasizing useful channels and suppressing useless ones);
- Then, spatial attention scores each pixel position of the feature map (emphasizing key regions and suppressing background);
- The two steps run in series, doubly focusing on key information, which works better than SE attention alone.
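For reference, the two serial steps can be written in the notation of the original CBAM paper (Woo et al., ECCV 2018), where σ is the sigmoid and ⊗ is element-wise multiplication with broadcasting; the implementation below follows these equations directly:

```latex
% CBAM's two serial refinement steps: channel attention first, then spatial
F'  = M_c(F) \otimes F \qquad F'' = M_s(F') \otimes F'

% Channel attention: shared MLP over the avg- and max-pooled channel descriptors
M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big)

% Spatial attention: 7x7 conv over the channel-wise avg/max maps
M_s(F) = \sigma\big(f^{7 \times 7}([\mathrm{AvgPool}(F);\, \mathrm{MaxPool}(F)])\big)
```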
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ===================== 1. Channel Attention Module =====================
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        # Global average pooling + global max pooling
        # (more robust than SE, which uses average pooling only)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer MLP (squeeze the channel dim, then restore it)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, C, H, W = x.size()
        # Average-pooling branch
        avg_out = self.avg_pool(x).view(batch, C)
        avg_out = self.fc(avg_out)
        # Max-pooling branch
        max_out = self.max_pool(x).view(batch, C)
        max_out = self.fc(max_out)
        # Add the two branches, apply sigmoid -> per-channel weights
        channel_weight = self.sigmoid(avg_out + max_out)
        channel_weight = channel_weight.view(batch, C, 1, 1)
        # Weights x original features (broadcasting)
        return x * channel_weight

# ===================== 2. Spatial Attention Module =====================
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Kernel size must be odd so the output stays center-aligned
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = kernel_size // 2  # keep the feature map size unchanged
        # Conv layer learns the spatial weights (2 input channels -> 1 output channel)
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Average / max pooling along the channel dimension -> two spatial maps
        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B,1,H,W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B,1,H,W)
        # Concatenate the two maps -> (B,2,H,W)
        spatial_feature = torch.cat([avg_out, max_out], dim=1)
        # Conv learns the spatial weights -> (B,1,H,W)
        spatial_weight = self.sigmoid(self.conv(spatial_feature))
        # Weights x original features (broadcasting)
        return x * spatial_weight

# ===================== 3. Full CBAM Module (channel + spatial, in series) =====================
class CBAM(nn.Module):
    def __init__(self, in_channels, reduction=16, spatial_kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(spatial_kernel_size)

    def forward(self, x):
        # Channel attention first, then spatial attention
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

# ===================== 4. CBAM Integrated into a CNN (practical example) =====================
class CNN_WITH_CBAM(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN_WITH_CBAM, self).__init__()
        # Basic conv layers
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)  # BN for more stable training
        self.cbam1 = CBAM(32)          # first CBAM
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.cbam2 = CBAM(64)          # second CBAM
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.cbam3 = CBAM(128)         # third CBAM
        # Pooling + fully connected layers
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)  # CIFAR10 input is 32x32; 4x4 after 3 poolings
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Conv + BN + ReLU + pooling + CBAM
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.cbam1(x)
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.cbam2(x)
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.cbam3(x)
        # Flatten + fully connected
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ===================== 5. Testing the CBAM Module =====================
if __name__ == "__main__":
    # 1. Check that CBAM preserves the input shape
    cbam = CBAM(in_channels=32)
    x = torch.randn(4, 32, 16, 16)  # (batch, channels, H, W)
    out = cbam(x)
    print(f"CBAM input shape: {x.shape}, output shape: {out.shape}")  # unchanged: (4,32,16,16)

    # 2. Test the CNN with CBAM integrated
    model = CNN_WITH_CBAM(num_classes=10)
    x = torch.randn(8, 3, 32, 32)  # CIFAR10 input shape
    out = model(x)
    print(f"CNN+CBAM input shape: {x.shape}, output shape: {out.shape}")  # (8,10)

    # 3. Print the model structure (inspect the layers)
    print("\nModel structure:")
    print(model)
```
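For completeness, here is a minimal training-loop sketch for `CNN_WITH_CBAM` (defined above); the torchvision CIFAR-10 loading, the Adam optimizer, and all hyperparameters are illustrative assumptions, not part of the original code:

```python
# Minimal training sketch (hyperparameters and data path are assumptions)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    # commonly used CIFAR10 channel statistics
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_WITH_CBAM(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):  # epoch count is arbitrary here
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}: last-batch loss = {loss.item():.4f}")
```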
Summary
- CBAM's core advantages:
  - Dual attention mechanism: channel + spatial, focusing precisely on key information;
  - Lightweight: very little extra computation for a significant gain (see the parameter-count sketch after this list);
  - Plug-and-play: fits any CNN model (custom or pretrained).
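To make the "lightweight" claim concrete, here is a quick sketch that counts the parameters one CBAM block adds, compared against a single 3x3 conv layer of the same width; it reuses the `CBAM` class defined above, and the channel count of 128 is illustrative:

```python
# Parameter count for one CBAM block vs. one 3x3 conv (channel count is illustrative)
import torch.nn as nn  # assumes the CBAM class from the code above is in scope

def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

cbam = CBAM(in_channels=128, reduction=16)
conv = nn.Conv2d(128, 128, 3, padding=1)  # a typical conv layer at the same width

print(f"CBAM(128) params:    {num_params(cbam)}")  # 2282   (shared MLP + one 7x7 conv)
print(f"Conv3x3(128) params: {num_params(conv)}")  # 147584
```

At this width, the CBAM block adds roughly 1.5% of the parameters of a single conv layer.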
- Core code points:
  - Channel attention: parallel average- and max-pooling branches, with a shared fully connected MLP learning the per-channel weights;
  - Spatial attention: pooling along the channel dimension, then a conv layer learning the per-position weights;
  - Serial execution: channel first, then spatial, filtering features layer by layer.
- Practical advice:
  - Small models (e.g., a custom CNN): add 1-2 CBAM blocks after the key conv blocks;
  - Large models (e.g., ResNet): add CBAM after each conv stage and fine-tune from pretrained weights (see the sketch after this list);
  - Hyperparameters: `reduction=16` and `kernel_size=7` are the recommended defaults (the values used in the CBAM paper).
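As an illustration of the large-model case, here is one common way to wrap CBAM around the residual stages of a pretrained torchvision ResNet-18; the choice of insertion points and the head layout are assumptions for this sketch, not the only option, and it reuses the `CBAM` class defined above:

```python
# Sketch: CBAM after each residual stage of a pretrained ResNet-18
# (assumes the CBAM class from the code above; requires torchvision >= 0.13)
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

class ResNetWithCBAM(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1,
                                  backbone.relu, backbone.maxpool)
        # One CBAM after each residual stage; channel widths follow ResNet-18
        self.stage1, self.cbam1 = backbone.layer1, CBAM(64)
        self.stage2, self.cbam2 = backbone.layer2, CBAM(128)
        self.stage3, self.cbam3 = backbone.layer3, CBAM(256)
        self.stage4, self.cbam4 = backbone.layer4, CBAM(512)
        self.head = nn.Sequential(backbone.avgpool, nn.Flatten(),
                                  nn.Linear(512, num_classes))

    def forward(self, x):
        x = self.stem(x)
        x = self.cbam1(self.stage1(x))
        x = self.cbam2(self.stage2(x))
        x = self.cbam3(self.stage3(x))
        x = self.cbam4(self.stage4(x))
        return self.head(x)
```

Since the inserted CBAM blocks start from random weights, the usual recipe is to fine-tune the whole network at a small learning rate so the pretrained features are not destroyed.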