Preface
Fetch the output of every node in an ONNX model so the intermediate results can be printed and saved to disk, which is handy for layer-by-layer debugging.
Code
```python
import os
import onnx
import onnx.helper as helper
import onnxruntime
from collections import OrderedDict
import numpy as np


def get_onnx_node_out(onnx_file, save_onnx):
    """Expose the first output of every node as a graph output and save the model."""
    model = onnx.load(onnx_file)
    out_names = []
    for node in model.graph.node:
        out_names.append(node.output[0])
    # Skip names that are already graph outputs; duplicates make onnxruntime
    # reject the model when it is loaded later.
    existing = {o.name for o in model.graph.output}
    for out_name in out_names:
        if out_name in existing:
            continue
        # A ValueInfoProto carrying only a name is enough for onnxruntime to
        # return the tensor; type/shape information is optional here.
        intermediate_layer_value_info = helper.ValueInfoProto()
        intermediate_layer_value_info.name = out_name
        model.graph.output.append(intermediate_layer_value_info)
    onnx.save(model, save_onnx)
def onnxruntime_infer(onnx_path, input_data):
    """Run the model and dump every output tensor to ./onnx_output/<name>.npy."""
    session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    outputs = [x.name for x in session.get_outputs()]
    print("onnx input_name:", input_name)
    print("onnx outputs:", outputs)
    ort_outs = session.run(outputs, {input_name: input_data})
    ort_outs = OrderedDict(zip(outputs, ort_outs))
    # For debug: save each node output. Create the directory first so np.save
    # does not fail when it is missing.
    os.makedirs("./onnx_output", exist_ok=True)
    for key, val in ort_outs.items():
        # Node names may contain '/', so keep only the last component as the file name.
        file = "./onnx_output/" + key.split("/")[-1] + ".npy"
        np.save(file, val, allow_pickle=True, fix_imports=True)
if __name__ == '__main__':
    base_path = "./"
    onnx_file = os.path.join(base_path, "example4.onnx")
    save_onnx = os.path.join(base_path, "example4_out.onnx")
    get_onnx_node_out(onnx_file, save_onnx)
    path = "./10.npy"  # input tensor saved beforehand as .npy
    input_data = np.load(path)
    print(f"input_data shape: {input_data.shape}")
    onnxruntime_infer(save_onnx, input_data)
```
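After the script runs, each node's activation sits in `./onnx_output/` as a `.npy` file. Below is a minimal sketch for reloading and summarizing those dumps, assuming the directory produced by `onnxruntime_infer` above; printing per-tensor statistics is just one quick way to spot NaNs or saturated values.

```python
import os
import numpy as np

dump_dir = "./onnx_output"  # directory written by onnxruntime_infer above
for fname in sorted(os.listdir(dump_dir)):
    if not fname.endswith(".npy"):
        continue
    arr = np.load(os.path.join(dump_dir, fname), allow_pickle=True)
    # Basic per-tensor statistics; enough to spot NaNs or exploding values.
    print(f"{fname}: shape={arr.shape} dtype={arr.dtype} "
          f"min={arr.min():.4f} max={arr.max():.4f} mean={arr.mean():.4f}")
```

The same dumps can also be diffed against activations captured from the original framework (e.g., with `np.allclose`) to locate the first layer where the two diverge.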
Summary
- A simple, practical application of the code above for dumping per-node ONNX outputs.