DenseNet详解,附模型代码(pytorch)

文章目录

前言

在当时的计算机视觉领域,从LeNet开始,卷积神经网络逐步开始成为最主流的方法。像AlexNet,VGG,GoogLeNet等等,大家都始终没有放弃去寻找一个最优的网络架构。尤其是当ResNet的出现,其成为了深度学习方向最主要的网络结构之一。 因为ResNet 可以训练出更深的 CNN 模型,其让深度学习也成为了可能,走向深度,从而实现更高的准确度。它的核心在于层与层之间的短路连接 (skip connection), skip connection 有助于训练过程中的梯度的反向传播,一定程度上减缓因为梯度消失导致网络训练不动,甚至效果下降的情况。关于ResNet的文章,我也讲解过,大家感兴趣的可以去看看。

DenseNet的提出

今天我们将要介绍DenseNet,其同样也是为了去探索最优的网络架构。

DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。

为什么建立前后层的连接很重要,尤其是在深层网络中?

**就是我们希望网络通过一系列非线性变换,抽象提取出其所蕴含的深层语义信息,其是符合我们所学习的数据的分布的,但是当网络深了之后,其提取的信息是有偏移的,即不是原来数据的分布。因为每一次的非线性变换其是会损失数据信息的,每一层的偏移会导致最后所学习到数据分布是有问题的,从而导致模型的表现不佳。**以下这图相信能更好的帮助理解。

与ResNet的对比

DenseNet和 ResNet、Inception 网络不同的是,DenseNet 并没有主要从网络的深度和宽度入手,DenseNet 的作者从 feature 入手,通过对 feature 的细腻操控,达到了更好的效果和更少的参数。DenseNet 的中心思想是与其多次学习冗余的特征,特征复用,是一种更好的特征提取方式

ResNet: 通过建立前面层和后面层的短路连接(skip connection), 帮助实现训练过程中更有效的反向传播,训练出更深的 CNN 网络;

DenseNet: 采用了比ResNet更极端的方法,通过密集连接机制,互相连接所有的层,每个层会将前面所有层的输出,在 channel 维度上进行 concat 操作,作为当前层的输入,进而实现特征重用。使用过该种方法不仅仅缓解了梯度消失的现象,也使得其在参数和计算量更少的情况下实现比 ResNet 更优的性能;

同时这里大家注意两者的Skip connection的方式是不同的,ResNet的是add的方式,而DenseNet是concat的方式。

那么二者有什么区别呢?

add

我们来看,以下是 keras 中对 add 的实现源码,pytorch的封装更复杂一些,不过原理都是一样的,看这个就行:

python 复制代码
def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
        output += inputs[i]
    return output

其中 inputs 为待融合的特征图,inputs[0]、inputs[1]......等的通道数一样,且特征图宽与高也一样。

从代码中可以很容易地看出,add 方式有以下特点

  1. 做的是对应通道对应位置的值的相加,通道数不变
  2. 描述图像的特征个数不变,但是每个特征下的信息却增加了。

concat

同样的,我们通过阅读下面代码实例帮助理解 concat 的工作原理:

python 复制代码
import torch

# 创建两个张量
t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# 沿第1维拼接
result_1 = torch.cat([t1, t2], dim=1)
print(result_1)
# 输出: tensor([[ 1,  2,  3,  7,  8,  9],
#               [ 4,  5,  6, 10, 11, 12]])

在模型网路当中,数据通常为 4 个维度,即 num×channels×height×width ,因此默认值 1 表示的是 channels 通道进行拼接。如:

python 复制代码
combine = torch.cat([d1, add1, add2, add3, add4], 1)

从代码中可以很容易地看出,concat 方式有以下特点:

  1. 做的是通道的合并,通道数变多了
  2. 描述图像的特征个数变多,但是每个特征下的信息却不变。

所以到这里,我们就能够很清晰的知道add操作和concat操作的不同了。

操作 描述 优点 缺点 补充
add - 相当于加了一种prior - 要求两路输入的对应通道特征图语义类似 - 计算量少 - 特征提取能力差 - 对应通道信息类似时,可融合多通道信息 - 尺度不一致时,小尺度特征可能被淹没
concat - 通过训练学习整合两个特征图通道之间的信息 - 特征提取能力强 - 计算量大(是add的2倍) - 能提取更合适的信息,效果更好

网络结构

其实Densenet的构建我们主要就是从三个方面入手的,首先我们要构建Densenet,其实也类似于Resnet的搭建,分块进行构建。我们要搭建Densenet,就要构建DenseBlock,然后要搭建DenseBlock,就要构建好里面的每个DenseLayer,最后我们需要连接不同的DenseBlock,又需要Transition部分。最后搭建Densenet,就将不同的部分按照一定的顺序连接起来即可。其实就是如下图所示的。

DenseLayer

DenseLayer的搭建还是很简单的,就是BN-ReLU-Conv三件套连着来两次。这里有个细节要注意的就是由于越到后面输入会越大,这里Densenet为了减少计算量,在第一个Conv的时候使用1x1卷积调整通道数到 b n s i z e ∗ g r o w t h r a t e bn_{size}*growth_{rate} bnsize∗growthrate,一般bn_size设置为4。从而能够降低特征数量,提升计算效率。

然后最后forward函数里面我们最后输出的是torch.cat([x,new_feature],1),将我们新生成的feature加入到输入中,然后作为下一个DenseLayer的输入,这样就实现了dense connection的思想。

python 复制代码
class _DenseLayer(nn.Module):
    def __init__(self,num_input_features,growth_rate,bn_size,drop_rate):
        super(_DenseLayer, self).__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features,bn_size*growth_rate,kernel_size=1,stride=1,padding=0,bias=False)

        self.norm2 = nn.BatchNorm2d(bn_size*growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size*growth_rate,growth_rate,kernel_size=3,stride=1,padding=1,bias=False)
        self.drop_rate = drop_rate


    def forward(self,x):
        new_feature = self.norm1(x)
        new_feature = self.relu1(new_feature)
        new_feature = self.conv1(new_feature)
        new_feature = self.norm2(new_feature)
        new_feature = self.relu2(new_feature)
        new_feature = self.conv2(new_feature)
        if self.drop_rate > 0:
            new_feature = F.dropout(new_feature, p=self.drop_rate, training=self.training)

        return torch.cat([x,new_feature],1)

DenseBlock

主要就是看每个block里面有多少个layer嘛,主要每个layer的输入channel就行了,是递增的,根据增长率growth_rate。

python 复制代码
class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self,num_layers,num_input_features,growth_rate,bn_size,drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate,growth_rate,bn_size,drop_rate)
            self.add_module('denselayer%d'%(i+1),layer)

    def forward(self,features):
        for name,layer in self.items():
            features = layer(features)
        return features

Transition

对于Transition层 ,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition层可以起到压缩模型的作用。

这里我们按照论文中所讲的压缩模型,就是AvgPool2d的设置 ,论文中讲述的压缩系数0.5,即设置2x2 AvgPooling即可。减少shape,起到下采样的作用。

python 复制代码
class _Transition(nn.Sequential):
    def __init__(self,num_input_features,num_output_features):
        super(_Transition, self).__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features,num_output_features,kernel_size=1,stride=1,padding=0,bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

Densenet

这部分就比较简单了,对于Densenet,首先经过一个初步的特征编码,7x7的卷积,BN层,激活层,池化层,然后就是四个DenseBlock层的搭建,并且每个DenseBlock使用Transition层进行连接,最后就是BN和全连接层了。这是整体的架构的搭建,在来看一些细节的处理,对于我们所搭建的网络模型,我们肯定是需要进行参数的初始化的,对于卷积层的参数采用凯明初始化,BN层的参数初始化为权重为1,偏置为0。

python 复制代码
class DenseNet(nn.Module):
    def __init__(self,growth_rate=32,block_config=(6,12,24,16),num_init_features=64,bn_size=4,drop_rate=0,num_classes=1000):
        super(DenseNet, self).__init__()
        self.features = nn.Sequential(
            OrderedDict([
                ("conv0",nn.Conv2d(3,num_init_features,kernel_size=7,stride=2,padding=3,bias=False)),
                ("norm0",nn.BatchNorm2d(num_init_features)),
                ("relu0",nn.ReLU(inplace=True)),
                ("pool0",nn.MaxPool2d(kernel_size=3, stride=2)),
            ])
        )
        num_features = num_init_features
        for i,num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers,num_features,growth_rate,bn_size,drop_rate)
            self.features.add_module('denseblock%d'%(i+1),block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config)-1:
                transition = _Transition(num_features,num_features // 2)
                self.features.add_module('transition%d'%(i+1),transition)
                num_features = num_features // 2

        self.features.add_module('norm5',nn.BatchNorm2d(num_features))

        self.classifier = nn.Linear(num_features,num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features,inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

DenseNet121\161\169\201\264

然后不同深度的DenseNet网络,就是通过控制其是使用的block每层具体的layer数量来控制,所以我们可以搭建多个不同深度的ResNet模型。

python 复制代码
def densenet121(pretrained=True,**kwargs):
    model = DenseNet(growth_rate=32,block_config=(6,12,24,16),**kwargs)
    if pretrained:
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)

    return model


def densenet161(pretrained=True,**kwargs):
    model = DenseNet(growth_rate=48,block_config=(6, 12, 36, 24),**kwargs)
    if pretrained:
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet161'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

def densenet169(pretrained=True,**kwargs):
    model = DenseNet(growth_rate=32,block_config=(6,12,32,32),**kwargs)
    if pretrained:
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet169'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

def densenet201(pretrained=True,**kwargs):
    model = DenseNet(growth_rate=32,block_config=(6,12,48,32),**kwargs)
    if pretrained:
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet201'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

def densenet264(**kwargs):
    model = DenseNet(growth_rate=32,block_config=(6,12,64,48),**kwargs)
    return model

参考资料

希望这篇文章能够给大家带来些思考,让大家能够有所收获,同时以上内容有所参考下列文章进行学习,并包含了我自己的思考。后续将会给大家带来使用DensenNet解决相关问题的实际应用部署。

DensetNet 介绍 - lucky_light - 博客园

相关推荐
Mryan20052 分钟前
✨ 使用 Flask 实现头像文件上传与加载功能
后端·python·flask
程序员是干活的8 分钟前
Java EE前端技术编程脚本语言JavaScript
java·大数据·前端·数据库·人工智能
chaofan98037 分钟前
ERNIE-4.5-0.3B 实战指南:文心一言 4.5 开源模型的轻量化部署与效能跃升
人工智能·开源·文心一言
hppyhjh39 分钟前
【昇腾CANN训练营】深入cann-ops仓算子编译出包流程
人工智能
飞凌嵌入式39 分钟前
飞凌嵌入式亮相第九届瑞芯微开发者大会:AIoT模型创新重做产品
人工智能·嵌入式硬件·嵌入式·飞凌嵌入式
大模型工程师40 分钟前
TongYiLingMa插件下Qwen3-Coder
人工智能
大模型工程师1 小时前
独立开发:高效集成大模型,看这篇就够了
人工智能
程序员的世界你不懂1 小时前
Jmeter的元件使用介绍:(四)前置处理器详解
开发语言·python·jmeter
倔强青铜三1 小时前
苦练Python第35天:数据结构挑战题,实战演练
人工智能·python·面试
你的电影很有趣1 小时前
lesson24:Python的logging模块
开发语言·python