PyTorch模型转换ONNX 入门

目录

前言

什么是ONNX文件

ONNX文件简介

[ONNX 文件的主要特点](#ONNX 文件的主要特点)

[ONNX 文件的基本结构](#ONNX 文件的基本结构)

pytorch模型转ONNX模型

前期准备

cuda安装

pytorch安装

ONNX模块安装

pytorch模型导出ONNX模型

简单体验

查看ONNX文件

torch.onnx.export函数

加载ONNX模型


前言

什么是ONNX文件

ONNX文件简介

ONNX(Open Neural Network Exchange)是一种开放的文件格式,用于表示机器学习模型,旨在促进不同框架之间的互操作性。

ONNX 文件通常以 .onnx 为扩展名,能够存储神经网络的结构和权重,使得模型可以在不同的深度学习框架(如 TensorFlow、PyTorch、Caffe 等)之间进行转换和部署

ONNX 文件的主要特点

  1. 跨平台兼容性 :ONNX 支持在多个框架之间共享模型,用户可以在一个框架中训练模型,然后将其导出为 ONNX 格式,以便在另一个框架中进行推理。

  2. 开放标准:ONNX 是一个开放的标准,由多个行业合作伙伴共同开发和维护。它为机器学习社区提供了一个统一的格式。

  3. 高效性:ONNX 文件能够有效地存储模型的计算图、参数和操作,这样可以更高效地进行推理。

  4. 支持多种操作:ONNX 定义了一组标准操作符,支持多种神经网络架构,包括卷积神经网络(CNN)、循环神经网络(RNN)等。

  5. 工具支持:ONNX 提供了一系列工具和库,支持将模型从不同框架导出到 ONNX 格式,也支持从 ONNX 文件加载模型进行推理。

ONNX 文件的基本结构

一个典型的 ONNX 文件包含以下内容:

  • 计算图:描述了模型的结构,包括各层的连接关系。
  • 参数:模型的权重和偏置值。
  • 元数据:关于模型的一些附加信息,如输入输出的形状、数据类型等。

pytorch模型转ONNX模型

前期准备

用户需事先安装cuda、cudnn(可选)和pytorch

cuda安装

windows下cuda的安装见

windows安装cuda与cudnn-CSDN博客

linux下的cuda安装见

【CUDA】Ubuntu系统如何安装CUDA保姆级教程(2022年最新)_ubuntu安装cuda-CSDN博客

无论在哪个系统上安装cuda,只要输入以下命令时有信息输出即表示安装成功

bash 复制代码
nvcc -V

pytorch安装

当cuda安装成功后,输入nvcc -V命令查看cuda版本号,然后进入pytorch官网,下载对应cuda版本的pytorch即可

Previous PyTorch Versions | PyTorch

无论什么系统,只要在命令行输出以下结果,即表示pytorch安装成功

ONNX模块安装

python 复制代码
pip install onnx
python 复制代码
pip install onnxruntime

pytorch模型导出ONNX模型

简单体验

首先使用pytorch写一个简单的网络模型

python 复制代码
import torch
import torchvision
import numpy as np
 
devide=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义一个简单的PyTorch 模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(64 * 8 * 8, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x
 
# 创建模型实例
model = MyModel().to(devide)
 
# 指定模型输入尺寸
dummy_input = torch.randn(1, 3, 32, 32).to(devide)
 
# 将PyTorch模型转为ONNX模型
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

如上所示,我们手写了一个简单的网络模型,可以看到,转ONNX模型文件的代码只有最后一行

python 复制代码
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

事实上,torch.onnx.export函数就是将一个torch模型转换为ONNX文件的函数

查看ONNX文件

上述代码成功运行后,会在本地生成一个mymodel.onnx文件,该文件的打开需要使用netron,有关netron的安装见

netron安装(windows && linux)-CSDN博客

安装成功后,使用netron打开mymodel.onnx,如下所示

torch.onnx.export函数

上述我们写了一个小demo体验了torch模型转换为ONNX文件,并查看了ONNX文件到底是什么,接下来我们来看torch模型转换ONNX文件的核心函数中参数含义都是什么

python 复制代码
torch.onnx.export(
            model, 
            args, 
            f, 
            export_params=True, 
            opset_version=10, 
            do_constant_folding=True, 
            input_names=['input'], 
            output_names=['output'], 
            dynamic_axes=None, 
            verbose=False, 
            example_outputs=None, 
            keep_initializers_as_inputs=None)

参数详解

  1. model:

    • 类型: torch.nn.Module
    • 描述: 被转换的 PyTorch 模型
  2. args:

    • 类型: tupletorch.Tensor
    • 描述: torch模型的输入示例,可以是一个单一的张量或多个张量(以元组的形式)。这些输入数据用于执行模型,确定模型的输入形状。
  3. f:

    • 类型: strPath
    • 描述: 导出模型的目标文件路径或文件名,通常以 .onnx 作为扩展名
  4. export_params:

    • 类型: bool,默认: True
    • 描述: 是否将模型的参数(权重和偏置)也导出到 ONNX 文件中 。如果设置为 True,导出的模型会包含所有的参数。
  5. opset_version:

    • 类型: int,默认: 9
    • 描述: 指定要使用的 ONNX 操作集版本。不同版本可能支持不同的操作和功能。设置合适的版本可以确保兼容性。
  6. do_constant_folding:

    • 类型: bool,默认: True
    • 描述: 是否进行常量折叠优化。常量折叠会在导出过程中将一些常量计算提前,从而简化模型的计算图,提升推理效率。
  7. input_names:

    • 类型: list,默认: None
    • 描述: 输入张量的名称列表。可以用来指定导出模型输入的名称,便于后续在其他框架中识别。
  8. output_names:

    • 类型: list,默认: None
    • 描述: 输出张量的名称列表。类似于 input_names,用于指定导出模型输出的名称。
  9. dynamic_axes:

    • 类型: dictNone,默认: None

    • 描述: 允许动态维度的输入输出。在导出时,可以指定某些维度是动态的,这样在推理时输入的形状可以变化。例如:

      python 复制代码
      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

      这表示 inputoutput 的第一个维度是动态的(例如 batch size)。

  10. verbose:

    • 类型: bool,默认: False
    • 描述: 是否在导出时打印详细信息。如果设置为 True,会显示更多的调试信息,便于跟踪导出过程中的问题。
  11. example_outputs:

    • 类型: tupletorch.Tensor,默认: None
    • 描述: 用于指定模型的示例输出,这有助于 ONNX 在导出时进行类型推断。可以提供一个或多个输出张量,以便更好地推断输出的形状和类型。
  12. keep_initializers_as_inputs:

    • 类型: bool,默认: None
    • 描述: 是否将模型的初始值(权重)作为输入保存。如果设置为 True,则初始值将被视为模型的输入之一,而不是存储在模型的参数中。

下面是一个常用的模板

python 复制代码
import torch.onnx 
 
# 转为ONNX
def Convert_ONNX(model): 
 
    # 设置模型为推理模式
    model.eval() 
 
    # 设置模型输入的尺寸
    dummy_input = torch.randn(1, input_size, requires_grad=True)  
 
    # 导出ONNX模型  
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "xxx.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX')
 
 
if __name__ == "__main__": 
 
    # 构建模型并训练
    # xxxxxxxxxxxx
 
    # 测试模型精度
    #testAccuracy() 
 
    # 加载模型结构与权重
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 
 
    # 转换为ONNX 
    Convert_ONNX(model)

加载ONNX模型

导出ONNX模型后,加载ONNX模型需要用到onnxruntime库,以下是一个导出ONNX模型的示例

python 复制代码
import onnxruntime as ort
 
# 加载 ONNX 模型
ort_session = ort.InferenceSession("model.onnx")
 
# 准备输入信息
input_info = ort_session.get_inputs()[0]
input_name = input_info.name
input_shape = input_info.shape
input_type = input_info.type
 
 
# 运行ONNX模型
outputs = ort_session.run(input_name, input_data)
 
# 获取输出信息
output_info = ort_session.get_outputs()[0]
output_name = output_info.name
output_shape = output_info.shape
output_data = outputs[0]
 
print("outputs:", outputs)
print("output_info :", output_info )
print("output_name :", output_name )
print("output_shape :", output_shape )
print("output_data :", output_data )

在以下案例中,我们首先将resnet-18模型导出为ONNX模型,然后再加载导出的ONNX模型,最后对比torch模型和ONNX模型的输出差异

python 复制代码
import torch
import torchvision.models as models
import onnx
import onnxruntime
 
# 加载 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()
 
# 定义输入和输出张量的名称和形状
input_names = ["input"]
output_names = ["output"]
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
 
# 将 PyTorch 模型转换为 ONNX 格式
torch.onnx.export(
    model,  # 要转换的 PyTorch 模型
    torch.randn(input_shape),  # 模型输入的随机张量
    "resnet18.onnx",  # 保存的 ONNX 模型的文件名
    input_names=input_names,  # 输入张量的名称
    output_names=output_names,  # 输出张量的名称
    dynamic_axes={input_names[0]: {0: "batch_size"}, output_names[0]: {0: "batch_size"}}  # 动态轴,即输入和输出张量可以具有不同的批次大小
)
 
# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx_model_graph = onnx_model.graph
onnx_session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
 
# 使用随机张量测试 ONNX 模型
x = torch.randn(input_shape).numpy()
onnx_output = onnx_session.run(output_names, {input_names[0]: x})[0]
 
print(f"PyTorch output: {model(torch.from_numpy(x)).detach().numpy()[0, :5]}")
print(f"ONNX output: {onnx_output[0, :5]}")

运行结果如下所示

python 复制代码
PyTorch output: [0.22972351 2.4930785  2.4462368  2.7443404  4.7080407 ]
ONNX output: [0.22972152 2.4930775  2.4462373  2.7443395  4.708042  ]
相关推荐
秦朝胖子得加钱19 分钟前
Flask
后端·python·flask
幽兰的天空23 分钟前
Python实现的简单时钟
开发语言·python
NCU_AI1 小时前
Python 网络爬虫快速入门
python·网络爬虫
幽兰的天空1 小时前
简单的Python爬虫实例
开发语言·爬虫·python
XinZong1 小时前
【AI开源项目】OneAPI -核心概念、特性、优缺点以及如何在本地和服务器上进行部署!
人工智能·开源
机器之心1 小时前
Runway CEO:AI公司的时代已经结束了
人工智能·后端
Kalika0-02 小时前
多层感知机从零开始实现
pytorch·学习
T0uken2 小时前
【机器学习】过拟合与欠拟合
人工智能·机器学习
IT·小灰灰2 小时前
Python——自动化发送邮件
运维·网络·后端·python·自动化
狼刀流2 小时前
(8) cuda分析工具
python·cuda