1. Overview
This post covers model parameters, model inference and use, and saving and loading models.
2. Training Parameters and the Model
In this unit we will look at how to load a model together with its persisted parameter state, and how to run inference with it. To load a model, we define the model class, which describes the structure and parameters of the neural network the model was trained with.
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
To load the model weights, we first need to instantiate the model class, because the class defines the structure of the network. We then load the parameters with the load_state_dict() method.
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
Note: be sure to call model.eval() before inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
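To see what this means concretely, here is a minimal sketch (the dropout layer below is a toy example for illustration only, not part of the model above):

# Toy example: a dropout layer behaves differently in train vs. eval mode.
drop = nn.Dropout(p=0.5)
t = torch.ones(1, 4)

drop.train()    # training mode: randomly zeroes elements and rescales the rest
print(drop(t))  # e.g. tensor([[2., 0., 0., 2.]]) -- varies from call to call

drop.eval()     # evaluation mode: dropout becomes the identity
print(drop(t))  # tensor([[1., 1., 1., 1.]])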
3. Model Inference
Optimizing a model to run on a variety of platforms and programming languages is difficult. Squeezing the most performance out of every combination of framework and hardware is very time-consuming.
The Open Neural Network Exchange (ONNX) Runtime offers a solution: train once, then accelerate inference on any hardware, in the cloud, or on edge devices.
ONNX is a common format, supported by many vendors, for sharing neural networks and other machine learning models. With the ONNX format you can run inference on your model from other programming languages and frameworks such as Java, JavaScript, C#, and ML.NET.
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)
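Optionally, you can sanity-check the exported file with the standalone onnx package (assumed to be installed; note it is a different package from the torch.onnx module aliased as onnx above):

import onnx as onnx_pkg  # the standalone onnx package, not torch.onnx

proto = onnx_pkg.load('data/model.onnx')
onnx_pkg.checker.check_model(proto)                   # raises if the file is malformed
print(onnx_pkg.helper.printable_graph(proto.graph))   # human-readable graph dump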
We will use the test dataset as sample data to run predictions from the ONNX model.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
We need to create an inference session with onnxruntime. To run inference on the ONNX model, we call run, passing in the list of outputs we want returned (leave it empty if you want all of them) and a map of the input values. The result is a list of outputs:
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name: x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')
4. torch.utils.data.DataLoader and torch.utils.data.Dataset
PyTorch has two primitives for working with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset, as the sketch below illustrates.
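To make that contract concrete, here is a minimal sketch of a custom Dataset (a hypothetical in-memory example, not used in the rest of this tutorial); a Dataset only needs __len__ and __getitem__, and DataLoader handles the batching:

import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical in-memory dataset: 100 random "images" with integer labels.
class ToyDataset(Dataset):
    def __init__(self):
        self.samples = torch.randn(100, 1, 28, 28)
        self.labels = torch.randint(0, 10, (100,))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

loader = DataLoader(ToyDataset(), batch_size=8)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([8, 1, 28, 28]) torch.Size([8])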
%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. In this tutorial we will use the TorchVision datasets.
The torchvision.datasets module contains Dataset objects for many real-world vision datasets such as CIFAR and COCO. In this tutorial we use the FashionMNIST dataset. Every TorchVision Dataset accepts two arguments, transform and target_transform, which modify the samples and the labels respectively.
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
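Note that the Lambda transform imported above is not actually used in this section. As a sketch of what target_transform can do, it could one-hot encode the integer labels:

# Sketch (illustrative only): one-hot encode the labels with target_transform.
onehot_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)
print(onehot_data[0][1])  # tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])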
We pass the Dataset as an argument to DataLoader. This wraps an iterable around our dataset and supports automatic batching, sampling, shuffling, and multi-process data loading. Here we define a batch size of 64, i.e. each element in the dataloader iterable will return a batch of 64 features and labels.
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
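The loaders above use the default settings; the shuffling and multi-process loading mentioned earlier are opt-in keyword arguments, as in this sketch (the argument values are illustrative):

# Sketch: a DataLoader that reshuffles every epoch and loads in the background.
shuffled_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,    # reshuffle the samples at every epoch
    num_workers=2,   # load batches in two background worker processes
)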
5. Creating the Model
To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the network's layers in the __init__ function and specify how data flows through the network in the forward function. To accelerate operations in the neural network, we move it to the GPU if one is available.
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
Using cuda device
NeuralNetwork(
  (flatten): Flatten()
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)
6. Optimizing the Model Parameters
To train a model we need a loss function and an optimizer. We will use nn.CrossEntropyLoss for the loss and stochastic gradient descent (torch.optim.SGD) for optimization.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and backpropagates the prediction error to adjust the model's parameters.
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
We also check the model's performance against the test dataset to make sure it is actually learning.
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
The training process runs over several iterations (epochs). In each epoch the model learns parameters to make better predictions. We print the model's accuracy and loss at every epoch; we'd like to see the accuracy increase and the loss decrease with every epoch.
epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)
print("Done!")
Epoch 1
-------------------------------
loss: 2.295450 [ 0/60000]
loss: 2.293073 [ 6400/60000]
loss: 2.278504 [12800/60000]
loss: 2.282501 [19200/60000]
loss: 2.273211 [25600/60000]
loss: 2.258452 [32000/60000]
loss: 2.248237 [38400/60000]
loss: 2.228594 [44800/60000]
loss: 2.240276 [51200/60000]
loss: 2.221318 [57600/60000]
Test Error:
Accuracy: 51.8%, Avg loss: 0.034745
Epoch 2
-------------------------------
loss: 2.212354 [ 0/60000]
loss: 2.207739 [ 6400/60000]
loss: 2.160400 [12800/60000]
loss: 2.176181 [19200/60000]
loss: 2.168270 [25600/60000]
loss: 2.146453 [32000/60000]
loss: 2.119934 [38400/60000]
loss: 2.083791 [44800/60000]
loss: 2.126453 [51200/60000]
loss: 2.077550 [57600/60000]
Test Error:
Accuracy: 53.2%, Avg loss: 0.032452
Epoch 3
-------------------------------
loss: 2.082280 [ 0/60000]
loss: 2.068733 [ 6400/60000]
loss: 1.965958 [12800/60000]
loss: 1.997126 [19200/60000]
loss: 2.002057 [25600/60000]
loss: 1.967370 [32000/60000]
loss: 1.910595 [38400/60000]
loss: 1.849006 [44800/60000]
loss: 1.944741 [51200/60000]
loss: 1.861265 [57600/60000]
Test Error:
Accuracy: 51.6%, Avg loss: 0.028937
Epoch 4
-------------------------------
loss: 1.872628 [ 0/60000]
loss: 1.844543 [ 6400/60000]
loss: 1.710179 [12800/60000]
loss: 1.779804 [19200/60000]
loss: 1.737971 [25600/60000]
loss: 1.746953 [32000/60000]
loss: 1.624768 [38400/60000]
loss: 1.575720 [44800/60000]
loss: 1.742827 [51200/60000]
loss: 1.653375 [57600/60000]
Test Error:
Accuracy: 58.4%, Avg loss: 0.025570
Epoch 5
-------------------------------
loss: 1.662315 [ 0/60000]
loss: 1.636235 [ 6400/60000]
loss: 1.508407 [12800/60000]
loss: 1.606842 [19200/60000]
loss: 1.560728 [25600/60000]
loss: 1.606024 [32000/60000]
loss: 1.426900 [38400/60000]
loss: 1.406240 [44800/60000]
loss: 1.619918 [51200/60000]
loss: 1.521326 [57600/60000]
Test Error:
Accuracy: 61.2%, Avg loss: 0.023459
Epoch 6
-------------------------------
loss: 1.527535 [ 0/60000]
loss: 1.511209 [ 6400/60000]
loss: 1.377129 [12800/60000]
loss: 1.494889 [19200/60000]
loss: 1.457990 [25600/60000]
loss: 1.502333 [32000/60000]
loss: 1.291539 [38400/60000]
loss: 1.285098 [44800/60000]
loss: 1.484891 [51200/60000]
loss: 1.414015 [57600/60000]
Test Error:
Accuracy: 62.2%, Avg loss: 0.021480
Epoch 7
-------------------------------
loss: 1.376779 [ 0/60000]
loss: 1.384830 [ 6400/60000]
loss: 1.230116 [12800/60000]
loss: 1.382574 [19200/60000]
loss: 1.255630 [25600/60000]
loss: 1.396211 [32000/60000]
loss: 1.157718 [38400/60000]
loss: 1.186382 [44800/60000]
loss: 1.340606 [51200/60000]
loss: 1.321607 [57600/60000]
Test Error:
Accuracy: 62.8%, Avg loss: 0.019737
Epoch 8
-------------------------------
loss: 1.243344 [ 0/60000]
loss: 1.279124 [ 6400/60000]
loss: 1.121769 [12800/60000]
loss: 1.293069 [19200/60000]
loss: 1.128232 [25600/60000]
loss: 1.315465 [32000/60000]
loss: 1.069528 [38400/60000]
loss: 1.123324 [44800/60000]
loss: 1.243827 [51200/60000]
loss: 1.255190 [57600/60000]
Test Error:
Accuracy: 63.4%, Avg loss: 0.018518
Epoch 9
-------------------------------
loss: 1.154148 [ 0/60000]
loss: 1.205280 [ 6400/60000]
loss: 1.046463 [12800/60000]
loss: 1.229866 [19200/60000]
loss: 1.048813 [25600/60000]
loss: 1.254785 [32000/60000]
loss: 1.010614 [38400/60000]
loss: 1.077114 [44800/60000]
loss: 1.176766 [51200/60000]
loss: 1.206567 [57600/60000]
Test Error:
Accuracy: 64.3%, Avg loss: 0.017640
Epoch 10
-------------------------------
loss: 1.090360 [ 0/60000]
loss: 1.149150 [ 6400/60000]
loss: 0.990786 [12800/60000]
loss: 1.183704 [19200/60000]
loss: 0.997114 [25600/60000]
loss: 1.207199 [32000/60000]
loss: 0.967512 [38400/60000]
loss: 1.043431 [44800/60000]
loss: 1.127000 [51200/60000]
loss: 1.169639 [57600/60000]
Test Error:
Accuracy: 65.3%, Avg loss: 0.016974
Epoch 11
-------------------------------
loss: 1.041194 [ 0/60000]
loss: 1.104409 [ 6400/60000]
loss: 0.947670 [12800/60000]
loss: 1.149421 [19200/60000]
loss: 0.960403 [25600/60000]
loss: 1.169899 [32000/60000]
loss: 0.935149 [38400/60000]
loss: 1.018250 [44800/60000]
loss: 1.088222 [51200/60000]
loss: 1.139813 [57600/60000]
Test Error:
Accuracy: 66.2%, Avg loss: 0.016446
Epoch 12
-------------------------------
loss: 1.000646 [ 0/60000]
loss: 1.067356 [ 6400/60000]
loss: 0.912046 [12800/60000]
loss: 1.122742 [19200/60000]
loss: 0.932827 [25600/60000]
loss: 1.138785 [32000/60000]
loss: 0.910242 [38400/60000]
loss: 0.999010 [44800/60000]
loss: 1.056596 [51200/60000]
loss: 1.114582 [57600/60000]
Test Error:
Accuracy: 67.5%, Avg loss: 0.016011
Epoch 13
-------------------------------
loss: 0.966393 [ 0/60000]
loss: 1.035691 [ 6400/60000]
loss: 0.881672 [12800/60000]
loss: 1.100845 [19200/60000]
loss: 0.910265 [25600/60000]
loss: 1.112597 [32000/60000]
loss: 0.889558 [38400/60000]
loss: 0.982751 [44800/60000]
loss: 1.029199 [51200/60000]
loss: 1.092738 [57600/60000]
Test Error:
Accuracy: 68.5%, Avg loss: 0.015636
Epoch 14
-------------------------------
loss: 0.936334 [ 0/60000]
loss: 1.007734 [ 6400/60000]
loss: 0.854663 [12800/60000]
loss: 1.081601 [19200/60000]
loss: 0.890581 [25600/60000]
loss: 1.089641 [32000/60000]
loss: 0.872057 [38400/60000]
loss: 0.969192 [44800/60000]
loss: 1.005193 [51200/60000]
loss: 1.073098 [57600/60000]
Test Error:
Accuracy: 69.4%, Avg loss: 0.015304
Epoch 15
-------------------------------
loss: 0.908971 [ 0/60000]
loss: 0.982067 [ 6400/60000]
loss: 0.830095 [12800/60000]
loss: 1.064921 [19200/60000]
loss: 0.874204 [25600/60000]
loss: 1.069008 [32000/60000]
loss: 0.856447 [38400/60000]
loss: 0.957340 [44800/60000]
loss: 0.983547 [51200/60000]
loss: 1.055251 [57600/60000]
Test Error:
Accuracy: 70.3%, Avg loss: 0.015001
Done!
The accuracy will not be very good at first (that's OK!). Try running the loop for more epochs or adjusting the learning_rate to a larger value, as in the sketch below. It could also be that the model configuration we chose is not optimal for this kind of problem.
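For example (the values here are illustrative only), retraining with a larger learning rate just means re-creating the optimizer and re-running the loop:

# Illustrative values only: a larger learning rate and more epochs.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)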
7. Saving the Model
A common way to save a model is to serialize its internal state dictionary (which contains the model parameters).
torch.save(model.state_dict(), "data/model.pth")
print("Saved PyTorch Model State to model.pth")
8. Loading the Model
The process for loading a model includes re-creating the model structure and loading the state dictionary into it.
model = NeuralNetwork()
model.load_state_dict(torch.load("data/model.pth"))
This model can now be used to make predictions.
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"
Congratulations! You have completed the PyTorch beginner tutorial! We hope this tutorial has helped you get started with deep learning in PyTorch.