- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、创新的思路
创新思路:基于双通道(Dual Path)的 ResNet 与 DenseNet 融合架构。
在前几周的学习中,我认识到 ResNet 的 Element-wise Add 机制能够有效缓解梯度消失,使得网络可以极深;而 DenseNet 的 Concatenation 机制能够实现特征重用,使得参数利用率极高。
为了结合两者的优势,我设计了一个 ResDense Block (残差-密集混合块)。
具体做法是:在特征提取后,将输出的通道(Channels)分为两部分。一部分与输入特征进行相加(Add)以保证特征和梯度的顺畅流通;另一部分与输入特征进行拼接(Concat)以探索新的特征并实现特征重用。这样,一个模块内既有残差的高速直达,又有密集连接的特征丰富度,实现 1+1>2 的效果。


二、前期准备
python
import torch
import torch.nn as nn
三、定义创新核心组件 (ResDense Block)
python
class ResDenseBlock(nn.Module):
def __init__(self, in_channels, growth_rate):
super(ResDenseBlock, self).__init__()
# 为了能让特征相加,提取特征的后半部分通道数必须等于 in_channels
# 所以这个卷积层输出的总通道数是:in_channels (用于相加) + growth_rate (用于拼接)
out_channels = in_channels + growth_rate
# 预激活 (Pre-activation) 思想,融合了 J2 和 J3 的精华
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.in_channels = in_channels
def forward(self, x):
# 1. 提取新特征
out = self.conv1(self.relu(self.bn1(x)))
# 2. 使用 split 函数切分通道
# 一部分切出来用来做残差(Res),一部分用来做密集拼接(Dense)
out_res, out_dense = torch.split(out, [self.in_channels, out.size(1) - self.in_channels], dim=1)
# 3. ResNet 的特性:相加 (Add)
res_features = x + out_res
# 4. DenseNet 的特性:拼接 (Concat)
# 把相加后的结果,和刚提取出的新生特征拼接在一起
final_out = torch.cat([res_features, out_dense], dim=1)
return final_out
四、组装全局融合网络 (HybridNet)
python
class HybridNet(nn.Module):
def __init__(self, num_classes=2):
super(HybridNet, self).__init__()
# 初始的特征提取
self.init_conv = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
# 使用创新模块构建网络
# 第1个混合块:输入64,额外生长32个通道 -> 输出 64+32 = 96
self.block1 = ResDenseBlock(in_channels=64, growth_rate=32)
# 第2个混合块:输入96,额外生长32个通道 -> 输出 96+32 = 128
self.block2 = ResDenseBlock(in_channels=96, growth_rate=32)
# 全局池化和分类器
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(128, num_classes)
def forward(self, x):
x = self.init_conv(x)
x = self.block1(x)
x = self.block2(x)
x = self.global_pool(x)
x = torch.flatten(x, 1)
out = self.fc(x)
return out
五、模型结构验证与维度测试
python
if __name__ == "__main__":
print("正在组装 ResNet+DenseNet 融合架构")
model = HybridNet(num_classes=2)
print("正在生成一张 224x224 的虚拟医学 X 光片送入模型测试")
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)
print("融合模型构建大成功")
print(f"模型的最终输出形状为: {output.shape} ")
