深度学习每周学习总结J4(ResDenseNet 算法探索实践 - 鸟类识别)

目录

    • [一:回顾与总结: 三种神经网络模型对比研究及尝试构成新的网络结构模型](#一:回顾与总结: 三种神经网络模型对比研究及尝试构成新的网络结构模型)
      • 卷积计算过程
      • [ResNet-50 模型](#ResNet-50 模型)
        • [1. 关于残差](#1. 关于残差)
        • [2. 关于ConvBlock的思考](#2. 关于ConvBlock的思考)
        • [3. 关于IdentityBlock的思考](#3. 关于IdentityBlock的思考)
      • [ResNet-50v2 模型](#ResNet-50v2 模型)
        • 1.与ResNet-50的区别
        • [2. 各模块分解](#2. 各模块分解)
          • [**Block2 类**](#Block2 类)
          • [**Stack2 类**](#Stack2 类)
          • [**ResNet50V2 类**](#ResNet50V2 类)
      • [DenseNet 模型](#DenseNet 模型)
        • [1. 设计理念](#1. 设计理念)
        • [2. 网络结构](#2. 网络结构)
        • [3. 与CNN和ResNet的对比](#3. 与CNN和ResNet的对比)
        • [4. 代码及关键点解析](#4. 代码及关键点解析)
          • [1. 模块分解](#1. 模块分解)
            • [**_DenseLayer 类**](#_DenseLayer 类)
            • [**_DenseBlock 类**](#_DenseBlock 类)
            • [**_Transition 类**](#_Transition 类)
            • [**DenseNet 类**](#DenseNet 类)
          • [2. DenseNet 关键特点分析](#2. DenseNet 关键特点分析)
            • [特点 1: 特征重用机制](#特点 1: 特征重用机制)
            • [特点 2: 瓶颈设计](#特点 2: 瓶颈设计)
            • [特点 3: 过渡层](#特点 3: 过渡层)
            • [特点 4: Dropout 机制](#特点 4: Dropout 机制)
            • [特点 5: Globally Average Pooling](#特点 5: Globally Average Pooling)
          • [3. 总结](#3. 总结)
      • 新的模型框架
        • [1. 理解 ResNet 和 DenseNet 的关键特点](#1. 理解 ResNet 和 DenseNet 的关键特点)
          • [ResNet 的关键特点](#ResNet 的关键特点)
          • [DenseNet 的关键特点](#DenseNet 的关键特点)
        • [2. 设计新的模型框架](#2. 设计新的模型框架)
        • [3. 新模型框架的构建](#3. 新模型框架的构建)
        • [4. 关键点和总结](#4. 关键点和总结)
        • [5. 未来的方向](#5. 未来的方向)
    • 二:代码流程
      • [0. 总结](#0. 总结)
      • [1. 设置GPU](#1. 设置GPU)
      • [2. 导入数据及处理部分](#2. 导入数据及处理部分)
      • [3. 划分数据集](#3. 划分数据集)
      • [4. 模型构建部分](#4. 模型构建部分)
      • [5. 设置超参数:定义损失函数,学习率,以及根据学习率定义优化器等](#5. 设置超参数:定义损失函数,学习率,以及根据学习率定义优化器等)
      • [6. 训练函数](#6. 训练函数)
      • [7. 测试函数](#7. 测试函数)
      • [8. 正式训练](#8. 正式训练)
      • [9. 结果可视化](#9. 结果可视化)
      • [10. 模型的保存](#10. 模型的保存)
      • [11. 使用训练好的模型进行预测](#11. 使用训练好的模型进行预测)

一:回顾与总结: 三种神经网络模型对比研究及尝试构成新的网络结构模型

卷积计算过程

对于卷积新的理解:

我之前的误解是错误以为有几个卷积核,就有几个权重矩阵。比如输入通道数为3,输出通道数为2,我误以为只有两个权重矩阵。但其实对于一个卷积层来说,卷积核的数量等同于输出通道数(也称为输出的特征图数)。如果输入通道数为3,输出通道数为2,那么根据卷积操作的定义,每个输出通道对应一个单独的卷积核组。每个卷积核组由数量等于输入通道数的权重矩阵组成,因此在这个例子中,实际上有 3 × 2 = 6 3×2=6 3×2=6 个权重矩阵(即每个输出通道有3个权重矩阵,与输入通道数相对应)。

此外,在训练神经网络时,除了卷积核的参数矩阵(也称为权重矩阵)以外,偏置参数 $ b$ 也是会随着训练过程而改变的。偏置参数 b b b 是每个神经元的一个独立参数。需要注意这个偏置是与整个卷积核相关的,而不是与单个权重矩阵相关,一个卷积核组可能有复数个权重矩阵。因此,对于每个输出通道,模型在计算完成卷积后,会加上各自的偏置来调整输出,它在每次反向传播过程中通过梯度下降或其他优化算法进行更新。

这样的更新机制是为了让神经网络模型能够更好地拟合训练数据。偏置参数 b b b的存在使得每个神经元在没有任何输入时就可以有一个非零输出,从而增加了模型的灵活性和表达能力。

Source: http://cs231n.github.io/convolutional-networks/

ResNet-50 模型

1. 关于残差

目前残差结构主要有两种:

左边的单元为 ResNet 两层的残差单元,两层的残差单元包含两个相同输出的通道数的 3x3 卷积,只是用于较浅的 ResNet 网络,对较深的网络主要使用三层的残差单元。三层的残差单元又称为 bottleneck 结构,先用一个 1x1 卷积进行降维,然后 3x3 卷积,最后用 1x1 升维恢复原有的维度。另外,如果有输入输出维度不同的情况,可以对输入做一个线性映射变换维度,再连接后面的层。三层的残差单元对于相同数量的层又减少了参数量,因此可以拓展更深的模型。通过残差单元的组合有经典的 ResNet-50,ResNet-101 等网络结构。

2. 关于ConvBlock的思考

关于convblock的参数设置例如:[64,64,256]:

主要基于参数选择的经验性:第一个是常用的通道数64是常见的开始值,第二个值为64是为了保持相同的特征数保留和加强特征,最后一个256会大一些是为了增加特征维度和网络模型的表达能力

残差层(layer1 到 layer4)

每个残差层通过 _build_layer 函数构建,每个层的输出通道配置如 [64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048],仍然是基于经验法则和文献中的设计,确保了网络在进行深度学习时能够高效提取特征。

3. 关于IdentityBlock的思考
Identity Block 的特征

在经典的残差网络实现中,一个 Identity Block 通常具有以下特征:

无尺寸或通道变化:在 Identity Block 中,输入和输出的通道数相同,并且通常步幅也为 1。

直接使用输入作为快捷路径:在没有下采样和通道变换的情况下,快捷连接就直接是输入。

相关代码解析
python 复制代码
self.shortcut = nn.Sequential()
if in_channels != filters3 or strides != 1:
    self.shortcut = nn.Sequential(
        nn.Conv2d(in_channels, filters3, kernel_size=1, stride=strides),
        nn.BatchNorm2d(filters3)
    )
  1. shortcut的定义
  • self.shortcut 是一个用于储存快捷连接的神经网络模块(Sequential)。
  • 当输入层in_channels 与输出层 filters3 不同,或者 strides 不等于1(即需要下采样)时,表示需要调整输入特征图的尺寸或通道数。
    • 这种情况下,shortcut 被定义为一个卷积层(Conv2d)和一个批量归一化层(BatchNorm2d)的组合。
  • 如果通道数和步幅满足条件(即相等且为1),那么 shortcut 直接使用输入 x,实际上不会发生任何变化。
  1. Identity Block 的特性
  • 身份连接 :在残差网络中,身份连接使得输入能直接回馈到输出。当 in_channelsfilters3 相等,且步幅为1时,shortcut 就会等于输入 x。这正是 Identity Block 的特性,即不对输入进行任何修改。
完整前向传播过程

forward 方法中,这段代码实现了身份连接:

python 复制代码
def forward(self, x):
    shortcut = self.shortcut(x)  # 使用 shortcut 将 x 进行改动或保持不变

    x = F.relu(self.bn1(self.conv1(x)))  # 卷积操作和激活
    x = F.relu(self.bn2(self.conv2(x)))  # 另一个卷积操作和激活
    x = self.bn3(self.conv3(x))          # 最后的卷积操作

    x += shortcut  # 将调整(或不调整)的 shortcut 加到输出 x 上
    x = F.relu(x)  # 再次进行激活
    return x
  • 加法操作x += shortcut):这一行实现了残差连接。根据前面定义的逻辑,如果 shortcut 是通过卷积操作得到的结果,那么输入信号经过了变换,以下采样的方式对信号进行匹配。
  • 如果 shortcut 是直接使用的输入 x,那么就实现了身份连接,保持了输入特征图的维度和内容。

总结:

你的 ConvBlock 实现中已经巧妙地综合了 ConvBlockIdentity Block 的特性。通过对 shortcut 的定义和处理,你的代码在有需要时能够对信号进行变换,而在不需要时则能够保持信号不变。这种灵活性使得你的实现既可以实现残差学习的功能,也符合对身份块的定义。因此,不需要独立声明一个单独的 Identity Block。你的实现已经具备了这种能力。

ResNet-50v2 模型

1.与ResNet-50的区别

关键区别

  1. 预激活结构

    • ResNet50V2 使用了预激活方式(pre-activation block),在执行卷积操作之前先进行 Batch Normalization 和激活函数处理。这种结构使得网络在训练时更平稳,梯度传播更有效。
  2. 快捷连接的实现

    • Block2 中,允许通过卷积层进行快捷连接(conv_shortcut),而不是简单地保持输入不变。这种方式在输出通道数需要变化时能够更好地对齐特征图。
  3. 堆叠的结构

    • Stack2 类负责管理多个 Block2 的堆叠,每个堆叠都在调整通道数和步幅时保持更清晰的结构。这在模型构建时,使得模块更易于组织和理解。
  4. 模块化和可重用性

    • ResNet50V2 更加模块化,各模块之间具有良好的隔离性,这样有助于模型的可读性和重用性。
  5. 整体设计

    • 整体的通道调整和层数设定使得 ResNet50V2 能够在特征提取上更加灵活,适应不同的输入特征。

总结

ResNet50V2 改进了原始 ResNet50 的结构,使之在处理深度学习任务时更加高效。通过预激活、灵活的快捷连接和模块化设计,使得网络能够更好地训练和泛化。这些改进为构建更深更复杂的神经网络打下了基础,也进一步解放了网络结构设计的限制。

2. 各模块分解
Block2 类
  • 初始化方法 __init__

    • 设定了 conv_shortcutstride 等参数,用于控制是否使用卷积快捷连接以及步幅的设置。
    • 预激活(Pre-activation):使用了 Batch Normalization 和 ReLU 激活函数。预激活是一种改进,有助于缓解深层网络的梯度消失问题。
    • 快捷路径(Shortcut Path)
      • 如果 conv_shortcut 为真,通过 1x1 卷积调整通道数(4 * filters),以实现跨层连接。
      • 如果步幅不为 1 而 conv_shortcut 为假,则使用最大池化。
    • 该模块的三个卷积层结构沿用了标准的 ResNet 架构。
  • 前向传播 forward

    • 先进行预激活处理(Batch Normalization + ReLU)。
    • 根据设定的 conv_shortcut 决定快捷连接的实现。
    • 将经过卷积的特征图与快捷路径相加,然后返回结果。
Stack2 类
  • 在该类中,多个 Block2 被堆叠在一起。
  • 第一层使用卷积快捷连接,后续层保持输入和输出通道一致。
ResNet50V2 类
  • 初始化方法 __init__

    • 模型的输入从 3(RGB 图像)开始,经过一系列卷积层和堆栈结构。
    • stack1stack4 各自包含不同数量的 Block2,每层都对应通道数和块数的调整。
    • 使用了标准的卷积后,增加了一个 Batch Normalization 层和 ReLU 激活。
  • 前向传播 forward

    • 输入数据经过了一系列处理,包括卷积、池化、堆叠处理、全局平均池化和最终的全连接层输出。

DenseNet 模型

它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

1. 设计理念

DenseNet的核心设计理念是"密集连接"(Dense Connectivity)。其目标是通过在每一层与所有前面层的特征图进行连接,来更有效地利用特征,缓解深度网络中的梯度消失问题。具体来说,DenseNet的主要优点包括:

  • 特征重用:每一层都直接连接到前面的所有层,使得特征的重用变得更加高效,提高了信息流动性。
  • 降低参数数量:通过密集连接,DenseNet显著减少了需要学习的参数数量,从而降低了过拟合的风险。
  • 改善梯度流动:更好的梯度流动使得网络训练更加稳定,即使在极深的网络结构中也能保持良好的性能。
2. 网络结构

DenseNet的基本结构由多个稠密块(Dense Block)组成,每个稠密块内的每一层都与前面所有层进行连接。DenseNet的结构通常包含以下几个部分:

  • 输入层:输入原始图像。
  • 卷积层:初始卷积层用于提取特征。
  • 稠密块:由多个卷积层组成,每个卷积层的输出都会被连接到下一层。
  • 过渡层(Transition Layer):通常用于降低特征图的尺寸,采用1x1卷积和平均池化。
  • 分类层:最终的全局平均池化层和全连接层进行分类。

以下是DenseNet的一个简单结构示意图:


3. 与CNN和ResNet的对比
  • CNN(卷积神经网络):传统的CNN采用每层单独连接的结构,深度较大时会面临梯度消失和特征梯度消失问题。DenseNet通过密集连接,有效地解决了这些问题。

  • ResNet(残差网络):ResNet引入了残差学习,通过跳跃连接(skip connections)来缓解深度网络中的梯度消失问题。与ResNet不同,DenseNet将每一层都连接起来,这意味着每层的信息都可以从所有前面层中获得。DenseNet通常比同深度的ResNet具有更少的参数,从而提高了计算效率和分类精度。

以下是一个简单的对比图示:

DenseNet通过密集连接层间的特征图,有效地利用了信息流,降低了对参数的需求,并且改善了梯度流动。与传统的CNN和ResNet相比,DenseNet在许多应用上能够提供更好的性能和更高的效率。

4. 代码及关键点解析

让我们逐步分析这段 DenseNet 的代码实现,并详细解析模型的关键特点在代码中的体现。

1. 模块分解
_DenseLayer 类
  • 初始化方法 __init__

    • 创建一个基本的稠密层(Dense Layer),实现了瓶颈结构,主要由两个卷积层和两个 Batch Normalization 层组成。
    • 首先将输入特征通过 Batch Normalization 和 ReLU 激活,然后进行 1x1 卷积,接着经过 Batch Normalization 和 ReLU,再进行 3x3 的卷积。
    • drop_rate 是用来随机失活(Dropout)的概率。
  • 前向传播 forward

    • 先经过 BatchNormReLU,再执行 Conv2d 操作。
    • 如果 drop_rate 大于 0,则对新特征进行随机失活。
    • 最后,使用 torch.cat 将原始输入 x 和新产生的特征 new_features 在通道维度上连接。这是 DenseNet 的核心思想之一:每层都可以访问前面所有层的特征。
_DenseBlock 类
  • 初始化方法 __init__
    • 创建一个稠密块(Dense Block),由多个 _DenseLayer 组成。
    • 根据传入的 num_layers 为每层创建一个 _DenseLayer,并将其添加到该块中。
_Transition 类
  • 初始化方法 __init__
    • 实现了稠密块之间的过渡层(Transition Layer),主要用于减少特征图的通道数和进行空间下采样。
    • 包括 Batch Normalization、ReLU、1x1 卷积和 2x2 的平均池化。
DenseNet 类
  • 初始化方法 __init__

    • 初始化 DenseNet 模型的各个部分。
    • 卷积层和最大池化 :从输入的 RGB 图片开始,然后进行 7x7 的卷积和 3x3 的最大池化,生成初始特征。
    • DenseBlocks :根据 block_config 指定的结构,多个 _DenseBlock 被堆叠在一起。在每个 Block 之后,如果不是最后一个 Block,会加入一个过渡层。
    • 最终的 Batch Normalization 和 ReLU 处理经过所有 DenseBlocks 的特征。
    • 通过一个全连接层进行分类。
  • 参数初始化:使用 Kaiming 初始化方法初始化卷积层,并将 Batch Normalization 和全连接层的偏置和权重进行初始化。

  • 前向传播 forward

    • 将输入 x 传递通过 features 串联的各层。
    • 在最后进行全局平均池化(7x7),将特征展平,并通过分类层输出最终的类概率。
2. DenseNet 关键特点分析
特点 1: 特征重用机制
  • DenseNet 通过 torch.cat 将所有前面层的特征连接起来,允许每一层直接访问所有之前层的输出。这种设计能够有效减轻深层网络的退化问题,并增强特征重用的机制。

代码片段:

python 复制代码
return torch.cat([x, new_features], 1)  # 在通道维度连接所有特征
特点 2: 瓶颈设计
  • _DenseLayer 中使用 1x1 卷积作为瓶颈层,这减少了在特征图上进一步运算需要的计算量和空间复杂度。
python 复制代码
self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate, kernel_size=1, stride=1, bias=False))
特点 3: 过渡层
  • 使用 _Transition 类来控制模型的复杂度和特征图的尺寸,使网络更具灵活性。通过平均池化来降低特征图尺寸,通过卷积减少通道数,防止过拟合。
python 复制代码
self.add_module("pool", nn.AvgPool2d(2, stride=2))  # 下采样
特点 4: Dropout 机制
  • 在每个稠密层之后可以使用 Dropout 来减少过拟合风险,增加了模型的泛化能力。
python 复制代码
if self.drop_rate > 0:
    new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
特点 5: Globally Average Pooling
  • 最终使用全局平均池化,这有助于减少模型的参数并提取全局特征,特别是在分类任务中表现更好。
3. 总结

DenseNet 在结构设计和数据流动态上与传统的卷积神经网络(如 ResNet)有显著差异,具体包括:

  • 特征重用:通过所有层的特征连接,增强了特征的流动和重用。
  • 瓶颈结构:减小了每个 DenseLayer 的计算负担,提高了计算效率。
  • 稠密块和过渡层:通过更细致地控制特征图的生长,保持了模型的紧凑和高效。
  • Dropout 和全局平均池化:增强了模型的泛化能力并减少了参数数量。

这些特性使 DenseNet 在处理复杂视觉任务时具有更好的性能,并在一定程度上克服了深度网络常见的退化和过拟合问题。

新的模型框架

探索将 ResNet 和 DenseNet 结合的可能性,是深度学习模型设计中的一个有趣方向。两者各自在不同的特性和优势,为我们提供了独特的构建块。我们可以尝试创建一个新的模型框架,结合这两者的优点,提升模型的性能。

1. 理解 ResNet 和 DenseNet 的关键特点
ResNet 的关键特点
  • 残差连接:通过直接连接输入和输出,允许梯度流过深层网络,缓解梯度消失问题。
  • 简化结构:在架构上相对简单,通过较少的层数实现深度,便于训练。
DenseNet 的关键特点
  • 特征重用:每一层利用之前所有层的特征图,使得网络更加高效和有效,增强了特征的传递能力。
  • 密集连接:与前一层直接相连,促进梯度流动和信息流动。
2. 设计新的模型框架

结合 ResNet 和 DenseNet 的特性,我们可以考虑以下几个关键设计原则:

  1. 引入残差连接:在 DenseNet 的每个层之间保留残差连接,以使训练过程更稳定。
  2. 特征重用与密集连接相结合:在模型内部使用 DenseNet 的特征重用机制,同时保持 ResNet 的残差连接。
  3. 选择合适的转换块:在 Dense Block 和 Residual Block 之间逐步过渡,以便控制特征图的维度和复杂度。
3. 新模型框架的构建

下面是一个新的混合模型框架的代码实现示例,结合了 ResNet 和 DenseNet 的特性。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

class DenseResLayer(nn.Module):
    """A layer that combines DenseNet and ResNet features."""
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(DenseResLayer, self).__init__()
        # Pre-activation Block
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        
        # Dense Block
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        self.drop_rate = drop_rate

    def forward(self, x):
        # First part of the block (DenseNet)
        out = self.norm1(x)
        out = self.relu1(out)
        
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu2(out)
        
        out = self.conv2(out)
        
        # Adding Residual Connection
        out += x
        
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)

        return out

class DenseResNet(nn.Module):
    """A new Dense-Residual network model."""
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, drop_rate=0, num_classes=1000):
        super(DenseResNet, self).__init__()

        # Initial Conv Layer
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))

        num_features = num_init_features
        
        for i, num_layers in enumerate(block_config):
            for j in range(num_layers):
                layer = DenseResLayer(num_features, growth_rate, bn_size, drop_rate)
                self.features.add_module("denselayer%d-%d" % (i + 1, j + 1), layer)
                num_features += growth_rate  # Update feature number

        # Final BatchNorm + ReLU
        self.features.add_module("norm_final", nn.BatchNorm2d(num_features))
        self.features.add_module("relu_final", nn.ReLU(inplace=True))

        # Classification Layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, features.size(2)).view(features.size(0), -1)  # Global Average Pooling
        out = self.classifier(out)
        return out
4. 关键点和总结
关键点
  1. DenseResLayer:

    • 结合了 DenseNet 的特征重用机制(通过添加特征)和 ResNet 的残差连接。
    • 使用 Batch Normalization 和 ReLU 激活确保梯度流动和训练稳定性。
  2. DenseResNet:

    • 在网络构建中,使用多层稠密残差层,将 ResNet 的残差连接与 Dense Block 的特征连接相结合。
    • 通过堆叠多个这种层,模型集成了 DenseNet 和 ResNet 的优势。
  3. 灵活性和扩展性

    • 可以通过调整 growth_rateblock_config 和其他超参数来创建不同深度和宽度的网络,来适应不同任务和数据集的需求。
5. 未来的方向

通过结合 ResNet 和 DenseNet 的特性,我们获得了一种新的模型框架,理论上会在不同的任务上表现良好。在实际应用中,可以通过以下方式进一步优化和评估:

  • 超参数调整:根据具体任务,在训练前进行超参数优化,以获得更好的性能。
  • 扩展到更深的网络:在具有足够计算资源的情况下,尝试构建更深的网络,评估深度增加对性能的影响。
  • 多任务学习或迁移学习:评估该框架在多个计算机视觉任务上的有效性,或在预训练模型的基础上进行微调。

结合 ResNet 和 DenseNet 的方法,是一种深层网络结构设计的探索,能够更好地利用特征并提升模型表现。希望这个新的模型框架对你的研究和应用有所帮助!

二:代码流程

0. 总结

数据导入及处理部分:本次数据导入没有使用torchvision自带的数据集,需要将原始数据进行处理包括数据导入,查看数据分类情况,定义transforms,进行数据类型转换等操作。

划分数据集:划定训练集测试集后,再使用torch.utils.data中的DataLoader()分别加载上一步处理好的训练及测试数据,查看批处理维度.

模型构建部分:resdesnet

设置超参数:在这之前需要定义损失函数,学习率(动态学习率),以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。

定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

结果可视化

模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), 'model.pth') 保存模型的参数,使用 model.load_state_dict(torch.load('model.pth')) 加载参数。

需要改进优化的地方:确保模型和数据的一致性,都存到GPU或者CPU;注意numclasses不要直接用默认的1000,需要根据实际数据集改进;实例化模型也要注意numclasses这个参数;此外注意测试模型需要用(3,224,224)3表示通道数,这和tensorflow定义的顺序是不用的(224,224,3),做代码转换时需要注意。

python 复制代码
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn.functional as F
from collections import OrderedDict 


import os,PIL,pathlib
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore') # 忽略警告信息

plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 # 分辨率

1. 设置GPU

python 复制代码
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

2. 导入数据及处理部分

python 复制代码
# 获取数据分布情况
path_dir = './data/bird_photos/'
path_dir = pathlib.Path(path_dir)

paths = list(path_dir.glob('*'))
# classNames = [str(path).split("\\")[-1] for path in paths] # ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
classNames = [path.parts[-1] for path in paths]
classNames
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
python 复制代码
# 定义transforms 并处理数据
train_transforms = transforms.Compose([
    transforms.Resize([224,224]),      # 将输入图片resize成统一尺寸
    transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),             # 将PIL Image 或 numpy.ndarray 装换为tensor,并归一化到[0,1]之间
    transforms.Normalize(              # 标准化处理 --> 转换为标准正太分布(高斯分布),使模型更容易收敛
        mean = [0.485,0.456,0.406],    # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
        std = [0.229,0.224,0.225]
    )
])
test_transforms = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = [0.485,0.456,0.406],
        std = [0.229,0.224,0.225]
    )
])
total_data = datasets.ImageFolder('./data/bird_photos/',transform = train_transforms)
total_data
Dataset ImageFolder
    Number of datapoints: 565
    Root location: ./data/bird_photos/
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
python 复制代码
total_data.class_to_idx
{'Bananaquit': 0,
 'Black Skimmer': 1,
 'Black Throated Bushtiti': 2,
 'Cockatoo': 3}

3. 划分数据集

python 复制代码
# 划分数据集
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size

train_dataset,test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x24b5c8a54e0>,
 <torch.utils.data.dataset.Subset at 0x24b5c8a5720>)
python 复制代码
# 定义DataLoader用于数据集的加载

batch_size = 32

train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1
)
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1
)
python 复制代码
# 观察数据维度
for X,y in test_dl:
    print("Shape of X [N,C,H,W]: ",X.shape)
    print("Shape of y: ", y.shape,y.dtype)
    break
Shape of X [N,C,H,W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4. 模型构建部分

python 复制代码
class _DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer) """
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)
    
class _DenseBlock(nn.Sequential):
    """DenseBlock"""
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,
                                drop_rate)
            self.add_module("denselayer%d" % (i+1,), layer)
            
class _Transition(nn.Sequential):
    """Transition layer between two adjacent DenseBlock"""
    def __init__(self, num_input_feature, num_output_features):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_feature))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(2, stride=2))

        
class DenseNet(nn.Module):
    "DenseNet-BC model"
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        """
        :param growth_rate: (int) number of filters used in DenseLayer, `k` in the paper
        :param block_config: (list of 4 ints) number of layers in each DenseBlock
        :param num_init_features: (int) number of filters in the first Conv2d
        :param bn_size: (int) the factor using in the bottleneck layer
        :param compression_rate: (float) the compression rate used in Transition Layer
        :param drop_rate: (float) the drop rate after each DenseLayer
        :param num_classes: (int) number of classes for classification
        """
        super(DenseNet, self).__init__()
        # first Conv2d
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(3, stride=2, padding=1))
        ]))

        # DenseBlock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features += num_layers*growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_features, int(num_features*compression_rate))
                self.features.add_module("transition%d" % (i + 1), transition)
                num_features = int(num_features * compression_rate)

        # final bn+ReLU
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out
python 复制代码
# Now, instantiate and use the model
densenet121 = DenseNet(num_init_features=64, # init_channel=64,
                       growth_rate=32,
                       block_config=(6,12,24,16),
                       num_classes=len(classNames))  

model = densenet121.to(device)
model
DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer3): _DenseLayer(
        (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer4): _DenseLayer(
        (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer5): _DenseLayer(
        (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer6): _DenseLayer(
        (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (transition1): _Transition(
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (denseblock2): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer3): _DenseLayer(
        (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer4): _DenseLayer(
        (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer5): _DenseLayer(
        (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer6): _DenseLayer(
        (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer7): _DenseLayer(
        (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer8): _DenseLayer(
        (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer9): _DenseLayer(
        (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer10): _DenseLayer(
        (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer11): _DenseLayer(
        (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer12): _DenseLayer(
        (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (transition2): _Transition(
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (denseblock3): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer3): _DenseLayer(
        (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer4): _DenseLayer(
        (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer5): _DenseLayer(
        (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer6): _DenseLayer(
        (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer7): _DenseLayer(
        (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer8): _DenseLayer(
        (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer9): _DenseLayer(
        (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer10): _DenseLayer(
        (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer11): _DenseLayer(
        (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer12): _DenseLayer(
        (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer13): _DenseLayer(
        (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer14): _DenseLayer(
        (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer15): _DenseLayer(
        (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer16): _DenseLayer(
        (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer17): _DenseLayer(
        (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer18): _DenseLayer(
        (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer19): _DenseLayer(
        (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer20): _DenseLayer(
        (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer21): _DenseLayer(
        (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer22): _DenseLayer(
        (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer23): _DenseLayer(
        (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer24): _DenseLayer(
        (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (transition3): _Transition(
      (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (denseblock4): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer3): _DenseLayer(
        (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer4): _DenseLayer(
        (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer5): _DenseLayer(
        (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer6): _DenseLayer(
        (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer7): _DenseLayer(
        (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer8): _DenseLayer(
        (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer9): _DenseLayer(
        (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer10): _DenseLayer(
        (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer11): _DenseLayer(
        (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer12): _DenseLayer(
        (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer13): _DenseLayer(
        (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer14): _DenseLayer(
        (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer15): _DenseLayer(
        (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer16): _DenseLayer(
        (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu5): ReLU(inplace=True)
  )
  (classifier): Linear(in_features=1024, out_features=4, bias=True)
)
python 复制代码
# 查看模型详情
import torchsummary as summary
summary.summary(model,(3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
      BatchNorm2d-11           [-1, 96, 56, 56]             192
             ReLU-12           [-1, 96, 56, 56]               0
           Conv2d-13          [-1, 128, 56, 56]          12,288
      BatchNorm2d-14          [-1, 128, 56, 56]             256
             ReLU-15          [-1, 128, 56, 56]               0
           Conv2d-16           [-1, 32, 56, 56]          36,864
      BatchNorm2d-17          [-1, 128, 56, 56]             256
             ReLU-18          [-1, 128, 56, 56]               0
           Conv2d-19          [-1, 128, 56, 56]          16,384
      BatchNorm2d-20          [-1, 128, 56, 56]             256
             ReLU-21          [-1, 128, 56, 56]               0
           Conv2d-22           [-1, 32, 56, 56]          36,864
      BatchNorm2d-23          [-1, 160, 56, 56]             320
             ReLU-24          [-1, 160, 56, 56]               0
           Conv2d-25          [-1, 128, 56, 56]          20,480
      BatchNorm2d-26          [-1, 128, 56, 56]             256
             ReLU-27          [-1, 128, 56, 56]               0
           Conv2d-28           [-1, 32, 56, 56]          36,864
      BatchNorm2d-29          [-1, 192, 56, 56]             384
             ReLU-30          [-1, 192, 56, 56]               0
           Conv2d-31          [-1, 128, 56, 56]          24,576
      BatchNorm2d-32          [-1, 128, 56, 56]             256
             ReLU-33          [-1, 128, 56, 56]               0
           Conv2d-34           [-1, 32, 56, 56]          36,864
      BatchNorm2d-35          [-1, 224, 56, 56]             448
             ReLU-36          [-1, 224, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          28,672
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40           [-1, 32, 56, 56]          36,864
      BatchNorm2d-41          [-1, 256, 56, 56]             512
             ReLU-42          [-1, 256, 56, 56]               0
           Conv2d-43          [-1, 128, 56, 56]          32,768
        AvgPool2d-44          [-1, 128, 28, 28]               0
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 128, 28, 28]          16,384
      BatchNorm2d-48          [-1, 128, 28, 28]             256
             ReLU-49          [-1, 128, 28, 28]               0
           Conv2d-50           [-1, 32, 28, 28]          36,864
      BatchNorm2d-51          [-1, 160, 28, 28]             320
             ReLU-52          [-1, 160, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          20,480
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56           [-1, 32, 28, 28]          36,864
      BatchNorm2d-57          [-1, 192, 28, 28]             384
             ReLU-58          [-1, 192, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          24,576
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62           [-1, 32, 28, 28]          36,864
      BatchNorm2d-63          [-1, 224, 28, 28]             448
             ReLU-64          [-1, 224, 28, 28]               0
           Conv2d-65          [-1, 128, 28, 28]          28,672
      BatchNorm2d-66          [-1, 128, 28, 28]             256
             ReLU-67          [-1, 128, 28, 28]               0
           Conv2d-68           [-1, 32, 28, 28]          36,864
      BatchNorm2d-69          [-1, 256, 28, 28]             512
             ReLU-70          [-1, 256, 28, 28]               0
           Conv2d-71          [-1, 128, 28, 28]          32,768
      BatchNorm2d-72          [-1, 128, 28, 28]             256
             ReLU-73          [-1, 128, 28, 28]               0
           Conv2d-74           [-1, 32, 28, 28]          36,864
      BatchNorm2d-75          [-1, 288, 28, 28]             576
             ReLU-76          [-1, 288, 28, 28]               0
           Conv2d-77          [-1, 128, 28, 28]          36,864
      BatchNorm2d-78          [-1, 128, 28, 28]             256
             ReLU-79          [-1, 128, 28, 28]               0
           Conv2d-80           [-1, 32, 28, 28]          36,864
      BatchNorm2d-81          [-1, 320, 28, 28]             640
             ReLU-82          [-1, 320, 28, 28]               0
           Conv2d-83          [-1, 128, 28, 28]          40,960
      BatchNorm2d-84          [-1, 128, 28, 28]             256
             ReLU-85          [-1, 128, 28, 28]               0
           Conv2d-86           [-1, 32, 28, 28]          36,864
      BatchNorm2d-87          [-1, 352, 28, 28]             704
             ReLU-88          [-1, 352, 28, 28]               0
           Conv2d-89          [-1, 128, 28, 28]          45,056
      BatchNorm2d-90          [-1, 128, 28, 28]             256
             ReLU-91          [-1, 128, 28, 28]               0
           Conv2d-92           [-1, 32, 28, 28]          36,864
      BatchNorm2d-93          [-1, 384, 28, 28]             768
             ReLU-94          [-1, 384, 28, 28]               0
           Conv2d-95          [-1, 128, 28, 28]          49,152
      BatchNorm2d-96          [-1, 128, 28, 28]             256
             ReLU-97          [-1, 128, 28, 28]               0
           Conv2d-98           [-1, 32, 28, 28]          36,864
      BatchNorm2d-99          [-1, 416, 28, 28]             832
            ReLU-100          [-1, 416, 28, 28]               0
          Conv2d-101          [-1, 128, 28, 28]          53,248
     BatchNorm2d-102          [-1, 128, 28, 28]             256
            ReLU-103          [-1, 128, 28, 28]               0
          Conv2d-104           [-1, 32, 28, 28]          36,864
     BatchNorm2d-105          [-1, 448, 28, 28]             896
            ReLU-106          [-1, 448, 28, 28]               0
          Conv2d-107          [-1, 128, 28, 28]          57,344
     BatchNorm2d-108          [-1, 128, 28, 28]             256
            ReLU-109          [-1, 128, 28, 28]               0
          Conv2d-110           [-1, 32, 28, 28]          36,864
     BatchNorm2d-111          [-1, 480, 28, 28]             960
            ReLU-112          [-1, 480, 28, 28]               0
          Conv2d-113          [-1, 128, 28, 28]          61,440
     BatchNorm2d-114          [-1, 128, 28, 28]             256
            ReLU-115          [-1, 128, 28, 28]               0
          Conv2d-116           [-1, 32, 28, 28]          36,864
     BatchNorm2d-117          [-1, 512, 28, 28]           1,024
            ReLU-118          [-1, 512, 28, 28]               0
          Conv2d-119          [-1, 256, 28, 28]         131,072
       AvgPool2d-120          [-1, 256, 14, 14]               0
     BatchNorm2d-121          [-1, 256, 14, 14]             512
            ReLU-122          [-1, 256, 14, 14]               0
          Conv2d-123          [-1, 128, 14, 14]          32,768
     BatchNorm2d-124          [-1, 128, 14, 14]             256
            ReLU-125          [-1, 128, 14, 14]               0
          Conv2d-126           [-1, 32, 14, 14]          36,864
     BatchNorm2d-127          [-1, 288, 14, 14]             576
            ReLU-128          [-1, 288, 14, 14]               0
          Conv2d-129          [-1, 128, 14, 14]          36,864
     BatchNorm2d-130          [-1, 128, 14, 14]             256
            ReLU-131          [-1, 128, 14, 14]               0
          Conv2d-132           [-1, 32, 14, 14]          36,864
     BatchNorm2d-133          [-1, 320, 14, 14]             640
            ReLU-134          [-1, 320, 14, 14]               0
          Conv2d-135          [-1, 128, 14, 14]          40,960
     BatchNorm2d-136          [-1, 128, 14, 14]             256
            ReLU-137          [-1, 128, 14, 14]               0
          Conv2d-138           [-1, 32, 14, 14]          36,864
     BatchNorm2d-139          [-1, 352, 14, 14]             704
            ReLU-140          [-1, 352, 14, 14]               0
          Conv2d-141          [-1, 128, 14, 14]          45,056
     BatchNorm2d-142          [-1, 128, 14, 14]             256
            ReLU-143          [-1, 128, 14, 14]               0
          Conv2d-144           [-1, 32, 14, 14]          36,864
     BatchNorm2d-145          [-1, 384, 14, 14]             768
            ReLU-146          [-1, 384, 14, 14]               0
          Conv2d-147          [-1, 128, 14, 14]          49,152
     BatchNorm2d-148          [-1, 128, 14, 14]             256
            ReLU-149          [-1, 128, 14, 14]               0
          Conv2d-150           [-1, 32, 14, 14]          36,864
     BatchNorm2d-151          [-1, 416, 14, 14]             832
            ReLU-152          [-1, 416, 14, 14]               0
          Conv2d-153          [-1, 128, 14, 14]          53,248
     BatchNorm2d-154          [-1, 128, 14, 14]             256
            ReLU-155          [-1, 128, 14, 14]               0
          Conv2d-156           [-1, 32, 14, 14]          36,864
     BatchNorm2d-157          [-1, 448, 14, 14]             896
            ReLU-158          [-1, 448, 14, 14]               0
          Conv2d-159          [-1, 128, 14, 14]          57,344
     BatchNorm2d-160          [-1, 128, 14, 14]             256
            ReLU-161          [-1, 128, 14, 14]               0
          Conv2d-162           [-1, 32, 14, 14]          36,864
     BatchNorm2d-163          [-1, 480, 14, 14]             960
            ReLU-164          [-1, 480, 14, 14]               0
          Conv2d-165          [-1, 128, 14, 14]          61,440
     BatchNorm2d-166          [-1, 128, 14, 14]             256
            ReLU-167          [-1, 128, 14, 14]               0
          Conv2d-168           [-1, 32, 14, 14]          36,864
     BatchNorm2d-169          [-1, 512, 14, 14]           1,024
            ReLU-170          [-1, 512, 14, 14]               0
          Conv2d-171          [-1, 128, 14, 14]          65,536
     BatchNorm2d-172          [-1, 128, 14, 14]             256
            ReLU-173          [-1, 128, 14, 14]               0
          Conv2d-174           [-1, 32, 14, 14]          36,864
     BatchNorm2d-175          [-1, 544, 14, 14]           1,088
            ReLU-176          [-1, 544, 14, 14]               0
          Conv2d-177          [-1, 128, 14, 14]          69,632
     BatchNorm2d-178          [-1, 128, 14, 14]             256
            ReLU-179          [-1, 128, 14, 14]               0
          Conv2d-180           [-1, 32, 14, 14]          36,864
     BatchNorm2d-181          [-1, 576, 14, 14]           1,152
            ReLU-182          [-1, 576, 14, 14]               0
          Conv2d-183          [-1, 128, 14, 14]          73,728
     BatchNorm2d-184          [-1, 128, 14, 14]             256
            ReLU-185          [-1, 128, 14, 14]               0
          Conv2d-186           [-1, 32, 14, 14]          36,864
     BatchNorm2d-187          [-1, 608, 14, 14]           1,216
            ReLU-188          [-1, 608, 14, 14]               0
          Conv2d-189          [-1, 128, 14, 14]          77,824
     BatchNorm2d-190          [-1, 128, 14, 14]             256
            ReLU-191          [-1, 128, 14, 14]               0
          Conv2d-192           [-1, 32, 14, 14]          36,864
     BatchNorm2d-193          [-1, 640, 14, 14]           1,280
            ReLU-194          [-1, 640, 14, 14]               0
          Conv2d-195          [-1, 128, 14, 14]          81,920
     BatchNorm2d-196          [-1, 128, 14, 14]             256
            ReLU-197          [-1, 128, 14, 14]               0
          Conv2d-198           [-1, 32, 14, 14]          36,864
     BatchNorm2d-199          [-1, 672, 14, 14]           1,344
            ReLU-200          [-1, 672, 14, 14]               0
          Conv2d-201          [-1, 128, 14, 14]          86,016
     BatchNorm2d-202          [-1, 128, 14, 14]             256
            ReLU-203          [-1, 128, 14, 14]               0
          Conv2d-204           [-1, 32, 14, 14]          36,864
     BatchNorm2d-205          [-1, 704, 14, 14]           1,408
            ReLU-206          [-1, 704, 14, 14]               0
          Conv2d-207          [-1, 128, 14, 14]          90,112
     BatchNorm2d-208          [-1, 128, 14, 14]             256
            ReLU-209          [-1, 128, 14, 14]               0
          Conv2d-210           [-1, 32, 14, 14]          36,864
     BatchNorm2d-211          [-1, 736, 14, 14]           1,472
            ReLU-212          [-1, 736, 14, 14]               0
          Conv2d-213          [-1, 128, 14, 14]          94,208
     BatchNorm2d-214          [-1, 128, 14, 14]             256
            ReLU-215          [-1, 128, 14, 14]               0
          Conv2d-216           [-1, 32, 14, 14]          36,864
     BatchNorm2d-217          [-1, 768, 14, 14]           1,536
            ReLU-218          [-1, 768, 14, 14]               0
          Conv2d-219          [-1, 128, 14, 14]          98,304
     BatchNorm2d-220          [-1, 128, 14, 14]             256
            ReLU-221          [-1, 128, 14, 14]               0
          Conv2d-222           [-1, 32, 14, 14]          36,864
     BatchNorm2d-223          [-1, 800, 14, 14]           1,600
            ReLU-224          [-1, 800, 14, 14]               0
          Conv2d-225          [-1, 128, 14, 14]         102,400
     BatchNorm2d-226          [-1, 128, 14, 14]             256
            ReLU-227          [-1, 128, 14, 14]               0
          Conv2d-228           [-1, 32, 14, 14]          36,864
     BatchNorm2d-229          [-1, 832, 14, 14]           1,664
            ReLU-230          [-1, 832, 14, 14]               0
          Conv2d-231          [-1, 128, 14, 14]         106,496
     BatchNorm2d-232          [-1, 128, 14, 14]             256
            ReLU-233          [-1, 128, 14, 14]               0
          Conv2d-234           [-1, 32, 14, 14]          36,864
     BatchNorm2d-235          [-1, 864, 14, 14]           1,728
            ReLU-236          [-1, 864, 14, 14]               0
          Conv2d-237          [-1, 128, 14, 14]         110,592
     BatchNorm2d-238          [-1, 128, 14, 14]             256
            ReLU-239          [-1, 128, 14, 14]               0
          Conv2d-240           [-1, 32, 14, 14]          36,864
     BatchNorm2d-241          [-1, 896, 14, 14]           1,792
            ReLU-242          [-1, 896, 14, 14]               0
          Conv2d-243          [-1, 128, 14, 14]         114,688
     BatchNorm2d-244          [-1, 128, 14, 14]             256
            ReLU-245          [-1, 128, 14, 14]               0
          Conv2d-246           [-1, 32, 14, 14]          36,864
     BatchNorm2d-247          [-1, 928, 14, 14]           1,856
            ReLU-248          [-1, 928, 14, 14]               0
          Conv2d-249          [-1, 128, 14, 14]         118,784
     BatchNorm2d-250          [-1, 128, 14, 14]             256
            ReLU-251          [-1, 128, 14, 14]               0
          Conv2d-252           [-1, 32, 14, 14]          36,864
     BatchNorm2d-253          [-1, 960, 14, 14]           1,920
            ReLU-254          [-1, 960, 14, 14]               0
          Conv2d-255          [-1, 128, 14, 14]         122,880
     BatchNorm2d-256          [-1, 128, 14, 14]             256
            ReLU-257          [-1, 128, 14, 14]               0
          Conv2d-258           [-1, 32, 14, 14]          36,864
     BatchNorm2d-259          [-1, 992, 14, 14]           1,984
            ReLU-260          [-1, 992, 14, 14]               0
          Conv2d-261          [-1, 128, 14, 14]         126,976
     BatchNorm2d-262          [-1, 128, 14, 14]             256
            ReLU-263          [-1, 128, 14, 14]               0
          Conv2d-264           [-1, 32, 14, 14]          36,864
     BatchNorm2d-265         [-1, 1024, 14, 14]           2,048
            ReLU-266         [-1, 1024, 14, 14]               0
          Conv2d-267          [-1, 512, 14, 14]         524,288
       AvgPool2d-268            [-1, 512, 7, 7]               0
     BatchNorm2d-269            [-1, 512, 7, 7]           1,024
            ReLU-270            [-1, 512, 7, 7]               0
          Conv2d-271            [-1, 128, 7, 7]          65,536
     BatchNorm2d-272            [-1, 128, 7, 7]             256
            ReLU-273            [-1, 128, 7, 7]               0
          Conv2d-274             [-1, 32, 7, 7]          36,864
     BatchNorm2d-275            [-1, 544, 7, 7]           1,088
            ReLU-276            [-1, 544, 7, 7]               0
          Conv2d-277            [-1, 128, 7, 7]          69,632
     BatchNorm2d-278            [-1, 128, 7, 7]             256
            ReLU-279            [-1, 128, 7, 7]               0
          Conv2d-280             [-1, 32, 7, 7]          36,864
     BatchNorm2d-281            [-1, 576, 7, 7]           1,152
            ReLU-282            [-1, 576, 7, 7]               0
          Conv2d-283            [-1, 128, 7, 7]          73,728
     BatchNorm2d-284            [-1, 128, 7, 7]             256
            ReLU-285            [-1, 128, 7, 7]               0
          Conv2d-286             [-1, 32, 7, 7]          36,864
     BatchNorm2d-287            [-1, 608, 7, 7]           1,216
            ReLU-288            [-1, 608, 7, 7]               0
          Conv2d-289            [-1, 128, 7, 7]          77,824
     BatchNorm2d-290            [-1, 128, 7, 7]             256
            ReLU-291            [-1, 128, 7, 7]               0
          Conv2d-292             [-1, 32, 7, 7]          36,864
     BatchNorm2d-293            [-1, 640, 7, 7]           1,280
            ReLU-294            [-1, 640, 7, 7]               0
          Conv2d-295            [-1, 128, 7, 7]          81,920
     BatchNorm2d-296            [-1, 128, 7, 7]             256
            ReLU-297            [-1, 128, 7, 7]               0
          Conv2d-298             [-1, 32, 7, 7]          36,864
     BatchNorm2d-299            [-1, 672, 7, 7]           1,344
            ReLU-300            [-1, 672, 7, 7]               0
          Conv2d-301            [-1, 128, 7, 7]          86,016
     BatchNorm2d-302            [-1, 128, 7, 7]             256
            ReLU-303            [-1, 128, 7, 7]               0
          Conv2d-304             [-1, 32, 7, 7]          36,864
     BatchNorm2d-305            [-1, 704, 7, 7]           1,408
            ReLU-306            [-1, 704, 7, 7]               0
          Conv2d-307            [-1, 128, 7, 7]          90,112
     BatchNorm2d-308            [-1, 128, 7, 7]             256
            ReLU-309            [-1, 128, 7, 7]               0
          Conv2d-310             [-1, 32, 7, 7]          36,864
     BatchNorm2d-311            [-1, 736, 7, 7]           1,472
            ReLU-312            [-1, 736, 7, 7]               0
          Conv2d-313            [-1, 128, 7, 7]          94,208
     BatchNorm2d-314            [-1, 128, 7, 7]             256
            ReLU-315            [-1, 128, 7, 7]               0
          Conv2d-316             [-1, 32, 7, 7]          36,864
     BatchNorm2d-317            [-1, 768, 7, 7]           1,536
            ReLU-318            [-1, 768, 7, 7]               0
          Conv2d-319            [-1, 128, 7, 7]          98,304
     BatchNorm2d-320            [-1, 128, 7, 7]             256
            ReLU-321            [-1, 128, 7, 7]               0
          Conv2d-322             [-1, 32, 7, 7]          36,864
     BatchNorm2d-323            [-1, 800, 7, 7]           1,600
            ReLU-324            [-1, 800, 7, 7]               0
          Conv2d-325            [-1, 128, 7, 7]         102,400
     BatchNorm2d-326            [-1, 128, 7, 7]             256
            ReLU-327            [-1, 128, 7, 7]               0
          Conv2d-328             [-1, 32, 7, 7]          36,864
     BatchNorm2d-329            [-1, 832, 7, 7]           1,664
            ReLU-330            [-1, 832, 7, 7]               0
          Conv2d-331            [-1, 128, 7, 7]         106,496
     BatchNorm2d-332            [-1, 128, 7, 7]             256
            ReLU-333            [-1, 128, 7, 7]               0
          Conv2d-334             [-1, 32, 7, 7]          36,864
     BatchNorm2d-335            [-1, 864, 7, 7]           1,728
            ReLU-336            [-1, 864, 7, 7]               0
          Conv2d-337            [-1, 128, 7, 7]         110,592
     BatchNorm2d-338            [-1, 128, 7, 7]             256
            ReLU-339            [-1, 128, 7, 7]               0
          Conv2d-340             [-1, 32, 7, 7]          36,864
     BatchNorm2d-341            [-1, 896, 7, 7]           1,792
            ReLU-342            [-1, 896, 7, 7]               0
          Conv2d-343            [-1, 128, 7, 7]         114,688
     BatchNorm2d-344            [-1, 128, 7, 7]             256
            ReLU-345            [-1, 128, 7, 7]               0
          Conv2d-346             [-1, 32, 7, 7]          36,864
     BatchNorm2d-347            [-1, 928, 7, 7]           1,856
            ReLU-348            [-1, 928, 7, 7]               0
          Conv2d-349            [-1, 128, 7, 7]         118,784
     BatchNorm2d-350            [-1, 128, 7, 7]             256
            ReLU-351            [-1, 128, 7, 7]               0
          Conv2d-352             [-1, 32, 7, 7]          36,864
     BatchNorm2d-353            [-1, 960, 7, 7]           1,920
            ReLU-354            [-1, 960, 7, 7]               0
          Conv2d-355            [-1, 128, 7, 7]         122,880
     BatchNorm2d-356            [-1, 128, 7, 7]             256
            ReLU-357            [-1, 128, 7, 7]               0
          Conv2d-358             [-1, 32, 7, 7]          36,864
     BatchNorm2d-359            [-1, 992, 7, 7]           1,984
            ReLU-360            [-1, 992, 7, 7]               0
          Conv2d-361            [-1, 128, 7, 7]         126,976
     BatchNorm2d-362            [-1, 128, 7, 7]             256
            ReLU-363            [-1, 128, 7, 7]               0
          Conv2d-364             [-1, 32, 7, 7]          36,864
     BatchNorm2d-365           [-1, 1024, 7, 7]           2,048
            ReLU-366           [-1, 1024, 7, 7]               0
          Linear-367                    [-1, 4]           4,100
================================================================
Total params: 6,957,956
Trainable params: 6,957,956
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.57
Params size (MB): 26.54
Estimated Total Size (MB): 321.69
----------------------------------------------------------------

5. 设置超参数:定义损失函数,学习率,以及根据学习率定义优化器等

python 复制代码
# loss_fn = nn.CrossEntropyLoss() # 创建损失函数

# learn_rate = 1e-3 # 初始学习率
# def adjust_learning_rate(optimizer,epoch,start_lr):
#     # 每两个epoch 衰减到原来的0.98
#     lr = start_lr * (0.92 ** (epoch//2))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
        
# optimizer = torch.optim.Adam(model.parameters(),lr=learn_rate)
python 复制代码
# 调用官方接口示例
loss_fn = nn.CrossEntropyLoss()

learn_rate = 1e-4
lambda1 = lambda epoch:(0.92**(epoch//2))

optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # 选定调整方法

6. 训练函数

python 复制代码
# 训练函数
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset) # 训练集大小
    num_batches = len(dataloader) # 批次数目
    
    train_loss,train_acc = 0,0
    
    for X,y in dataloader:
        X,y = X.to(device),y.to(device)
        
        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred,y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 记录acc与loss
        train_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss += loss.item()
        
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc,train_loss

7. 测试函数

python 复制代码
# 测试函数
def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    
    test_acc,test_loss = 0,0
    
    with torch.no_grad():
        for X,y in dataloader:
            X,y = X.to(device),y.to(device)
            
            # 计算loss
            pred = model(X)
            loss = loss_fn(pred,y)
            
            test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
            test_loss += loss.item()
            
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc,test_loss

8. 正式训练

python 复制代码
import copy

epochs = 40

train_acc = []
train_loss = []
test_acc = []
test_loss = []

best_acc = 0.0

for epoch in range(epochs):
    
    # 更新学习率------使用自定义学习率时使用
    # adjust_learning_rate(optimizer,epoch,learn_rate)
    
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,optimizer)
    scheduler.step() # 更新学习率------调用官方动态学习率时使用
    
    model.eval()
    epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)
    
    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

print('Done')
Epoch: 1,Train_acc:57.5%,Train_loss:1.125,Test_acc:67.3%,Test_loss:0.976,Lr:1.00E-04
Epoch: 2,Train_acc:74.8%,Train_loss:0.812,Test_acc:78.8%,Test_loss:0.694,Lr:9.20E-05
Epoch: 3,Train_acc:81.9%,Train_loss:0.677,Test_acc:78.8%,Test_loss:0.528,Lr:9.20E-05
Epoch: 4,Train_acc:86.5%,Train_loss:0.463,Test_acc:84.1%,Test_loss:0.457,Lr:8.46E-05
Epoch: 5,Train_acc:86.5%,Train_loss:0.427,Test_acc:90.3%,Test_loss:0.311,Lr:8.46E-05
Epoch: 6,Train_acc:90.5%,Train_loss:0.350,Test_acc:89.4%,Test_loss:0.417,Lr:7.79E-05
Epoch: 7,Train_acc:92.7%,Train_loss:0.287,Test_acc:85.8%,Test_loss:0.303,Lr:7.79E-05
Epoch: 8,Train_acc:92.0%,Train_loss:0.339,Test_acc:90.3%,Test_loss:0.313,Lr:7.16E-05
Epoch: 9,Train_acc:93.1%,Train_loss:0.231,Test_acc:92.0%,Test_loss:0.243,Lr:7.16E-05
Epoch:10,Train_acc:94.9%,Train_loss:0.286,Test_acc:90.3%,Test_loss:0.285,Lr:6.59E-05
Epoch:11,Train_acc:95.4%,Train_loss:0.185,Test_acc:88.5%,Test_loss:0.334,Lr:6.59E-05
Epoch:12,Train_acc:95.6%,Train_loss:0.151,Test_acc:92.9%,Test_loss:0.226,Lr:6.06E-05
Epoch:13,Train_acc:96.9%,Train_loss:0.118,Test_acc:91.2%,Test_loss:0.263,Lr:6.06E-05
Epoch:14,Train_acc:98.5%,Train_loss:0.114,Test_acc:94.7%,Test_loss:0.177,Lr:5.58E-05
Epoch:15,Train_acc:98.0%,Train_loss:0.145,Test_acc:92.9%,Test_loss:0.244,Lr:5.58E-05
Epoch:16,Train_acc:96.9%,Train_loss:0.126,Test_acc:94.7%,Test_loss:0.200,Lr:5.13E-05
Epoch:17,Train_acc:97.8%,Train_loss:0.196,Test_acc:93.8%,Test_loss:0.166,Lr:5.13E-05
Epoch:18,Train_acc:97.6%,Train_loss:0.167,Test_acc:93.8%,Test_loss:0.236,Lr:4.72E-05
Epoch:19,Train_acc:98.0%,Train_loss:0.102,Test_acc:92.9%,Test_loss:0.258,Lr:4.72E-05
Epoch:20,Train_acc:98.7%,Train_loss:0.093,Test_acc:93.8%,Test_loss:0.207,Lr:4.34E-05
Epoch:21,Train_acc:98.9%,Train_loss:0.070,Test_acc:96.5%,Test_loss:0.180,Lr:4.34E-05
Epoch:22,Train_acc:99.1%,Train_loss:0.115,Test_acc:93.8%,Test_loss:0.213,Lr:4.00E-05
Epoch:23,Train_acc:99.3%,Train_loss:0.051,Test_acc:92.9%,Test_loss:0.241,Lr:4.00E-05
Epoch:24,Train_acc:99.1%,Train_loss:0.064,Test_acc:92.9%,Test_loss:0.201,Lr:3.68E-05
Epoch:25,Train_acc:99.3%,Train_loss:0.070,Test_acc:91.2%,Test_loss:0.313,Lr:3.68E-05
Epoch:26,Train_acc:98.5%,Train_loss:0.083,Test_acc:92.0%,Test_loss:0.288,Lr:3.38E-05
Epoch:27,Train_acc:99.1%,Train_loss:0.125,Test_acc:97.3%,Test_loss:0.146,Lr:3.38E-05
Epoch:28,Train_acc:99.3%,Train_loss:0.109,Test_acc:94.7%,Test_loss:0.259,Lr:3.11E-05
Epoch:29,Train_acc:98.7%,Train_loss:0.073,Test_acc:92.0%,Test_loss:0.268,Lr:3.11E-05
Epoch:30,Train_acc:99.6%,Train_loss:0.097,Test_acc:91.2%,Test_loss:0.254,Lr:2.86E-05
Epoch:31,Train_acc:98.9%,Train_loss:0.083,Test_acc:90.3%,Test_loss:0.255,Lr:2.86E-05
Epoch:32,Train_acc:99.1%,Train_loss:0.058,Test_acc:90.3%,Test_loss:0.326,Lr:2.63E-05
Epoch:33,Train_acc:99.6%,Train_loss:0.042,Test_acc:92.9%,Test_loss:0.155,Lr:2.63E-05
Epoch:34,Train_acc:100.0%,Train_loss:0.037,Test_acc:94.7%,Test_loss:0.157,Lr:2.42E-05
Epoch:35,Train_acc:99.1%,Train_loss:0.140,Test_acc:97.3%,Test_loss:0.144,Lr:2.42E-05
Epoch:36,Train_acc:99.3%,Train_loss:0.059,Test_acc:93.8%,Test_loss:0.174,Lr:2.23E-05
Epoch:37,Train_acc:99.8%,Train_loss:0.038,Test_acc:91.2%,Test_loss:0.215,Lr:2.23E-05
Epoch:38,Train_acc:99.8%,Train_loss:0.045,Test_acc:92.9%,Test_loss:0.163,Lr:2.05E-05
Epoch:39,Train_acc:99.1%,Train_loss:0.076,Test_acc:93.8%,Test_loss:0.185,Lr:2.05E-05
Epoch:40,Train_acc:98.9%,Train_loss:0.090,Test_acc:93.8%,Test_loss:0.176,Lr:1.89E-05
Done

9. 结果可视化

python 复制代码
epochs_range = range(epochs)

plt.figure(figsize = (12,3))

plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label = 'Training Accuracy')
plt.plot(epochs_range,test_acc,label = 'Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label = 'Test Accuracy')
plt.plot(epochs_range,test_loss,label = 'Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and validation Loss')
plt.show()

10. 模型的保存

python 复制代码
# 自定义模型保存
# 状态字典保存
torch.save(model.state_dict(),'./模型参数/J3_densenet121_model_state_dict.pth') # 仅保存状态字典

# 定义模型用来加载参数

best_model = DenseNet(num_init_features=64, # init_channel=64,
                       growth_rate=32,
                       block_config=(6,12,24,16),
                       num_classes=len(classNames)).to(device)

best_model.load_state_dict(torch.load('./模型参数/J3_densenet121_model_state_dict.pth')) # 加载状态字典到模型
<All keys matched successfully>

11. 使用训练好的模型进行预测

python 复制代码
# 指定路径图片预测
from PIL import Image
import torchvision.transforms as transforms

classes = list(total_data.class_to_idx) # classes = list(total_data.class_to_idx)

def predict_one_image(image_path,model,transform,classes):
    
    test_img = Image.open(image_path).convert('RGB')
    # plt.imshow(test_img) # 展示待预测的图片
    
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)
    print(output) # 观察模型预测结果的输出数据
    
    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
python 复制代码
# 预测训练集中的某张照片
predict_one_image(image_path='./data/bird_photos/Bananaquit/007.jpg',
                 model = model,
                 transform = test_transforms,
                 classes = classes
                 )
tensor([[ 5.3812, -3.3856, -0.2259, -2.2854]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
预测结果是:Bananaquit
python 复制代码
classes
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
python 复制代码
相关推荐
ahadee18 分钟前
蓝桥杯每日真题 - 第18天
c语言·vscode·算法·蓝桥杯
思通数科多模态大模型23 分钟前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
数据岛27 分钟前
数据集论文:面向深度学习的土地利用场景分类与变化检测
人工智能·深度学习
地平线开发者1 小时前
CPU& 内存加压工具 stress-ng 介绍
算法·自动驾驶
XY.散人1 小时前
初识算法 · 分治(2)
算法
DanielYQ1 小时前
LCR 001 两数相除
开发语言·python·算法
学不会lostfound1 小时前
三、计算机视觉_05MTCNN人脸检测
pytorch·深度学习·计算机视觉·mtcnn·p-net·r-net·o-net
红色的山茶花1 小时前
YOLOv8-ultralytics-8.2.103部分代码阅读笔记-block.py
笔记·深度学习·yolo
龙的爹23331 小时前
论文翻译 | RECITATION-AUGMENTED LANGUAGE MODELS
人工智能·语言模型·自然语言处理·prompt·gpu算力
白光白光1 小时前
凸函数与深度学习调参
人工智能·深度学习