从PyTorch到ONNX:模型部署性能提升

在深度学习模型部署过程中,推理性能优化是一个关键环节。ONNX(Open Neural Network Exchange)作为一个开放的神经网络交换格式,能够在不同框架之间实现模型的无缝转换,并通过优化的运行时环境显著提升推理速度。

环境准备

首先安装必要的依赖包:

bash 复制代码
# 安装PyTorch生态
pip install torch torchvision torchaudio

# 安装ONNX相关包
pip install onnx onnxruntime-gpu  # GPU版本
# pip install onnxruntime          # CPU版本

ONNX常用方法汇总

模型导出 (PyTorch → ONNX)

python 复制代码
import torch
import torch.onnx

# 加载PyTorch模型
model = torch.load('model.pth')
model.eval()

# 创建虚拟输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出ONNX
torch.onnx.export(
    model,                      # 模型
    dummy_input,               # 虚拟输入
    'model.onnx',              # 输出路径
    input_names=['input'],     # 输入名称
    output_names=['output'],   # 输出名称
    dynamic_axes={'input': {0: 'batch_size'}},  # 动态维度
    opset_version=17
)

模型加载与推理

python 复制代码
import onnxruntime as ort
import numpy as np

# 加载ONNX模型
session = ort.InferenceSession('model.onnx', 
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# 获取模型信息
print("输入:", [input.name for input in session.get_inputs()])
print("输出:", [output.name for output in session.get_outputs()])
print("执行提供商:", session.get_providers())

# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_dict = {session.get_inputs()[0].name: input_data}

# 推理
outputs = session.run(None, input_dict)
result = outputs[0]

模型信息查看

python 复制代码
import onnx

# 加载ONNX模型
model = onnx.load('model.onnx')

# 检查模型
onnx.checker.check_model(model)

# 打印模型信息
print("Graph inputs:", [input.name for input in model.graph.input])
print("Graph outputs:", [output.name for output in model.graph.output])

性能优化设置

python 复制代码
# 会话选项配置
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4        # 线程数
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession('model.onnx', sess_options)

完整实现代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import onnxruntime as ort
from torchvision.models import resnet50, ResNet50_Weights

# ================== 基础配置 ==================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
C, H, W = 3, 224, 224  # 输入图像尺寸
PTH_PATH = "resnet50.pth"
ONNX_PATH = "resnet50.onnx"

print(f"使用设备: {device}")
print(f"输入尺寸: {C}x{H}x{W}")

# ================== 模型下载与导出 ==================
def download_and_export():
    """下载预训练模型并导出为ONNX格式"""
    print("\n开始下载预训练模型...")
    
    # 加载预训练的ResNet50模型
    pt_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
    pt_model.eval()  # 设置为评估模式
    
    # 保存PyTorch模型权重
    torch.save(pt_model.state_dict(), PTH_PATH)
    print(f"PyTorch模型已保存至: {PTH_PATH}")
    
    # 创建虚拟输入用于导出
    dummy_input = torch.randn(1, C, H, W, device=device)
    
    # 导出为ONNX格式
    print("正在导出ONNX模型...")
    torch.onnx.export(
        pt_model,                    # 要导出的模型
        dummy_input,                 # 虚拟输入张量
        ONNX_PATH,                   # 输出文件路径
        input_names=['input'],       # 输入节点名称
        output_names=['output'],     # 输出节点名称
        dynamic_axes={               # 动态轴配置
            "input": {0: "batch_size"},   # batch维度可动态变化
            "output": {0: "batch_size"},
        },
        opset_version=17,            # ONNX算子集版本
        do_constant_folding=True,    # 常量折叠优化
        dynamo=False                 # 禁用TorchDynamo
    )
    print(f"ONNX模型已导出至: {ONNX_PATH}")

# ================== 模型加载 ==================
def load_models():
    """加载PyTorch和ONNX模型"""
    print("\n正在加载模型...")
    
    # 加载PyTorch模型
    model_pt = resnet50(weights=None).to(device)  # 不重新下载权重
    model_pt.load_state_dict(torch.load(PTH_PATH, map_location=device))
    model_pt.eval()
    
    # 加载ONNX模型
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    sess = ort.InferenceSession(ONNX_PATH, providers=providers)
    
    print(f"PyTorch模型已加载")
    print(f"ONNX模型已加载,使用提供商: {sess.get_providers()}")
    
    return model_pt, sess

# ================== 性能测试函数 ==================
def benchmark_pytorch(model, data, num_batches=50):
    """测试PyTorch模型推理性能"""
    print(f"\n开始PyTorch性能测试 (批次数: {num_batches})")
    
    with torch.no_grad():
        # 预热阶段
        for _ in range(10):
            _ = model(data)
        
        # 同步GPU操作
        if device == 'cuda':
            torch.cuda.synchronize()
        
        # 正式计时
        start = time.time()
        for _ in range(num_batches):
            _ = model(data)
        
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.time()
    
    avg_time = (end - start) / num_batches
    print(f"PyTorch平均推理时间: {avg_time:.6f}秒")
    return avg_time

def benchmark_onnx(sess, data, num_batches=50):
    """测试ONNX模型推理性能"""
    print(f"\n开始ONNX性能测试 (批次数: {num_batches})")
    
    # 准备ONNX输入格式
    onnx_input = {sess.get_inputs()[0].name: data.cpu().numpy()}
    
    # 预热阶段
    for _ in range(10):
        _ = sess.run(None, onnx_input)
    
    # 正式计时
    start = time.time()
    for _ in range(num_batches):
        _ = sess.run(None, onnx_input)
    end = time.time()
    
    avg_time = (end - start) / num_batches
    print(f"ONNX平均推理时间: {avg_time:.6f}秒")
    return avg_time

# ================== 精度验证 ==================
def verify_correctness(model_pt, sess):
    """验证PyTorch和ONNX模型输出一致性"""
    print("\n开始精度验证...")
    
    # 创建测试输入
    input_tensor = torch.randn(1, C, H, W, device=device)
    
    # PyTorch推理
    with torch.no_grad():
        pytorch_out = model_pt(input_tensor).detach().cpu().numpy()
    
    # ONNX推理
    onnx_input = {sess.get_inputs()[0].name: input_tensor.cpu().numpy()}
    onnx_out = sess.run(None, onnx_input)[0]
    
    # 比较结果
    is_close = np.allclose(pytorch_out, onnx_out, atol=1e-3)
    max_diff = np.max(np.abs(pytorch_out - onnx_out))
    
    print(f"输出一致性检查: {'通过' if is_close else '失败'}")
    print(f" 最大差异: {max_diff:.6f}")
    
    return is_close, max_diff

# ================== 主函数 ==================
def main():
    """主执行函数"""
    print("ONNX模型导出与性能对比教程")
    print("=" * 50)
    
    # 1. 下载并导出模型
    download_and_export()
    
    # 2. 加载模型
    model_pt, sess = load_models()
    
    # 3. 准备测试数据
    batch_size = 64
    test_data = torch.randn(batch_size, C, H, W, device=device)
    print(f"\n测试数据形状: {test_data.shape}")
    
    # 4. 性能测试
    print("\n" + "="*30 + " 性能测试 " + "="*30)
    
    num_batches = 20
    pt_time = benchmark_pytorch(model_pt, test_data, num_batches)
    onnx_time = benchmark_onnx(sess, test_data, num_batches)
    
    # 5. 结果分析
    print("\n" + "="*30 + " 测试结果 " + "="*30)
    print(f"PyTorch推理时间:  {pt_time:.6f}秒")
    print(f"ONNX推理时间:     {onnx_time:.6f}秒")
    
    if onnx_time < pt_time:
        speedup = pt_time / onnx_time
        print(f"ONNX加速比:       {speedup:.2f}x")
        print(f"性能提升:         {((pt_time - onnx_time) / pt_time * 100):.1f}%")
    else:
        slowdown = onnx_time / pt_time
        print(f"ONNX较慢:         {slowdown:.2f}x")
    
    # 6. 精度验证
    print("\n" + "="*30 + " 精度验证 " + "="*30)
    is_accurate, max_diff = verify_correctness(model_pt, sess)
    
if __name__ == "__main__":
    main()
相关推荐
库库8399 小时前
Spring AI 知识点总结
java·人工智能·spring
AndrewHZ9 小时前
【图像处理基石】通过立体视觉重建建筑高度:原理、实操与代码实现
图像处理·人工智能·计算机视觉·智慧城市·三维重建·立体视觉·1024程序员节
Theodore_10229 小时前
深度学习(3)神经网络
人工智能·深度学习·神经网络·算法·机器学习·计算机视觉
文火冰糖的硅基工坊9 小时前
[人工智能-大模型-70]:模型层技术 - 从数据中自动学习一个有用的数学函数的全过程,AI函数计算三大件:神经网络、损失函数、优化器
人工智能·深度学习·神经网络
我叫张土豆10 小时前
Neo4j 版本选型与 Java 技术栈深度解析:Spring Data Neo4j vs Java Driver,如何抉择?
java·人工智能·spring·neo4j
lang2015092810 小时前
Spring环境配置与属性管理完全指南
java·python·spring
IT_陈寒10 小时前
Vue3性能提升30%的秘密:5个90%开发者不知道的组合式API优化技巧
前端·人工智能·后端
懒惰蜗牛10 小时前
Day10:Python实现Excel自动汇总
python·numpy·pandas·pip·1024程序员节·python读写excel
我是华为OD~HR~栗栗呀10 小时前
华为od-22届考研-C++面经
java·前端·c++·python·华为od·华为·面试