30分钟吃掉 Pytorch 转 onnx

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:

《大模型面试宝典》(2024版) 发布!

圈粉无数!《PyTorch 实战宝典》火了!!!


PyTorch 是一个用于机器学习的开源深度学习框架,而ONNX(Open Neural Network Exchange)是一个用于表示深度学习模型的开放式格式。

将 PyTorch 模型转换为ONNX格式有几个原因和优势:

  1. 跨平台部署: ONNX是一个跨平台的格式,支持多种深度学习框架,包括PyTorch、TensorFlow等。将模型转换为ONNX格式可以使模型在不同框架和设备上进行部署和运行。

  2. 性能优化: ONNX格式可以在不同框架之间实现性能优化。例如,可以在PyTorch中训练模型,然后转换为ONNX格式,并在性能更高的框架(如TensorRT)中进行推理。

  3. 模型压缩: ONNX格式可以实现模型的压缩和优化,从而减小模型的体积并提高推理速度。这对于在资源受限的设备上部署模型尤为重要。

pytorch 模型线上部署最常见的方式是转换成onnx,然后再转成tensorRT 在cuda上进行部署推理。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了Pytorch 技术与面试交流群, 想要获取最新面试题、了解最新面试动态的、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群

方式②、添加微信号:mlc2040,备注:技术交流

本文介绍将pytorch模型转换成onnx模型并进行推理的方法。

bash 复制代码
#!pip install onnx 
#!pip install onnxruntime
#!pip install torchvision

一,准备pytorch模型

我们先导入torchvision中的resnet18模型,演示它的推理效果。

以便和onnx的结果进行对比。

python 复制代码
import torch
import torchvision.models as models
import numpy as np
import torchvision
import torchvision.transforms as T

from PIL import Image

def create_net():
    net = models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    return net 

net = create_net()

torch.save(net.state_dict(),'resnet18.pt')
net.eval();
python 复制代码
def get_test_transform():
    return T.Compose([
        T.Resize([320, 320]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

image = Image.open("dog.png") # 289
img = get_test_transform()(image)
img = img.unsqueeze_(0) 
output = net(img)
score, indice = torch.max(torch.softmax(output,axis=-1),1)
info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}

def show_image(image, title):
    import matplotlib.pyplot as plt 
    ax=plt.subplot()
    ax.imshow(image)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([]) 
    plt.show()

show_image(image, title = info)

二,pytorch模型转换成onnx模型

1, 简化版本

python 复制代码
import onnxruntime
import onnx

batch_size = 1  
input_shape = (3, 320, 320)   

x = torch.randn(batch_size, *input_shape)
onnx_file = "resnet18.onnx"
torch.onnx.export(net,x,onnx_file,
                opset_version=10,
                do_constant_folding=True,  # 是否执行常量折叠优化
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    "input":{0:"batch_size"},  
                     "output":{0:"batch_size"}})
bash 复制代码
!du -s -h resnet18.pt
bash 复制代码
 45M	resnet18.pt
bash 复制代码
!du -s -h resnet18.onnx 
bash 复制代码
 45M	resnet18.onnx

可以在 https://netron.app/ 中拖入 resnet18.onnx 文件查看模型结构

2,全面版本

下面的代码包括了设置输入输出尺寸,以及动态可以变batch等等。

python 复制代码
import argparse
from argparse import Namespace
import time
import sys
import os
import torch
import torch.nn as nn
import torchvision.models as models
import onnx
import onnxruntime

from io import BytesIO


ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

params = Namespace(weights='resnet18.pt',
                   img_size=[320,320],
                   batch_size=1,
                   half=False,
                   dynamic_batch=True
                  )

parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='checkpoint.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--dynamic-batch', action='store_true', help='export dynamic batch onnx model')
parser.add_argument('--trt-version', type=int, default=8, help='tensorrt version')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

args = parser.parse_args(args='',namespace=params)


args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
print(args)

t = time.time()

# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device(f'cuda:{args.device}' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'

# Load PyTorch model
model = create_net()
model.to(device)
model.load_state_dict(torch.load(args.weights)) # pytorch模型加载

# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device)  # image size(1,3,320,192) iDetection

# Update model
if args.half:
    img, model = img.half(), model.half()  # to FP16
model.eval()

prediction = model(img)  # dry run

# ONNX export
print('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx')  # filename
with BytesIO() as f:
    dynamic_axes = {"input":{0:"batch_size"}, "output":{0:"batch_size"} } if args.dynamic_batch else None
    torch.onnx.export(model, img, f, verbose=False, opset_version=13,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=dynamic_axes)
    f.seek(0)
    # Checks
    onnx_model = onnx.load(f)  # load onnx model
    onnx.checker.check_model(onnx_model)  # check onnx model
    
if args.simplify:
    try:
        import onnxsim
        print('\nStarting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print(f'Simplifier failure: {e}')

onnx.save(onnx_model, export_file)

print(f'ONNX export success, saved as {export_file}')

# Finish
print('\nExport complete (%.2fs)' % (time.time() - t))
bash 复制代码
Namespace(weights='resnet18.pt', img_size=[320, 320], batch_size=1, half=False, dynamic_batch=True, inplace=False, simplify=False, trt_version=8, device='cpu')

Starting to export ONNX...
ONNX export success, saved as resnet18.onnx

Export complete (0.57s)

三,使用onnx模型进行推理

1,函数风格

bash 复制代码
onnx_sesstion = onnxruntime.InferenceSession(export_file)
python 复制代码
def pipe(img_path,
         onnx_sesstion = onnx_sesstion):
    image = Image.open(img_path) 
    img = get_test_transform()(image)
    img = img.unsqueeze_(0) 

    to_numpy = lambda tensor: tensor.data.cpu().numpy()
    
    inputs = {onnx_sesstion.get_inputs()[0].name: to_numpy(img)}
    outs = onnx_sesstion.run(None, inputs)[0]

    score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
    info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}
    return info
bash 复制代码
img_path = 'dog.png'image = Image.open(img_path)info = pipe(img_path)show_image(image,info)

2,对象风格

python 复制代码
import os, sys

import onnxruntime
import onnx
    
class ONNXModel():
    def __init__(self, onnx_path):
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_names = [node.name for node in self.onnx_session.get_inputs()]
        self.output_names = [node.name for node in self.onnx_session.get_outputs()]
        print("input_name:{}".format(self.input_names))
        print("output_name:{}".format(self.output_names))
 
    def forward(self, x):
        if isinstance(x,np.ndarray):
            assert len(self.input_names)==1
            input_feed = {self.input_names[0]:x}
        elif isinstance(x,(tuple,list)):
            assert len(self.input_names)==len(x)
            input_feed = {k:v for k,v in zip(self.input_names,x)}
        else:
            assert isinstance(x,dict)
            input_feed = x
        outs = self.onnx_session.run(self.output_names, input_feed=input_feed)
        return outs
    
    def predict(self,img_path):
        image = Image.open(img_path) 
        img = get_test_transform()(image)
        img = img.unsqueeze_(0) 
        to_numpy = lambda tensor: tensor.data.cpu().numpy()
        outs = self.forward(to_numpy(img))[0]
        score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)
        return {'score':score[0].data.numpy().tolist(),
            'indice':indice[0].data.numpy().tolist()}
bash 复制代码
onnx_model = ONNXModel(export_file)
info = onnx_model.predict(img_path)
show_image(image, title = info)
bash 复制代码
input_name:['input']
output_name:['output']
相关推荐
Chef_Chen4 分钟前
从0开始机器学习--Day17--神经网络反向传播作业
python·神经网络·机器学习
千澜空24 分钟前
celery在django项目中实现并发任务和定时任务
python·django·celery·定时任务·异步任务
学习前端的小z27 分钟前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc
斯凯利.瑞恩31 分钟前
Python决策树、随机森林、朴素贝叶斯、KNN(K-最近邻居)分类分析银行拉新活动挖掘潜在贷款客户附数据代码
python·决策树·随机森林
yannan201903131 小时前
【算法】(Python)动态规划
python·算法·动态规划
埃菲尔铁塔_CV算法1 小时前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR1 小时前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️1 小时前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
蒙娜丽宁1 小时前
《Python OpenCV从菜鸟到高手》——零基础进阶,开启图像处理与计算机视觉的大门!
python·opencv·计算机视觉
光芒再现dev1 小时前
已解决,部署GPTSoVITS报错‘AsyncRequest‘ object has no attribute ‘_json_response_data‘
运维·python·gpt·语言模型·自然语言处理