使用ONNX Runtime在Python中进行模型推理

ONNX Runtime(ORT)是一种高性能的推理引擎,支持多种深度学习框架,如PyTorch、TensorFlow和scikit-learn。以下是如何在Python中安装和使用ONNX Runtime的简要指南。

安装ONNX Runtime

ONNX Runtime有两个主要的Python包:CPU版本和GPU版本。根据你的硬件选择合适的版本。

  • CPU版本:适用于Arm-based CPU和macOS。

    bash 复制代码
    pip install onnxruntime
  • GPU版本(CUDA 12.x):默认支持CUDA 12.x。

    bash 复制代码
    pip install onnxruntime-gpu
  • GPU版本(CUDA 11.8):需要从Azure DevOps Feed安装。

    bash 复制代码
    pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/

安装ONNX导出工具

PyTorch

PyTorch自带ONNX支持,直接安装PyTorch即可。

bash 复制代码
pip install torch

TensorFlow

需要额外安装tf2onnx来支持ONNX导出。

bash 复制代码
pip install tf2onnx

scikit-learn

需要安装skl2onnx来支持ONNX导出。

bash 复制代码
pip install skl2onnx

快速开始示例

PyTorch CV示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch模型,device是设备(如GPU或CPU)
    model = ...  # 初始化你的模型
    device = ...  # 设备
    
    input_data = torch.randn(1, 28, 28).to(device)
    onnx.export(model, input_data, "fashion_mnist_model.onnx", 
                input_names=['input'], output_names=['output'])
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("fashion_mnist_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('fashion_mnist_model.onnx')
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    outputs = ort_sess.run(None, {'input': x.numpy()})

PyTorch NLP示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch NLP模型
    model = ...  # 初始化你的模型
    
    text = "示例文本"
    text_tensor = torch.tensor(text_pipeline(text))  # 假设text_pipeline是文本预处理函数
    offsets = torch.tensor([0])
    
    onnx.export(model, (text_tensor, offsets), "ag_news_model.onnx", 
                input_names=['input', 'offsets'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("ag_news_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('ag_news_model.onnx')
    
    # 假设text是输入文本,offsets是偏移量
    text = ...  # 初始化输入文本
    offsets = torch.tensor([0])
    
    outputs = ort_sess.run(None, {'input': text.numpy(), 'offsets': offsets.numpy()})

TensorFlow CV示例

  1. 导出模型

    python 复制代码
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    import tf2onnx
    
    # 加载预训练模型
    model = ResNet50(weights='imagenet')
    
    # 转换为ONNX模型
    spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
    output_path = model.name + ".onnx"
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnxruntime as rt
    
    # 创建推理会话
    providers = ['CPUExecutionProvider']
    m = rt.InferenceSession(output_path, providers=providers)
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    output_names = [n.name for n in model_proto.graph.output]
    onnx_pred = m.run(output_names, {"input": x})

scikit-learn CV示例

  1. 导出模型

    python 复制代码
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import convert_sklearn
    
    # 加载iris数据集
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    # 训练逻辑回归模型
    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    
    # 转换为ONNX模型
    initial_type = [('float_input', FloatTensorType([None, 4]))]
    onx = convert_sklearn(clr, initial_types=initial_type)
    with open("logreg_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())
  2. 加载ONNX模型并进行推理

    python 复制代码
    import numpy
    import onnxruntime as rt
    
    # 创建推理会话
    sess = rt.InferenceSession("logreg_iris.onnx")
    input_name = sess.get_inputs()[0].name
    
    # 假设X_test是测试数据
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]

通过这些示例,你可以轻松地将不同框架的模型导出为ONNX格式,并使用ONNX Runtime进行高效的推理。

相关推荐
yhole3 分钟前
springboot 修复 Spring Framework 特定条件下目录遍历漏洞(CVE-2024-38819)
spring boot·后端·spring
BingoGo8 分钟前
Laravel 13 正式发布 使用 Laravel AI 无缝平滑升级
后端·php
l软件定制开发工作室24 分钟前
Spring开发系列教程(34)——打包Spring Boot应用
java·spring boot·后端·spring·springboot
随风,奔跑27 分钟前
Spring MVC
java·后端·spring
美团技术团队1 小时前
美团 BI 在指标平台和分析引擎上的探索和实践
后端
JimmtButler1 小时前
我用 Claude Code 给 Claude Code 做了一个 DevTools
后端·claude
风止何安啊1 小时前
数字太长看花眼?一招教它排好队:千分位处理的实现
前端·javascript·面试
Java水解1 小时前
Java 中实现多租户架构:数据隔离策略与实践指南
java·后端
Master_Azur2 小时前
Java面向对象之多态与重写
后端
Java面试题总结2 小时前
2026Java面试八股文合集(持续更新)
java·spring·面试·职场和发展·java面试·java八股文