pytorch-模型训练

目录

  • [1. 模型训练的基本步骤](#1. 模型训练的基本步骤)
    • [1.1 train、test数据下载](#1.1 train、test数据下载)
    • [1.2 train、test数据加载](#1.2 train、test数据加载)
    • [1.3 Lenet5实例化、初始化loss函数、初始化优化器](#1.3 Lenet5实例化、初始化loss函数、初始化优化器)
    • [1.4 开始train和test](#1.4 开始train和test)
  • [2. 完整代码](#2. 完整代码)

1. 模型训练的基本步骤

以cifar10和Lenet5为例

1.1 train、test数据下载

使用torchvision中的datasets可以方便下载cifar10数据

python 复制代码
cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)

transforms.Resize((32, 32)) 将数据图形数据resize为32x32,这里可不用因为cifar10本身就是32x32

transforms.ToTensor()是将numpy或者numpy数组或PIL图像)转换为PyTorch的Tensor格式,以便输入网络。

transforms.Normalize()根据指定的均值和标准差对每个颜色通道进行图像归一化,可以提高神经网络训练过程中的收敛速度

1.2 train、test数据加载

使用pytorch torch.utils.data中的DataLoader用来加载数据

python 复制代码
cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

batch_size=batchz: 这里batchz是一个变量,代表每个批次的样本数量。

shuffle=True: 这个参数设定为True意味着在每次训练循环(epoch)开始前,数据集中的样本会被随机打乱顺序。这样做可以增加训练过程中的随机性,帮助模型更好地泛化,避免过拟合特定的样本排列顺序。

1.3 Lenet5实例化、初始化loss函数、初始化优化器

python 复制代码
    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

注意:网络和模型一定要搬到GPU上

1.4 开始train和test

  • 循环epoch
  • 加载train数据、输入模型、计算loss、backward、调用优化器
  • 加载test数据、输入模型、计算prediction、计算正确率
  • 输出正确率
python 复制代码
 for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)

2. 完整代码

python 复制代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import sys

sys.path.append('.')
from Lenet5 import Lenet5


def main():
    batchz = 128
    cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

model.train()和model.eval()的区别和作用

model.train()

作用:当调用模型的model.train()方法时,模型会进入训练模式。这意味着:

启用 Dropout层和BatchNorm层:在训练模式下,Dropout层会按照设定的概率随机"丢弃"一部分神经元以防止过拟合,而Batch Normalization(批规范化)层会根据当前批次的数据动态计算均值和方差进行归一化。

梯度计算:允许梯度计算,这是反向传播和权重更新的基础。

应用场景:在模型的训练循环中,每次迭代开始之前调用,以确保模型处于正确的训练状态。

model.eval()

作用:调用model.eval()方法后,模型会进入评估模式。此时:

禁用 Dropout层:Dropout层在评估时不发挥作用,所有的神经元都会被保留,以确保预测的确定性和可重复性。

固定 BatchNorm层:BatchNorm层使用训练过程中积累的统计量(全局均值和方差)进行归一化,而不是当前批次的统计量,这有助于模型输出更加稳定和一致。

应用场景:在验证或测试模型性能时使用,确保模型输出是确定性的,不受训练时特有的随机操作影响,以便于准确评估模型的泛化能力。

相关推荐
dundunmm几秒前
【每天一个知识点】本体论
人工智能·rag·本体论
nimadan122 分钟前
**免费有声书配音软件2025推荐,高拟真度AI配音与多场景
人工智能·python
jkyy20148 分钟前
汽车×大健康融合:智慧健康监测座舱成车企新赛道核心布局
大数据·人工智能·物联网·汽车·健康医疗
可触的未来,发芽的智生9 分钟前
完全原生态思考:从零学习的本质探索→刻石头
javascript·人工智能·python·神经网络·程序人生
凤希AI伴侣10 分钟前
重构与远见:凤希AI伴侣的看图升级与P2P算力共享蓝图-凤希AI伴侣-2026年1月12日
人工智能·重构·凤希ai伴侣
叫我:松哥11 分钟前
基于Flask+ECharts+Bootstrap构建的微博智能数据分析大屏
人工智能·python·信息可视化·数据分析·flask·bootstrap·echarts
倔强的石头10619 分钟前
什么是机器学习?—— 用 “买西瓜” 讲透核心逻辑
人工智能·机器学习
美团技术团队20 分钟前
KuiTest:基于大模型通识的UI交互遍历测试
人工智能
Study99620 分钟前
大语言模型的详解与训练
人工智能·ai·语言模型·自然语言处理·大模型·llm·agent
Pyeako20 分钟前
Opencv计算机视觉--边界填充&图像形态学
人工智能·python·opencv·计算机视觉·pycharm·图像形态学·边缘填充