使用ONNX Runtime在Python中进行模型推理

ONNX Runtime(ORT)是一种高性能的推理引擎,支持多种深度学习框架,如PyTorch、TensorFlow和scikit-learn。以下是如何在Python中安装和使用ONNX Runtime的简要指南。

安装ONNX Runtime

ONNX Runtime有两个主要的Python包:CPU版本和GPU版本。根据你的硬件选择合适的版本。

  • CPU版本:适用于Arm-based CPU和macOS。

    bash 复制代码
    pip install onnxruntime
  • GPU版本(CUDA 12.x):默认支持CUDA 12.x。

    bash 复制代码
    pip install onnxruntime-gpu
  • GPU版本(CUDA 11.8):需要从Azure DevOps Feed安装。

    bash 复制代码
    pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/

安装ONNX导出工具

PyTorch

PyTorch自带ONNX支持,直接安装PyTorch即可。

bash 复制代码
pip install torch

TensorFlow

需要额外安装tf2onnx来支持ONNX导出。

bash 复制代码
pip install tf2onnx

scikit-learn

需要安装skl2onnx来支持ONNX导出。

bash 复制代码
pip install skl2onnx

快速开始示例

PyTorch CV示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch模型,device是设备(如GPU或CPU)
    model = ...  # 初始化你的模型
    device = ...  # 设备
    
    input_data = torch.randn(1, 28, 28).to(device)
    onnx.export(model, input_data, "fashion_mnist_model.onnx", 
                input_names=['input'], output_names=['output'])
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("fashion_mnist_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('fashion_mnist_model.onnx')
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    outputs = ort_sess.run(None, {'input': x.numpy()})

PyTorch NLP示例

  1. 导出模型

    python 复制代码
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch NLP模型
    model = ...  # 初始化你的模型
    
    text = "示例文本"
    text_tensor = torch.tensor(text_pipeline(text))  # 假设text_pipeline是文本预处理函数
    offsets = torch.tensor([0])
    
    onnx.export(model, (text_tensor, offsets), "ag_news_model.onnx", 
                input_names=['input', 'offsets'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load("ag_news_model.onnx")
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession('ag_news_model.onnx')
    
    # 假设text是输入文本,offsets是偏移量
    text = ...  # 初始化输入文本
    offsets = torch.tensor([0])
    
    outputs = ort_sess.run(None, {'input': text.numpy(), 'offsets': offsets.numpy()})

TensorFlow CV示例

  1. 导出模型

    python 复制代码
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    import tf2onnx
    
    # 加载预训练模型
    model = ResNet50(weights='imagenet')
    
    # 转换为ONNX模型
    spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
    output_path = model.name + ".onnx"
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
  2. 加载ONNX模型并进行推理

    python 复制代码
    import onnxruntime as rt
    
    # 创建推理会话
    providers = ['CPUExecutionProvider']
    m = rt.InferenceSession(output_path, providers=providers)
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    output_names = [n.name for n in model_proto.graph.output]
    onnx_pred = m.run(output_names, {"input": x})

scikit-learn CV示例

  1. 导出模型

    python 复制代码
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import convert_sklearn
    
    # 加载iris数据集
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    # 训练逻辑回归模型
    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    
    # 转换为ONNX模型
    initial_type = [('float_input', FloatTensorType([None, 4]))]
    onx = convert_sklearn(clr, initial_types=initial_type)
    with open("logreg_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())
  2. 加载ONNX模型并进行推理

    python 复制代码
    import numpy
    import onnxruntime as rt
    
    # 创建推理会话
    sess = rt.InferenceSession("logreg_iris.onnx")
    input_name = sess.get_inputs()[0].name
    
    # 假设X_test是测试数据
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]

通过这些示例,你可以轻松地将不同框架的模型导出为ONNX格式,并使用ONNX Runtime进行高效的推理。

相关推荐
magic334165632 分钟前
Springboot整合MinIO文件服务(windows版本)
windows·spring boot·后端·minio·文件对象存储
222you5 分钟前
Git仓库推送到GitHub
git·github
开心-开心急了12 分钟前
Flask入门教程——李辉 第一、二章关键知识梳理(更新一次)
后端·python·flask
掘金码甲哥23 分钟前
调试grpc的哼哈二将,你值得拥有
后端
小学鸡!37 分钟前
Spring Boot实现日志链路追踪
java·spring boot·后端
用户21411832636022 小时前
OpenSpec 实战:用规范驱动开发破解 AI 编程协作难题
后端
你的人类朋友3 小时前
hotfix分支的使用
git·gitlab·github
Olrookie3 小时前
若依前后端分离版学习笔记(二十)——实现滑块验证码(vue3)
java·前端·笔记·后端·学习·vue·ruoyi
LucianaiB3 小时前
招聘可以AI面试,那么我制作了一个AI面试教练不过分吧
后端
无奈何杨4 小时前
CoolGuard更新,ip2region升级、名单增加过期时间
后端