ResNet与DenseNet结合探索:构建新模型笔记

一、创新的思路

创新思路:基于双通道(Dual Path)的 ResNet 与 DenseNet 融合架构。

在前几周的学习中,我认识到 ResNet 的 Element-wise Add 机制能够有效缓解梯度消失,使得网络可以极深;而 DenseNet 的 Concatenation 机制能够实现特征重用,使得参数利用率极高。

为了结合两者的优势,我设计了一个 ResDense Block (残差-密集混合块)。

具体做法是:在特征提取后,将输出的通道(Channels)分为两部分。一部分与输入特征进行相加(Add)以保证特征和梯度的顺畅流通;另一部分与输入特征进行拼接(Concat)以探索新的特征并实现特征重用。这样,一个模块内既有残差的高速直达,又有密集连接的特征丰富度,实现 1+1>2 的效果。


二、前期准备

python 复制代码
import torch
import torch.nn as nn

三、定义创新核心组件 (ResDense Block)

python 复制代码
class ResDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(ResDenseBlock, self).__init__()
        # 为了能让特征相加,提取特征的后半部分通道数必须等于 in_channels
        # 所以这个卷积层输出的总通道数是:in_channels (用于相加) + growth_rate (用于拼接)
        out_channels = in_channels + growth_rate
        
        # 预激活 (Pre-activation) 思想,融合了 J2 和 J3 的精华
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        
        self.in_channels = in_channels

    def forward(self, x):
        # 1. 提取新特征
        out = self.conv1(self.relu(self.bn1(x)))
        
        # 2. 使用 split 函数切分通道
        # 一部分切出来用来做残差(Res),一部分用来做密集拼接(Dense)
        out_res, out_dense = torch.split(out, [self.in_channels, out.size(1) - self.in_channels], dim=1)
        
        # 3. ResNet 的特性:相加 (Add)
        res_features = x + out_res
        
        # 4. DenseNet 的特性:拼接 (Concat)
        # 把相加后的结果,和刚提取出的新生特征拼接在一起
        final_out = torch.cat([res_features, out_dense], dim=1)
        
        return final_out

四、组装全局融合网络 (HybridNet)

python 复制代码
class HybridNet(nn.Module):
    def __init__(self, num_classes=2):
        super(HybridNet, self).__init__()
        
        # 初始的特征提取
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # 使用创新模块构建网络
        # 第1个混合块:输入64,额外生长32个通道 -> 输出 64+32 = 96
        self.block1 = ResDenseBlock(in_channels=64, growth_rate=32)
        # 第2个混合块:输入96,额外生长32个通道 -> 输出 96+32 = 128
        self.block2 = ResDenseBlock(in_channels=96, growth_rate=32)
        
        # 全局池化和分类器
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        out = self.fc(x)
        return out

五、模型结构验证与维度测试

python 复制代码
if __name__ == "__main__":
    print("正在组装 ResNet+DenseNet 融合架构")
    model = HybridNet(num_classes=2)
    
    print("正在生成一张 224x224 的虚拟医学 X 光片送入模型测试")
    dummy_input = torch.randn(2, 3, 224, 224) 
    
    output = model(dummy_input)
    
    print("融合模型构建大成功")
    print(f"模型的最终输出形状为: {output.shape} ")
相关推荐
CC大煊11 分钟前
一个Javaer的AI转型笔记(1):入坑LangChain,我的第一个hello world
笔记·langchain
元气少女小圆丶2 小时前
SenseGlove Nova 2+Unity开发笔记1
笔记·学习·unity
冰暮流星3 小时前
javascript之history对象介绍
前端·笔记
jialiguo4 小时前
博客摘录「 尚硅谷Vue3入门到实战,最新版Vue3+TypeScript前端开发教程」2024年8月7日
笔记
風清掦5 小时前
【STM32学习笔记-14】WDG看门狗 - 14.2 WWDG窗口看门狗
笔记·stm32·单片机·嵌入式硬件·学习·fpga开发
晓梦林5 小时前
bughush靶场学习笔记
笔记·学习
sakiko_6 小时前
Swift学习笔记34-MVC架构,SwiftUI与UIkit混编练习
笔记·学习·swiftui·mvc·swift
Afans_fire6 小时前
多渠道广告归因:3种逻辑解决效果分配难题
笔记·内容运营·广告投放·广告营销·徐州巨量星河
泉飒6 小时前
qt软件无法打开编译
笔记·工业视觉
穗余7 小时前
2026 AI x Web3 School共学营笔记-Day10-Women Builders in AI × Web3
人工智能·笔记·web3