CNN Algorithms in Practice 05 | Upgrading ResDenseNet with SE Attention

I. Background

1. Summary

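Building on the J4 ResDenseNet (which combines ResNet residual connections with DenseNet dense connections), we introduce the SE-Net channel attention mechanism to build SE-ResDenseNet, further improving the model's ability to focus on the most informative feature channels.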

SE-ResDenseNet: what is new

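| Aspect | J4: ResDenseNet | J5: SE-ResDenseNet |
|--------|-----------------|--------------------|
| Channel attention | None | SE module adaptively recalibrates channel weights |
| Inside the DRB | Dense connections + residual connection | Dense connections + SE channel weighting + residual connection |
| Feature selection | All channels weighted equally | Informative channels amplified, less useful ones suppressed |
| After the compression layer | Added directly to the residual | Weighted by SE first, then added to the residual |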

Where the SE module goes (key design choice)

Following the SE paper's recommendation, the branch features are recalibrated by SE before the Addition:

DenseLayer x L → Concat → Compress 1x1 → [SE Attention] → + (Residual Add)
                                              ↑
                                    Squeeze: Global AvgPool
                                    Excitation: FC→ReLU→FC→Sigmoid
                                    Scale: per-channel weighting
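To make the three stages of the diagram concrete, here is a minimal NumPy sketch of one SE pass over a single feature map. The weight matrices `w1` and `w2` are hypothetical random stand-ins for the two learned FC layers, and biases are omitted, mirroring the `bias=False` layers of the `SEAttention` module in section 2.1.

```python
import numpy as np

def se_forward(x, w1, w2):
    """SE recalibration of one feature map x with shape (C, H, W).

    w1: (C // r, C) weights of the reduction FC layer (hypothetical stand-in)
    w2: (C, C // r) weights of the expansion FC layer (hypothetical stand-in)
    """
    # Squeeze: global average pooling -> one value per channel, shape (C,)
    z = x.mean(axis=(1, 2))
    # Excitation: FC -> ReLU -> FC -> Sigmoid, giving one gate in (0, 1) per channel
    s = np.maximum(w1 @ z, 0.0)
    s = 1.0 / (1.0 + np.exp(-(w2 @ s)))
    # Scale: multiply each channel by its gate (broadcast over H and W)
    return x * s[:, None, None], s

rng = np.random.default_rng(0)
C, r = 64, 16
x = rng.standard_normal((C, 56, 56))
w1 = rng.standard_normal((C // r, C)) * 0.1
w2 = rng.standard_normal((C, C // r)) * 0.1
out, gates = se_forward(x, w1, w2)
print(out.shape, gates.shape)  # (64, 56, 56) (64,)
```

The output keeps the input shape; only the per-channel magnitudes change.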

Transferring the idea

SE attention is a general-purpose channel attention method and can be carried over to:

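  • CNNs: embed it in the blocks of any ResNet/Inception/MobileNet
  • Transformers: similar squeeze-and-excitation-style channel gating has been adapted into some attention designs
  • Object detection/segmentation: add SE modules to FPNs and other feature pyramids to improve feature selection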
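As an illustration of the CNN case, the sketch below places SE inside a ResNet-style basic block, using the same "recalibrate the branch, then add" ordering as SE-ResDenseNet. The class names `SE` and `SEBasicBlock` are hypothetical, and the block assumes stride 1 with equal input and output channels so that the identity shortcut needs no projection.

```python
import torch
import torch.nn as nn

class SE(nn.Module):
    """Squeeze-and-Excitation gate (same structure as SEAttention in this post)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w

class SEBasicBlock(nn.Module):
    """ResNet basic block with SE applied to the residual branch before the addition."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.se = SE(channels, reduction)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.se(self.branch(x))  # recalibrate the branch only
        return self.relu(out + x)      # the identity path is left unscaled

block = SEBasicBlock(64)
y = block(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])
```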

Architecture comparison

[Figure: comparison of the core blocks of ResDenseNet and SE-ResDenseNet]

II. Code Implementation

1. Preparation

1.1 Set up the GPU

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import os, copy, warnings
import numpy as np
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

1.2 Load the data (augmentation strategy)

We reuse the augmentation strategy from J4: random horizontal flip + color jitter + normalization.
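The numbers in this augmentation pipeline can be checked by hand. A small sketch, assuming torchvision's documented rule that a scalar jitter value `v` becomes the factor range `[max(0, 1 - v), 1 + v]` (consistent with the `ColorJitter` repr in the printed dataset); `jitter_range` is a hypothetical helper. The second part shows the per-channel arithmetic `Normalize` applies to a mid-grey pixel after `ToTensor()` has scaled it to [0, 1].

```python
import numpy as np

# ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1): a scalar v becomes
# the sampling range [max(0, 1 - v), 1 + v] for the enhancement factor
def jitter_range(v):
    return (max(1.0 - v, 0.0), 1.0 + v)

print(jitter_range(0.2), jitter_range(0.1))  # (0.8, 1.2) (0.9, 1.1)

# Normalize(mean, std): (pixel - mean) / std, channel by channel
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
pixel = np.array([0.5, 0.5, 0.5])  # a mid-grey RGB pixel after ToTensor()
normalized = (pixel - mean) / std
print(normalized)
```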

data_dir = './data/day01'

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
total_data

Dataset ImageFolder
    Number of datapoints: 1661
    Root location: ./data/day01
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

total_data.class_to_idx

{'0Normal': 0, '2Mild': 1, '4Severe': 2}

1.3 Split the dataset

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

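# Give each subset its own ImageFolder so the test split uses the non-augmented transform
# (ImageFolder orders files deterministically, so the split indices still line up)
test_dataset.dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
train_dataset.dataset = datasets.ImageFolder(data_dir, transform=train_transforms)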

batch_size = 8

train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
test_dl  = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                      num_workers=0)

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Shape of X [N, C, H, W]:  torch.Size([8, 3, 224, 224])
Shape of y:  torch.Size([8]) torch.int64
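With 1661 images, the 80/20 split and `batch_size = 8` work out as follows. Note that `random_split` above is called without a generator, so the split changes from run to run; passing a seeded `torch.Generator` (the seed 42 here is arbitrary) is one way to make it reproducible. This sketch splits plain indices rather than the dataset itself.

```python
import math
import torch
from torch.utils.data import random_split

n = 1661                      # number of images reported by ImageFolder above
train_size = int(0.8 * n)
test_size = n - train_size
print(train_size, test_size)  # 1328 333

# A fixed generator makes the split reproducible across runs (seed chosen arbitrarily)
g = torch.Generator().manual_seed(42)
train_idx, test_idx = random_split(range(n), [train_size, test_size], generator=g)

# Batches per epoch with batch_size=8 (the last batch may be smaller)
batch_size = 8
print(math.ceil(train_size / batch_size), math.ceil(test_size / batch_size))  # 166 42
```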

2. Build the SE-ResDenseNet model

SE-ResDenseNet design rationale

An SE attention module is embedded in each Dense Residual Block (DRB) of the J4 ResDenseNet:

J4 DRB pipeline:
  DenseLayer x L → Concat → 1x1 Compress → (+ Residual Add)

J5 SE-DRB pipeline (with the new SE module):
  DenseLayer x L → Concat → 1x1 Compress → [SE Attention] → (+ Residual Add)
                                                 ↓
                                    Squeeze: Global AvgPool → C values
                                    Excitation: FC(C→C/r)→ReLU→FC(C/r→C)→Sigmoid
                                    Scale: per-channel weighting (informative channels amplified)

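Why the SE module sits where it does (following the SE paper):

  • Before the Addition, recalibrating only the branch → the identity (main) path is never multiplied by the 0~1 sigmoid scale, which would otherwise attenuate gradients block after block
  • After Compress → channel selection operates on the compressed features (in_channels rather than in_channels + L × growth_rate channels), so the SE module needs fewer parameters
  • reduction=16 → the first FC layer reduces the channel count to C/16, balancing accuracy against parameter count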
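The first bullet can be checked numerically. In this simplified sketch the residual branch is zeroed out so that only the shortcut carries signal, and a constant `gate = 0.5` stands in for the SE weights; `identity_path_grad` is a hypothetical helper. With SE before the addition the shortcut gradient stays at 1, whereas scaling the sum after the addition shrinks it by `gate` at every block.

```python
import torch

def identity_path_grad(num_blocks, gate, se_before_add):
    """Gradient of the output w.r.t. the input through the shortcut only.

    The residual branch is zeroed so only the identity path contributes;
    `gate` is a fixed stand-in for an SE weight in (0, 1).
    """
    x = torch.ones(1, requires_grad=True)
    h = x
    for _ in range(num_blocks):
        branch = torch.zeros_like(h)   # isolate the shortcut
        if se_before_add:
            h = gate * branch + h      # SE scales the branch, then add
        else:
            h = gate * (branch + h)    # SE scales the sum, shortcut included
    h.sum().backward()
    return x.grad.item()

print(identity_path_grad(10, 0.5, se_before_add=True))   # 1.0
print(identity_path_grad(10, 0.5, se_before_add=False))  # 0.0009765625
```

Real SE gates are learned per channel and per input rather than fixed at 0.5, so this only illustrates the trend, not the magnitude in the actual model.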

2.1 The SE attention module

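The SE (Squeeze-and-Excitation) module implements channel attention in three steps:

  1. Squeeze: global average pooling compresses each channel into a single scalar
  2. Excitation: two fully connected layers learn the dependencies between channels
  3. Scale: the learned weights are multiplied back onto the original features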

class SEAttention(nn.Module):
    """SE (Squeeze-and-Excitation) channel attention module

    Paper: https://arxiv.org/abs/1709.01507

    Steps:
    1. Squeeze: global average pooling -> each channel compressed to 1 value
    2. Excitation: FC(C→C/r) -> ReLU -> FC(C/r→C) -> Sigmoid
    3. Scale: multiply the weights back onto the feature map
    """
    def __init__(self, channel, reduction=16):
        super(SEAttention, self).__init__()
        # Squeeze: adaptive average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two fully connected layers (bottleneck structure)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze
        y = self.avg_pool(x).view(b, c)
        # Excitation
        y = self.fc(y).view(b, c, 1, 1)
        # Scale
        return x * y.expand_as(x)

Verify the SE module

if __name__ == '__main__':
    x = torch.randn(8, 512, 32, 32)
    se = SEAttention(channel=512, reduction=16)
    out = se(x)
    print(f'SE module test: input {x.shape} -> output {out.shape}')
    print(f'SE params: {sum(p.numel() for p in se.parameters()):,}')

SE module test: input torch.Size([8, 512, 32, 32]) -> output torch.Size([8, 512, 32, 32])
SE params: 32,768
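Since both FC layers are bias-free, an SE module on C channels with reduction r holds C·(C/r) + (C/r)·C = 2C²/r weights. A short check with a hypothetical helper `se_params`, using the four stage widths of the model (64, 128, 256, 512):

```python
def se_params(channels, reduction=16):
    """Weights in one SEAttention: two bias-free Linear layers C -> C//r -> C."""
    hidden = channels // reduction
    return channels * hidden + hidden * channels

print(se_params(512))                                  # 32768
print(sum(se_params(c) for c in (64, 128, 256, 512)))  # 43520
```

The 43,520 total agrees with the SE parameter count reported in section 2.4.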

2.2 The full SE-ResDenseNet model

class DenseLayer(nn.Module):
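    """Dense layer: pre-activation bottleneck (BN->ReLU->1x1Conv->BN->ReLU->3x3Conv)

    Borrows the pre-activation design of ResNetV2, placing BN and ReLU before each convolution.
    The bottleneck first maps the input to bn_size * growth_rate channels with a 1x1 conv,
    then produces growth_rate new feature channels with a 3x3 conv.
    """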
    def __init__(self, in_channels, growth_rate, bn_size=4):
        super(DenseLayer, self).__init__()
        # Pre-activation: BN -> ReLU -> Conv
        self.bn1   = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.bn2   = nn.BatchNorm2d(bn_size * growth_rate)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return torch.cat([x, out], dim=1)  # DenseNet style: concatenate


class SE_DenseResidualBlock(nn.Module):
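    """SE dense residual block (SE-DRB): dense connections + SE channel attention + residual connection

    Adds an SE attention module to the J4 DRB:
    1. Inside: several DenseLayers concatenated densely -> feature reuse (DenseNet)
    2. Compression: a 1x1 conv maps the high-dimensional features back to the input channel count
    3. SE attention: recalibrates the channels of the compressed features (SE-Net)  ← NEW
    4. Outside: the SE-weighted result is added to the input -> direct gradient path (ResNet)
    """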
    def __init__(self, in_channels, num_layers, growth_rate, bn_size=4, se_reduction=16):
        super(SE_DenseResidualBlock, self).__init__()

        # Dense layers
        self.dense_layers = nn.ModuleList()
        for i in range(num_layers):
            self.dense_layers.append(
                DenseLayer(in_channels + i * growth_rate, growth_rate, bn_size)
            )

        # Compression layer: maps the concatenated high-dimensional features back to the input channel count
        total_channels = in_channels + num_layers * growth_rate
        self.compress = nn.Sequential(
            nn.BatchNorm2d(total_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(total_channels, in_channels, kernel_size=1, stride=1, bias=False)
        )

        # SE channel attention module (NEW)
        self.se = SEAttention(channel=in_channels, reduction=se_reduction)

    def forward(self, x):
        # Dense connection path
        features = x
        for layer in self.dense_layers:
            features = layer(features)
        # Compress
        out = self.compress(features)
        # SE channel attention weighting (NEW)
        out = self.se(out)
        # Residual connection
        return out + x   # ResNet style: residual addition


class Transition(nn.Sequential):
    """Transition layer: BN->ReLU->1x1Conv->AvgPool"""
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.add_module('bn',   nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class SEResDenseNet(nn.Module):
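    """SE-ResDenseNet: ResNet residual connections + DenseNet dense connections + SE channel attention

    Architecture: Stem -> [SE-DRB -> Transition] x 3 -> SE-DRB -> BN -> ReLU -> GAP -> FC

    Compared with the J4 ResDenseNet:
    - Every DRB becomes an SE-DRB, with SE attention after compression and before the residual addition
    - The SE module learns how important each channel is, amplifying key features and suppressing redundant ones
    """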
    def __init__(self, growth_rate=32, block_config=(3, 4, 6, 4),
                 num_init_features=64, bn_size=4, se_reduction=16, num_classes=1000):
        super(SEResDenseNet, self).__init__()

        # ===== Stem =====
        self.stem = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # ===== Stages =====
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()

        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            # SE-Dense Residual Block
            self.stages.append(
                SE_DenseResidualBlock(num_features, num_layers, growth_rate, bn_size, se_reduction)
            )
            # Transition (omitted after the last stage)
            if i != len(block_config) - 1:
                next_features = num_features * 2
                self.transitions.append(
                    Transition(num_features, next_features)
                )
                num_features = next_features

        # ===== Final BN + Classifier =====
        self.final_bn   = nn.BatchNorm2d(num_features)
        self.classifier = nn.Linear(num_features, num_classes)

        # ===== Weight initialization (Kaiming) =====
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.transitions):
                x = self.transitions[i](x)
        x = F.relu(self.final_bn(x), inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = SEResDenseNet(num_classes=3).to(device)
model

SEResDenseNet(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (stages): ModuleList(
    (0): SE_DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (se): SEAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=64, out_features=4, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=4, out_features=64, bias=False)
          (3): Sigmoid()
        )
      )
    )
    (1): SE_DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (se): SEAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=8, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=8, out_features=128, bias=False)
          (3): Sigmoid()
        )
      )
    )
    (2): SE_DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (4): DenseLayer(
          (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (5): DenseLayer(
          (bn1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (se): SEAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=256, out_features=16, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=16, out_features=256, bias=False)
          (3): Sigmoid()
        )
      )
    )
    (3): SE_DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(640, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (se): SEAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=32, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=32, out_features=512, bias=False)
          (3): Sigmoid()
        )
      )
    )
  )
  (transitions): ModuleList(
    (0): Transition(
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (1): Transition(
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (2): Transition(
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
  )
  (final_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Linear(in_features=512, out_features=3, bias=True)
)
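The channel and spatial sizes in this printout follow simple bookkeeping: each DenseLayer adds `growth_rate` channels, each Transition doubles the channels and halves the resolution, and the stem downsamples by 4. A sketch reproducing those numbers for the default configuration (`stage_widths` and `spatial_sizes` are hypothetical helpers, not part of the model code):

```python
def stage_widths(num_init_features=64, growth_rate=32, block_config=(3, 4, 6, 4)):
    """Per-stage (input channels, concatenated channels before compression)."""
    widths = []
    c = num_init_features
    for i, num_layers in enumerate(block_config):
        widths.append((c, c + num_layers * growth_rate))
        if i != len(block_config) - 1:
            c *= 2  # Transition doubles the channels
    return widths

def spatial_sizes(size=224, num_stages=4):
    """Feature-map side length entering each stage."""
    size //= 4  # stride-2 7x7 conv, then stride-2 max pool
    sizes = [size]
    for _ in range(num_stages - 1):
        size //= 2  # Transition: 2x2 average pool with stride 2
        sizes.append(size)
    return sizes

print(stage_widths())   # [(64, 160), (128, 256), (256, 448), (512, 640)]
print(spatial_sizes())  # [56, 28, 14, 7]
```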

2.3 Inspect the model summary

import torchsummary as summary
summary.summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
            Conv2d-6          [-1, 128, 56, 56]           8,192
       BatchNorm2d-7          [-1, 128, 56, 56]             256
            Conv2d-8           [-1, 32, 56, 56]          36,864
        DenseLayer-9           [-1, 96, 56, 56]               0
      BatchNorm2d-10           [-1, 96, 56, 56]             192
           Conv2d-11          [-1, 128, 56, 56]          12,288
      BatchNorm2d-12          [-1, 128, 56, 56]             256
           Conv2d-13           [-1, 32, 56, 56]          36,864
       DenseLayer-14          [-1, 128, 56, 56]               0
      BatchNorm2d-15          [-1, 128, 56, 56]             256
           Conv2d-16          [-1, 128, 56, 56]          16,384
      BatchNorm2d-17          [-1, 128, 56, 56]             256
           Conv2d-18           [-1, 32, 56, 56]          36,864
       DenseLayer-19          [-1, 160, 56, 56]               0
      BatchNorm2d-20          [-1, 160, 56, 56]             320
             ReLU-21          [-1, 160, 56, 56]               0
           Conv2d-22           [-1, 64, 56, 56]          10,240
AdaptiveAvgPool2d-23             [-1, 64, 1, 1]               0
           Linear-24                    [-1, 4]             256
             ReLU-25                    [-1, 4]               0
           Linear-26                   [-1, 64]             256
          Sigmoid-27                   [-1, 64]               0
      SEAttention-28           [-1, 64, 56, 56]               0
SE_DenseResidualBlock-29           [-1, 64, 56, 56]               0
      BatchNorm2d-30           [-1, 64, 56, 56]             128
             ReLU-31           [-1, 64, 56, 56]               0
           Conv2d-32          [-1, 128, 56, 56]           8,192
        AvgPool2d-33          [-1, 128, 28, 28]               0
      BatchNorm2d-34          [-1, 128, 28, 28]             256
           Conv2d-35          [-1, 128, 28, 28]          16,384
      BatchNorm2d-36          [-1, 128, 28, 28]             256
           Conv2d-37           [-1, 32, 28, 28]          36,864
       DenseLayer-38          [-1, 160, 28, 28]               0
      BatchNorm2d-39          [-1, 160, 28, 28]             320
           Conv2d-40          [-1, 128, 28, 28]          20,480
      BatchNorm2d-41          [-1, 128, 28, 28]             256
           Conv2d-42           [-1, 32, 28, 28]          36,864
       DenseLayer-43          [-1, 192, 28, 28]               0
      BatchNorm2d-44          [-1, 192, 28, 28]             384
           Conv2d-45          [-1, 128, 28, 28]          24,576
      BatchNorm2d-46          [-1, 128, 28, 28]             256
           Conv2d-47           [-1, 32, 28, 28]          36,864
       DenseLayer-48          [-1, 224, 28, 28]               0
      BatchNorm2d-49          [-1, 224, 28, 28]             448
           Conv2d-50          [-1, 128, 28, 28]          28,672
      BatchNorm2d-51          [-1, 128, 28, 28]             256
           Conv2d-52           [-1, 32, 28, 28]          36,864
       DenseLayer-53          [-1, 256, 28, 28]               0
      BatchNorm2d-54          [-1, 256, 28, 28]             512
             ReLU-55          [-1, 256, 28, 28]               0
           Conv2d-56          [-1, 128, 28, 28]          32,768
AdaptiveAvgPool2d-57            [-1, 128, 1, 1]               0
           Linear-58                    [-1, 8]           1,024
             ReLU-59                    [-1, 8]               0
           Linear-60                  [-1, 128]           1,024
          Sigmoid-61                  [-1, 128]               0
      SEAttention-62          [-1, 128, 28, 28]               0
SE_DenseResidualBlock-63          [-1, 128, 28, 28]               0
      BatchNorm2d-64          [-1, 128, 28, 28]             256
             ReLU-65          [-1, 128, 28, 28]               0
           Conv2d-66          [-1, 256, 28, 28]          32,768
        AvgPool2d-67          [-1, 256, 14, 14]               0
      BatchNorm2d-68          [-1, 256, 14, 14]             512
           Conv2d-69          [-1, 128, 14, 14]          32,768
      BatchNorm2d-70          [-1, 128, 14, 14]             256
           Conv2d-71           [-1, 32, 14, 14]          36,864
       DenseLayer-72          [-1, 288, 14, 14]               0
      BatchNorm2d-73          [-1, 288, 14, 14]             576
           Conv2d-74          [-1, 128, 14, 14]          36,864
      BatchNorm2d-75          [-1, 128, 14, 14]             256
           Conv2d-76           [-1, 32, 14, 14]          36,864
       DenseLayer-77          [-1, 320, 14, 14]               0
      BatchNorm2d-78          [-1, 320, 14, 14]             640
           Conv2d-79          [-1, 128, 14, 14]          40,960
      BatchNorm2d-80          [-1, 128, 14, 14]             256
           Conv2d-81           [-1, 32, 14, 14]          36,864
       DenseLayer-82          [-1, 352, 14, 14]               0
      BatchNorm2d-83          [-1, 352, 14, 14]             704
           Conv2d-84          [-1, 128, 14, 14]          45,056
      BatchNorm2d-85          [-1, 128, 14, 14]             256
           Conv2d-86           [-1, 32, 14, 14]          36,864
       DenseLayer-87          [-1, 384, 14, 14]               0
      BatchNorm2d-88          [-1, 384, 14, 14]             768
           Conv2d-89          [-1, 128, 14, 14]          49,152
      BatchNorm2d-90          [-1, 128, 14, 14]             256
           Conv2d-91           [-1, 32, 14, 14]          36,864
       DenseLayer-92          [-1, 416, 14, 14]               0
      BatchNorm2d-93          [-1, 416, 14, 14]             832
           Conv2d-94          [-1, 128, 14, 14]          53,248
      BatchNorm2d-95          [-1, 128, 14, 14]             256
           Conv2d-96           [-1, 32, 14, 14]          36,864
       DenseLayer-97          [-1, 448, 14, 14]               0
      BatchNorm2d-98          [-1, 448, 14, 14]             896
             ReLU-99          [-1, 448, 14, 14]               0
          Conv2d-100          [-1, 256, 14, 14]         114,688
AdaptiveAvgPool2d-101            [-1, 256, 1, 1]               0
          Linear-102                   [-1, 16]           4,096
            ReLU-103                   [-1, 16]               0
          Linear-104                  [-1, 256]           4,096
         Sigmoid-105                  [-1, 256]               0
     SEAttention-106          [-1, 256, 14, 14]               0
SE_DenseResidualBlock-107          [-1, 256, 14, 14]               0
     BatchNorm2d-108          [-1, 256, 14, 14]             512
            ReLU-109          [-1, 256, 14, 14]               0
          Conv2d-110          [-1, 512, 14, 14]         131,072
       AvgPool2d-111            [-1, 512, 7, 7]               0
     BatchNorm2d-112            [-1, 512, 7, 7]           1,024
          Conv2d-113            [-1, 128, 7, 7]          65,536
     BatchNorm2d-114            [-1, 128, 7, 7]             256
          Conv2d-115             [-1, 32, 7, 7]          36,864
      DenseLayer-116            [-1, 544, 7, 7]               0
     BatchNorm2d-117            [-1, 544, 7, 7]           1,088
          Conv2d-118            [-1, 128, 7, 7]          69,632
     BatchNorm2d-119            [-1, 128, 7, 7]             256
          Conv2d-120             [-1, 32, 7, 7]          36,864
      DenseLayer-121            [-1, 576, 7, 7]               0
     BatchNorm2d-122            [-1, 576, 7, 7]           1,152
          Conv2d-123            [-1, 128, 7, 7]          73,728
     BatchNorm2d-124            [-1, 128, 7, 7]             256
          Conv2d-125             [-1, 32, 7, 7]          36,864
      DenseLayer-126            [-1, 608, 7, 7]               0
     BatchNorm2d-127            [-1, 608, 7, 7]           1,216
          Conv2d-128            [-1, 128, 7, 7]          77,824
     BatchNorm2d-129            [-1, 128, 7, 7]             256
          Conv2d-130             [-1, 32, 7, 7]          36,864
      DenseLayer-131            [-1, 640, 7, 7]               0
     BatchNorm2d-132            [-1, 640, 7, 7]           1,280
            ReLU-133            [-1, 640, 7, 7]               0
          Conv2d-134            [-1, 512, 7, 7]         327,680
AdaptiveAvgPool2d-135            [-1, 512, 1, 1]               0
          Linear-136                   [-1, 32]          16,384
            ReLU-137                   [-1, 32]               0
          Linear-138                  [-1, 512]          16,384
         Sigmoid-139                  [-1, 512]               0
     SEAttention-140            [-1, 512, 7, 7]               0
SE_DenseResidualBlock-141            [-1, 512, 7, 7]               0
     BatchNorm2d-142            [-1, 512, 7, 7]           1,024
          Linear-143                    [-1, 3]           1,539
================================================================
Total params: 2,030,211
Trainable params: 2,030,211
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 117.21
Params size (MB): 7.74
Estimated Total Size (MB): 125.53
----------------------------------------------------------------
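The last SE block in the summary maps 512 channels down to 32 and back: AdaptiveAvgPool2d-135 → Linear-136 → ReLU-137 → Linear-138 → Sigmoid-139. Each Linear layer reports 16,384 = 512 × 32 parameters. That suggests the layers have no bias terms and use a reduction ratio r = 16; both are inferred from the numbers, not stated in the source.

Below is a minimal NumPy sketch of the squeeze / excitation / scale steps for a single 7×7 feature map. The batch dimension is dropped, and random weights stand in for learned ones. It illustrates the computation only and is not the author's `SEAttention` class.

```python
import numpy as np

def se_attention(x, w1, w2):
    """Squeeze-and-Excitation on one feature map x of shape (C, H, W).
    w1: (C // r, C) and w2: (C, C // r) are bias-free FC weights
    (hypothetical random values here; in the model they are learned)."""
    s = x.mean(axis=(1, 2))                # Squeeze: global average pool -> (C,)
    z = np.maximum(w1 @ s, 0.0)            # FC (C -> C/r) + ReLU
    w = 1.0 / (1.0 + np.exp(-(w2 @ z)))    # FC (C/r -> C) + Sigmoid -> weights in (0, 1)
    return x * w[:, None, None], w         # Scale: reweight every channel

# Assumed sizes: C=512 and 7x7 as in the summary, r=16 inferred from 512 -> 32
C, r, H, W = 512, 16, 7, 7
rng = np.random.default_rng(0)
x = rng.standard_normal((C, H, W))
w1 = rng.standard_normal((C // r, C)) * 0.05
w2 = rng.standard_normal((C, C // r)) * 0.05

y, weights = se_attention(x, w1, w2)
print(y.shape, w1.size, w2.size)  # (512, 7, 7) 16384 16384
```

The output keeps the input shape. Only the per-channel scale changes, which is why SEAttention-140 reports the same [-1, 512, 7, 7] shape as the Conv2d before it.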

2.4 Comparing the parameter counts of ResDenseNet and SE-ResDenseNet

# Count the parameters of SE-ResDenseNet
se_params = sum(p.numel() for p in model.parameters())
print(f'SE-ResDenseNet total parameters: {se_params:,} ({se_params/1e6:.2f}M)')

# Count the parameters added by the SE attention modules
se_only_params = 0
for module in model.modules():
    if isinstance(module, SEAttention):
        se_only_params += sum(p.numel() for p in module.parameters())
print(f'Parameters added by SE attention modules: {se_only_params:,} ({se_only_params/1e3:.2f}K)')
print(f'Share of SE module parameters: {se_only_params/se_params*100:.2f}%')

SE-ResDenseNet total parameters: 2,030,211 (2.03M)
Parameters added by SE attention modules: 43,520 (43.52K)
Share of SE module parameters: 2.14%
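The 43,520 figure can be reproduced by hand. Each SE module has two bias-free Linear layers (C → C/r and C/r → C), so it adds 2·C·C/r parameters.

The summary excerpt only shows the 512-channel block. The channel list [64, 128, 256, 512] used below is an assumption about the other SE blocks, chosen because it reproduces the reported total with r = 16. The ResDenseNet baseline count printed at the end is valid only if the SE modules are the sole architectural difference.

```python
def se_params(channels, reduction=16):
    # Two bias-free Linear layers: C -> C // r and C // r -> C
    hidden = channels // reduction
    return channels * hidden + hidden * channels

block_channels = [64, 128, 256, 512]   # hypothetical: output channels of the four SE blocks
per_block = [se_params(c) for c in block_channels]
total_se = sum(per_block)
total_model = 2_030_211                # "Total params" from the summary above

print(per_block)                       # [512, 2048, 8192, 32768]
print(total_se, f"{total_se / total_model * 100:.2f}%")  # 43520 2.14%
print(total_model - total_se)          # baseline without SE, under the assumption above
```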

3. Training the model

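3.1 Writing the training function

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)   # number of training samples
    num_batches = len(dataloader)    # number of batches per epoch

    train_loss, train_acc = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate correct predictions and the batch-mean loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc  /= size         # accuracy over all samples
    train_loss /= num_batches  # average of the per-batch mean losses

    return train_acc, train_loss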
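Accuracy and loss are normalized differently in `train()` and `test()`. Correct predictions are summed and divided by the dataset size, so accuracy is exact. `loss.item()` is already a batch mean (the default `reduction='mean'` of `CrossEntropyLoss`), so dividing by `num_batches` gives a mean of batch means. That mean weights samples in a smaller final batch slightly more.

A pure-Python sketch with hypothetical numbers (10 samples, batch size 4) makes the difference visible:

```python
def epoch_loss_two_ways(per_sample_losses, batch_size):
    # Split per-sample losses into consecutive batches (the last one may be smaller)
    batches = [per_sample_losses[i:i + batch_size]
               for i in range(0, len(per_sample_losses), batch_size)]
    # What train()/test() compute: sum of batch means divided by the number of batches
    mean_of_batch_means = sum(sum(b) / len(b) for b in batches) / len(batches)
    # Exact per-sample mean over the whole dataset
    per_sample_mean = sum(per_sample_losses) / len(per_sample_losses)
    return mean_of_batch_means, per_sample_mean

# Hypothetical: 10 samples, batch_size=4 -> batches of 4, 4 and 2
losses = [1.0] * 8 + [0.0] * 2
print(epoch_loss_two_ways(losses, batch_size=4))
```

Here the mean of batch means is about 0.667, while the per-sample mean is 0.8. When every batch has the same size, the two values coincide. With a realistic dataset and batch size the gap is usually small, but it is not zero.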

3.2 Writing the test function

def test(dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # number of test samples
    num_batches = len(dataloader)
    test_loss, test_acc = 0, 0

    # Gradients are not needed for evaluation, which saves memory and time
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # Forward pass only; no backward pass or parameter update
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

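3.3 Training

Training strategy (carried over from the best practices of week J4):

  • AdamW optimizer with weight decay (weight_decay=1e-4)
  • Label smoothing (label_smoothing=0.1) to curb overfitting
  • Cosine annealing learning-rate schedule (CosineAnnealingLR)
  • Kaiming weight initialization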
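With `label_smoothing=ε`, `CrossEntropyLoss` replaces the one-hot target with (1 − ε)·one_hot + ε/K, as described in the PyTorch documentation. Cross-entropy against a fixed target distribution is minimized when the prediction equals that target. So the lowest reachable loss is the entropy of the smoothed target, not 0.

The sketch below assumes K = 3 classes, read from the final `Linear` layer's output shape [-1, 3] in the summary. It computes that floor with plain Python.

```python
import math

def smoothed_target(num_classes, true_idx, eps):
    # Mix the one-hot target with a uniform distribution:
    # (1 - eps) * one_hot + eps / num_classes
    return [(1 - eps) * (1.0 if k == true_idx else 0.0) + eps / num_classes
            for k in range(num_classes)]

def entropy(p):
    # Entropy in nats (natural log), matching CrossEntropyLoss
    return -sum(pk * math.log(pk) for pk in p if pk > 0)

target = smoothed_target(num_classes=3, true_idx=0, eps=0.1)
print([round(t, 4) for t in target])   # [0.9333, 0.0333, 0.0333]

loss_floor = entropy(target)           # minimum achievable loss per sample
print(round(loss_floor, 3))            # 0.291
```

Even a model that classifies every training image correctly and confidently would therefore report a training loss of roughly 0.29 under this setting. Keep this in mind when reading the Train_loss column of the log.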

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn   = nn.CrossEntropyLoss(label_smoothing=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

epochs = 10

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

best_acc   = 0
best_model = copy.deepcopy(model)  # fallback so best_model always exists

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    scheduler.step()  # cosine annealing: update the learning rate once per epoch

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep a copy of the model with the highest test accuracy so far
    if epoch_test_acc > best_acc:
        best_acc   = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Read after scheduler.step(), so this is the lr the next epoch will use
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))

# Save the best model
PATH = './model/J5_se_resdensenet_best_model.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(best_model.state_dict(), PATH)

print('Done')

    Epoch: 1, Train_acc:64.8%, Train_loss:0.910, Test_acc:62.5%, Test_loss:1.009, Lr:9.76E-04
    Epoch: 2, Train_acc:67.5%, Train_loss:0.859, Test_acc:59.5%, Test_loss:2.214, Lr:9.05E-04
    Epoch: 3, Train_acc:70.8%, Train_loss:0.835, Test_acc:79.0%, Test_loss:0.715, Lr:7.94E-04
    Epoch: 4, Train_acc:73.6%, Train_loss:0.771, Test_acc:70.6%, Test_loss:0.965, Lr:6.55E-04
    Epoch: 5, Train_acc:77.1%, Train_loss:0.718, Test_acc:67.3%, Test_loss:1.010, Lr:5.01E-04
    Epoch: 6, Train_acc:79.0%, Train_loss:0.691, Test_acc:87.7%, Test_loss:0.555, Lr:3.46E-04
    Epoch: 7, Train_acc:82.0%, Train_loss:0.649, Test_acc:82.3%, Test_loss:0.609, Lr:2.07E-04
    Epoch: 8, Train_acc:82.6%, Train_loss:0.612, Test_acc:87.4%, Test_loss:0.516, Lr:9.64E-05
    Epoch: 9, Train_acc:83.3%, Train_loss:0.600, Test_acc:88.3%, Test_loss:0.534, Lr:2.54E-05
    Epoch:10, Train_acc:84.4%, Train_loss:0.584, Test_acc:89.2%, Test_loss:0.500, Lr:1.00E-06
    Done
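The Lr column can be checked against the closed-form cosine annealing schedule from the PyTorch `CosineAnnealingLR` documentation: η_t = η_min + (η_max − η_min)·(1 + cos(π·t / T_max)) / 2. The loop reads the learning rate after `scheduler.step()`, so the value printed at epoch e corresponds to t = e. PyTorch updates the rate step by step rather than evaluating this formula directly, so this is a sketch that should agree up to rounding, not a reimplementation of the scheduler.

```python
import math

def cosine_annealing_lr(t, lr_max=1e-3, eta_min=1e-6, t_max=10):
    # Closed-form cosine annealing with the settings used above
    return eta_min + (lr_max - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

# lr printed at the end of epoch e is the value for step t = e
lrs = [cosine_annealing_lr(t) for t in range(1, 11)]
for epoch, lr in enumerate(lrs, start=1):
    print(f"Epoch:{epoch:2d}, Lr:{lr:.2E}")
```

The schedule reaches η_min = 1e-6 exactly at t = T_max = 10, which matches the final log line.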

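4. Visualizing the results

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']    = ['SimHei']  # font that can render Chinese characters
plt.rcParams['axes.unicode_minus'] = False       # render minus signs correctly with SimHei
plt.rcParams['figure.dpi']         = 100

from datetime import datetime
current_time = datetime.now()  # timestamp shown as the x-axis label of the accuracy plot

epochs_range = range(epochs)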

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc,  label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('SE-ResDenseNet - Training and Validation Accuracy')
plt.xlabel(current_time)

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss,  label='Test Loss')
plt.legend(loc='upper right')
plt.title('SE-ResDenseNet - Training and Validation Loss')
plt.show()

5. Model evaluation

best_model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
best_model.eval()  # make sure BatchNorm uses its running statistics
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'SE-ResDenseNet Best Test Accuracy: {epoch_test_acc*100:.1f}%')
print(f'SE-ResDenseNet Best Test Loss:     {epoch_test_loss:.4f}')

SE-ResDenseNet Best Test Accuracy: 89.2%
SE-ResDenseNet Best Test Loss:     0.5000
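The reloaded model reproduces epoch 10's Test_acc (89.2%) and Test_loss (0.500). The reason is that the highest test accuracy in the log came at the last epoch, so that is the checkpoint the loop saved.

The sketch below replays the selection rule on the Test_acc values copied from the training log. The strict `>` keeps the earliest epoch when accuracies tie. Because the checkpoint is chosen on the same test set it is reported on, this figure reads more like a validation score than an independent test result.

```python
# Test accuracies (%) copied from the training log above
test_acc_log = [62.5, 59.5, 79.0, 70.6, 67.3, 87.7, 82.3, 87.4, 88.3, 89.2]

best_acc, best_epoch = 0.0, None
for epoch, acc in enumerate(test_acc_log, start=1):
    # Same rule as the training loop: strict '>' keeps the earliest epoch on ties
    if acc > best_acc:
        best_acc, best_epoch = acc, epoch

print(best_epoch, best_acc)  # 10 89.2
```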