pytorch-模型训练

目录

  • [1. 模型训练的基本步骤](#1. 模型训练的基本步骤)
    • [1.1 train、test数据下载](#1.1 train、test数据下载)
    • [1.2 train、test数据加载](#1.2 train、test数据加载)
    • [1.3 Lenet5实例化、初始化loss函数、初始化优化器](#1.3 Lenet5实例化、初始化loss函数、初始化优化器)
    • [1.4 开始train和test](#1.4 开始train和test)
  • [2. 完整代码](#2. 完整代码)

1. 模型训练的基本步骤

以cifar10和Lenet5为例

1.1 train、test数据下载

使用torchvision中的datasets可以方便下载cifar10数据

python 复制代码
cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)

transforms.Resize((32, 32)) 将数据图形数据resize为32x32,这里可不用因为cifar10本身就是32x32

transforms.ToTensor()是将numpy或者numpy数组或PIL图像)转换为PyTorch的Tensor格式,以便输入网络。

transforms.Normalize()根据指定的均值和标准差对每个颜色通道进行图像归一化,可以提高神经网络训练过程中的收敛速度

1.2 train、test数据加载

使用pytorch torch.utils.data中的DataLoader用来加载数据

python 复制代码
cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

batch_size=batchz: 这里batchz是一个变量,代表每个批次的样本数量。

shuffle=True: 这个参数设定为True意味着在每次训练循环(epoch)开始前,数据集中的样本会被随机打乱顺序。这样做可以增加训练过程中的随机性,帮助模型更好地泛化,避免过拟合特定的样本排列顺序。

1.3 Lenet5实例化、初始化loss函数、初始化优化器

python 复制代码
    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

注意:网络和模型一定要搬到GPU上

1.4 开始train和test

  • 循环epoch
  • 加载train数据、输入模型、计算loss、backward、调用优化器
  • 加载test数据、输入模型、计算prediction、计算正确率
  • 输出正确率
python 复制代码
 for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)

2. 完整代码

python 复制代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import sys

sys.path.append('.')
from Lenet5 import Lenet5


def main():
    batchz = 128
    cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

model.train()和model.eval()的区别和作用

model.train()

作用:当调用模型的model.train()方法时,模型会进入训练模式。这意味着:

启用 Dropout层和BatchNorm层:在训练模式下,Dropout层会按照设定的概率随机"丢弃"一部分神经元以防止过拟合,而Batch Normalization(批规范化)层会根据当前批次的数据动态计算均值和方差进行归一化。

梯度计算:允许梯度计算,这是反向传播和权重更新的基础。

应用场景:在模型的训练循环中,每次迭代开始之前调用,以确保模型处于正确的训练状态。

model.eval()

作用:调用model.eval()方法后,模型会进入评估模式。此时:

禁用 Dropout层:Dropout层在评估时不发挥作用,所有的神经元都会被保留,以确保预测的确定性和可重复性。

固定 BatchNorm层:BatchNorm层使用训练过程中积累的统计量(全局均值和方差)进行归一化,而不是当前批次的统计量,这有助于模型输出更加稳定和一致。

应用场景:在验证或测试模型性能时使用,确保模型输出是确定性的,不受训练时特有的随机操作影响,以便于准确评估模型的泛化能力。

相关推荐
多米Domi0111 天前
0x3f 第35天 电脑硬盘坏了 +二叉树直径,将有序数组转换为二叉搜索树
java·数据结构·python·算法·leetcode·链表
猫天意1 天前
【深度学习小课堂】| torch | 升维打击还是原位拼接?深度解码 PyTorch 中 stack 与 cat 的几何奥义
开发语言·人工智能·pytorch·深度学习·神经网络·yolo·机器学习
cyyt1 天前
深度学习周报(1.12~1.18)
人工智能·算法·机器学习
摸鱼仙人~1 天前
深度对比:Prompt Tuning、P-tuning 与 Prefix Tuning 有何不同?
人工智能·prompt
塔能物联运维1 天前
隧道照明“智能进化”:PLC 通信 + AI 调光守护夜间通行生命线
大数据·人工智能
瑶光守护者1 天前
【AI经典论文解读】《Denoising Diffusion Implicit Models(去噪扩散隐式模型)》论文深度解读
人工智能
wwwzhouhui1 天前
2026年1月18日-Obsidian + AI,笔记效率提升10倍!一键生成Canvas和小红书风格笔记
人工智能·obsidian·skills
我星期八休息1 天前
MySQL数据可视化实战指南
数据库·人工智能·mysql·算法·信息可视化
UR的出不克1 天前
使用 Python 爬取 Bilibili 弹幕数据并导出 Excel
java·python·excel
wuk9981 天前
基于遗传算法优化BP神经网络实现非线性函数拟合
人工智能·深度学习·神经网络