Weekly Deep Learning Summary J8 (Inception V1 in Practice and Analysis - Monkeypox Recognition)

Table of Contents

      • [0. Summary](#0. Summary)
      • [Inception V1 Overview](#Inception V1 Overview)
      • [1. Set Up the GPU](#1. Set Up the GPU)
      • [2. Data Import and Preprocessing](#2. Data Import and Preprocessing)
      • [3. Splitting the Dataset](#3. Splitting the Dataset)
      • [4. Building the Model](#4. Building the Model)
      • [5. Hyperparameters: Loss Function, Learning Rate, and Optimizer](#5. Hyperparameters: Loss Function, Learning Rate, and Optimizer)
      • [6. Training Function](#6. Training Function)
      • [7. Test Function](#7. Test Function)
      • [8. Training](#8. Training)
      • [9. Visualizing the Results](#9. Visualizing the Results)
      • [10. Saving the Model](#10. Saving the Model)
      • 11. Using the Trained Model for Prediction

0. Summary

Data import and preprocessing: this time the data does not come from a torchvision built-in dataset, so the raw data must be handled manually: importing it, inspecting the class distribution, defining the transforms, converting data types, and so on.

Splitting the dataset: after setting aside the training and test sets, DataLoader() from torch.utils.data loads each of the prepared splits, and we check the batch dimensions.

Model construction: Inception V1.

Setting hyperparameters: beforehand we define the loss function, the learning rate (a dynamic schedule here), and, from the learning rate, the optimizer (e.g., SGD, stochastic gradient descent) that updates the parameters during training to minimize the loss.

Defining the training function: it takes four arguments: the prepared DataLoader(), the model, the loss function, and the optimizer. Inside, the loss and accuracy are initialized to 0, then the loop pulls one batch at a time from the DataLoader(), feeds the batch through the model to get predictions, and computes the loss with the loss function. Then come backpropagation and the optimizer step; zeroing the gradients may go either before the backward pass or after the optimizer step, and by convention it is placed before the backward pass. A minimal sketch follows below.
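
A minimal sketch of that step, using the same names (dataloader, model, loss_fn, optimizer, device) as the full function in Section 6:

python
# One training step; optimizer.zero_grad() may equally go after optimizer.step(),
# but by convention it is placed before the backward pass
for X, y in dataloader:
    X, y = X.to(device), y.to(device)
    pred = model(X)              # forward pass
    loss = loss_fn(pred, y)      # compute the loss
    optimizer.zero_grad()        # clear stale gradients
    loss.backward()              # backpropagate
    optimizer.step()             # update the parameters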

Defining the test function: it takes one argument fewer than the training function, dropping the optimizer: just the prepared DataLoader(), the model, and the loss function. Apart from skipping the gradient zeroing, backpropagation, and optimizer step for each batch, everything else mirrors the training function.

Training loop: set the number of epochs (each epoch is one full pass over the dataset) and initialize four empty lists to record the training and test accuracy and loss. Call model.train() to switch to training mode and run the training function to get its accuracy and loss; call model.eval() to switch to evaluation mode and run the test function likewise. Finally, append the training and test accuracy and loss to the corresponding lists and print them together, giving the metrics after each full epoch.

Result visualization.

Saving, loading, and using the model: in PyTorch, model parameters are usually saved with torch.save(model.state_dict(), 'model.pth') and restored with model.load_state_dict(torch.load('model.pth')).

Things to improve and watch out for: keep the model and the data consistent, both on the GPU or both on the CPU; do not leave num_classes at the default 1000, set it from the actual dataset, and remember the num_classes argument when instantiating the model; also note that a test input uses shape (3, 224, 224), where 3 is the channel count, which differs from TensorFlow's (224, 224, 3) ordering, so take care when porting code between the two. A small illustration follows below.
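
As a small illustration of that channel-order difference (a standalone snippet; the tensor names are made up):

python
import torch

x_torch = torch.randn(1, 3, 224, 224)    # PyTorch layout: (N, C, H, W)
x_tf_like = x_torch.permute(0, 2, 3, 1)  # TensorFlow layout: (N, H, W, C)
print(x_torch.shape)    # torch.Size([1, 3, 224, 224])
print(x_tf_like.shape)  # torch.Size([1, 224, 224, 3])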

On tuning (very important): this run pushed test accuracy up to 96.03% (with the random seed set to 42).

1: Multiple GPUs are not necessarily better than a single GPU; further tuning is needed.

2: Two main adjustments were made this time. First, the initial learning rate was raised from 1e-4 to 3e-4. Second, where image preprocessing previously used only random horizontal flips, this run added small-angle random rotation, random resized cropping, lighting jitter, and so on, which worked noticeably better: test accuracy improved substantially. The accuracy curves after training also show the training and test accuracy staying close, with test accuracy sometimes even higher; before this change, training accuracy was always far above test accuracy.

Key code example:

python
import torchvision.transforms as transforms
 
# Transforms for monkeypox recognition
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),            # unify image size
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.RandomRotation(degrees=15),   # small-angle random rotation
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.2)),  # random resized crop
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),  # lighting jitter
    transforms.ToTensor(),                   # convert to a tensor
    transforms.Normalize(                    # normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

Inception V1 Overview

What is Inception V1?

Inception V1, also known as GoogLeNet, is a convolutional neural network (CNN) architecture that Google introduced at the 2014 ILSVRC competition, where it took first place. Compared with the then-popular VGGNet, Inception V1 achieves similar performance with far fewer parameters, making it much more computationally efficient.

The core idea of the Inception Module

At the core of Inception V1 is the Inception Module, which uses parallel convolutions within the same layer to extract features at different scales. This design not only increases the network's depth but also captures several kinds of feature information effectively.

Concretely, an Inception Module usually contains the following branches:

  1. 1x1 convolution branch: reduces the channel count of the input feature map, cutting computation.
  2. 1x1 convolution followed by 3x3 convolution: first reduce dimensionality with a 1x1 convolution, then extract features with a 3x3 convolution.
  3. 1x1 convolution followed by 5x5 convolution: like the 3x3 branch, but with a larger kernel to capture features over a wider area.
  4. 3x3 max pooling followed by 1x1 convolution: pool first, then integrate the features with a 1x1 convolution.

By concatenating the outputs of these branches along the channel dimension, the Inception Module integrates multi-scale information at the same level of the network, improving the model's expressive power; see the sketch below.
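
For example, with the inception3a configuration used in Section 4 (64, 128, 32, and 32 output channels across the four branches), the concatenated result has 64 + 128 + 32 + 32 = 256 channels, which is exactly the input width of inception3b. A minimal sketch with dummy branch outputs:

python
import torch

# Dummy branch outputs with inception3a's channel widths at 28x28 resolution
b1 = torch.randn(1, 64, 28, 28)   # 1x1 branch
b2 = torch.randn(1, 128, 28, 28)  # 1x1 -> 3x3 branch
b3 = torch.randn(1, 32, 28, 28)   # 1x1 -> 5x5 branch
b4 = torch.randn(1, 32, 28, 28)   # pool -> 1x1 branch
out = torch.cat([b1, b2, b3, b4], dim=1)  # concatenate along the channel dimension
print(out.shape)  # torch.Size([1, 256, 28, 28])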

The role of 1x1 convolutions

1x1 convolutions are mainly used for dimensionality reduction, i.e., shrinking the number of channels in a feature map. This lowers both the parameter count and the computation, and it indirectly deepens the network, which helps performance. For example (a quick numeric check follows the list):

  • Original input: 100x100x128
  • Reduce to 32 channels with a 1x1 convolution, then apply a 5x5 convolution; the output is still 100x100x256
  • The number of multiply operations drops from roughly 8.192×10⁹ to about 2.048×10⁹
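
A quick back-of-the-envelope check of those multiplication counts:

python
H, W, C_in, C_out, k = 100, 100, 128, 256, 5

direct = H * W * k * k * C_in * C_out     # 5x5 conv applied directly
reduced = (H * W * C_in * 32              # 1x1 conv down to 32 channels
           + H * W * k * k * 32 * C_out)  # followed by the 5x5 conv
print(f'{direct:.3e}')   # 8.192e+09
print(f'{reduced:.3e}')  # 2.089e+09 (~2.048e+09 from the 5x5 conv alone)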

Auxiliary classifiers

Inception V1 also introduces auxiliary classifiers, which serve two main purposes:

  1. Mitigating vanishing gradients: classifiers attached to intermediate layers help gradients propagate more effectively.
  2. Model ensembling: the intermediate outputs are also used for classification, improving generalization.

In practice, however, these auxiliary classifiers are only active during training and are removed at inference time.
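
The model implemented below omits them. As a hedged illustration of how they are used elsewhere, torchvision's GoogLeNet exposes both auxiliary heads in training mode, and the original paper weights their losses by 0.3 (this snippet assumes a recent torchvision and is not part of this post's model):

python
import torch
import torchvision.models as models

# In train mode with aux_logits=True, forward returns
# GoogLeNetOutputs(logits, aux_logits2, aux_logits1)
net = models.googlenet(num_classes=2, aux_logits=True, init_weights=True)
net.train()
out = net(torch.randn(2, 3, 224, 224))

loss_fn = torch.nn.CrossEntropyLoss()
y = torch.tensor([0, 1])
loss = (loss_fn(out.logits, y)
        + 0.3 * (loss_fn(out.aux_logits1, y) + loss_fn(out.aux_logits2, y)))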

python
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn.functional as F
from collections import OrderedDict 


import os,PIL,pathlib
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore') # suppress warnings

plt.rcParams['font.sans-serif'] = ['SimHei'] # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False   # render minus signs correctly
plt.rcParams['figure.dpi'] = 100 # figure resolution

1. Set Up the GPU

python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cuda')

2. Data Import and Preprocessing

python
# Inspect how the data is organized
path_dir = './data/mpox_recognize/'
path_dir = pathlib.Path(path_dir)

paths = list(path_dir.glob('*'))
# classNames = [str(path).split("\\")[-1] for path in paths] # ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
classNames = [path.parts[-1] for path in paths]
classNames
Output:
['Monkeypox', 'Others']
python
# Define the transforms and preprocess the data
# train_transforms = transforms.Compose([
#     transforms.Resize([224,224]),      # resize input images to a uniform size
#     transforms.RandomHorizontalFlip(), # random horizontal flip
#     transforms.ToTensor(),             # convert a PIL Image or numpy.ndarray to a tensor and scale to [0,1]
#     transforms.Normalize(              # normalize --> standard normal (Gaussian) distribution, which helps the model converge
#         mean = [0.485,0.456,0.406],    # mean=[0.485,0.456,0.406] and std=[0.229,0.224,0.225] were estimated from random samples of the dataset.
#         std = [0.229,0.224,0.225]
#     )
# ])

# Transforms for monkeypox recognition
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),            # unify image size
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.RandomRotation(degrees=15),   # small-angle random rotation
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.2)),  # random resized crop
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),  # lighting jitter
    transforms.ToTensor(),                   # convert to a tensor
    transforms.Normalize(                    # normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

test_transforms = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = [0.485,0.456,0.406],
        std = [0.229,0.224,0.225]
    )
])
# Note: train_transforms is applied to total_data as a whole, so after random_split the
# test subset also receives these training-time augmentations; test_transforms above is unused here.
total_data = datasets.ImageFolder('./data/mpox_recognize/',transform = train_transforms)
total_data
Output:
Dataset ImageFolder
    Number of datapoints: 2142
    Root location: ./data/mpox_recognize/
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-15.0, 15.0], interpolation=nearest, expand=False, fill=0)
               RandomResizedCrop(size=(224, 224), scale=(0.8, 1.2), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=True)
               ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

3. Splitting the Dataset

python
# Set the random seed
torch.manual_seed(42)

# Split the dataset into train and test
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size

train_dataset,test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
Output:
(<torch.utils.data.dataset.Subset at 0x1b854727580>,
 <torch.utils.data.dataset.Subset at 0x1b854727c40>)
python
# DataLoaders for loading the datasets

batch_size = 32 # if using multiple GPUs, make sure batch_size is a multiple of the GPU count

train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1
)
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1
)
python
# Inspect the data dimensions
for X,y in test_dl:
    print("Shape of X [N,C,H,W]: ",X.shape)
    print("Shape of y: ", y.shape,y.dtype)
    break
Output:
Shape of X [N,C,H,W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4. Building the Model

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class inception_block(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(inception_block, self).__init__()

        # 1x1 conv branch
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )

        # 1x1 conv -> 3x3 conv branch
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.BatchNorm2d(ch3x3red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch3x3),
            nn.ReLU(inplace=True)
        )

        # 1x1 conv -> 5x5 conv branch
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.BatchNorm2d(ch5x5red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )

        # 3x3 max pooling -> 1x1 conv branch
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Compute forward pass through all branches and concatenate the output feature maps
        branch1_output = self.branch1(x)
        branch2_output = self.branch2(x)
        branch3_output = self.branch3(x)
        branch4_output = self.branch4(x)

        outputs = [branch1_output, branch2_output, branch3_output, branch4_output]
        return torch.cat(outputs, 1)
    
class InceptionV1(nn.Module):
    def __init__(self, num_classes=1000):
        super(InceptionV1, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3    = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4    = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        

        self.inception5b=nn.Sequential(
            inception_block(832, 384, 192, 384, 48, 128, 128),
            nn.AvgPool2d(kernel_size=7,stride=1,padding=0),
            nn.Dropout(0.4)
        )
        
        # Pooling before the fully connected layers: in Inception V1 the last Inception module
        # is usually followed by global average pooling to reduce the feature dimensions,
        # so it is added here after inception5b:
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected layers for classification
        self.classifier = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=num_classes),
            nn.Softmax(dim=1)  # note: nn.CrossEntropyLoss already applies log-softmax internally, so this extra Softmax weakens gradients; kept to match the recorded run
        )

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)
        
        x = self.avgpool(x) # pooling before the fully connected layers
        
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        
        return x
python
model = InceptionV1(num_classes=len(classNames)).to(device)
model
Output:
InceptionV1(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception3a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception3b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception4a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4c): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4d): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 144, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(144, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4e): inception_block(
    (branch1): Sequential(
      (0): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception5a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception5b): Sequential(
    (0): inception_block(
      (branch1): Sequential(
        (0): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (branch2): Sequential(
        (0): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch3): Sequential(
        (0): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch4): Sequential(
        (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
        (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (1): AvgPool2d(kernel_size=7, stride=1, padding=0)
    (2): Dropout(p=0.4, inplace=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=2, bias=True)
    (3): Softmax(dim=1)
  )
)
python
# Inspect the model in detail
import torchsummary as summary
summary.summary(model,(3,224,224))
Output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
            Conv2d-3           [-1, 64, 56, 56]           4,160
            Conv2d-4          [-1, 192, 56, 56]         110,784
         MaxPool2d-5          [-1, 192, 28, 28]               0
            Conv2d-6           [-1, 64, 28, 28]          12,352
       BatchNorm2d-7           [-1, 64, 28, 28]             128
              ReLU-8           [-1, 64, 28, 28]               0
            Conv2d-9           [-1, 96, 28, 28]          18,528
      BatchNorm2d-10           [-1, 96, 28, 28]             192
             ReLU-11           [-1, 96, 28, 28]               0
           Conv2d-12          [-1, 128, 28, 28]         110,720
      BatchNorm2d-13          [-1, 128, 28, 28]             256
             ReLU-14          [-1, 128, 28, 28]               0
           Conv2d-15           [-1, 16, 28, 28]           3,088
      BatchNorm2d-16           [-1, 16, 28, 28]              32
             ReLU-17           [-1, 16, 28, 28]               0
           Conv2d-18           [-1, 32, 28, 28]          12,832
      BatchNorm2d-19           [-1, 32, 28, 28]              64
             ReLU-20           [-1, 32, 28, 28]               0
        MaxPool2d-21          [-1, 192, 28, 28]               0
           Conv2d-22           [-1, 32, 28, 28]           6,176
      BatchNorm2d-23           [-1, 32, 28, 28]              64
             ReLU-24           [-1, 32, 28, 28]               0
  inception_block-25          [-1, 256, 28, 28]               0
           Conv2d-26          [-1, 128, 28, 28]          32,896
      BatchNorm2d-27          [-1, 128, 28, 28]             256
             ReLU-28          [-1, 128, 28, 28]               0
           Conv2d-29          [-1, 128, 28, 28]          32,896
      BatchNorm2d-30          [-1, 128, 28, 28]             256
             ReLU-31          [-1, 128, 28, 28]               0
           Conv2d-32          [-1, 192, 28, 28]         221,376
      BatchNorm2d-33          [-1, 192, 28, 28]             384
             ReLU-34          [-1, 192, 28, 28]               0
           Conv2d-35           [-1, 32, 28, 28]           8,224
      BatchNorm2d-36           [-1, 32, 28, 28]              64
             ReLU-37           [-1, 32, 28, 28]               0
           Conv2d-38           [-1, 96, 28, 28]          76,896
      BatchNorm2d-39           [-1, 96, 28, 28]             192
             ReLU-40           [-1, 96, 28, 28]               0
        MaxPool2d-41          [-1, 256, 28, 28]               0
           Conv2d-42           [-1, 64, 28, 28]          16,448
      BatchNorm2d-43           [-1, 64, 28, 28]             128
             ReLU-44           [-1, 64, 28, 28]               0
  inception_block-45          [-1, 480, 28, 28]               0
        MaxPool2d-46          [-1, 480, 14, 14]               0
           Conv2d-47          [-1, 192, 14, 14]          92,352
      BatchNorm2d-48          [-1, 192, 14, 14]             384
             ReLU-49          [-1, 192, 14, 14]               0
           Conv2d-50           [-1, 96, 14, 14]          46,176
      BatchNorm2d-51           [-1, 96, 14, 14]             192
             ReLU-52           [-1, 96, 14, 14]               0
           Conv2d-53          [-1, 208, 14, 14]         179,920
      BatchNorm2d-54          [-1, 208, 14, 14]             416
             ReLU-55          [-1, 208, 14, 14]               0
           Conv2d-56           [-1, 16, 14, 14]           7,696
      BatchNorm2d-57           [-1, 16, 14, 14]              32
             ReLU-58           [-1, 16, 14, 14]               0
           Conv2d-59           [-1, 48, 14, 14]          19,248
      BatchNorm2d-60           [-1, 48, 14, 14]              96
             ReLU-61           [-1, 48, 14, 14]               0
        MaxPool2d-62          [-1, 480, 14, 14]               0
           Conv2d-63           [-1, 64, 14, 14]          30,784
      BatchNorm2d-64           [-1, 64, 14, 14]             128
             ReLU-65           [-1, 64, 14, 14]               0
  inception_block-66          [-1, 512, 14, 14]               0
           Conv2d-67          [-1, 160, 14, 14]          82,080
      BatchNorm2d-68          [-1, 160, 14, 14]             320
             ReLU-69          [-1, 160, 14, 14]               0
           Conv2d-70          [-1, 112, 14, 14]          57,456
      BatchNorm2d-71          [-1, 112, 14, 14]             224
             ReLU-72          [-1, 112, 14, 14]               0
           Conv2d-73          [-1, 224, 14, 14]         226,016
      BatchNorm2d-74          [-1, 224, 14, 14]             448
             ReLU-75          [-1, 224, 14, 14]               0
           Conv2d-76           [-1, 24, 14, 14]          12,312
      BatchNorm2d-77           [-1, 24, 14, 14]              48
             ReLU-78           [-1, 24, 14, 14]               0
           Conv2d-79           [-1, 64, 14, 14]          38,464
      BatchNorm2d-80           [-1, 64, 14, 14]             128
             ReLU-81           [-1, 64, 14, 14]               0
        MaxPool2d-82          [-1, 512, 14, 14]               0
           Conv2d-83           [-1, 64, 14, 14]          32,832
      BatchNorm2d-84           [-1, 64, 14, 14]             128
             ReLU-85           [-1, 64, 14, 14]               0
  inception_block-86          [-1, 512, 14, 14]               0
           Conv2d-87          [-1, 128, 14, 14]          65,664
      BatchNorm2d-88          [-1, 128, 14, 14]             256
             ReLU-89          [-1, 128, 14, 14]               0
           Conv2d-90          [-1, 128, 14, 14]          65,664
      BatchNorm2d-91          [-1, 128, 14, 14]             256
             ReLU-92          [-1, 128, 14, 14]               0
           Conv2d-93          [-1, 256, 14, 14]         295,168
      BatchNorm2d-94          [-1, 256, 14, 14]             512
             ReLU-95          [-1, 256, 14, 14]               0
           Conv2d-96           [-1, 24, 14, 14]          12,312
      BatchNorm2d-97           [-1, 24, 14, 14]              48
             ReLU-98           [-1, 24, 14, 14]               0
           Conv2d-99           [-1, 64, 14, 14]          38,464
     BatchNorm2d-100           [-1, 64, 14, 14]             128
            ReLU-101           [-1, 64, 14, 14]               0
       MaxPool2d-102          [-1, 512, 14, 14]               0
          Conv2d-103           [-1, 64, 14, 14]          32,832
     BatchNorm2d-104           [-1, 64, 14, 14]             128
            ReLU-105           [-1, 64, 14, 14]               0
 inception_block-106          [-1, 512, 14, 14]               0
          Conv2d-107          [-1, 112, 14, 14]          57,456
     BatchNorm2d-108          [-1, 112, 14, 14]             224
            ReLU-109          [-1, 112, 14, 14]               0
          Conv2d-110          [-1, 144, 14, 14]          73,872
     BatchNorm2d-111          [-1, 144, 14, 14]             288
            ReLU-112          [-1, 144, 14, 14]               0
          Conv2d-113          [-1, 288, 14, 14]         373,536
     BatchNorm2d-114          [-1, 288, 14, 14]             576
            ReLU-115          [-1, 288, 14, 14]               0
          Conv2d-116           [-1, 32, 14, 14]          16,416
     BatchNorm2d-117           [-1, 32, 14, 14]              64
            ReLU-118           [-1, 32, 14, 14]               0
          Conv2d-119           [-1, 64, 14, 14]          51,264
     BatchNorm2d-120           [-1, 64, 14, 14]             128
            ReLU-121           [-1, 64, 14, 14]               0
       MaxPool2d-122          [-1, 512, 14, 14]               0
          Conv2d-123           [-1, 64, 14, 14]          32,832
     BatchNorm2d-124           [-1, 64, 14, 14]             128
            ReLU-125           [-1, 64, 14, 14]               0
 inception_block-126          [-1, 528, 14, 14]               0
          Conv2d-127          [-1, 256, 14, 14]         135,424
     BatchNorm2d-128          [-1, 256, 14, 14]             512
            ReLU-129          [-1, 256, 14, 14]               0
          Conv2d-130          [-1, 160, 14, 14]          84,640
     BatchNorm2d-131          [-1, 160, 14, 14]             320
            ReLU-132          [-1, 160, 14, 14]               0
          Conv2d-133          [-1, 320, 14, 14]         461,120
     BatchNorm2d-134          [-1, 320, 14, 14]             640
            ReLU-135          [-1, 320, 14, 14]               0
          Conv2d-136           [-1, 32, 14, 14]          16,928
     BatchNorm2d-137           [-1, 32, 14, 14]              64
            ReLU-138           [-1, 32, 14, 14]               0
          Conv2d-139          [-1, 128, 14, 14]         102,528
     BatchNorm2d-140          [-1, 128, 14, 14]             256
            ReLU-141          [-1, 128, 14, 14]               0
       MaxPool2d-142          [-1, 528, 14, 14]               0
          Conv2d-143          [-1, 128, 14, 14]          67,712
     BatchNorm2d-144          [-1, 128, 14, 14]             256
            ReLU-145          [-1, 128, 14, 14]               0
 inception_block-146          [-1, 832, 14, 14]               0
       MaxPool2d-147            [-1, 832, 7, 7]               0
          Conv2d-148            [-1, 256, 7, 7]         213,248
     BatchNorm2d-149            [-1, 256, 7, 7]             512
            ReLU-150            [-1, 256, 7, 7]               0
          Conv2d-151            [-1, 160, 7, 7]         133,280
     BatchNorm2d-152            [-1, 160, 7, 7]             320
            ReLU-153            [-1, 160, 7, 7]               0
          Conv2d-154            [-1, 320, 7, 7]         461,120
     BatchNorm2d-155            [-1, 320, 7, 7]             640
            ReLU-156            [-1, 320, 7, 7]               0
          Conv2d-157             [-1, 32, 7, 7]          26,656
     BatchNorm2d-158             [-1, 32, 7, 7]              64
            ReLU-159             [-1, 32, 7, 7]               0
          Conv2d-160            [-1, 128, 7, 7]         102,528
     BatchNorm2d-161            [-1, 128, 7, 7]             256
            ReLU-162            [-1, 128, 7, 7]               0
       MaxPool2d-163            [-1, 832, 7, 7]               0
          Conv2d-164            [-1, 128, 7, 7]         106,624
     BatchNorm2d-165            [-1, 128, 7, 7]             256
            ReLU-166            [-1, 128, 7, 7]               0
 inception_block-167            [-1, 832, 7, 7]               0
          Conv2d-168            [-1, 384, 7, 7]         319,872
     BatchNorm2d-169            [-1, 384, 7, 7]             768
            ReLU-170            [-1, 384, 7, 7]               0
          Conv2d-171            [-1, 192, 7, 7]         159,936
     BatchNorm2d-172            [-1, 192, 7, 7]             384
            ReLU-173            [-1, 192, 7, 7]               0
          Conv2d-174            [-1, 384, 7, 7]         663,936
     BatchNorm2d-175            [-1, 384, 7, 7]             768
            ReLU-176            [-1, 384, 7, 7]               0
          Conv2d-177             [-1, 48, 7, 7]          39,984
     BatchNorm2d-178             [-1, 48, 7, 7]              96
            ReLU-179             [-1, 48, 7, 7]               0
          Conv2d-180            [-1, 128, 7, 7]         153,728
     BatchNorm2d-181            [-1, 128, 7, 7]             256
            ReLU-182            [-1, 128, 7, 7]               0
       MaxPool2d-183            [-1, 832, 7, 7]               0
          Conv2d-184            [-1, 128, 7, 7]         106,624
     BatchNorm2d-185            [-1, 128, 7, 7]             256
            ReLU-186            [-1, 128, 7, 7]               0
 inception_block-187           [-1, 1024, 7, 7]               0
       AvgPool2d-188           [-1, 1024, 1, 1]               0
         Dropout-189           [-1, 1024, 1, 1]               0
AdaptiveAvgPool2d-190           [-1, 1024, 1, 1]               0
          Linear-191                 [-1, 1024]       1,049,600
            ReLU-192                 [-1, 1024]               0
          Linear-193                    [-1, 2]           2,050
         Softmax-194                    [-1, 2]               0
================================================================
Total params: 7,039,122
Trainable params: 7,039,122
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 69.62
Params size (MB): 26.85
Estimated Total Size (MB): 97.05
----------------------------------------------------------------

5. Hyperparameters: Loss Function, Learning Rate, and Optimizer

python
# loss_fn = nn.CrossEntropyLoss() # create the loss function

# learn_rate = 1e-3 # initial learning rate
# def adjust_learning_rate(optimizer,epoch,start_lr):
#     # decay to 0.92 of the previous value every two epochs
#     lr = start_lr * (0.92 ** (epoch//2))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
        
# optimizer = torch.optim.Adam(model.parameters(),lr=learn_rate)
python
# Example using the official scheduler API
loss_fn = nn.CrossEntropyLoss()

# learn_rate = 1e-4  
learn_rate = 3e-4
lambda1 = lambda epoch:(0.92**(epoch//2))

optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # choose the LR schedule
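
With this schedule, the learning rate applied at epoch e is learn_rate * 0.92 ** (e // 2). A quick standalone check of the decay (not part of the training code):

python
# The value LambdaLR applies at epoch e: learn_rate * 0.92 ** (e // 2)
for e in [0, 2, 10, 58]:
    print(e, 3e-4 * 0.92 ** (e // 2))
# approximately 3.00e-04, 2.76e-04, 1.98e-04, 2.67e-05 --
# these values show up in the Lr column of the Section 8 training log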

6. Training Function

python
# Training function
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset) # size of the training set
    num_batches = len(dataloader) # number of batches
    
    train_loss,train_acc = 0,0
    
    for X,y in dataloader:
        X,y = X.to(device),y.to(device)
        
        # compute the prediction error
        pred = model(X)
        loss = loss_fn(pred,y)
        
        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # accumulate accuracy and loss
        train_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss += loss.item()
        
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc,train_loss

7. Test Function

python
# Test function
def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    
    test_acc,test_loss = 0,0
    
    with torch.no_grad():
        for X,y in dataloader:
            X,y = X.to(device),y.to(device)
            
            # compute the loss
            pred = model(X)
            loss = loss_fn(pred,y)
            
            test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
            test_loss += loss.item()
            
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc,test_loss

8. Training

python
import copy

epochs = 60

train_acc = []
train_loss = []
test_acc = []
test_loss = []

best_acc = 0.0

# Check GPU availability and print device info
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"Initial Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
        print(f"Initial Memory Cached: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
else:
    print("No GPU available. Using CPU.")

# Multi-GPU setup: this run uses PyTorch's built-in DataParallel; it could later be switched to
# DistributedDataParallel, which is more efficient. Multiple GPUs are not always better than one and need tuning.
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs")
#     model = nn.DataParallel(model)
# model = model.to('cuda')

for epoch in range(epochs):
    
    # update the learning rate ------ use with the custom schedule
    # adjust_learning_rate(optimizer,epoch,learn_rate)
    
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,optimizer)
    scheduler.step() # update the learning rate ------ use with the official scheduler
    
    model.eval()
    epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)
    
    # keep the best model so far in best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

    # monitor GPU status in real time
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i} Usage:")
            print(f"  Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
            print(f"  Memory Cached: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
            print(f"  Max Memory Allocated: {torch.cuda.max_memory_allocated(i)/1024**2:.2f} MB")
            print(f"  Max Memory Cached: {torch.cuda.max_memory_reserved(i)/1024**2:.2f} MB")

print('Done','best_acc: ',best_acc)
Output:
GPU 0: NVIDIA GeForce RTX 4070 Laptop GPU
Initial Memory Allocated: 335.65 MB
Initial Memory Cached: 586.00 MB
Epoch: 1,Train_acc:63.3%,Train_loss:0.645,Test_acc:64.8%,Test_loss:0.634,Lr:3.00E-04
GPU 0 Usage:
  Memory Allocated: 455.01 MB
  Memory Cached: 2072.00 MB
  Max Memory Allocated: 1845.06 MB
  Max Memory Cached: 2072.00 MB
Epoch: 2,Train_acc:63.6%,Train_loss:0.638,Test_acc:62.5%,Test_loss:0.670,Lr:2.76E-04
GPU 0 Usage:
  Memory Allocated: 454.59 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1872.34 MB
  Max Memory Cached: 2086.00 MB
Epoch: 3,Train_acc:67.1%,Train_loss:0.622,Test_acc:62.9%,Test_loss:0.651,Lr:2.76E-04
GPU 0 Usage:
  Memory Allocated: 454.37 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1874.41 MB
  Max Memory Cached: 2086.00 MB
Epoch: 4,Train_acc:66.1%,Train_loss:0.627,Test_acc:67.1%,Test_loss:0.621,Lr:2.54E-04
GPU 0 Usage:
  Memory Allocated: 453.73 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1874.41 MB
  Max Memory Cached: 2086.00 MB
Epoch: 5,Train_acc:68.6%,Train_loss:0.616,Test_acc:60.8%,Test_loss:0.683,Lr:2.54E-04
GPU 0 Usage:
  Memory Allocated: 453.00 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1874.41 MB
  Max Memory Cached: 2086.00 MB
Epoch: 6,Train_acc:68.5%,Train_loss:0.601,Test_acc:69.5%,Test_loss:0.602,Lr:2.34E-04
GPU 0 Usage:
  Memory Allocated: 454.46 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1874.41 MB
  Max Memory Cached: 2086.00 MB
Epoch: 7,Train_acc:72.2%,Train_loss:0.583,Test_acc:70.4%,Test_loss:0.601,Lr:2.34E-04
GPU 0 Usage:
  Memory Allocated: 454.09 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1874.41 MB
  Max Memory Cached: 2086.00 MB
Epoch: 8,Train_acc:72.6%,Train_loss:0.572,Test_acc:69.9%,Test_loss:0.607,Lr:2.15E-04
GPU 0 Usage:
  Memory Allocated: 453.72 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch: 9,Train_acc:75.8%,Train_loss:0.545,Test_acc:73.7%,Test_loss:0.567,Lr:2.15E-04
GPU 0 Usage:
  Memory Allocated: 454.52 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:10,Train_acc:75.8%,Train_loss:0.544,Test_acc:72.0%,Test_loss:0.584,Lr:1.98E-04
GPU 0 Usage:
  Memory Allocated: 454.52 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:11,Train_acc:76.6%,Train_loss:0.539,Test_acc:75.3%,Test_loss:0.542,Lr:1.98E-04
GPU 0 Usage:
  Memory Allocated: 455.12 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:12,Train_acc:78.6%,Train_loss:0.517,Test_acc:72.7%,Test_loss:0.574,Lr:1.82E-04
GPU 0 Usage:
  Memory Allocated: 455.12 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:13,Train_acc:78.2%,Train_loss:0.521,Test_acc:74.1%,Test_loss:0.569,Lr:1.82E-04
GPU 0 Usage:
  Memory Allocated: 455.12 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:14,Train_acc:78.1%,Train_loss:0.525,Test_acc:79.3%,Test_loss:0.509,Lr:1.67E-04
GPU 0 Usage:
  Memory Allocated: 454.52 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:15,Train_acc:83.0%,Train_loss:0.483,Test_acc:72.7%,Test_loss:0.575,Lr:1.67E-04
GPU 0 Usage:
  Memory Allocated: 454.52 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.13 MB
  Max Memory Cached: 2086.00 MB
Epoch:16,Train_acc:82.6%,Train_loss:0.482,Test_acc:75.3%,Test_loss:0.545,Lr:1.54E-04
GPU 0 Usage:
  Memory Allocated: 455.53 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.22 MB
  Max Memory Cached: 2086.00 MB
Epoch:17,Train_acc:83.1%,Train_loss:0.476,Test_acc:79.5%,Test_loss:0.506,Lr:1.54E-04
GPU 0 Usage:
  Memory Allocated: 454.47 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.22 MB
  Max Memory Cached: 2086.00 MB
Epoch:18,Train_acc:84.8%,Train_loss:0.457,Test_acc:83.4%,Test_loss:0.471,Lr:1.42E-04
GPU 0 Usage:
  Memory Allocated: 454.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.22 MB
  Max Memory Cached: 2086.00 MB
Epoch:19,Train_acc:84.5%,Train_loss:0.467,Test_acc:81.8%,Test_loss:0.495,Lr:1.42E-04
GPU 0 Usage:
  Memory Allocated: 455.60 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:20,Train_acc:85.2%,Train_loss:0.457,Test_acc:83.2%,Test_loss:0.467,Lr:1.30E-04
GPU 0 Usage:
  Memory Allocated: 455.02 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:21,Train_acc:86.0%,Train_loss:0.445,Test_acc:79.7%,Test_loss:0.503,Lr:1.30E-04
GPU 0 Usage:
  Memory Allocated: 455.60 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:22,Train_acc:86.2%,Train_loss:0.444,Test_acc:86.0%,Test_loss:0.454,Lr:1.20E-04
GPU 0 Usage:
  Memory Allocated: 453.39 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:23,Train_acc:87.0%,Train_loss:0.437,Test_acc:85.5%,Test_loss:0.452,Lr:1.20E-04
GPU 0 Usage:
  Memory Allocated: 453.02 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:24,Train_acc:87.9%,Train_loss:0.432,Test_acc:88.8%,Test_loss:0.427,Lr:1.10E-04
GPU 0 Usage:
  Memory Allocated: 454.34 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:25,Train_acc:88.5%,Train_loss:0.423,Test_acc:86.2%,Test_loss:0.435,Lr:1.10E-04
GPU 0 Usage:
  Memory Allocated: 454.33 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:26,Train_acc:89.3%,Train_loss:0.421,Test_acc:86.7%,Test_loss:0.436,Lr:1.01E-04
GPU 0 Usage:
  Memory Allocated: 454.33 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:27,Train_acc:90.0%,Train_loss:0.411,Test_acc:87.2%,Test_loss:0.435,Lr:1.01E-04
GPU 0 Usage:
  Memory Allocated: 454.33 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:28,Train_acc:90.4%,Train_loss:0.404,Test_acc:89.3%,Test_loss:0.424,Lr:9.34E-05
GPU 0 Usage:
  Memory Allocated: 453.29 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:29,Train_acc:90.3%,Train_loss:0.405,Test_acc:89.7%,Test_loss:0.411,Lr:9.34E-05
GPU 0 Usage:
  Memory Allocated: 455.36 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:30,Train_acc:89.6%,Train_loss:0.411,Test_acc:89.0%,Test_loss:0.424,Lr:8.59E-05
GPU 0 Usage:
  Memory Allocated: 455.31 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:31,Train_acc:91.9%,Train_loss:0.392,Test_acc:90.7%,Test_loss:0.412,Lr:8.59E-05
GPU 0 Usage:
  Memory Allocated: 453.24 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:32,Train_acc:91.8%,Train_loss:0.392,Test_acc:89.5%,Test_loss:0.420,Lr:7.90E-05
GPU 0 Usage:
  Memory Allocated: 453.27 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:33,Train_acc:91.7%,Train_loss:0.392,Test_acc:91.8%,Test_loss:0.387,Lr:7.90E-05
GPU 0 Usage:
  Memory Allocated: 455.28 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:34,Train_acc:91.0%,Train_loss:0.401,Test_acc:89.7%,Test_loss:0.410,Lr:7.27E-05
GPU 0 Usage:
  Memory Allocated: 454.91 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:35,Train_acc:91.6%,Train_loss:0.392,Test_acc:92.5%,Test_loss:0.389,Lr:7.27E-05
GPU 0 Usage:
  Memory Allocated: 453.79 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:36,Train_acc:92.8%,Train_loss:0.386,Test_acc:92.1%,Test_loss:0.387,Lr:6.69E-05
GPU 0 Usage:
  Memory Allocated: 453.79 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:37,Train_acc:91.9%,Train_loss:0.392,Test_acc:88.8%,Test_loss:0.422,Lr:6.69E-05
GPU 0 Usage:
  Memory Allocated: 453.79 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:38,Train_acc:93.2%,Train_loss:0.382,Test_acc:90.9%,Test_loss:0.405,Lr:6.15E-05
GPU 0 Usage:
  Memory Allocated: 453.79 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:39,Train_acc:93.1%,Train_loss:0.382,Test_acc:93.0%,Test_loss:0.380,Lr:6.15E-05
GPU 0 Usage:
  Memory Allocated: 455.30 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:40,Train_acc:93.1%,Train_loss:0.381,Test_acc:92.8%,Test_loss:0.386,Lr:5.66E-05
GPU 0 Usage:
  Memory Allocated: 455.30 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:41,Train_acc:93.5%,Train_loss:0.377,Test_acc:92.8%,Test_loss:0.387,Lr:5.66E-05
GPU 0 Usage:
  Memory Allocated: 455.30 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:42,Train_acc:93.8%,Train_loss:0.373,Test_acc:93.9%,Test_loss:0.374,Lr:5.21E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:43,Train_acc:94.3%,Train_loss:0.370,Test_acc:92.8%,Test_loss:0.381,Lr:5.21E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:44,Train_acc:93.8%,Train_loss:0.373,Test_acc:92.3%,Test_loss:0.394,Lr:4.79E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:45,Train_acc:94.5%,Train_loss:0.368,Test_acc:93.9%,Test_loss:0.367,Lr:4.79E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:46,Train_acc:94.3%,Train_loss:0.370,Test_acc:93.0%,Test_loss:0.385,Lr:4.41E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:47,Train_acc:94.6%,Train_loss:0.364,Test_acc:93.0%,Test_loss:0.378,Lr:4.41E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:48,Train_acc:93.2%,Train_loss:0.380,Test_acc:93.0%,Test_loss:0.383,Lr:4.06E-05
GPU 0 Usage:
  Memory Allocated: 453.05 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:49,Train_acc:93.9%,Train_loss:0.371,Test_acc:95.1%,Test_loss:0.361,Lr:4.06E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:50,Train_acc:94.6%,Train_loss:0.368,Test_acc:94.6%,Test_loss:0.367,Lr:3.73E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:51,Train_acc:94.3%,Train_loss:0.368,Test_acc:94.6%,Test_loss:0.365,Lr:3.73E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:52,Train_acc:94.9%,Train_loss:0.363,Test_acc:93.2%,Test_loss:0.376,Lr:3.43E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:53,Train_acc:95.2%,Train_loss:0.362,Test_acc:94.4%,Test_loss:0.362,Lr:3.43E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:54,Train_acc:96.4%,Train_loss:0.348,Test_acc:94.4%,Test_loss:0.373,Lr:3.16E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:55,Train_acc:96.5%,Train_loss:0.347,Test_acc:93.7%,Test_loss:0.371,Lr:3.16E-05
GPU 0 Usage:
  Memory Allocated: 454.08 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:56,Train_acc:95.6%,Train_loss:0.355,Test_acc:95.8%,Test_loss:0.355,Lr:2.91E-05
GPU 0 Usage:
  Memory Allocated: 453.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:57,Train_acc:95.4%,Train_loss:0.358,Test_acc:94.6%,Test_loss:0.363,Lr:2.91E-05
GPU 0 Usage:
  Memory Allocated: 453.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:58,Train_acc:95.6%,Train_loss:0.355,Test_acc:94.4%,Test_loss:0.369,Lr:2.67E-05
GPU 0 Usage:
  Memory Allocated: 453.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:59,Train_acc:96.5%,Train_loss:0.348,Test_acc:93.9%,Test_loss:0.372,Lr:2.67E-05
GPU 0 Usage:
  Memory Allocated: 453.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Epoch:60,Train_acc:96.8%,Train_loss:0.344,Test_acc:96.0%,Test_loss:0.355,Lr:2.46E-05
GPU 0 Usage:
  Memory Allocated: 453.93 MB
  Memory Cached: 2086.00 MB
  Max Memory Allocated: 1875.26 MB
  Max Memory Cached: 2086.00 MB
Done best_acc:  0.9603729603729604

9. Visualizing the Results

python
epochs_range = range(epochs)

plt.figure(figsize = (12,3))

plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label = 'Training Accuracy')
plt.plot(epochs_range,test_acc,label = 'Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label = 'Training Loss')
plt.plot(epochs_range,test_loss,label = 'Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Loss')
plt.show()

10. Saving the Model

python
# Custom model saving
# Save the state dict
torch.save(model.state_dict(),'./模型参数/J8_InceptionV1_model_state_dict.pth') # save the state dict only

# Define a model to load the parameters into
# (note: this reassigns best_model, overwriting the best checkpoint kept in memory during Section 8)
best_model = InceptionV1(num_classes=len(classNames)).to(device)

best_model.load_state_dict(torch.load('./模型参数/J8_InceptionV1_model_state_dict.pth')) # load the state dict into the model
Output:
<All keys matched successfully>
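
One caveat: the training loop in Section 8 kept the best checkpoint in best_model, and the reassignment above overwrites it, while the file on disk holds the last-epoch weights. If the best checkpoint is what you want to keep, a small variant (with a hypothetical file name) is to save it before reusing the variable:

python
# Save the best checkpoint tracked in Section 8 -- run this before best_model is reassigned
torch.save(best_model.state_dict(), './模型参数/J8_InceptionV1_best_state_dict.pth')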

11. Using the Trained Model for Prediction

python
# Predict an image at a given path
from PIL import Image
import torchvision.transforms as transforms

classes = list(total_data.class_to_idx) # class names, ordered by their indices

def predict_one_image(image_path,model,transform,classes):
    
    test_img = Image.open(image_path).convert('RGB')
    # plt.imshow(test_img) # show the image to be predicted
    
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)
    print(output) # inspect the raw model output
    
    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'Predicted class: {pred_class}')
python
# Predict one image from the training set
predict_one_image(image_path='./data/mpox_recognize/Monkeypox/M01_01_04.jpg',
                 model = model,
                 transform = test_transforms,
                 classes = classes
                 )
Output:
tensor([[0.0015, 0.9985]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Predicted class: Others
python
classes
Output:
['Monkeypox', 'Others']