TensorRT加速推理入门-1:Pytorch转ONNX

这篇文章,用于记录将TransReID的pytorch模型转换为onnx的学习过程,期间参考和学习了许多大佬编写的博客,在参考文章这一章节中都已列出,非常感谢。

1. 在pytorch下使用ONNX主要步骤

1.1. 环境准备

安装onnxruntime包

安装教程可参考:
onnx模型预测环境安装笔记
onnxruntime配置

CPU版本:

直接pip安装

python 复制代码
pip install onnxruntime

GPU版本:

先查看自己CUDA版本然后在下面的链接去找对应的onnxruntime的版本
CUDA版本的查询,可参考这个
onnxruntime版本查询

查询到对应版本,直接pip安装即可,例如

python 复制代码
pip install onnxruntime-gpu==1.13.1

安装onnxsim包

python 复制代码
pip install onnx-simplifier

1.2. 搭建 PyTorch 模型(TransReID)

python 复制代码
def get_net(model_path,opt_=False):
    if opt_:
        cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
    #cfg.freeze()
        train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
        net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
    
    else:
        cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
        train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
        net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)

    
  #state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model_state_dict=net.state_dict()
    for key in list(state_dict.keys()):
        if key[7:] in model_state_dict.keys():
            model_state_dict[key[7:]]=state_dict[key]
    
    net.load_state_dict(model_state_dict)
    return net

1.3. pytorch模型转换为 ONNX 模型

这个提供了静态转换(静态转换支持静态输入)和动态转换(动态转换支持动态输入)两个函数,可根据需要选择。

python 复制代码
def convert_onnx_dynamic(model,save_path,simp=False):
    x = torch.randn(4, 3, 256,128)
    input_name = 'input'
    output_name = 'class'
    torch.onnx.export(model,x,save_path,input_names = [input_name],
                    output_names = [output_name],dynamic_axes= {
                        input_name: {0: 'B'},
                        output_name: {0: 'B'}}
                   )
    if simp:
        onnx_model = onnx.load(save_path) 
        model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, save_path)
        print('simplify onnx done')

def convert_onnx(model,save_path,batch=1,simp=False):
    
    input_names = ['input']
    output_names=['class']
    x = torch.randn(batch, 3, 256, 128)
    for para in model.parameters():
        para.requires_grad = False
    
    # model_script = torch.jit.script(model)
    # model_trace = torch.jit.trace(model, x)
    torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)
    if simp:
        onnx_model = onnx.load(save_path) 
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, save_path)
        print('simplify onnx done')

pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export,来看看该函数的参数和用法。

python 复制代码
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数 用法
model 需要导出的pytorch模型
args 模型的任意一组输入(模拟实际输入数据的大小,比如三通道的512*512大小的图片,就可以设置为torch.randn(1, 3 , 512, 512)
path 输出的onnx模型的位置,例如yolov5.onnx
export_params 输出模型是否可训练。default=True,表示导出trained model, 否则untrained。
verbose 是否打印模型转换信息,default=None
opset_version onnx算子集的版本
input_names 模型的输入节点名称(自己定义的),如果不写,默认输出数字类型 的名称
output_name 模型的输出节点名称(自己定义的), 如果不写,默认输出数字类型的名称
do_constant_folding 是否使用常量折叠,默认即可。default=True。
dynamic_axes 设置动态输入输出,用法:"输入输出名:[支持动态的维度",如"支持动态的维度设置为[0, 2, 3]"则表示第0维,第2维,第3维支持动态输入输出。
模型的输入输出有时是可变的,如rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b, 3, h, w), 其中batch、height、width是可变的,但是chancel是固定三通道。 格式如下:1)仅list(int)dynamic_axes={'input':[0, 2, 3], 'output':{0:'batch', 1:'c'}} 2)仅dict<int, string> dynamic_axes={'input':{'input':{0:'batch', 2:'height', 3:'width'}, 'output':{0:'batch', 1:'c'}} 3)mixed dynamic_axes={'input':{0:'batch', 2:'height', 3:'width'}, 'output':[0,1]}

注意onnx不支持结构中带有if语句的模型,如:

当我们在网络中嵌入一些if选择性的语句时,不好意思,模型不会考虑这些, 它只会记录下运行时走过的节点,不会根据if的实际情况来选择走哪条路, 所以势必会丢弃一部分节点,而丢弃哪些则是根据我们转模型时的输入来定的,一旦指定了,后面运行onnx模型都会如此。另一个问题就是,我们在代码中有一些循环或者迭代的操作时,要注意,尤其是我们的迭代次数是根据输入不同 会有变化时,也会因为这些操作导致后面的推理出现意外错误,正像前面说的,模型转换不喜欢不确定的东西,它会把这些变量dump成常量,所以会导致推理 错误。

对于实际部署的需求,很多时候pytorch是不满足的,所以需要转成其他模型格式来加快推理。常用的就是onnx,onnx天然支持很多框架模型的转换,如Pytorch,tf,darknet,caffe等。而pytorch也给我们提供了对应的接口,就是torch.onnx.export。下面具体到每一步。
原文来自: Windows下使用ONNX+pytorch记录
首先,环境和依赖:onnx包,cuda和cudnn,我用的版本号分别是1.7.0, 10.1, 7.5.4。

我们需要提供一个pytorch的模型,然后调用torch.onnx.export,同时还需要提供另外一些参数。我们一个个来分析,一是我们要给一个dummy input, 就是随便指定一个和我们实际输入时尺寸相同的一个随机数,是Tensor类型的,然后我们要指定转换的device,即是在gpu还是cpu。 然后我们要给一个input_names和output_names,这是绑定输入和输出,当然输入和输出可能不止一个,那就根据实际的输入和输出个数来给出name列表,

如果我们指定的输入和输出名和实际的网络结构不一致的话,onnx会自动给我们设置一个名字。一般是数字字符串。

输入和输出的绑定之后,我我们们可以看到还有一个参数叫做dynamic_axes,这是做什么的呢?哦,这是指定动态输入的,为了满足我们实际推理过程中,可能每张图片的分辨率不一样,所以允许我们给每个维度设置动态输入,这样是不是灵活多了?然后,设置完这些参数和输入,我们就可以开始转换模型了,如果不报错就是成功了,会在当前目录下生成一个.onnx文件。
原文来自: 一文掌握Pytorch-onnx-tensorrt模型转换

1.4 onnx-simplifier简化onnx模型

python 复制代码
model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)

Pytorch转换为ONNX的完整代码pytorch_to_onnx.py

python 复制代码
import json
import os
import onnx
import torch
import argparse
import torch.nn as nn
from onnxsim import simplify
from collections import OrderedDict
import torch.nn.functional as F

# TransReID的模型构建需要的包
from model.make_model import *
from config import cfg
from datasets import make_dataloader 

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

def convert_onnx_dynamic(model,save_path,simp=False):
    x = torch.randn(4, 3, 256,128)
    input_name = 'input'
    output_name = 'class'
    torch.onnx.export(model,x,save_path,input_names = [input_name],
                    output_names = [output_name],dynamic_axes= {
                        input_name: {0: 'B'},
                        output_name: {0: 'B'}}
                   )
    if simp:
        onnx_model = onnx.load(save_path) 
        model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, save_path)
        print('simplify onnx done')

def convert_onnx(model,save_path,batch=1,simp=False):
    
    input_names = ['input']
    output_names=['class']
    x = torch.randn(batch, 3, 256, 128)
    for para in model.parameters():
        para.requires_grad = False
    
    # model_script = torch.jit.script(model)
    # model_trace = torch.jit.trace(model, x)
    torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)
    if simp:
        onnx_model = onnx.load(save_path) 
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, save_path)
        print('simplify onnx done')


def get_net(model_path,opt_=False):
    if opt_:
        cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
    #cfg.freeze()
        train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
        net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
    
    else:
        cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
        train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
        net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)

    
  #state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model_state_dict=net.state_dict()
    for key in list(state_dict.keys()):
        if key[7:] in model_state_dict.keys():
            model_state_dict[key[7:]]=state_dict[key]
    
    net.load_state_dict(model_state_dict)
    return net
    
if __name__=="__main__":
    parser = argparse.ArgumentParser(description='torch to onnx describe.')
    parser.add_argument(
        "--model_path",
        type = str,
        default="/home/TransReID-main/weights/vit_transreid_occ_duke.pth",
        help="torch weight path, default is MobileViT_Pytorch/weights-file/model_best.pth.tar.")

    parser.add_argument(
        "--save_path",
        type=str,
        default="/home/TransReID-main/weights/vit_transreid_occ_duke_v2.onnx",
        help="save direction of onnx models,default is ./target/MobileViT.onnx.")

    parser.add_argument(
        "--batch",
        type=int,
        default=1,
        help="batchsize of onnx models, default is 1.")

    parser.add_argument(
        "--opt",
        default=False, action='store_true',
        help="model optmization , default is False.")
    parser.add_argument(
        "--dynamic",
        default=False, action='store_true',
        help="export  dynamic onnx model , default is False.")

    args = parser.parse_args()
#   print(args)
  #net = get_net(args.model_path,opt_=args.opt)
    net = get_net(args.model_path)
    if args.dynamic:
        convert_onnx_dynamic(net,args.save_path,simp=True)
    else:
        with torch.no_grad():
            convert_onnx(net,args.save_path,simp=True,batch=args.batch)

1.5 查看onnx模型

当将pytorch模型保存为 ONNX 之后,可以使用一款名为 Netron 的软件打开 .onnx 文件,查看模型结构。

2. 参考文章

[1] Windows下使用ONNX+pytorch记录

[2] pytorch-onnx-tensorrt全链路简单教程(支持动态输入)

[3] PyTorch语义分割模型转ONNX以及对比转换后的效果(PyTorch2ONNX、Torch2ONNX、pth2onnx、pt2onnx、修改名称、转换、测试、加载ONNX、运行ONNX)

[4] ONNX系列一:ONNX的使用,从转化到推理

相关推荐
GISer_Jing5 分钟前
神经网络初学总结(一)
人工智能·深度学习·神经网络
szxinmai主板定制专家14 分钟前
【国产NI替代】基于A7 FPGA+AI的16振动(16bits)终端PCIE数据采集板卡
人工智能·fpga开发
数据分析能量站1 小时前
神经网络-AlexNet
人工智能·深度学习·神经网络
Ven%1 小时前
如何修改pip全局缓存位置和全局安装包存放路径
人工智能·python·深度学习·缓存·自然语言处理·pip
szxinmai主板定制专家1 小时前
【NI国产替代】基于国产FPGA+全志T3的全国产16振动+2转速(24bits)高精度终端采集板卡
人工智能·fpga开发
YangJZ_ByteMaster1 小时前
EndtoEnd Object Detection with Transformers
人工智能·深度学习·目标检测·计算机视觉
Anlici1 小时前
模型训练与数据分析
人工智能·机器学习
余~~185381628002 小时前
NFC 碰一碰发视频源码搭建技术详解,支持OEM
开发语言·人工智能·python·音视频
唔皇万睡万万睡2 小时前
五子棋小游戏设计(Matlab)
人工智能·matlab·游戏程序
视觉语言导航2 小时前
AAAI-2024 | 大语言模型赋能导航决策!NavGPT:基于大模型显式推理的视觉语言导航
人工智能·具身智能