【深度学习PyTorch简介】7.Load and run model predictions 加载和运行模型预测

Load and run model predictions 加载和运行模型预测

Load the model 加载模型

在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。

python 复制代码
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

python 复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

python 复制代码
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
复制代码
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

**注意:**请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。

Model Inference 模型推理

优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。

ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。

Exporting the model to ONNX 将模型导出到 ONNX

PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape函数中获取大小:tensor.shape):

python 复制代码
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。

python 复制代码
test_data = datasets.FashionMNIST(
    root = "data",
    train = False,
    download = True,
    transform = ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

x, y = test_data[0][0], test_data[0][1]

我们使用 onnxruntime.InferenceSession 创建推理会话。要推断 ONNX 模型,请调用 run 并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。

python 复制代码
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
复制代码
Predicted: "Ankle boot", Actual: Ankle boot

ONNX 模型使您能够在不同平台上以不同编程语言运行推理。

知识检查

什么是 PyTorch 模型 state_dict?

它是模型的内部状态字典,用于存储已学习的参数。

相关推荐
澳鹏Appen2 小时前
数据集月度精选 | 高质量具身智能数据集:打开机器人“感知-决策-动作”闭环的钥匙
人工智能·机器人·具身智能
q***71014 小时前
开源模型应用落地-工具使用篇-Spring AI-Function Call(八)
人工智能·spring·开源
极限实验室4 小时前
Coco AI 参选 Gitee 2025 最受欢迎开源软件!您的每一票,都是对中国开源的硬核支持
人工智能·开源
secondyoung4 小时前
Mermaid流程图高效转换为图片方案
c语言·人工智能·windows·vscode·python·docker·流程图
iFlow_AI4 小时前
iFlow CLI Hooks 「从入门到实战」应用指南
开发语言·前端·javascript·人工智能·ai·iflow·iflow cli
Shang180989357264 小时前
THC63LVD1027D一款10位双链路LVDS信号中继器芯片,支持WUXGA分辨率视频数据传输THC63LVD1027支持30位数据通道方案
人工智能·考研·信息与通信·信号处理·thc63lvd1027d·thc63lvd1027
飞哥数智坊4 小时前
项目太大,AI无法理解?试试这3种思路
人工智能·ai编程
桜吹雪5 小时前
手搓一个简易Agent
前端·人工智能·后端
数字时代全景窗5 小时前
从App时代到智能体时代,如何打破“三堵墙”
人工智能·软件工程
weixin_469163695 小时前
金融科技项目管理方式在AI加持下发展方向之,需求分析精准化减少业务与技术偏差
人工智能·科技·金融·项目管理·需求管理