【深度学习PyTorch简介】7.Load and run model predictions 加载和运行模型预测

Load and run model predictions 加载和运行模型预测

Load the model 加载模型

在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。

python 复制代码
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

python 复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

python 复制代码
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
复制代码
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

**注意:**请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。

Model Inference 模型推理

优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。

ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。

Exporting the model to ONNX 将模型导出到 ONNX

PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape函数中获取大小:tensor.shape):

python 复制代码
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。

python 复制代码
test_data = datasets.FashionMNIST(
    root = "data",
    train = False,
    download = True,
    transform = ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

x, y = test_data[0][0], test_data[0][1]

我们使用 onnxruntime.InferenceSession 创建推理会话。要推断 ONNX 模型,请调用 run 并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。

python 复制代码
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
复制代码
Predicted: "Ankle boot", Actual: Ankle boot

ONNX 模型使您能够在不同平台上以不同编程语言运行推理。

知识检查

什么是 PyTorch 模型 state_dict?

它是模型的内部状态字典,用于存储已学习的参数。

相关推荐
ACP广源盛139246256732 小时前
(ACP广源盛)GSV1175---- MIPI/LVDS 转 Type-C/DisplayPort 1.2 转换器产品说明及功能分享
人工智能·音视频
胡耀超2 小时前
隐私计算技术全景:从联邦学习到可信执行环境的实战指南—数据安全——隐私计算 联邦学习 多方安全计算 可信执行环境 差分隐私
人工智能·安全·数据安全·tee·联邦学习·差分隐私·隐私计算
停停的茶3 小时前
深度学习(目标检测)
人工智能·深度学习·目标检测
Y200309163 小时前
基于 CIFAR10 数据集的卷积神经网络(CNN)模型训练与集成学习
人工智能·cnn·集成学习
老兵发新帖3 小时前
主流神经网络快速应用指南
人工智能·深度学习·神经网络
AI量化投资实验室4 小时前
15年122倍,年化43.58%,回撤才20%,Optuna机器学习多目标调参backtrader,附python代码
人工智能·python·机器学习
java_logo4 小时前
vllm-openai Docker 部署手册
运维·人工智能·docker·ai·容器
倔强青铜三4 小时前
苦练Python第67天:光速读取任意行,linecache模块解锁文件处理新姿势
人工智能·python·面试
算家计算4 小时前
重磅突破!全球首个真实物理环境机器人基准测试正式发布,具身智能迎来 “ImageNet 时刻”
人工智能·资讯
新智元5 小时前
苹果 M5「夜袭」高通英特尔!AI 算力狂飙 400%,Pro 三剑客火速上新
人工智能·openai