这篇文章,用于记录将TransReID的pytorch模型转换为onnx的学习过程,期间参考和学习了许多大佬编写的博客,在参考文章这一章节中都已列出,非常感谢。
1. 在pytorch下使用ONNX主要步骤
1.1. 环境准备
安装onnxruntime包
安装教程可参考:
onnx模型预测环境安装笔记
onnxruntime配置
CPU版本:
直接pip安装
python
pip install onnxruntime
GPU版本:
先查看自己CUDA版本然后在下面的链接去找对应的onnxruntime的版本
CUDA版本的查询,可参考这个
onnxruntime版本查询
查询到对应版本,直接pip安装即可,例如
python
pip install onnxruntime-gpu==1.13.1
安装onnxsim包
python
pip install onnx-simplifier
1.2. 搭建 PyTorch 模型(TransReID)
python
def get_net(model_path,opt_=False):
if opt_:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
#cfg.freeze()
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
else:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model_state_dict=net.state_dict()
for key in list(state_dict.keys()):
if key[7:] in model_state_dict.keys():
model_state_dict[key[7:]]=state_dict[key]
net.load_state_dict(model_state_dict)
return net
1.3. pytorch模型转换为 ONNX 模型
这个提供了静态转换(静态转换支持静态输入)和动态转换(动态转换支持动态输入)两个函数,可根据需要选择。
python
def convert_onnx_dynamic(model,save_path,simp=False):
x = torch.randn(4, 3, 256,128)
input_name = 'input'
output_name = 'class'
torch.onnx.export(model,x,save_path,input_names = [input_name],
output_names = [output_name],dynamic_axes= {
input_name: {0: 'B'},
output_name: {0: 'B'}}
)
if simp:
onnx_model = onnx.load(save_path)
model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, save_path)
print('simplify onnx done')
def convert_onnx(model,save_path,batch=1,simp=False):
input_names = ['input']
output_names=['class']
x = torch.randn(batch, 3, 256, 128)
for para in model.parameters():
para.requires_grad = False
# model_script = torch.jit.script(model)
# model_trace = torch.jit.trace(model, x)
torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)
if simp:
onnx_model = onnx.load(save_path)
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, save_path)
print('simplify onnx done')
pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export
,来看看该函数的参数和用法。
python
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数 | 用法 |
---|---|
model | 需要导出的pytorch模型 |
args | 模型的任意一组输入(模拟实际输入数据的大小,比如三通道的512*512大小的图片,就可以设置为torch.randn(1, 3 , 512, 512) |
path | 输出的onnx模型的位置,例如yolov5.onnx |
export_params | 输出模型是否可训练。default=True,表示导出trained model, 否则untrained。 |
verbose | 是否打印模型转换信息,default=None |
opset_version | onnx算子集的版本 |
input_names | 模型的输入节点名称(自己定义的),如果不写,默认输出数字类型 的名称 |
output_name | 模型的输出节点名称(自己定义的), 如果不写,默认输出数字类型的名称 |
do_constant_folding | 是否使用常量折叠,默认即可。default=True。 |
dynamic_axes | 设置动态输入输出,用法:"输入输出名:[支持动态的维度",如"支持动态的维度设置为[0, 2, 3]"则表示第0维,第2维,第3维支持动态输入输出。 |
模型的输入输出有时是可变的,如rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b, 3, h, w), 其中batch、height、width是可变的,但是chancel是固定三通道。 | 格式如下:1)仅list(int)dynamic_axes={'input':[0, 2, 3], 'output':{0:'batch', 1:'c'}} 2)仅dict<int, string> dynamic_axes={'input':{'input':{0:'batch', 2:'height', 3:'width'}, 'output':{0:'batch', 1:'c'}} 3)mixed dynamic_axes={'input':{0:'batch', 2:'height', 3:'width'}, 'output':[0,1]} |
注意onnx不支持结构中带有if语句的模型,如:
当我们在网络中嵌入一些if选择性的语句时,不好意思,模型不会考虑这些, 它只会记录下运行时走过的节点,不会根据if的实际情况来选择走哪条路, 所以势必会丢弃一部分节点,而丢弃哪些则是根据我们转模型时的输入来定的,一旦指定了,后面运行onnx模型都会如此。另一个问题就是,我们在代码中有一些循环或者迭代的操作时,要注意,尤其是我们的迭代次数是根据输入不同 会有变化时,也会因为这些操作导致后面的推理出现意外错误,正像前面说的,模型转换不喜欢不确定的东西,它会把这些变量dump成常量,所以会导致推理 错误。
对于实际部署的需求,很多时候pytorch是不满足的,所以需要转成其他模型格式来加快推理。常用的就是onnx,onnx天然支持很多框架模型的转换,如Pytorch,tf,darknet,caffe等。而pytorch也给我们提供了对应的接口,就是torch.onnx.export。下面具体到每一步。
原文来自: Windows下使用ONNX+pytorch记录
首先,环境和依赖:onnx包,cuda和cudnn,我用的版本号分别是1.7.0, 10.1, 7.5.4。我们需要提供一个pytorch的模型,然后调用torch.onnx.export,同时还需要提供另外一些参数。我们一个个来分析,一是我们要给一个dummy input, 就是随便指定一个和我们实际输入时尺寸相同的一个随机数,是Tensor类型的,然后我们要指定转换的device,即是在gpu还是cpu。 然后我们要给一个input_names和output_names,这是绑定输入和输出,当然输入和输出可能不止一个,那就根据实际的输入和输出个数来给出name列表,
如果我们指定的输入和输出名和实际的网络结构不一致的话,onnx会自动给我们设置一个名字。一般是数字字符串。
输入和输出的绑定之后,我我们们可以看到还有一个参数叫做dynamic_axes,这是做什么的呢?哦,这是指定动态输入的,为了满足我们实际推理过程中,可能每张图片的分辨率不一样,所以允许我们给每个维度设置动态输入,这样是不是灵活多了?然后,设置完这些参数和输入,我们就可以开始转换模型了,如果不报错就是成功了,会在当前目录下生成一个.onnx文件。
原文来自: 一文掌握Pytorch-onnx-tensorrt模型转换
1.4 onnx-simplifier简化onnx模型
python
model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)
Pytorch转换为ONNX的完整代码 :pytorch_to_onnx.py
python
import json
import os
import onnx
import torch
import argparse
import torch.nn as nn
from onnxsim import simplify
from collections import OrderedDict
import torch.nn.functional as F
# TransReID的模型构建需要的包
from model.make_model import *
from config import cfg
from datasets import make_dataloader
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
def convert_onnx_dynamic(model,save_path,simp=False):
x = torch.randn(4, 3, 256,128)
input_name = 'input'
output_name = 'class'
torch.onnx.export(model,x,save_path,input_names = [input_name],
output_names = [output_name],dynamic_axes= {
input_name: {0: 'B'},
output_name: {0: 'B'}}
)
if simp:
onnx_model = onnx.load(save_path)
model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, save_path)
print('simplify onnx done')
def convert_onnx(model,save_path,batch=1,simp=False):
input_names = ['input']
output_names=['class']
x = torch.randn(batch, 3, 256, 128)
for para in model.parameters():
para.requires_grad = False
# model_script = torch.jit.script(model)
# model_trace = torch.jit.trace(model, x)
torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)
if simp:
onnx_model = onnx.load(save_path)
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, save_path)
print('simplify onnx done')
def get_net(model_path,opt_=False):
if opt_:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
#cfg.freeze()
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
else:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model_state_dict=net.state_dict()
for key in list(state_dict.keys()):
if key[7:] in model_state_dict.keys():
model_state_dict[key[7:]]=state_dict[key]
net.load_state_dict(model_state_dict)
return net
if __name__=="__main__":
parser = argparse.ArgumentParser(description='torch to onnx describe.')
parser.add_argument(
"--model_path",
type = str,
default="/home/TransReID-main/weights/vit_transreid_occ_duke.pth",
help="torch weight path, default is MobileViT_Pytorch/weights-file/model_best.pth.tar.")
parser.add_argument(
"--save_path",
type=str,
default="/home/TransReID-main/weights/vit_transreid_occ_duke_v2.onnx",
help="save direction of onnx models,default is ./target/MobileViT.onnx.")
parser.add_argument(
"--batch",
type=int,
default=1,
help="batchsize of onnx models, default is 1.")
parser.add_argument(
"--opt",
default=False, action='store_true',
help="model optmization , default is False.")
parser.add_argument(
"--dynamic",
default=False, action='store_true',
help="export dynamic onnx model , default is False.")
args = parser.parse_args()
# print(args)
#net = get_net(args.model_path,opt_=args.opt)
net = get_net(args.model_path)
if args.dynamic:
convert_onnx_dynamic(net,args.save_path,simp=True)
else:
with torch.no_grad():
convert_onnx(net,args.save_path,simp=True,batch=args.batch)
1.5 查看onnx模型
当将pytorch模型保存为 ONNX 之后,可以使用一款名为 Netron 的软件打开 .onnx 文件,查看模型结构。
2. 参考文章
[2] pytorch-onnx-tensorrt全链路简单教程(支持动态输入)
[3] PyTorch语义分割模型转ONNX以及对比转换后的效果(PyTorch2ONNX、Torch2ONNX、pth2onnx、pt2onnx、修改名称、转换、测试、加载ONNX、运行ONNX)