使用ONNX模型(java)

一、目标

本次探索的目标是探索一种将ONNX模型集成到Java中的方法,以便后期可以在联合仿真环境中加载和执行ONNX模型。

二、研究

什么是ONNX

在进行技术探索之前,我们需要了解ONNX的相关知识。

ONNX(Open Neural Network Exchange)是一种用于表示机器学习模型的开放式格式,可以将模型从一个框架转移到另一个框架。ONNX模型可以使用不同的工具和库进行加载和执行,例如TensorFlow、PyTorch、Caffe2等。在机器学习和人工智能领域,ONNX已成为一个流行的标准格式。由于其开放式和跨平台的特性,ONNX模型可以在不同的环境和设备上使用,例如移动设备、嵌入式系统、云计算平台等。

如下的图来自官方,可以看到有提供了Java的API:

由于ONNX Runtime是跨平台的高性能推理引擎,可以使用ONNX Runtime Java库可以方便地加载和执行ONNX模型。

下面是一个简单的代码示例,展示如何在Java系统中使用ONNX Runtime Java库加载和执行ONNX模型:

java 复制代码
import ai.onnxruntime.*;

// Load the model and create InferenceSession
String modelPath = "path/to/your/onnx/model";
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession(modelPath);

// Load and preprocess the input image inputTensor
...

// Run inference
OrtSession.Result outputs = session.run(inputTensor);
System.out.println(outputs.get(0).getTensor().getFloatBuffer().get(0));

实现

实现步骤如下:

  • 配置ONNX Runtime Java库
  • 将ONNX模型加载到系统中
  • 设置输入,并验证输出

配置POM

xml 复制代码
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.15.1</version>
</dependency>

相关代码

scss 复制代码
public class LoadACloopOnnx {
    private static final String DEFAULT_MODEL= "onnx/simple/simple_model.onnx";

    public static void main(String[] args) {
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(getResource(DEFAULT_MODEL).toString(),new OrtSession.SessionOptions())){
            for (String name: session.getInputNames()) {
                System.out.println("输入: " + session.getInputInfo().get(name));
            }
            for (String name: session.getOutputNames()) {
                System.out.println("输出: " + session.getOutputInfo().get(name));
            }
        
            Optional.ofNullable(session.getInputInfo().keySet())
                    .orElse(Collections.emptySet())
                    .stream()
                    .findFirst()
                    .ifPresent(key->{
                        try {
                            NodeInfo nodeInfo = session.getInputInfo().get(key);
                            if (nodeInfo.getInfo() instanceof  TensorInfo) {
                                Map<String,OnnxTensor> stringOnnxTensorMap = new HashMap<>();
                                stringOnnxTensorMap.put("input1",OnnxTensor.createTensor(env,new float[]{1}));
                                stringOnnxTensorMap.put("input2",OnnxTensor.createTensor(env,new float[]{2}));
                                try (OrtSession.Result result = session.run(stringOnnxTensorMap)){
                                    for (Map.Entry<String, OnnxValue> entry : result) {
                                        System.out.println(String.format("结果项[%s]",entry.getKey()));
                                        System.out.println("信息:"+entry.getValue().getInfo());
                                        System.out.println("类型:"+entry.getValue().getType());
                                        printMultiArrayHelper(entry.getValue().getValue(),"");
                                    }
                                }
                            }
                        } catch (OrtException e) {
                            e.printStackTrace();
                        }
                    });
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }



    private static void printMultiArrayHelper(Object array, String indent) {
        if (array == null) {
            System.out.println("null");
            return;
        }

        Class<?> componentType = array.getClass().getComponentType();
        if (!componentType.isArray()) {
            System.out.print(indent);
            System.out.print("[ ");
            for (int i = 0; i < Array.getLength(array); i++) {
                if (i > 0) {
                    System.out.print(", ");
                }
                System.out.print(Array.get(array, i));
            }
            System.out.println(" ]");
        } else {
            System.out.println(indent + "[");
            for (int i = 0; i < Array.getLength(array); i++) {
                Object subArray = Array.get(array, i);
                printMultiArrayHelper(subArray, indent + "  ");
            }
            System.out.println(indent + "]");
        }
    }


    private static Path getResource(String name) {
        return Paths.get("src/main/resources").toAbsolutePath().resolve(name);
    }
}

输出结果

ini 复制代码
输入: NodeInfo(name=input1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输入: NodeInfo(name=input2,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输出: NodeInfo(name=output,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
结果项[output]
信息:TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])
类型:ONNX_TYPE_TENSOR
[ 3.0 ]

三、结论

使用ONNX Runtime Java库可以方便地将ONNX模型集成到Java环境中,并与其他子系统进行交互,使用ONNX模型可以方便地在不同的环境和设备上共享和使用。Java系统可以尝试使用该库来加载和执行ONNX模型,并进行集成。

最后,异构模型的支持是大势所趋,联合仿真系统应该积极探索和尝试新的技术和方法,以不断提升系统的性能和功能。ONNX作为一种先进的机器学习模型表示格式,将为联合仿真系统的发展带来新的机遇和挑战,可拓展联合仿真系统的适用范围。

四、参考文档

Get Started with ORT for Java

相关推荐
红尘散仙6 小时前
我把终端小说阅读器接上了 AI Agent:TRNovel 现在能用 skill 生成书源了
人工智能·后端·rust
卷毛的技术笔记7 小时前
告别硬编码!Spring AI Alibaba 实现 AI Agent 智能工具调用(Tool Calling)
java·人工智能·后端·python·spring·ai编程
会编程的土豆7 小时前
Go 语言反射(Reflection)详解
开发语言·后端·golang
喵个咪8 小时前
GoWind Toolkit Go后端代码生成 完整全流程实战
后端·go·orm
basketball6168 小时前
Go 语言从入门到进阶:4. 数组和MAP使用方法总结
开发语言·后端·golang
qq_2518364578 小时前
SpringBoot+Vue 共享电池柜管理系统 完整实现 前后端分离项目实战 完整代码
vue.js·spring boot·后端
zhangxingchao9 小时前
AI 大模型核心六:量化、Workflow 与 Agent、多轮 RAG
前端·人工智能·后端
IT_陈寒10 小时前
Vite打包时遇到的坑,原来问题出在这里
前端·人工智能·后端
ayqy贾杰11 小时前
基层管理的三板斧,在AI时代行不通了
前端·后端·团队管理
Apifox11 小时前
Apifox 5 月更新|Postman 导入优化、Runner 支持非 root 运行、请求代码自动带鉴权
前端·后端·安全