Pytorch-day09-模型微调-checkpoint

模型微调(fine-tune)-迁移学习

  • torchvision微调
  • timm微调
  • 半精度训练

起源:

  • 1、随着深度学习的发展,模型的参数越来越大,许多开源模型都是在较大数据集上进行训练的,比如Imagenet-1k,Imagenet-11k等
  • 2、如果数据集可能只有几千张,训练几千万参数的大模型,过拟合无法避免
  • 3、如果我们想从零开始训练一个大模型,那么我们的解决办法是收集更多的数据。然而,收集和标注数据会花费大量的时间和资⾦,成本无法承受

解决方案:

  • 应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上
  • 比如:ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成
  • 模型微调(finetune):就是先找到一个同类的别人训练好的模型,基于已经训练好的模型换成自己的数据,通过训练调整一下参数

不同数据集下使用微调:

  • 数据集1 - 数据量少,但数据相似度非常高 - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。

  • 数据集2 - 数据量少,数据相似度低 - 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。

  • 数据集3 - 数据量大,数据相似度低 - 在这种情况下,由于我们有一个大的数据集,我们的神经网络训练将会很有效。但是,由于我们的数据与用于训练我们的预训练模型的数据相比有很大不同。使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scatch)

  • 数据集4 - 数据量大,数据相似度高 - 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

微调的是什么?

  • 换数据源
  • 针对K层进行重新训练
  • K层的权重&shape调整

1、模型微调(fine-tune)一般流程:

  • 1、在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型
  • 2、创建一个新的神经网络模型,即目标模型,它复制了源模型上除了输出层外的所有模型设计及其参数
  • 3、为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数
  • 4、在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-x43rJPAE-1692785269018)(attachment:image.png)]

2、torchvision微调

2.1 实例化Model

python 复制代码
import torchvision.models as models
resnet34 = models.resnet34(pretrained=True)

pretrained参数说明:

  • 1、通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重
  • 2、当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重

注意:如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,否则会报错。

2.2 训练特定层

如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层

python 复制代码
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

2.3 实例

  • 使用resnet34为例的将1000类改为10类,但是仅改变最后一层的模型参数
  • 我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改
python 复制代码
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.models as models
from torchinfo import summary
python 复制代码
#超参数定义
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 2
# 方案二:使用"device",后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 
python 复制代码
# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets

#"data_transform"可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                   ])


train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)

#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
train_loader = torch.utils.data.DataLoader(train_cifar_dataset, 
                                           batch_size=batch_size, num_workers=4, 
                                           shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(test_cifar_dataset, 
                                         batch_size=batch_size, num_workers=4, 
                                         shuffle=False)
python 复制代码
# 下载预训练模型 restnet50
resnet34 = models.resnet34(pretrained=True)
print(resnet34)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\xulele/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:10<00:00, 8.57MB/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
python 复制代码
#查看模型结构
summary(resnet34, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 21,797,672
Trainable params: 21,797,672
Non-trainable params: 0
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.82
Params size (MB): 87.19
Estimated Total Size (MB): 147.61
==========================================================================================
python 复制代码
#检测 模型准确率
def cal_predict_correct(model):
    test_total_correct = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
    
        outputs = model(images)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
#     print("test_total_correct: "+ str(test_total_correct))
    return test_total_correct
python 复制代码
total_correct = cal_predict_correct(resnet34)
print("test_total_correct: "+ str(test_total_correct / 10000))
test_total_correct: 0.1
python 复制代码
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            

# 冻结参数的梯度
feature_extract = True
new_model = resnet34
set_parameter_requires_grad(new_model, feature_extract)

# 修改模型
#训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层
num_ftrs = new_model.fc.in_features
new_model.fc = nn.Linear(in_features=num_ftrs, out_features=10, bias=True)
python 复制代码
summary(new_model, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 10]                   --
├─Conv2d: 1-1                            [1, 64, 112, 112]         (9,408)
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         (128)
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          (73,728)
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          (256)
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          (8,448)
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          (294,912)
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          (512)
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          (33,280)
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            (1,179,648)
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            (1,024)
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            (132,096)
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 10]                   5,130
==========================================================================================
Total params: 21,289,802
Trainable params: 5,130
Non-trainable params: 21,284,672
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.81
Params size (MB): 85.16
Estimated Total Size (MB): 145.57
==========================================================================================
python 复制代码
#训练&验证
Resnet34_new = new_model.to(device)
# 定义损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 损失函数:自定义损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)
epoch = max_epochs

total_step = len(train_loader)
train_all_loss = []
test_all_loss = []

for i in range(epoch):
    Resnet34_new.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0

    for iter, (images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        train_total_correct += (outputs.argmax(1) == labels).sum().item()
        
        #backword
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))
    
    Resnet34_new.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100
    
    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num,4))
    test_all_loss.append(np.round(test_total_loss / test_total_num,4))
Epoch [1/2], Iter [1481/3125], train_loss:0.17220

半精度训练

问题:

GPU的性能主要分为两部分:算力和显存。

前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算

在可以使用的显存数量一定的情况下,每次训练能够加载的数据更多(也就是batch size更大),则也可以提高训练效率

定义:

PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性

但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为"半精度"

image.png

显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算

3.1、半精度训练的设置

1、引入 from torch.cuda.amp import autocast

2、forward函数指定 autocast 装饰器

3、训练过程: 只需在将数据输入模型及其之后的部分放入"with autocast():"

4、半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)

引入

from torch.cuda.amp import autocast

python 复制代码
# forward指定装饰器
@autocast()   
def forward(self, x):
    ...
    return x

python 复制代码
# 指定with autocast 
 for x in train_loader:
    x = x.cuda()
    with autocast():
            output = model(x)
        ...

半精度训练案例

from torch.cuda.amp import autocast

半精度模型

class DemoModel(nn.Module):

def init (self):

super(DemoModel, self).init ()

self.conv1 = nn.Conv2d(3, 6, 5)

self.pool = nn.MaxPool2d(2, 2)

self.conv2 = nn.Conv2d(6, 16, 5)

self.fc1 = nn.Linear(16 * 5 * 5, 120)

self.fc2 = nn.Linear(120, 84)

self.fc3 = nn.Linear(84, 10)

@autocast() 
def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

#训练&验证

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

half_model = DemoModel().to(device)

损失函数:自定义损失函数

criterion = nn.CrossEntropyLoss()

优化器

optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)

epoch = max_epochs

total_step = len(train_loader)

train_all_loss = []

test_all_loss = []

for i in range(epoch):

half_model.train()

train_total_loss = 0

train_total_num = 0

train_total_correct = 0

for iter, (images,labels) in enumerate(train_loader):

images = images.to(device)

labels = labels.to(device)

with autocast():

outputs = half_model(images)

loss = criterion(outputs,labels)

train_total_correct += (outputs.argmax(1) == labels).sum().item()

#backword

optimizer.zero_grad()

loss.backward()

optimizer.step()

        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))

half_model.eval()
test_total_loss = 0
test_total_correct = 0
test_total_num = 0
for iter,(images,labels) in enumerate(test_loader):
    images = images.to(device)
    labels = labels.to(device)
    with autocast():
        outputs = half_model(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
        print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
            i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100

))
train_all_loss.append(np.round(train_total_loss / train_total_num,4))
test_all_loss.append(np.round(test_total_loss / test_total_num,4))

相关推荐
豆本-豆豆奶11 分钟前
最全面的Python重点知识汇总,建议收藏!
开发语言·数据库·python·oracle
Bosenya1214 分钟前
【信号处理】绘制IQ信号时域图、星座图、功率谱
开发语言·python·信号处理
再不会python就不礼貌了18 分钟前
Ollama 0.4 发布!支持 Llama 3.2 Vision,实现多模态 RAG
人工智能·学习·机器学习·ai·开源·产品经理·llama
大大大反派21 分钟前
ONLYOFFICE 8.2深度测评:集成PDF编辑、数据可视化与AI功能的强大办公套件
人工智能·信息可视化·pdf
AI原吾32 分钟前
探索PyAV:Python中的多媒体处理利器
开发语言·python·ai·pyav
DK2215136 分钟前
机器学习系列-----主成分分析(PCA)
人工智能·算法·机器学习
oliveira-time42 分钟前
爬虫学习8
开发语言·javascript·爬虫·python·算法
正义的彬彬侠1 小时前
XGBoost算法Python代码实现
python·决策树·机器学习·numpy·集成学习·boosting·xgboost
昨天今天明天好多天1 小时前
【Python】解析 XML
xml·python
侯喵喵2 小时前
从零开始 blender插件开发
python·blender