【深度学习PyTorch简介】7.Load and run model predictions 加载和运行模型预测

Load and run model predictions 加载和运行模型预测

Load the model 加载模型

在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。

python 复制代码
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

python 复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

python 复制代码
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
复制代码
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

**注意:**请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。

Model Inference 模型推理

优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。

ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。

Exporting the model to ONNX 将模型导出到 ONNX

PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape函数中获取大小:tensor.shape):

python 复制代码
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。

python 复制代码
test_data = datasets.FashionMNIST(
    root = "data",
    train = False,
    download = True,
    transform = ToTensor()
)

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

x, y = test_data[0][0], test_data[0][1]

我们使用 onnxruntime.InferenceSession 创建推理会话。要推断 ONNX 模型,请调用 run 并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。

python 复制代码
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

result = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
复制代码
Predicted: "Ankle boot", Actual: Ankle boot

ONNX 模型使您能够在不同平台上以不同编程语言运行推理。

知识检查

什么是 PyTorch 模型 state_dict?

它是模型的内部状态字典,用于存储已学习的参数。

相关推荐
昨日之日200617 分钟前
Wan2.2-S2V - 音频驱动图像生成电影级质量的数字人视频 ComfyUI工作流 支持50系显卡 一键整合包下载
人工智能·音视频
SEO_juper3 小时前
大型语言模型SEO(LLM SEO)完全手册:驾驭搜索新范式
人工智能·语言模型·自然语言处理·chatgpt·llm·seo·数字营销
攻城狮7号3 小时前
腾讯混元翻译模型Hunyuan-MT-7B开源,先前拿了30个冠军
人工智能·hunyuan-mt-7b·腾讯混元翻译模型·30个冠军
zezexihaha4 小时前
从“帮写文案”到“管生活”:个人AI工具的边界在哪?
人工智能
算家云4 小时前
nano banana官方最强Prompt模板来了!六大场景模板详解
人工智能·谷歌·ai大模型·算家云·ai生图·租算力,到算家云·nano banana 提示词
暴躁的大熊4 小时前
AI助力决策:告别生活与工作中的纠结,明析抉择引领明智选择
人工智能
Gyoku Mint4 小时前
提示词工程(Prompt Engineering)的崛起——为什么“会写Prompt”成了新技能?
人工智能·pytorch·深度学习·神经网络·语言模型·自然语言处理·nlp
梁小憨憨4 小时前
zotero扩容
人工智能·笔记
大数据张老师4 小时前
AI架构师的思维方式与架构设计原则
人工智能·架构师·ai架构·后端架构
AKAMAI4 小时前
Entity Digital Sports 降低成本并快速扩展
人工智能·云计算