使用ONNX Runtime在Python中进行模型推理

ONNX Runtime(ORT)是一种高性能的推理引擎,支持多种深度学习框架,如PyTorch、TensorFlow和scikit-learn。以下是如何在Python中安装和使用ONNX Runtime的简要指南。

安装ONNX Runtime

ONNX Runtime有两个主要的Python包:CPU版本和GPU版本。根据你的硬件选择合适的版本。

  • CPU版本:适用于Arm-based CPU和macOS。

    bash 复制代码
    pip install onnxruntime
  • GPU版本(CUDA 12.x):默认支持CUDA 12.x。

    bash 复制代码
    pip install onnxruntime-gpu
  • GPU版本(CUDA 11.8):需要从Azure DevOps Feed安装。

    bash 复制代码
    pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/

安装ONNX导出工具

PyTorch

PyTorch自带ONNX支持,直接安装PyTorch即可。

bash 复制代码
pip install torch

TensorFlow

需要额外安装tf2onnx来支持ONNX导出。

bash 复制代码
pip install tf2onnx

scikit-learn

需要安装skl2onnx来支持ONNX导出。

bash 复制代码
pip install skl2onnx

快速开始示例

PyTorch CV示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch模型,device是设备(如GPU或CPU)
    model = ...  # 初始化你的模型
    device = ...  # 设备
    
    input_data = torch.randn(1, 28, 28).to(device)
    onnx.export(model, input_data, "fashion_mnist_model.onnx", 
                input_names=['input'], output_names=['output'])
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("fashion_mnist_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('fashion_mnist_model.onnx')
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    outputs = ort_sess.run(None, {'input': x.numpy()})

PyTorch NLP示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch NLP模型
    model = ...  # 初始化你的模型
    
    text = "示例文本"
    text_tensor = torch.tensor(text_pipeline(text))  # 假设text_pipeline是文本预处理函数
    offsets = torch.tensor([0])
    
    onnx.export(model, (text_tensor, offsets), "ag_news_model.onnx", 
                input_names=['input', 'offsets'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("ag_news_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('ag_news_model.onnx')
    
    # 假设text是输入文本,offsets是偏移量
    text = ...  # 初始化输入文本
    offsets = torch.tensor([0])
    
    outputs = ort_sess.run(None, {'input': text.numpy(), 'offsets': offsets.numpy()})

TensorFlow CV示例

  1. 导出模型

    python 复制代码
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    import tf2onnx
    
    # 加载预训练模型
    model = ResNet50(weights='imagenet')
    
    # 转换为ONNX模型
    spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
    output_path = model.name + ".onnx"
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnxruntime as rt
    
    # 创建推理会话
    providers = ['CPUExecutionProvider']
    m = rt.InferenceSession(output_path, providers=providers)
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    output_names = [n.name for n in model_proto.graph.output]
    onnx_pred = m.run(output_names, {"input": x})

scikit-learn CV示例

  1. 导出模型

    python 复制代码
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import convert_sklearn
    
    # 加载iris数据集
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    # 训练逻辑回归模型
    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    
    # 转换为ONNX模型
    initial_type = [('float_input', FloatTensorType([None, 4]))]
    onx = convert_sklearn(clr, initial_types=initial_type)
    with open("logreg_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())
  2. 加载ONNX模型并进行推理

    python 复制代码
    import numpy
    import onnxruntime as rt
    
    # 创建推理会话
    sess = rt.InferenceSession("logreg_iris.onnx")
    input_name = sess.get_inputs()[0].name
    
    # 假设X_test是测试数据
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]

通过这些示例,你可以轻松地将不同框架的模型导出为ONNX格式,并使用ONNX Runtime进行高效的推理。

相关推荐
蒟蒻小袁15 分钟前
力扣面试150题--除法求值
算法·leetcode·面试
保持学习ing24 分钟前
SpringBoot前后台交互 -- 登录功能实现(拦截器+异常捕获器)
java·spring boot·后端·ssm·交互·拦截器·异常捕获器
gadiaola30 分钟前
【JVM面试篇】高频八股汇总——类加载和类加载器
java·jvm·面试
十年老菜鸟1 小时前
spring boot源码和lib分开打包
spring boot·后端·maven
白宇横流学长2 小时前
基于SpringBoot实现的课程答疑系统设计与实现【源码+文档】
java·spring boot·后端
加瓦点灯3 小时前
什么?工作五年还不了解SafePoint?
后端
飞翔的猪猪3 小时前
GitHub Recovery Codes - 用于 GitHub Two-factor authentication (2FA) 凭据丢失时登录账号
前端·git·github
他日若遂凌云志3 小时前
Lua 模块系统的前世今生:从 module () 到 local _M 的迭代
后端
bnnnnnnnn3 小时前
看完就懂、懂完就敢讲的「原型与原型链」终极八卦!
前端·javascript·面试
David爱编程3 小时前
Docker 安全全揭秘:防逃逸、防漏洞、防越权,一篇学会容器防御!
后端·docker·容器