4.4 获取onnx每个节点的输出结果

前言

获取onnx每个节点的结果,进行输出显示、保存

Code

cpp 复制代码
import os
import onnx
import onnx.helper as helper
import onnxruntime
from collections import OrderedDict
import  numpy as np

def get_onnx_node_out(onnx_file, save_onnx):
    model = onnx.load(onnx_file)
    out_names=[]
    for i, node in enumerate(model.graph.node):
        out_names.append(node.output[0])
    for out_name in out_names:
        intermediate_layer_value_info = helper.ValueInfoProto()
        intermediate_layer_value_info.name = out_name
        model.graph.output.append(intermediate_layer_value_info)
    onnx.save(model, save_onnx)

def onnxruntime_infer(onnx_path, input_data, output_name="output"):
 
    session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    outputs = [x.name for x in session.get_outputs()]
    print("onnx input_name:", input_name)
    print("onnx outputs:", outputs)
    ort_outs = session.run(outputs, {input_name: input_data})
    ort_outs = OrderedDict(zip(outputs, ort_outs))
 
    # For debug
    for key in ort_outs:
        val = ort_outs[key]
        file = "./onnx_output/"+ key.split("/")[-1] +".npy"
        np.save(file, val, allow_pickle=True, fix_imports=True)
     
if __name__ == '__main__':

    base_path = "./"
    onnx_file = os.path.join(base_path,"example4.onnx")
    save_onnx = os.path.join(base_path,"example4_out.onnx")
    get_onnx_node_out(onnx_file, save_onnx)

    path = "./10.npy"  # 
    input_data = np.load(path)
    print(f"input_data shape:{input_data.shape}")

    onnxruntime_infer(save_onnx, input_data)   

总结

  • 相关代码简单运用
相关推荐
程序媛徐师姐4 分钟前
Python基于深度学习的手写输入识别系统【附源码、文档说明】
python·深度学习·python深度学习·手写输入识别系统·python手写输入识别系统·python手写输入识别·深度学习手写输入识别
2301_7641505611 分钟前
c++如何读取和解析带BOM头的UTF-8与UTF-16文本流【详解】
jvm·数据库·python
qq_4240985615 分钟前
HTML函数开发用窄边框笔记本有优势吗_便携与性能权衡【指南】
jvm·数据库·python
Wyz2012102418 分钟前
CSS如何实现导航栏下划线随鼠标移动_利用-hover伪类与过渡动画控制
jvm·数据库·python
2201_7610405918 分钟前
SQL如何统计每个用户的首次行为时间_MIN聚合与分组
jvm·数据库·python
qq_1898070324 分钟前
mysql如何实现定时清理缓存数据_利用event scheduler执行
jvm·数据库·python
Polar__Star26 分钟前
golang如何实现低功耗设备唤醒机制_golang低功耗设备唤醒机制实现教程
jvm·数据库·python
a95114164228 分钟前
CSS怎么在flex布局中实现项目均分间距_设置justify-content space-evenly
jvm·数据库·python
2201_7610405934 分钟前
Golang如何做灰度发布_Golang灰度发布教程【实战】
jvm·数据库·python
baidu_3409988241 分钟前
CSS Grid布局如何实现项目在网格内填充_掌握justify-items属性
jvm·数据库·python