【深度学习PyTorch简介】7.Load and run model predictions 加载和运行模型预测

Load and run model predictions 加载和运行模型预测

Load the model 加载模型

在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。

python 复制代码
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

python 复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

python 复制代码
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
复制代码
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

**注意:**请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。

Model Inference 模型推理

优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。

ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。

Exporting the model to ONNX 将模型导出到 ONNX

PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape函数中获取大小:tensor.shape):

python 复制代码
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。

python 复制代码
test_data = datasets.FashionMNIST(
    root = "data",
    train = False,
    download = True,
    transform = ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

x, y = test_data[0][0], test_data[0][1]

我们使用 onnxruntime.InferenceSession 创建推理会话。要推断 ONNX 模型,请调用 run 并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。

python 复制代码
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
复制代码
Predicted: "Ankle boot", Actual: Ankle boot

ONNX 模型使您能够在不同平台上以不同编程语言运行推理。

知识检查

什么是 PyTorch 模型 state_dict?

它是模型的内部状态字典,用于存储已学习的参数。

相关推荐
用户2704272838121 分钟前
OpenClaw,Token费用控制
人工智能
唯创知音7 分钟前
Stickerbox儿童AI贴纸打印机国产替代方案
人工智能
AI袋鼠帝13 分钟前
劲爆!个人微信官方接入龙虾了【喂饭级教程】
人工智能·微信
柯儿的天空15 分钟前
【OpenClaw 全面解析:从零到精通】第 016 篇:OpenClaw 实战案例——代码开发助手,从代码生成到部署自动化的全流程
运维·人工智能·ai作画·自动化·aigc·ai写作
我科绝伦(Huanhuan Zhou)19 分钟前
从自动化到自主化—AI Agent引领的运维范式变革
运维·人工智能·自动化
foenix6626 分钟前
我的第一个 Vibe Coding 项目:我做了一个能自动剪视频、写字幕、配音、生成文案的 AI 工作流
人工智能·音视频
新缸中之脑1 小时前
Unsloth Studio:一键微调LLM
人工智能
2301_766558651 小时前
本地部署+云端优化:矩阵跃动龙虾机器人,实现7×24小时AI获客无人值守
人工智能·矩阵·机器人
动物园猫1 小时前
蜜蜂目标检测数据集(7000张图片已标注划分)AI训练适用于目标检测任务
人工智能·目标检测·计算机视觉
未来之窗软件服务1 小时前
阿里云 page-agent 核心逻辑梳理[AI人工智能(六十一)]—东方仙盟
人工智能·阿里云·云计算·仙盟创梦ide·东方仙盟