4.8 构建onnx结构模型-Less

前言

构建onnx方式通常有两种:

1、通过代码转换成onnx结构,比如pytorch ---> onnx

2、通过onnx 自定义结点,图,生成onnx结构

本文主要是简单学习和使用两种不同onnx结构,

下面以 Less 结点进行分析

方式

方法一:pytorch --> onnx

暂缓,主要研究方式二

方法二: onnx

cpp 复制代码
import onnx 
from onnx import TensorProto, helper, numpy_helper
import numpy as np

def run():
    print("run start....\n")

    less = helper.make_node(
        "Less",
        name="Less_0",
        inputs=["input1", "input2"],
        outputs=["output1"],
    )
    input1_data = np.load("./tensor.npy") # 16, 397
    # input1_data = np.load("./data.npy")  # 16, 398 test
    # print(f"input1_data shape:{input1_data.shape}\n")
    # input1_data = np.zeros((16,398))
    initializer = [ 
        helper.make_tensor("input1", TensorProto.FLOAT, [16,397], input1_data)
    ]

    cast_nodel = helper.make_node(
            op_type="Cast",
            inputs=["output1"],
            outputs=["output2"],
            name="test_cast",
            to=TensorProto.FLOAT,
        )
    value_info = helper.make_tensor_value_info(
            "output2", TensorProto.BOOL, [16,397])

    graph = helper.make_graph(
        nodes=[less, cast_nodel],
        name="test_graph",
        inputs=[helper.make_tensor_value_info(
            "input2", TensorProto.FLOAT, [16,1]
        )],
        outputs=[helper.make_tensor_value_info(
            "output2",TensorProto.FLOAT, [16,397]
        )],
        initializer=initializer,
        value_info=[value_info],
    )

    op = onnx.OperatorSetIdProto()
    op.version = 11
    model = helper.make_model(graph, opset_imports=[op])
    model.ir_version = 8
    print("run done....\n")
    return model

if __name__ == "__main__":
    model = run()
    onnx.save(model, "./test_less_ori.onnx")

run

cpp 复制代码
import onnx
import onnxruntime
import numpy as np


# 检查onnx计算图
def check_onnx(mdoel):
    onnx.checker.check_model(model)
    # print(onnx.helper.printable_graph(model.graph))

def run(model):
    print(f'run start....\n')
    session = onnxruntime.InferenceSession(model,providers=['CPUExecutionProvider'])
    input_name1 = session.get_inputs()[0].name  
    input_data1= np.random.randn(16,1).astype(np.float32)
    print(f'input_data1 shape:{input_data1.shape}\n')

    output_name1 = session.get_outputs()[0].name

    pred_onx = session.run(
    [output_name1], {input_name1: input_data1})[0]

    print(f'pred_onx shape:{pred_onx.shape} \n')

    print(f'run end....\n')


if __name__ == '__main__':
    path = "./test_less_ori.onnx"
    model = onnx.load("./test_less_ori.onnx")
    check_onnx(model)
    run(path)
相关推荐
周杰伦_Jay10 小时前
【RocketMQ全面解析】架构原理、消费类型、性能优化、环境搭建
性能优化·架构·rocketmq
wx_ywyy679811 小时前
短剧APP开发性能优化专项:首屏加载提速技术拆解
性能优化·短剧app·短剧系统开发·短剧app开发·短剧app系统开发·短剧开发·短剧app开发性能优化
rengang6612 小时前
智能化的重构建议:大模型分析代码结构,提出可读性和性能优化建议
人工智能·性能优化·重构·ai编程
Wang's Blog16 小时前
MySQL: 高并发电商场景下的数据库架构演进与性能优化实践
mysql·性能优化·数据库架构
潘达斯奈基~1 天前
spark性能优化1:通过依赖关系重组优化Spark性能:宽窄依赖集中处理实践
大数据·性能优化·spark
W_chuanqi1 天前
RDEx:一种效果驱动的混合单目标优化器,自适应选择与融合多种算子与策略
人工智能·算法·机器学习·性能优化
fruge1 天前
2025前端工程化与性能优化实战指南:从构建到监控的全链路方案
前端·性能优化
武子康1 天前
Java-152 深入浅出 MongoDB 索引详解 从 MongoDB B-树 到 MySQL B+树 索引机制、数据结构与应用场景的全面对比分析
java·开发语言·数据库·sql·mongodb·性能优化·nosql
UTwelve1 天前
【UE】材质与半透明 - 00.什么是半透明材质
性能优化·材质
Mr YiRan2 天前
多线程性能优化基础
android·java·开发语言·性能优化