pytorch-模型训练

目录

  • [1. 模型训练的基本步骤](#1. 模型训练的基本步骤)
    • [1.1 train、test数据下载](#1.1 train、test数据下载)
    • [1.2 train、test数据加载](#1.2 train、test数据加载)
    • [1.3 Lenet5实例化、初始化loss函数、初始化优化器](#1.3 Lenet5实例化、初始化loss函数、初始化优化器)
    • [1.4 开始train和test](#1.4 开始train和test)
  • [2. 完整代码](#2. 完整代码)

1. 模型训练的基本步骤

以cifar10和Lenet5为例

1.1 train、test数据下载

使用torchvision中的datasets可以方便下载cifar10数据

python 复制代码
cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)

transforms.Resize((32, 32)) 将数据图形数据resize为32x32,这里可不用因为cifar10本身就是32x32

transforms.ToTensor()是将numpy或者numpy数组或PIL图像)转换为PyTorch的Tensor格式,以便输入网络。

transforms.Normalize()根据指定的均值和标准差对每个颜色通道进行图像归一化,可以提高神经网络训练过程中的收敛速度

1.2 train、test数据加载

使用pytorch torch.utils.data中的DataLoader用来加载数据

python 复制代码
cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

batch_size=batchz: 这里batchz是一个变量,代表每个批次的样本数量。

shuffle=True: 这个参数设定为True意味着在每次训练循环(epoch)开始前,数据集中的样本会被随机打乱顺序。这样做可以增加训练过程中的随机性,帮助模型更好地泛化,避免过拟合特定的样本排列顺序。

1.3 Lenet5实例化、初始化loss函数、初始化优化器

python 复制代码
    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

注意:网络和模型一定要搬到GPU上

1.4 开始train和test

  • 循环epoch
  • 加载train数据、输入模型、计算loss、backward、调用优化器
  • 加载test数据、输入模型、计算prediction、计算正确率
  • 输出正确率
python 复制代码
 for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)

2. 完整代码

python 复制代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import sys

sys.path.append('.')
from Lenet5 import Lenet5


def main():
    batchz = 128
    cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

model.train()和model.eval()的区别和作用

model.train()

作用:当调用模型的model.train()方法时,模型会进入训练模式。这意味着:

启用 Dropout层和BatchNorm层:在训练模式下,Dropout层会按照设定的概率随机"丢弃"一部分神经元以防止过拟合,而Batch Normalization(批规范化)层会根据当前批次的数据动态计算均值和方差进行归一化。

梯度计算:允许梯度计算,这是反向传播和权重更新的基础。

应用场景:在模型的训练循环中,每次迭代开始之前调用,以确保模型处于正确的训练状态。

model.eval()

作用:调用model.eval()方法后,模型会进入评估模式。此时:

禁用 Dropout层:Dropout层在评估时不发挥作用,所有的神经元都会被保留,以确保预测的确定性和可重复性。

固定 BatchNorm层:BatchNorm层使用训练过程中积累的统计量(全局均值和方差)进行归一化,而不是当前批次的统计量,这有助于模型输出更加稳定和一致。

应用场景:在验证或测试模型性能时使用,确保模型输出是确定性的,不受训练时特有的随机操作影响,以便于准确评估模型的泛化能力。

相关推荐
无人装备硬件开发爱好者5 小时前
Python + Blender 5.0 几何节点全栈实战教程1
开发语言·python·blender
witAI5 小时前
**AI漫剧制作工具2025推荐,解锁高效创作新体验**
人工智能·python
凉忆-5 小时前
llama-factory训练大模型
python·pip·llama
80530单词突击赢5 小时前
MPPI算法:ROS下的智能控制实战
开发语言·python
qinyia5 小时前
如何在服务器上查看网络连接数并进行综合分析
linux·运维·服务器·开发语言·人工智能·php
m0_706653235 小时前
自然语言处理(NLP)入门:使用NLTK和Spacy
jvm·数据库·python
新缸中之脑5 小时前
构建一个论文学习AI助手
人工智能·学习
说私域5 小时前
私域流量生态重构:链动2+1模式S2B2C商城小程序的流量整合与价值创造
人工智能·小程序·流量运营·私域运营
圆奋奋5 小时前
让“小爱音箱PRO”智能起来:接入豆包AI
人工智能
m0_736919105 小时前
Python游戏中的碰撞检测实现
jvm·数据库·python