I. Background
1. Summary
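Week J4 built ResDenseNet, which combines ResNet residual connections with DenseNet dense connections. This week adds the SE-Net channel attention mechanism on top of it to build SE-ResDenseNet, so the model can focus more on the most informative feature channels.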
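What SE-ResDenseNet adds
| Aspect | J4: ResDenseNet | J5: SE-ResDenseNet |
|--------|-----------------|--------------------|
| Channel attention | None | SE module adaptively recalibrates channel weights |
| Inside the DRB | Dense connections + residual connection | Dense connections + SE channel weighting + residual connection |
| Feature selection | All channels weighted equally | Important channels amplified, less useful ones suppressed |
| After the compression layer | Added directly to the residual | SE-weighted first, then added to the residual |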
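Where the SE module is inserted (key design choice)
Following the SE paper's recommendation, the branch features are recalibrated by SE before the addition:
DenseLayer x L → Concat → Compress 1x1 → [SE Attention] → + (Residual Add)
                                               ↑
                              Squeeze: Global AvgPool
                              Excitation: FC→ReLU→FC→Sigmoid
                              Scale: per-channel weighting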
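The SE box in the diagram can be sketched in plain Python (no PyTorch) on a single C x H x W feature map stored as nested lists. This is a minimal illustration, not the training code: `se_forward` is a hypothetical helper, the sizes are toy values, and the weights are random placeholders rather than learned ones.

```python
import math
import random

def se_forward(x, w1, w2):
    """Hypothetical pure-Python SE step for one feature map stored as x[c][i][j] (C x H x W).

    w1: (C // r) rows of C weights (first FC, no bias)
    w2: C rows of (C // r) weights (second FC, no bias)
    Returns the rescaled feature map and the per-channel weights s.
    """
    # Squeeze: global average pooling -> one scalar per channel
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # Excitation: FC -> ReLU -> FC -> Sigmoid
    hidden = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in w1]
    s = [1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden)))) for row in w2]
    # Scale: multiply every value of channel c by s[c]
    out = [[[v * s[c] for v in row] for row in ch] for c, ch in enumerate(x)]
    return out, s

random.seed(0)
C, H, W, r = 8, 4, 4, 4   # toy sizes, far smaller than the real blocks
x = [[[random.gauss(0, 1) for _ in range(W)] for _ in range(H)] for _ in range(C)]
w1 = [[random.gauss(0, 0.5) for _ in range(C)] for _ in range(C // r)]
w2 = [[random.gauss(0, 0.5) for _ in range(C // r)] for _ in range(C)]
out, s = se_forward(x, w1, w2)
print("channel weights:", [round(v, 3) for v in s])
```

Every weight s_c lies strictly between 0 and 1, so SE only rescales channels relative to one another; in the SE-DRB the residual addition afterwards adds the unscaled input back.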
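Transferring the idea
SE attention is a general-purpose channel attention method and can be carried over to:
- CNNs: embed it in the blocks of any ResNet, Inception, or MobileNet
- Transformers: the same idea of gating feature channels with weights computed from global context can be applied to Transformer features
- Object detection/segmentation: add SE modules to FPNs and other feature pyramids to improve feature selection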
Architecture comparison
Figure: core-block comparison of ResDenseNet and SE-ResDenseNet.

II. Code Implementation
1. Preparation
1.1 Set up the GPU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import os, copy, warnings
import numpy as np
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
1.2 Load the data (augmentation strategy)
Reuse the J4 augmentation strategy: random horizontal flip + color jitter + normalization.
data_dir = './data/day01'
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
total_data
Dataset ImageFolder
Number of datapoints: 1661
Root location: ./data/day01
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=None)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
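The printout shows how ColorJitter turns `brightness=0.2` into the factor range (0.8, 1.2) and `saturation=0.1` into (0.9, 1.1): a scalar argument v becomes [max(0, 1 - v), 1 + v], and one factor is drawn uniformly per image. A small sketch of that mapping and of the brightness adjustment (scale the pixels, then clip to [0, 1]); `jitter_range` and `adjust_brightness` are hypothetical helpers, not torchvision functions.

```python
import random

def jitter_range(value, center=1.0):
    """Turn a scalar ColorJitter argument into the (min, max) range of factors it samples from."""
    return (max(0.0, center - value), center + value)

def adjust_brightness(pixels, factor):
    """Scale pixel values in [0, 1] by `factor` and clip back into [0, 1]."""
    return [min(1.0, max(0.0, p * factor)) for p in pixels]

brightness = jitter_range(0.2)   # matches brightness=(0.8, 1.2) in the printout
saturation = jitter_range(0.1)   # matches saturation=(0.9, 1.1)
random.seed(0)
factor = random.uniform(*brightness)   # one factor drawn for this image
print(brightness, saturation, round(factor, 3))
print(adjust_brightness([0.1, 0.5, 0.9], factor))
```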
total_data.class_to_idx
{'0Normal': 0, '2Mild': 1, '4Severe': 2}
1.3 Split the dataset
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
# random_split returns Subsets that store indices; ImageFolder orders files deterministically,
# so the same indices still refer to the same images after swapping the underlying dataset.
# The test set uses the non-augmented transform
test_dataset.dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
# The training set keeps the augmented transform
train_dataset.dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
batch_size = 8
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                      num_workers=0)
# Inspect the shape of one batch
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
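The split and batch arithmetic for the 1661 images reported above works out as follows (a plain-Python check; `DataLoader` keeps the last partial batch because `drop_last` defaults to False):

```python
import math

n_total = 1661            # "Number of datapoints" from the ImageFolder printout
batch_size = 8
train_size = int(0.8 * n_total)          # int() truncates 1328.8 -> 1328
test_size = n_total - train_size         # the remaining 333 images
train_batches = math.ceil(train_size / batch_size)   # partial batches are kept
test_batches = math.ceil(test_size / batch_size)
last_test_batch = test_size - (test_batches - 1) * batch_size
print(train_size, test_size, train_batches, test_batches, last_test_batch)
```

Since `random_split` is called without a seeded generator here, which images land in each split changes from run to run.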
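2. Build the SE-ResDenseNet model
SE-ResDenseNet design rationale
An SE attention module is embedded in the Dense Residual Block (DRB) of the J4 ResDenseNet:
J4 DRB pipeline:
DenseLayer x L → Concat → 1x1 Compress → (+ Residual Add)
J5 SE-DRB pipeline (new SE module):
DenseLayer x L → Concat → 1x1 Compress → [SE Attention] → (+ Residual Add)
                                               ↓
                     Squeeze: Global AvgPool → C values
                     Excitation: FC(C→C/r)→ReLU→FC(C/r→C)→Sigmoid
                     Scale: per-channel weighting (important channels amplified)
Why SE is placed here (following the SE paper):
- Before the addition, on the branch: only the branch features are recalibrated, so the identity path is never multiplied by weights in (0, 1); scaling it would attenuate the signal and gradients passing through stacked blocks
- After the compression: channel selection operates on the compressed features, so the SE module needs fewer parameters
- reduction=16: the first FC layer reduces the channel count to C/16, trading off accuracy against parameter count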
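The channel bookkeeping behind the "after the compression" argument can be traced in plain Python. The sketch mirrors the constructor arguments of `SE_DenseResidualBlock` (growth_rate=32, reduction=16, block_config=(3, 4, 6, 4)); `se_drb_channels` is a hypothetical helper, and "SE before compression" is a hypothetical alternative placement used only for comparison.

```python
def se_drb_channels(in_channels, num_layers, growth_rate=32, reduction=16):
    """Channel bookkeeping for one SE-DRB (mirrors SE_DenseResidualBlock.__init__)."""
    dense_inputs = [in_channels + i * growth_rate for i in range(num_layers)]
    concat = in_channels + num_layers * growth_rate          # width fed to the 1x1 compress conv
    se_after_compress = 2 * in_channels * (in_channels // reduction)  # SE on C channels (this design)
    se_on_concat = 2 * concat * (concat // reduction)        # hypothetical: SE before compression
    return dense_inputs, concat, se_after_compress, se_on_concat

rows = []
features = 64
block_config = (3, 4, 6, 4)
for stage, num_layers in enumerate(block_config):
    dense_inputs, concat, se_after, se_before = se_drb_channels(features, num_layers)
    rows.append((stage, features, concat, se_after, se_before))
    print(f"stage {stage}: DenseLayer inputs {dense_inputs}, concat {concat} -> compress {features}, "
          f"SE params {se_after:,} (vs {se_before:,} if applied before compression)")
    if stage != len(block_config) - 1:
        features *= 2   # the Transition between stages doubles the channel count
```

The DenseLayer input widths match the model printout further down (64/96/128, 128 to 224, 256 to 416, 512 to 608), and applying SE after compression roughly halves its parameter count compared with applying it to the concatenated features.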
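2.1 SE attention module
The SE (Squeeze-and-Excitation) module implements channel attention:
- Squeeze: global average pooling compresses each channel into one scalar
- Excitation: two fully connected layers learn the dependencies between channels
- Scale: the learned weights are multiplied back onto the original features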
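For reference, the three steps can be written in the notation of the SE paper, where u_c is channel c of the block's feature map, δ is ReLU, σ is the sigmoid, and biases are omitted to match the `bias=False` layers of `SEAttention`. The last line adds the residual connection of the SE-DRB, with 𝓕 standing for the dense layers plus the 1x1 compression (a notational assumption for this sketch):

```latex
\begin{aligned}
z_c &= \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j), \qquad c = 1, \dots, C \\
\mathbf{s} &= \sigma\big(W_2\, \delta(W_1 \mathbf{z})\big), \qquad W_1 \in \mathbb{R}^{\frac{C}{r} \times C},\; W_2 \in \mathbb{R}^{C \times \frac{C}{r}} \\
\tilde{u}_c &= s_c \, u_c \\
\mathbf{y} &= \mathbf{x} + \tilde{\mathbf{u}}, \qquad \mathbf{u} = \mathcal{F}(\mathbf{x})
\end{aligned}
```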
class SEAttention(nn.Module):
    """SE (Squeeze-and-Excitation) channel attention module.

    Paper: https://arxiv.org/abs/1709.01507
    Steps:
        1. Squeeze: global average pooling -> one value per channel
        2. Excitation: FC(C→C/r) -> ReLU -> FC(C/r→C) -> Sigmoid
        3. Scale: multiply the weights back onto the input feature map
    """
    def __init__(self, channel, reduction=16):
        super(SEAttention, self).__init__()
        # Squeeze: adaptive average pooling to 1x1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two fully connected layers (bottleneck structure)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: [B, C, H, W] -> [B, C]
        y = self.avg_pool(x).view(b, c)
        # Excitation: [B, C] -> [B, C, 1, 1]
        y = self.fc(y).view(b, c, 1, 1)
        # Scale: broadcast the per-channel weights over H and W
        return x * y.expand_as(x)
Verify the SE module
if __name__ == '__main__':
    x = torch.randn(8, 512, 32, 32)
    se = SEAttention(channel=512, reduction=16)
    out = se(x)
    print(f'SE module test: input {x.shape} -> output {out.shape}')
    print(f'SE params: {sum(p.numel() for p in se.parameters()):,}')
SE module test: input torch.Size([8, 512, 32, 32]) -> output torch.Size([8, 512, 32, 32])
SE params: 32,768
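The 32,768 parameters follow from the two bias-free Linear layers: C x (C/r) + (C/r) x C = 2C²/r. Sweeping the reduction ratio for C = 512 shows the accuracy/size trade-off that `reduction=16` is chosen to balance (`se_param_count` is a hypothetical helper; only r = 16 is used in this model):

```python
def se_param_count(channels, reduction=16):
    """Parameters of SEAttention: two Linear layers without bias (C -> C/r -> C)."""
    hidden = channels // reduction
    return channels * hidden + hidden * channels

for r in (4, 8, 16, 32):
    print(f"C=512, r={r:>2}: {se_param_count(512, r):,} parameters")
```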
2.2 Full SE-ResDenseNet model
class DenseLayer(nn.Module):
    """Dense layer: pre-activation bottleneck (BN->ReLU->1x1Conv->BN->ReLU->3x3Conv).

    Borrows ResNetV2's pre-activation design, placing BN and ReLU before each convolution.
    The bottleneck first reduces to bn_size * growth_rate channels with a 1x1 conv,
    then a 3x3 conv produces growth_rate new feature maps.
    """
    def __init__(self, in_channels, growth_rate, bn_size=4):
        super(DenseLayer, self).__init__()
        # Pre-activation: BN -> ReLU -> Conv
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return torch.cat([x, out], dim=1)  # DenseNet style: concatenate


class SE_DenseResidualBlock(nn.Module):
    """SE dense residual block (SE-DRB): dense connections + SE channel attention + residual connection.

    Adds an SE attention module to the J4 DRB:
    1. Inside: several DenseLayers concatenated densely -> feature reuse (DenseNet)
    2. Compression: a 1x1 conv maps the high-dimensional features back to the input channel count
    3. SE attention: recalibrates the channels of the compressed features (SE-Net) <- NEW
    4. Outside: the SE-weighted output is added to the input -> direct gradient path (ResNet)
    """
    def __init__(self, in_channels, num_layers, growth_rate, bn_size=4, se_reduction=16):
        super(SE_DenseResidualBlock, self).__init__()
        # Dense layers: each one sees the input plus all previously grown features
        self.dense_layers = nn.ModuleList()
        for i in range(num_layers):
            self.dense_layers.append(
                DenseLayer(in_channels + i * growth_rate, growth_rate, bn_size)
            )
        # Compression: map the concatenated features back to the input channel count
        total_channels = in_channels + num_layers * growth_rate
        self.compress = nn.Sequential(
            nn.BatchNorm2d(total_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(total_channels, in_channels, kernel_size=1, stride=1, bias=False)
        )
        # SE channel attention module (NEW)
        self.se = SEAttention(channel=in_channels, reduction=se_reduction)

    def forward(self, x):
        # Dense connection path
        features = x
        for layer in self.dense_layers:
            features = layer(features)
        # Compression
        out = self.compress(features)
        # SE channel reweighting (NEW)
        out = self.se(out)
        # Residual connection
        return out + x  # ResNet style: residual addition


class Transition(nn.Sequential):
    """Transition layer: BN->ReLU->1x1Conv->AvgPool"""
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.add_module('bn', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class SEResDenseNet(nn.Module):
    """SE-ResDenseNet: ResNet residual connections + DenseNet dense connections + SE channel attention.

    Architecture: Stem -> [SE-DRB -> Transition] x 3 -> SE-DRB -> BN -> ReLU -> GAP -> FC
    Compared with the J4 ResDenseNet:
    - every DRB becomes an SE-DRB, with SE attention after compression and before the residual addition
    - the SE module learns how important each channel is, amplifying key features and suppressing redundant ones
    """
    def __init__(self, growth_rate=32, block_config=(3, 4, 6, 4),
                 num_init_features=64, bn_size=4, se_reduction=16, num_classes=1000):
        super(SEResDenseNet, self).__init__()
        # ===== Stem =====
        self.stem = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # ===== Stages =====
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            # SE dense residual block
            self.stages.append(
                SE_DenseResidualBlock(num_features, num_layers, growth_rate, bn_size, se_reduction)
            )
            # Transition (none after the last stage)
            if i != len(block_config) - 1:
                next_features = num_features * 2
                self.transitions.append(
                    Transition(num_features, next_features)
                )
                num_features = next_features
        # ===== Final BN + classifier =====
        self.final_bn = nn.BatchNorm2d(num_features)
        self.classifier = nn.Linear(num_features, num_classes)
        # ===== Weight initialization (Kaiming) =====
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.transitions):
                x = self.transitions[i](x)
        x = F.relu(self.final_bn(x), inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
model = SEResDenseNet(num_classes=3).to(device)
model
SEResDenseNet(
(stem): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(stages): ModuleList(
(0): SE_DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=64, out_features=4, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=4, out_features=64, bias=False)
(3): Sigmoid()
)
)
)
(1): SE_DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=128, out_features=8, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=8, out_features=128, bias=False)
(3): Sigmoid()
)
)
)
(2): SE_DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(4): DenseLayer(
(bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(5): DenseLayer(
(bn1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=256, out_features=16, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=16, out_features=256, bias=False)
(3): Sigmoid()
)
)
)
(3): SE_DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(640, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=32, out_features=512, bias=False)
(3): Sigmoid()
)
)
)
)
(transitions): ModuleList(
(0): Transition(
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(1): Transition(
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(2): Transition(
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
)
(final_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(classifier): Linear(in_features=512, out_features=3, bias=True)
)
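The feature-map sizes implied by this printout can be traced with the standard output-size formula floor((size + 2·padding - kernel) / stride) + 1. A plain-Python sketch (`out_size` is a hypothetical helper) that follows a 224 x 224 input through the stem, the four SE-DRBs and the three Transitions:

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial output size of a conv/pool layer along one axis (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

block_config = (3, 4, 6, 4)
channels, size = 64, 224
trace = []
size = out_size(size, kernel=7, stride=2, padding=3)   # stem Conv2d 7x7, stride 2
trace.append(("stem conv", channels, size))
size = out_size(size, kernel=3, stride=2, padding=1)   # stem MaxPool2d 3x3, stride 2
trace.append(("stem pool", channels, size))
for i in range(len(block_config)):
    # An SE-DRB keeps both C and H x W, so the residual addition sees matching shapes
    trace.append((f"SE-DRB {i}", channels, size))
    if i != len(block_config) - 1:
        channels *= 2                                   # Transition 1x1 conv doubles C
        size = out_size(size, kernel=2, stride=2)       # AvgPool2d(2, 2) halves H and W
        trace.append((f"Transition {i}", channels, size))
for name, c, s in trace:
    print(f"{name:<13} C={c:<4} {s}x{s}")
```

The last entry, 512 channels at 7 x 7, is what `final_bn` and the global average pooling receive before the 512 -> 3 classifier.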
2.3 Model summary
import torchsummary as summary
summary.summary(model, (3, 224, 224))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
BatchNorm2d-5 [-1, 64, 56, 56] 128
Conv2d-6 [-1, 128, 56, 56] 8,192
BatchNorm2d-7 [-1, 128, 56, 56] 256
Conv2d-8 [-1, 32, 56, 56] 36,864
DenseLayer-9 [-1, 96, 56, 56] 0
BatchNorm2d-10 [-1, 96, 56, 56] 192
Conv2d-11 [-1, 128, 56, 56] 12,288
BatchNorm2d-12 [-1, 128, 56, 56] 256
Conv2d-13 [-1, 32, 56, 56] 36,864
DenseLayer-14 [-1, 128, 56, 56] 0
BatchNorm2d-15 [-1, 128, 56, 56] 256
Conv2d-16 [-1, 128, 56, 56] 16,384
BatchNorm2d-17 [-1, 128, 56, 56] 256
Conv2d-18 [-1, 32, 56, 56] 36,864
DenseLayer-19 [-1, 160, 56, 56] 0
BatchNorm2d-20 [-1, 160, 56, 56] 320
ReLU-21 [-1, 160, 56, 56] 0
Conv2d-22 [-1, 64, 56, 56] 10,240
AdaptiveAvgPool2d-23 [-1, 64, 1, 1] 0
Linear-24 [-1, 4] 256
ReLU-25 [-1, 4] 0
Linear-26 [-1, 64] 256
Sigmoid-27 [-1, 64] 0
SEAttention-28 [-1, 64, 56, 56] 0
SE_DenseResidualBlock-29 [-1, 64, 56, 56] 0
BatchNorm2d-30 [-1, 64, 56, 56] 128
ReLU-31 [-1, 64, 56, 56] 0
Conv2d-32 [-1, 128, 56, 56] 8,192
AvgPool2d-33 [-1, 128, 28, 28] 0
BatchNorm2d-34 [-1, 128, 28, 28] 256
Conv2d-35 [-1, 128, 28, 28] 16,384
BatchNorm2d-36 [-1, 128, 28, 28] 256
Conv2d-37 [-1, 32, 28, 28] 36,864
DenseLayer-38 [-1, 160, 28, 28] 0
BatchNorm2d-39 [-1, 160, 28, 28] 320
Conv2d-40 [-1, 128, 28, 28] 20,480
BatchNorm2d-41 [-1, 128, 28, 28] 256
Conv2d-42 [-1, 32, 28, 28] 36,864
DenseLayer-43 [-1, 192, 28, 28] 0
BatchNorm2d-44 [-1, 192, 28, 28] 384
Conv2d-45 [-1, 128, 28, 28] 24,576
BatchNorm2d-46 [-1, 128, 28, 28] 256
Conv2d-47 [-1, 32, 28, 28] 36,864
DenseLayer-48 [-1, 224, 28, 28] 0
BatchNorm2d-49 [-1, 224, 28, 28] 448
Conv2d-50 [-1, 128, 28, 28] 28,672
BatchNorm2d-51 [-1, 128, 28, 28] 256
Conv2d-52 [-1, 32, 28, 28] 36,864
DenseLayer-53 [-1, 256, 28, 28] 0
BatchNorm2d-54 [-1, 256, 28, 28] 512
ReLU-55 [-1, 256, 28, 28] 0
Conv2d-56 [-1, 128, 28, 28] 32,768
AdaptiveAvgPool2d-57 [-1, 128, 1, 1] 0
Linear-58 [-1, 8] 1,024
ReLU-59 [-1, 8] 0
Linear-60 [-1, 128] 1,024
Sigmoid-61 [-1, 128] 0
SEAttention-62 [-1, 128, 28, 28] 0
SE_DenseResidualBlock-63 [-1, 128, 28, 28] 0
BatchNorm2d-64 [-1, 128, 28, 28] 256
ReLU-65 [-1, 128, 28, 28] 0
Conv2d-66 [-1, 256, 28, 28] 32,768
AvgPool2d-67 [-1, 256, 14, 14] 0
BatchNorm2d-68 [-1, 256, 14, 14] 512
Conv2d-69 [-1, 128, 14, 14] 32,768
BatchNorm2d-70 [-1, 128, 14, 14] 256
Conv2d-71 [-1, 32, 14, 14] 36,864
DenseLayer-72 [-1, 288, 14, 14] 0
BatchNorm2d-73 [-1, 288, 14, 14] 576
Conv2d-74 [-1, 128, 14, 14] 36,864
BatchNorm2d-75 [-1, 128, 14, 14] 256
Conv2d-76 [-1, 32, 14, 14] 36,864
DenseLayer-77 [-1, 320, 14, 14] 0
BatchNorm2d-78 [-1, 320, 14, 14] 640
Conv2d-79 [-1, 128, 14, 14] 40,960
BatchNorm2d-80 [-1, 128, 14, 14] 256
Conv2d-81 [-1, 32, 14, 14] 36,864
DenseLayer-82 [-1, 352, 14, 14] 0
BatchNorm2d-83 [-1, 352, 14, 14] 704
Conv2d-84 [-1, 128, 14, 14] 45,056
BatchNorm2d-85 [-1, 128, 14, 14] 256
Conv2d-86 [-1, 32, 14, 14] 36,864
DenseLayer-87 [-1, 384, 14, 14] 0
BatchNorm2d-88 [-1, 384, 14, 14] 768
Conv2d-89 [-1, 128, 14, 14] 49,152
BatchNorm2d-90 [-1, 128, 14, 14] 256
Conv2d-91 [-1, 32, 14, 14] 36,864
DenseLayer-92 [-1, 416, 14, 14] 0
BatchNorm2d-93 [-1, 416, 14, 14] 832
Conv2d-94 [-1, 128, 14, 14] 53,248
BatchNorm2d-95 [-1, 128, 14, 14] 256
Conv2d-96 [-1, 32, 14, 14] 36,864
DenseLayer-97 [-1, 448, 14, 14] 0
BatchNorm2d-98 [-1, 448, 14, 14] 896
ReLU-99 [-1, 448, 14, 14] 0
Conv2d-100 [-1, 256, 14, 14] 114,688
AdaptiveAvgPool2d-101 [-1, 256, 1, 1] 0
Linear-102 [-1, 16] 4,096
ReLU-103 [-1, 16] 0
Linear-104 [-1, 256] 4,096
Sigmoid-105 [-1, 256] 0
SEAttention-106 [-1, 256, 14, 14] 0
SE_DenseResidualBlock-107 [-1, 256, 14, 14] 0
BatchNorm2d-108 [-1, 256, 14, 14] 512
ReLU-109 [-1, 256, 14, 14] 0
Conv2d-110 [-1, 512, 14, 14] 131,072
AvgPool2d-111 [-1, 512, 7, 7] 0
BatchNorm2d-112 [-1, 512, 7, 7] 1,024
Conv2d-113 [-1, 128, 7, 7] 65,536
BatchNorm2d-114 [-1, 128, 7, 7] 256
Conv2d-115 [-1, 32, 7, 7] 36,864
DenseLayer-116 [-1, 544, 7, 7] 0
BatchNorm2d-117 [-1, 544, 7, 7] 1,088
Conv2d-118 [-1, 128, 7, 7] 69,632
BatchNorm2d-119 [-1, 128, 7, 7] 256
Conv2d-120 [-1, 32, 7, 7] 36,864
DenseLayer-121 [-1, 576, 7, 7] 0
BatchNorm2d-122 [-1, 576, 7, 7] 1,152
Conv2d-123 [-1, 128, 7, 7] 73,728
BatchNorm2d-124 [-1, 128, 7, 7] 256
Conv2d-125 [-1, 32, 7, 7] 36,864
DenseLayer-126 [-1, 608, 7, 7] 0
BatchNorm2d-127 [-1, 608, 7, 7] 1,216
Conv2d-128 [-1, 128, 7, 7] 77,824
BatchNorm2d-129 [-1, 128, 7, 7] 256
Conv2d-130 [-1, 32, 7, 7] 36,864
DenseLayer-131 [-1, 640, 7, 7] 0
BatchNorm2d-132 [-1, 640, 7, 7] 1,280
ReLU-133 [-1, 640, 7, 7] 0
Conv2d-134 [-1, 512, 7, 7] 327,680
AdaptiveAvgPool2d-135 [-1, 512, 1, 1] 0
Linear-136 [-1, 32] 16,384
ReLU-137 [-1, 32] 0
Linear-138 [-1, 512] 16,384
Sigmoid-139 [-1, 512] 0
SEAttention-140 [-1, 512, 7, 7] 0
SE_DenseResidualBlock-141 [-1, 512, 7, 7] 0
BatchNorm2d-142 [-1, 512, 7, 7] 1,024
Linear-143 [-1, 3] 1,539
================================================================
Total params: 2,030,211
Trainable params: 2,030,211
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 117.21
Params size (MB): 7.74
Estimated Total Size (MB): 125.53
----------------------------------------------------------------
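The size estimates at the bottom of the summary appear to follow from simple arithmetic, assuming 4-byte float32 values and 1 MB = 1024² bytes; this sketch reproduces the input and parameter figures:

```python
BYTES_PER_FLOAT32 = 4
MB = 1024 ** 2

input_mb = 3 * 224 * 224 * BYTES_PER_FLOAT32 / MB       # one 3 x 224 x 224 image
params_mb = 2_030_211 * BYTES_PER_FLOAT32 / MB          # "Total params" from the summary
print(f"Input size (MB): {input_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
```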
2.4 Parameter count: ResDenseNet vs. SE-ResDenseNet
# Count the SE-ResDenseNet parameters
se_params = sum(p.numel() for p in model.parameters())
print(f'SE-ResDenseNet total parameters: {se_params:,} ({se_params/1e6:.2f}M)')
# Count the parameters added by the SE modules
se_only_params = 0
for name, module in model.named_modules():
    if isinstance(module, SEAttention):
        se_only_params += sum(p.numel() for p in module.parameters())
print(f'Parameters added by the SE attention modules: {se_only_params:,} ({se_only_params/1e3:.2f}K)')
print(f'Share of SE module parameters: {se_only_params/se_params*100:.2f}%')
SE-ResDenseNet total parameters: 2,030,211 (2.03M)
Parameters added by the SE attention modules: 43,520 (43.52K)
Share of SE module parameters: 2.14%
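As a cross-check, the same totals can be derived by hand from the layer definitions: BatchNorm2d contributes 2C parameters (weight and bias; running statistics are buffers), a bias-free Conv2d contributes C_in·C_out·k², and each SE module 2C²/r. `count_params` is a hypothetical helper. The "without SE" figure assumes the J4 ResDenseNet used the same configuration apart from the SE modules, which this section does not confirm.

```python
def bn(c):
    """BatchNorm2d: weight + bias (running stats are buffers, not parameters)."""
    return 2 * c

def conv(cin, cout, k):
    """Conv2d without bias."""
    return cin * cout * k * k

def count_params(growth_rate=32, block_config=(3, 4, 6, 4), num_init_features=64,
                 bn_size=4, se_reduction=16, num_classes=3, use_se=True):
    total = conv(3, num_init_features, 7) + bn(num_init_features)   # stem
    se_total = 0
    c = num_init_features
    for i, num_layers in enumerate(block_config):
        mid = bn_size * growth_rate
        for j in range(num_layers):                                  # DenseLayers
            cin = c + j * growth_rate
            total += bn(cin) + conv(cin, mid, 1) + bn(mid) + conv(mid, growth_rate, 3)
        concat = c + num_layers * growth_rate
        total += bn(concat) + conv(concat, c, 1)                     # compress
        if use_se:
            se = 2 * c * (c // se_reduction)                         # two bias-free Linear layers
            se_total += se
            total += se
        if i != len(block_config) - 1:                               # Transition
            total += bn(c) + conv(c, 2 * c, 1)
            c *= 2
    total += bn(c) + c * num_classes + num_classes                   # final BN + classifier (with bias)
    return total, se_total

with_se, se_only = count_params()
without_se, _ = count_params(use_se=False)
print(f"SE-ResDenseNet: {with_se:,}")
print(f"Same network without SE: {without_se:,}")
print(f"SE overhead: {se_only:,} ({se_only / with_se * 100:.2f}%)")
```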
3. Train the model
3.1 Training function
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)     # number of samples
    num_batches = len(dataloader)      # number of batches
    train_loss, train_acc = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate correct predictions and the batch loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
3.2 Test function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)     # number of samples
    num_batches = len(dataloader)      # number of batches
    test_loss, test_acc = 0, 0
    # No gradients are needed for evaluation
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
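`train()` and `test()` divide summed accuracy by the number of samples but summed loss by the number of batches. Because `nn.CrossEntropyLoss` returns a per-batch mean by default, the reported loss is a mean of batch means, which gives the smaller final batch extra weight. With 333 test images and `batch_size=8` (41 full batches plus one batch of 5) the two averages can differ slightly; the 1328 training images split into 166 full batches, so there they coincide. A toy sketch with hypothetical per-sample losses:

```python
def mean_of_batch_means(losses, batch_size):
    """What train()/test() report: the average of per-batch mean losses (each batch weighted equally)."""
    batches = [losses[i:i + batch_size] for i in range(0, len(losses), batch_size)]
    return sum(sum(b) / len(b) for b in batches) / len(batches)

def per_sample_mean(losses):
    """Average over samples (each sample weighted equally)."""
    return sum(losses) / len(losses)

# Hypothetical per-sample losses for 333 test images: the 5 images in the last batch are hard
losses = [0.5] * 328 + [2.0] * 5
print(round(mean_of_batch_means(losses, 8), 4), round(per_sample_mean(losses), 4))
```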
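3.3 Training
Training strategy (reusing the best practices from J4):
- AdamW optimizer with weight decay (weight_decay=1e-4)
- Label smoothing (label_smoothing=0.1) to reduce overfitting
- Cosine annealing learning-rate schedule (CosineAnnealingLR)
- Kaiming weight initialization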
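With `label_smoothing=ε`, cross-entropy is computed against the target (1 - ε)·one_hot + ε/K instead of the one-hot label, where K = 3 classes here. A plain-Python sketch of that loss for one sample (`smoothed_ce` and the example logits are hypothetical; PyTorch's own implementation works on batched tensors):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def smoothed_ce(logits, target, eps=0.1):
    """Cross-entropy against the smoothed target (1 - eps) * one_hot + eps / K."""
    k = len(logits)
    logp = log_softmax(logits)
    q = [(1 - eps) * (1.0 if i == target else 0.0) + eps / k for i in range(k)]
    return -sum(qi * lpi for qi, lpi in zip(q, logp))

# Hypothetical logits for one image of class 0 ('0Normal'), predicted confidently and correctly
logits = [4.0, 0.0, 0.0]
print("no smoothing:", round(smoothed_ce(logits, 0, eps=0.0), 4))
print("eps = 0.1:  ", round(smoothed_ce(logits, 0, eps=0.1), 4))
```

With smoothing, even a perfect prediction cannot drive the loss to 0: the minimum is the entropy of the smoothed target, about 0.29 for K = 3 and ε = 0.1. Keep this in mind when reading the loss values in the training log.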
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # cosine annealing: update the learning rate once per epoch

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Keep a copy of the model with the best test accuracy so far
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Learning rate after this epoch's scheduler step
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))

# Save the best model
PATH = './model/J5_se_resdensenet_best_model.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(best_model.state_dict(), PATH)
print('Done')
Epoch: 1, Train_acc:64.8%, Train_loss:0.910, Test_acc:62.5%, Test_loss:1.009, Lr:9.76E-04
Epoch: 2, Train_acc:67.5%, Train_loss:0.859, Test_acc:59.5%, Test_loss:2.214, Lr:9.05E-04
Epoch: 3, Train_acc:70.8%, Train_loss:0.835, Test_acc:79.0%, Test_loss:0.715, Lr:7.94E-04
Epoch: 4, Train_acc:73.6%, Train_loss:0.771, Test_acc:70.6%, Test_loss:0.965, Lr:6.55E-04
Epoch: 5, Train_acc:77.1%, Train_loss:0.718, Test_acc:67.3%, Test_loss:1.010, Lr:5.01E-04
Epoch: 6, Train_acc:79.0%, Train_loss:0.691, Test_acc:87.7%, Test_loss:0.555, Lr:3.46E-04
Epoch: 7, Train_acc:82.0%, Train_loss:0.649, Test_acc:82.3%, Test_loss:0.609, Lr:2.07E-04
Epoch: 8, Train_acc:82.6%, Train_loss:0.612, Test_acc:87.4%, Test_loss:0.516, Lr:9.64E-05
Epoch: 9, Train_acc:83.3%, Train_loss:0.600, Test_acc:88.3%, Test_loss:0.534, Lr:2.54E-05
Epoch:10, Train_acc:84.4%, Train_loss:0.584, Test_acc:89.2%, Test_loss:0.500, Lr:1.00E-06
Done
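The Lr column can be reproduced with the closed-form cosine annealing formula given in the PyTorch documentation, η_t = η_min + (η_max - η_min)(1 + cos(π·t / T_max)) / 2, with η_max = 1e-3, η_min = 1e-6 and T_max = 10. Because the rate is read after `scheduler.step()`, the value printed for epoch t is the rate used in epoch t + 1. `cosine_lr` is a hypothetical helper; PyTorch updates the rate recursively, so the last printed digit may differ at exact rounding ties such as epoch 5.

```python
import math

def cosine_lr(step, base_lr=1e-3, eta_min=1e-6, t_max=10):
    """Closed-form cosine annealing: learning rate after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

for epoch in range(1, 11):
    print(f"Epoch:{epoch:2d}, Lr:{cosine_lr(epoch):.2E}")
```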
4. Visualize the results
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
from datetime import datetime
current_time = datetime.now()
epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('SE-ResDenseNet - Training and Validation Accuracy')
plt.xlabel(current_time)
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('SE-ResDenseNet - Training and Validation Loss')
plt.show()

5. Evaluate the model
best_model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'SE-ResDenseNet Best Test Accuracy: {epoch_test_acc*100:.1f}%')
print(f'SE-ResDenseNet Best Test Loss: {epoch_test_loss:.4f}')
SE-ResDenseNet Best Test Accuracy: 89.2%
SE-ResDenseNet Best Test Loss: 0.5000