从PyTorch到ONNX:模型部署性能提升

在深度学习模型部署过程中,推理性能优化是一个关键环节。ONNX(Open Neural Network Exchange)作为一个开放的神经网络交换格式,能够在不同框架之间实现模型的无缝转换,并通过优化的运行时环境显著提升推理速度。

环境准备

首先安装必要的依赖包:

bash 复制代码
# 安装PyTorch生态
pip install torch torchvision torchaudio

# 安装ONNX相关包
pip install onnx onnxruntime-gpu  # GPU版本
# pip install onnxruntime          # CPU版本

ONNX常用方法汇总

模型导出 (PyTorch → ONNX)

python 复制代码
import torch
import torch.onnx

# 加载PyTorch模型
model = torch.load('model.pth')
model.eval()

# 创建虚拟输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出ONNX
torch.onnx.export(
    model,                      # 模型
    dummy_input,               # 虚拟输入
    'model.onnx',              # 输出路径
    input_names=['input'],     # 输入名称
    output_names=['output'],   # 输出名称
    dynamic_axes={'input': {0: 'batch_size'}},  # 动态维度
    opset_version=17
)

模型加载与推理

python 复制代码
import onnxruntime as ort
import numpy as np

# 加载ONNX模型
session = ort.InferenceSession('model.onnx', 
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# 获取模型信息
print("输入:", [input.name for input in session.get_inputs()])
print("输出:", [output.name for output in session.get_outputs()])
print("执行提供商:", session.get_providers())

# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_dict = {session.get_inputs()[0].name: input_data}

# 推理
outputs = session.run(None, input_dict)
result = outputs[0]

模型信息查看

python 复制代码
import onnx

# 加载ONNX模型
model = onnx.load('model.onnx')

# 检查模型
onnx.checker.check_model(model)

# 打印模型信息
print("Graph inputs:", [input.name for input in model.graph.input])
print("Graph outputs:", [output.name for output in model.graph.output])

性能优化设置

python 复制代码
# 会话选项配置
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4        # 线程数
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession('model.onnx', sess_options)

完整实现代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import onnxruntime as ort
from torchvision.models import resnet50, ResNet50_Weights

# ================== 基础配置 ==================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
C, H, W = 3, 224, 224  # 输入图像尺寸
PTH_PATH = "resnet50.pth"
ONNX_PATH = "resnet50.onnx"

print(f"使用设备: {device}")
print(f"输入尺寸: {C}x{H}x{W}")

# ================== 模型下载与导出 ==================
def download_and_export():
    """下载预训练模型并导出为ONNX格式"""
    print("\n开始下载预训练模型...")
    
    # 加载预训练的ResNet50模型
    pt_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
    pt_model.eval()  # 设置为评估模式
    
    # 保存PyTorch模型权重
    torch.save(pt_model.state_dict(), PTH_PATH)
    print(f"PyTorch模型已保存至: {PTH_PATH}")
    
    # 创建虚拟输入用于导出
    dummy_input = torch.randn(1, C, H, W, device=device)
    
    # 导出为ONNX格式
    print("正在导出ONNX模型...")
    torch.onnx.export(
        pt_model,                    # 要导出的模型
        dummy_input,                 # 虚拟输入张量
        ONNX_PATH,                   # 输出文件路径
        input_names=['input'],       # 输入节点名称
        output_names=['output'],     # 输出节点名称
        dynamic_axes={               # 动态轴配置
            "input": {0: "batch_size"},   # batch维度可动态变化
            "output": {0: "batch_size"},
        },
        opset_version=17,            # ONNX算子集版本
        do_constant_folding=True,    # 常量折叠优化
        dynamo=False                 # 禁用TorchDynamo
    )
    print(f"ONNX模型已导出至: {ONNX_PATH}")

# ================== 模型加载 ==================
def load_models():
    """加载PyTorch和ONNX模型"""
    print("\n正在加载模型...")
    
    # 加载PyTorch模型
    model_pt = resnet50(weights=None).to(device)  # 不重新下载权重
    model_pt.load_state_dict(torch.load(PTH_PATH, map_location=device))
    model_pt.eval()
    
    # 加载ONNX模型
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    sess = ort.InferenceSession(ONNX_PATH, providers=providers)
    
    print(f"PyTorch模型已加载")
    print(f"ONNX模型已加载,使用提供商: {sess.get_providers()}")
    
    return model_pt, sess

# ================== 性能测试函数 ==================
def benchmark_pytorch(model, data, num_batches=50):
    """测试PyTorch模型推理性能"""
    print(f"\n开始PyTorch性能测试 (批次数: {num_batches})")
    
    with torch.no_grad():
        # 预热阶段
        for _ in range(10):
            _ = model(data)
        
        # 同步GPU操作
        if device == 'cuda':
            torch.cuda.synchronize()
        
        # 正式计时
        start = time.time()
        for _ in range(num_batches):
            _ = model(data)
        
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.time()
    
    avg_time = (end - start) / num_batches
    print(f"PyTorch平均推理时间: {avg_time:.6f}秒")
    return avg_time

def benchmark_onnx(sess, data, num_batches=50):
    """测试ONNX模型推理性能"""
    print(f"\n开始ONNX性能测试 (批次数: {num_batches})")
    
    # 准备ONNX输入格式
    onnx_input = {sess.get_inputs()[0].name: data.cpu().numpy()}
    
    # 预热阶段
    for _ in range(10):
        _ = sess.run(None, onnx_input)
    
    # 正式计时
    start = time.time()
    for _ in range(num_batches):
        _ = sess.run(None, onnx_input)
    end = time.time()
    
    avg_time = (end - start) / num_batches
    print(f"ONNX平均推理时间: {avg_time:.6f}秒")
    return avg_time

# ================== 精度验证 ==================
def verify_correctness(model_pt, sess):
    """验证PyTorch和ONNX模型输出一致性"""
    print("\n开始精度验证...")
    
    # 创建测试输入
    input_tensor = torch.randn(1, C, H, W, device=device)
    
    # PyTorch推理
    with torch.no_grad():
        pytorch_out = model_pt(input_tensor).detach().cpu().numpy()
    
    # ONNX推理
    onnx_input = {sess.get_inputs()[0].name: input_tensor.cpu().numpy()}
    onnx_out = sess.run(None, onnx_input)[0]
    
    # 比较结果
    is_close = np.allclose(pytorch_out, onnx_out, atol=1e-3)
    max_diff = np.max(np.abs(pytorch_out - onnx_out))
    
    print(f"输出一致性检查: {'通过' if is_close else '失败'}")
    print(f" 最大差异: {max_diff:.6f}")
    
    return is_close, max_diff

# ================== 主函数 ==================
def main():
    """主执行函数"""
    print("ONNX模型导出与性能对比教程")
    print("=" * 50)
    
    # 1. 下载并导出模型
    download_and_export()
    
    # 2. 加载模型
    model_pt, sess = load_models()
    
    # 3. 准备测试数据
    batch_size = 64
    test_data = torch.randn(batch_size, C, H, W, device=device)
    print(f"\n测试数据形状: {test_data.shape}")
    
    # 4. 性能测试
    print("\n" + "="*30 + " 性能测试 " + "="*30)
    
    num_batches = 20
    pt_time = benchmark_pytorch(model_pt, test_data, num_batches)
    onnx_time = benchmark_onnx(sess, test_data, num_batches)
    
    # 5. 结果分析
    print("\n" + "="*30 + " 测试结果 " + "="*30)
    print(f"PyTorch推理时间:  {pt_time:.6f}秒")
    print(f"ONNX推理时间:     {onnx_time:.6f}秒")
    
    if onnx_time < pt_time:
        speedup = pt_time / onnx_time
        print(f"ONNX加速比:       {speedup:.2f}x")
        print(f"性能提升:         {((pt_time - onnx_time) / pt_time * 100):.1f}%")
    else:
        slowdown = onnx_time / pt_time
        print(f"ONNX较慢:         {slowdown:.2f}x")
    
    # 6. 精度验证
    print("\n" + "="*30 + " 精度验证 " + "="*30)
    is_accurate, max_diff = verify_correctness(model_pt, sess)
    
if __name__ == "__main__":
    main()
相关推荐
xcnn_2 小时前
深度学习基础概念回顾(Pytorch架构)
人工智能·pytorch·深度学习
蒋星熠3 小时前
Flutter跨平台工程实践与原理透视:从渲染引擎到高质产物
开发语言·python·算法·flutter·设计模式·性能优化·硬件工程
attitude.x3 小时前
PyTorch 动态图的灵活性与实用技巧
前端·人工智能·深度学习
骥龙3 小时前
XX汽集团数字化转型:全生命周期网络安全、数据合规与AI工业物联网融合实践
人工智能·物联网·web安全
zskj_qcxjqr3 小时前
告别传统繁琐!七彩喜艾灸机器人:一键开启智能养生新时代
大数据·人工智能·科技·机器人
Ven%3 小时前
第一章 神经网络的复习
人工智能·深度学习·神经网络
爬虫程序猿3 小时前
《京东商品详情爬取实战指南》
爬虫·python
胡耀超3 小时前
4、Python面向对象编程与模块化设计
开发语言·python·ai·大模型·conda·anaconda
研梦非凡4 小时前
CVPR 2025|基于视觉语言模型的零样本3D视觉定位
人工智能·深度学习·计算机视觉·3d·ai·语言模型·自然语言处理