【PyTorch Guide】(Parts 6-7 of 7)

1. Overview

This part covers model parameters, model inference and usage, and saving and loading models.

2. Trained Parameters and the Model

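In this unit, we will learn how to load a model together with its persisted parameter state, and how to run inference to get model predictions. To load a model, we first define the model class, which describes the structure of the neural network whose parameters were saved during training.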

%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

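When loading model weights, we first need to instantiate the model class, because the class defines the structure of the network. Next, we load the parameters with the load_state_dict() method.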

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

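Note: be sure to call model.eval() before inference to set dropout and batch normalization layers to evaluation mode. If you skip this step, inference results will be inconsistent.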
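To see why model.eval() matters, the sketch below uses a small hypothetical model (a Linear layer followed by nn.Dropout(p=0.5), not the NeuralNetwork above, which has no dropout) and compares its behavior in training and evaluation mode. In training mode, dropout zeroes random elements and scales the survivors by 1/(1-p); in evaluation mode it passes its input through unchanged, so the output becomes deterministic.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model: Linear followed by Dropout(p=0.5)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

# Training mode: dropout randomly zeroes elements and scales the rest by 1/(1-p) = 2,
# so repeated calls on the same input can give different outputs
model.train()
with torch.no_grad():
    train_outputs = [model(x) for _ in range(5)]

# Evaluation mode: dropout is a no-op, so repeated calls give identical outputs
model.eval()
with torch.no_grad():
    eval_out_1 = model(x)
    eval_out_2 = model(x)

print(model.training)                       # False
print(torch.equal(eval_out_1, eval_out_2))  # True
```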

3. Model Inference

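Optimizing a model to run on a variety of platforms and programming languages is difficult, and maximizing performance across every combination of framework and hardware is very time-consuming.
ONNX (Open Neural Network Exchange) Runtime offers a solution: train once, then accelerate inference on any hardware, cloud, or edge device.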

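ONNX is a common format, supported by many vendors, for sharing neural networks and other machine learning models. With the ONNX format you can run inference on your model from other programming languages and frameworks, such as Java, JavaScript, C#, and ML.NET.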

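# Dummy input with the shape of one FashionMNIST image (1 x 28 x 28);
# export runs the model on it to trace the computation graph
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)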

We will use the test dataset as sample data and run inference with the ONNX model to make a prediction.

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]

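We need to create an inference session with onnxruntime.InferenceSession. To run inference on the ONNX model, call run and pass it the list of outputs you want returned (pass None to get all outputs) and a map from input names to input values. The result is a list of outputs: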

session = onnxruntime.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name: x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

4. torch.utils.data.DataLoader and torch.utils.data.Dataset

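PyTorch has two primitives for working with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. A Dataset stores the samples and their corresponding labels, and a DataLoader wraps an iterable around the Dataset.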
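A minimal sketch of the two primitives working together: the hypothetical SquaresDataset below (invented for illustration) implements the two methods a map-style Dataset needs, __len__ and __getitem__, and DataLoader turns it into an iterable of batches. Labels that are Python ints are collated into an int64 tensor.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy Dataset: sample i is the tensor [i], its label is i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return self.n

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index
        return torch.tensor([float(idx)]), idx * idx


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)  # shuffle defaults to False

batches = list(loader)
for X, y in batches:
    print(X.shape, y)
```

With 10 samples and a batch size of 4, the loader yields batches of 4, 4 and 2 samples.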

%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. In this tutorial, we will use a TorchVision dataset.

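The torchvision.datasets module contains Dataset objects for many real-world vision datasets, such as CIFAR and COCO. In this tutorial, we use the FashionMNIST dataset. Every TorchVision Dataset takes two arguments, transform and target_transform, which modify the samples and the labels, respectively.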

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

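We pass the Dataset as an argument to DataLoader. This wraps an iterable around our dataset and supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, so each element of the dataloader iterable returns a batch of 64 features and labels.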
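A quick check of what batch_size=64 means for the 10,000-image test set: 10,000 = 156 × 64 + 16, so the loader yields 157 batches and the last one holds only 16 samples. The sketch below uses a TensorDataset of indices as a stand-in for the real data (an assumption; only the length matters here) and also shows the shuffle and drop_last options.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the 10,000-image FashionMNIST test set
dataset = TensorDataset(torch.arange(10_000))

loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch_sizes = [batch[0].shape[0] for batch in loader]

print(len(loader), math.ceil(10_000 / 64))  # 157 157
print(batch_sizes[0], batch_sizes[-1])      # 64 16

# drop_last=True discards the final incomplete batch
loader_drop = DataLoader(dataset, batch_size=64, drop_last=True)
print(len(loader_drop))                     # 156

# shuffle=True reorders samples each epoch but still visits every sample exactly once
seen = torch.cat([batch[0] for batch in loader])
print(torch.equal(seen.sort().values, torch.arange(10_000)))  # True
```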

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
    
# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Output:
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

5. Creating the Model

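To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the network's layers in the __init__ function and specify how data passes through the network in the forward function. To accelerate operations in the neural network, we move it to the GPU if one is available.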
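A short sketch of what subclassing nn.Module gives you, using the same architecture as this section: layers assigned as attributes in __init__ are registered automatically, so model.parameters() finds every weight and bias, and .to(device) moves them all at once. The parameter count works out to (784·512 + 512) + (512·512 + 512) + (512·10 + 10) = 669,706. The random input batch is only for illustration.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"


class NeuralNetwork(nn.Module):
    # Same architecture as the model defined in this section
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, 10), nn.ReLU(),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))


model = NeuralNetwork().to(device)

# Registered submodules expose their parameters through model.parameters()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 669706

# Every parameter now lives on the selected device
print({p.device.type for p in model.parameters()})

# Call the module itself (not model.forward) so PyTorch's hooks also run
x = torch.rand(2, 1, 28, 28, device=device)
logits = model(x)
print(logits.shape)  # torch.Size([2, 10])
```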

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
Output:
Using cuda device
NeuralNetwork(
  (flatten): Flatten()
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

6. Optimizing the Model Parameters

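To train a model, we need a loss function and an optimizer. We will use nn.CrossEntropyLoss for the loss and stochastic gradient descent (torch.optim.SGD) for optimization.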
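A small worked check of both pieces on made-up numbers. nn.CrossEntropyLoss expects raw logits and integer class indices, and equals the negative log-softmax of the correct class averaged over the batch. A plain SGD step (no momentum or weight decay, as configured above) updates each parameter as w ← w − lr · grad.

```python
import torch
from torch import nn

# Hypothetical logits for 2 samples and 3 classes, with integer class targets
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

loss = nn.CrossEntropyLoss()(logits, targets)

# Same value computed by hand: -log_softmax of the correct class, averaged
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(torch.allclose(loss, manual))  # True

# One SGD step on a single parameter: w <- w - lr * grad
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
objective = (3.0 * w).sum()  # gradient with respect to w is 3
optimizer.zero_grad()
objective.backward()
optimizer.step()
print(w)  # approximately 0.7, i.e. 1.0 - 0.1 * 3
```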

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

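In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and backpropagates the prediction error to adjust the model's parameters.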
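Why the loop calls optimizer.zero_grad() before loss.backward(): PyTorch adds new gradients to whatever is already stored in .grad. The sketch below on a single scalar parameter (d(w²)/dw = 2w = 4 at w = 2) shows the accumulation and how zero_grad() resets it.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# First backward pass: gradient is 2 * w = 4
(w ** 2).sum().backward()
first_grad = w.grad.clone()

# Second backward pass without clearing: the new gradient is added to .grad
(w ** 2).sum().backward()
accumulated_grad = w.grad.clone()

# Clearing first gives a fresh gradient again (no optimizer.step() was taken, so w is unchanged)
optimizer.zero_grad()
(w ** 2).sum().backward()
fresh_grad = w.grad.clone()

print(first_grad.item(), accumulated_grad.item(), fresh_grad.item())  # 4.0 8.0 4.0
```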

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # training mode for layers like dropout/batch norm; test() switches to eval mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

We also check the model's performance against the test dataset to make sure it is learning.
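The accuracy line in test() can be checked on a tiny hand-made batch (hypothetical logits for 4 samples and 3 classes): argmax(1) picks the highest-scoring class per row, and comparing with the labels counts correct predictions. The sketch also shows that inside torch.no_grad() no computation graph is recorded.

```python
import torch
from torch import nn

# Hypothetical logits for 4 samples and 3 classes, plus the true labels
pred = torch.tensor([[0.1, 2.0, 0.3],
                     [1.5, 0.2, 0.1],
                     [0.0, 0.1, 0.9],
                     [0.3, 0.2, 0.1]])
y = torch.tensor([1, 0, 1, 0])

# argmax(1) gives [1, 0, 2, 0]; three of the four match y
correct = (pred.argmax(1) == y).type(torch.float).sum().item()
accuracy = correct / len(y)
print(correct, accuracy)  # 3.0 0.75

# Inside torch.no_grad(), outputs do not track gradients, saving memory and time
layer = nn.Linear(3, 2)
with torch.no_grad():
    out = layer(pred)
print(out.requires_grad)  # False
```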

def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

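Training runs over several iterations (epochs). During each epoch, the model learns parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the accuracy increase and the loss decrease with every epoch.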
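One detail about the printed numbers: test() sums the mean loss of each batch and then divides by the number of samples (len(dataloader.dataset)), not by the number of batches, so the reported "Avg loss" is about 1/64 of the typical per-batch loss. The arithmetic below uses the epoch-1 value from the output and assumes every test batch contributes equally (the last batch of 16 images counts like a full one).

```python
import math

num_samples = 10_000  # size of the FashionMNIST test set
batch_size = 64
num_batches = math.ceil(num_samples / batch_size)  # 157

printed_avg_loss = 0.034745  # "Avg loss" reported after epoch 1
summed_batch_losses = printed_avg_loss * num_samples
per_batch_mean = summed_batch_losses / num_batches

print(num_batches, round(per_batch_mean, 2))  # 157 2.21
```

The result, about 2.21, is in line with the training losses logged near the end of epoch 1. Dividing test_loss by len(dataloader) instead would report this per-batch average directly.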

epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)
print("Done!")
Output:
Epoch 1
-------------------------------
loss: 2.295450  [    0/60000]
loss: 2.293073  [ 6400/60000]
loss: 2.278504  [12800/60000]
loss: 2.282501  [19200/60000]
loss: 2.273211  [25600/60000]
loss: 2.258452  [32000/60000]
loss: 2.248237  [38400/60000]
loss: 2.228594  [44800/60000]
loss: 2.240276  [51200/60000]
loss: 2.221318  [57600/60000]
Test Error: 
 Accuracy: 51.8%, Avg loss: 0.034745 

Epoch 2
-------------------------------
loss: 2.212354  [    0/60000]
loss: 2.207739  [ 6400/60000]
loss: 2.160400  [12800/60000]
loss: 2.176181  [19200/60000]
loss: 2.168270  [25600/60000]
loss: 2.146453  [32000/60000]
loss: 2.119934  [38400/60000]
loss: 2.083791  [44800/60000]
loss: 2.126453  [51200/60000]
loss: 2.077550  [57600/60000]
Test Error: 
 Accuracy: 53.2%, Avg loss: 0.032452 

Epoch 3
-------------------------------
loss: 2.082280  [    0/60000]
loss: 2.068733  [ 6400/60000]
loss: 1.965958  [12800/60000]
loss: 1.997126  [19200/60000]
loss: 2.002057  [25600/60000]
loss: 1.967370  [32000/60000]
loss: 1.910595  [38400/60000]
loss: 1.849006  [44800/60000]
loss: 1.944741  [51200/60000]
loss: 1.861265  [57600/60000]
Test Error: 
 Accuracy: 51.6%, Avg loss: 0.028937 

Epoch 4
-------------------------------
loss: 1.872628  [    0/60000]
loss: 1.844543  [ 6400/60000]
loss: 1.710179  [12800/60000]
loss: 1.779804  [19200/60000]
loss: 1.737971  [25600/60000]
loss: 1.746953  [32000/60000]
loss: 1.624768  [38400/60000]
loss: 1.575720  [44800/60000]
loss: 1.742827  [51200/60000]
loss: 1.653375  [57600/60000]
Test Error: 
 Accuracy: 58.4%, Avg loss: 0.025570 

Epoch 5
-------------------------------
loss: 1.662315  [    0/60000]
loss: 1.636235  [ 6400/60000]
loss: 1.508407  [12800/60000]
loss: 1.606842  [19200/60000]
loss: 1.560728  [25600/60000]
loss: 1.606024  [32000/60000]
loss: 1.426900  [38400/60000]
loss: 1.406240  [44800/60000]
loss: 1.619918  [51200/60000]
loss: 1.521326  [57600/60000]
Test Error: 
 Accuracy: 61.2%, Avg loss: 0.023459 

Epoch 6
-------------------------------
loss: 1.527535  [    0/60000]
loss: 1.511209  [ 6400/60000]
loss: 1.377129  [12800/60000]
loss: 1.494889  [19200/60000]
loss: 1.457990  [25600/60000]
loss: 1.502333  [32000/60000]
loss: 1.291539  [38400/60000]
loss: 1.285098  [44800/60000]
loss: 1.484891  [51200/60000]
loss: 1.414015  [57600/60000]
Test Error: 
 Accuracy: 62.2%, Avg loss: 0.021480 

Epoch 7
-------------------------------
loss: 1.376779  [    0/60000]
loss: 1.384830  [ 6400/60000]
loss: 1.230116  [12800/60000]
loss: 1.382574  [19200/60000]
loss: 1.255630  [25600/60000]
loss: 1.396211  [32000/60000]
loss: 1.157718  [38400/60000]
loss: 1.186382  [44800/60000]
loss: 1.340606  [51200/60000]
loss: 1.321607  [57600/60000]
Test Error: 
 Accuracy: 62.8%, Avg loss: 0.019737 

Epoch 8
-------------------------------
loss: 1.243344  [    0/60000]
loss: 1.279124  [ 6400/60000]
loss: 1.121769  [12800/60000]
loss: 1.293069  [19200/60000]
loss: 1.128232  [25600/60000]
loss: 1.315465  [32000/60000]
loss: 1.069528  [38400/60000]
loss: 1.123324  [44800/60000]
loss: 1.243827  [51200/60000]
loss: 1.255190  [57600/60000]
Test Error: 
 Accuracy: 63.4%, Avg loss: 0.018518 

Epoch 9
-------------------------------
loss: 1.154148  [    0/60000]
loss: 1.205280  [ 6400/60000]
loss: 1.046463  [12800/60000]
loss: 1.229866  [19200/60000]
loss: 1.048813  [25600/60000]
loss: 1.254785  [32000/60000]
loss: 1.010614  [38400/60000]
loss: 1.077114  [44800/60000]
loss: 1.176766  [51200/60000]
loss: 1.206567  [57600/60000]
Test Error: 
 Accuracy: 64.3%, Avg loss: 0.017640 

Epoch 10
-------------------------------
loss: 1.090360  [    0/60000]
loss: 1.149150  [ 6400/60000]
loss: 0.990786  [12800/60000]
loss: 1.183704  [19200/60000]
loss: 0.997114  [25600/60000]
loss: 1.207199  [32000/60000]
loss: 0.967512  [38400/60000]
loss: 1.043431  [44800/60000]
loss: 1.127000  [51200/60000]
loss: 1.169639  [57600/60000]
Test Error: 
 Accuracy: 65.3%, Avg loss: 0.016974 

Epoch 11
-------------------------------
loss: 1.041194  [    0/60000]
loss: 1.104409  [ 6400/60000]
loss: 0.947670  [12800/60000]
loss: 1.149421  [19200/60000]
loss: 0.960403  [25600/60000]
loss: 1.169899  [32000/60000]
loss: 0.935149  [38400/60000]
loss: 1.018250  [44800/60000]
loss: 1.088222  [51200/60000]
loss: 1.139813  [57600/60000]
Test Error: 
 Accuracy: 66.2%, Avg loss: 0.016446 

Epoch 12
-------------------------------
loss: 1.000646  [    0/60000]
loss: 1.067356  [ 6400/60000]
loss: 0.912046  [12800/60000]
loss: 1.122742  [19200/60000]
loss: 0.932827  [25600/60000]
loss: 1.138785  [32000/60000]
loss: 0.910242  [38400/60000]
loss: 0.999010  [44800/60000]
loss: 1.056596  [51200/60000]
loss: 1.114582  [57600/60000]
Test Error: 
 Accuracy: 67.5%, Avg loss: 0.016011 

Epoch 13
-------------------------------
loss: 0.966393  [    0/60000]
loss: 1.035691  [ 6400/60000]
loss: 0.881672  [12800/60000]
loss: 1.100845  [19200/60000]
loss: 0.910265  [25600/60000]
loss: 1.112597  [32000/60000]
loss: 0.889558  [38400/60000]
loss: 0.982751  [44800/60000]
loss: 1.029199  [51200/60000]
loss: 1.092738  [57600/60000]
Test Error: 
 Accuracy: 68.5%, Avg loss: 0.015636 

Epoch 14
-------------------------------
loss: 0.936334  [    0/60000]
loss: 1.007734  [ 6400/60000]
loss: 0.854663  [12800/60000]
loss: 1.081601  [19200/60000]
loss: 0.890581  [25600/60000]
loss: 1.089641  [32000/60000]
loss: 0.872057  [38400/60000]
loss: 0.969192  [44800/60000]
loss: 1.005193  [51200/60000]
loss: 1.073098  [57600/60000]
Test Error: 
 Accuracy: 69.4%, Avg loss: 0.015304 

Epoch 15
-------------------------------
loss: 0.908971  [    0/60000]
loss: 0.982067  [ 6400/60000]
loss: 0.830095  [12800/60000]
loss: 1.064921  [19200/60000]
loss: 0.874204  [25600/60000]
loss: 1.069008  [32000/60000]
loss: 0.856447  [38400/60000]
loss: 0.957340  [44800/60000]
loss: 0.983547  [51200/60000]
loss: 1.055251  [57600/60000]
Test Error: 
 Accuracy: 70.3%, Avg loss: 0.015001 

Done!

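The accuracy will not be very good at first (that's OK!). Try running the loop for more epochs, or increase learning_rate. It may also be that the model configuration we chose is not the best fit for this kind of problem.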

7. Saving the Model

A common way to save a model is to serialize its internal state dictionary, which contains the model parameters.
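A sketch of what actually gets serialized, using a hypothetical nn.Sequential stand-in with the same kinds of layers (so its keys are numbered rather than named like NeuralNetwork's). Only layers with parameters appear in the state dict; Flatten and ReLU have none. An in-memory io.BytesIO buffer replaces the file path so the round trip can run anywhere; torch.save and torch.load accept file-like objects as well as paths.

```python
import io

import torch
from torch import nn


def make_model():
    # Hypothetical stand-in: Flatten -> Linear -> ReLU -> Linear
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 10))


model = make_model()
state = model.state_dict()

# Keys are "<submodule index>.<parameter name>"; parameter-free layers are absent
print(list(state.keys()))  # ['1.weight', '1.bias', '3.weight', '3.bias']

# Round trip through an in-memory buffer instead of "data/model.pth"
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

restored = make_model()
restored.load_state_dict(torch.load(buffer, weights_only=True))

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```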

torch.save(model.state_dict(), "data/model.pth")
print("Saved PyTorch Model State to data/model.pth")

8. Loading the Model

Loading a model involves re-creating the model structure and loading the state dictionary into it.
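Because the state dict holds only tensors, the re-created structure must match it. The sketch below uses hypothetical nn.Linear models and an in-memory buffer: loading into a matching architecture reports no missing or unexpected keys, while a layer with a different output size raises RuntimeError. It also passes map_location="cpu" to torch.load, which remaps tensors saved on a GPU (as in this tutorial, if CUDA was used) onto the CPU.

```python
import io

import torch
from torch import nn

saved = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(saved.state_dict(), buffer)

# map_location="cpu" loads tensors onto the CPU regardless of where they were saved
buffer.seek(0)
state = torch.load(buffer, map_location="cpu", weights_only=True)

# Matching architecture: loads cleanly
ok_model = nn.Linear(4, 2)
result = ok_model.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)  # [] []

# Mismatched architecture: tensor shapes differ, so loading fails
wrong_model = nn.Linear(4, 3)
try:
    wrong_model.load_state_dict(state)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print(type(mismatch_error).__name__)  # RuntimeError
```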

model = NeuralNetwork()
model.load_state_dict(torch.load("data/model.pth"))

This model can now be used to make predictions.

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Output:
Predicted: "Ankle boot", Actual: "Ankle boot"

Congratulations! You have completed the PyTorch beginner tutorial. We hope it has helped you get started with deep learning in PyTorch.
