《动手学深度学习(PyTorch版)》笔记7.7

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.7 Densely Connected Networks(DenseNet)

稠密连接网络(DenseNet)在某种程度上是ResNet的逻辑扩展。ResNet将函数展开为
f ( x ) = x + g ( x ) . f(\mathbf{x}) = \mathbf{x} + g(\mathbf{x}). f(x)=x+g(x).也就是说,ResNet将 f f f分解为一个简单的线性项和一个复杂的非线性项。如果想将 f f f拓展成超过两部分的信息,一种方案便是DenseNet。ResNet和DenseNet的关键区别在于,DenseNe的输出是连接 (用 [ , ] [,] [,]表示),而不是ResNet的简单相加(如下图所示),因此我们可以执行从 x \mathbf{x} x到其展开式的映射:

x → [ x , f 1 ( x ) , f 2 ( [ x , f 1 ( x ) ] ) , f 3 ( [ x , f 1 ( x ) , f 2 ( [ x , f 1 ( x ) ] ) ] ) , ... ] . \mathbf{x} \to \left[ \mathbf{x}, f_1(\mathbf{x}), f_2([\mathbf{x}, f_1(\mathbf{x})]), f_3([\mathbf{x}, f_1(\mathbf{x}), f_2([\mathbf{x}, f_1(\mathbf{x})])]), \ldots\right]. x→[x,f1(x),f2([x,f1(x)]),f3([x,f1(x),f2([x,f1(x)])]),...].

最后,将这些展开式结合到多层感知机中,再次减少特征的数量。

DenseNet这个名字由变量之间的"稠密连接"而得来,稠密连接如下图所示。稠密网络主要由2部分构成:稠密块 (dense block)和过渡层(transition layer)。前者定义如何连接输入和输出,而后者则控制通道数量,使其不会太复杂。

复制代码
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

def conv_block(input_channels, num_channels):
    return nn.Sequential(
        nn.BatchNorm2d(input_channels), 
        nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))
    
class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super(DenseBlock, self).__init__()
        layer = []
        for i in range(num_convs):
            layer.append(conv_block(num_channels * i + input_channels, num_channels))
        self.net = nn.Sequential(*layer)

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            # 连接通道维度上每个块的输入和输出
            X = torch.cat((X, Y), dim=1)
        return X

#在下面的例子中,我们定义一个有2个输出通道数为10的DenseBlock。
#使用通道数为3的输入时,我们会得到通道数为3+2x10=23的输出。
#卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为*增长率*(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.randn(4, 3, 8, 8)
Y = blk(X)
print(Y.shape)

#由于每个稠密块都会带来通道数的增加,过渡层可以用来控制模型复杂度。
#过渡层通过1x1卷积层来减小通道数,并使用步幅为2的平均汇聚层减半高和宽,以降低模型复杂度。

def transition_block(input_channels, num_channels):#过渡层
    return nn.Sequential(
        nn.BatchNorm2d(input_channels), nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2))

blk = transition_block(23, 10)
print(blk(Y).shape)

b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

#与ResNet类似,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与ResNet-18保持一致。 
#稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。
#num_channels为当前的通道数
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]#num_convs_in_dense_blocks表示每个稠密块中包含的卷积层的数量。在这里,有4个稠密块,每个稠密块中包含4个卷积层。
blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):#enumerate()函数用于同时遍历列表元素及其索引
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    # 上一个稠密块的输出通道数
    num_channels += num_convs * growth_rate
    # 在稠密块之间添加一个转换层,使通道数量减半
    if i != len(num_convs_in_dense_blocks) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2
        
net = nn.Sequential(
    b1, *blks,
    nn.BatchNorm2d(num_channels), nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(num_channels, 10))

lr, num_epochs, batch_size = 0.1, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

相关推荐
小尧嵌入式1 分钟前
Linux的shell命令
linux·运维·服务器·数据库·c++·windows·算法
码界奇点2 分钟前
基于Django与Ansible的自动化运维管理系统设计与实现
运维·python·django·毕业设计·ansible·源代码管理
Jeremy爱编码4 分钟前
leetcode热题路径总和 III
算法·leetcode·职场和发展
Warren2Lynch5 分钟前
解锁 UML 潜力:Visual Paradigm AI 如何革新用例、活动图和顺序图的设计
人工智能·uml
阿杰学AI6 分钟前
AI核心知识54——大语言模型之Structured CoT(简洁且通俗易懂版)
人工智能·ai·语言模型·prompt·pe·结构化提示词·structured cot
CoovallyAIHub7 分钟前
滑雪季又来了!如何用计算机视觉帮雪场解决最头疼的问题
深度学习·算法·计算机视觉
爱笑的眼睛1110 分钟前
超越 `assert`:深入 Pytest 的高级测试哲学与实践
java·人工智能·python·ai
爱笑的眼睛1110 分钟前
超越静态图表:Bokeh可视化API的实时数据流与交互式应用开发深度解析
java·人工智能·python·ai
lxmyzzs11 分钟前
X-AnyLabeling 自动数据标注保姆级教程:从安装到格式转换全流程
人工智能·数据标注
承渊政道14 分钟前
C++学习之旅【实战全面解析C++类和对象】
c++·笔记·学习