使用ONNX模型(java)

一、目标

本次探索的目标是探索一种将ONNX模型集成到Java中的方法,以便后期可以在联合仿真环境中加载和执行ONNX模型。

二、研究

什么是ONNX

在进行技术探索之前,我们需要了解ONNX的相关知识。

ONNX(Open Neural Network Exchange)是一种用于表示机器学习模型的开放式格式,可以将模型从一个框架转移到另一个框架。ONNX模型可以使用不同的工具和库进行加载和执行,例如TensorFlow、PyTorch、Caffe2等。在机器学习和人工智能领域,ONNX已成为一个流行的标准格式。由于其开放式和跨平台的特性,ONNX模型可以在不同的环境和设备上使用,例如移动设备、嵌入式系统、云计算平台等。

如下的图来自官方,可以看到有提供了Java的API:

由于ONNX Runtime是跨平台的高性能推理引擎,可以使用ONNX Runtime Java库可以方便地加载和执行ONNX模型。

下面是一个简单的代码示例,展示如何在Java系统中使用ONNX Runtime Java库加载和执行ONNX模型:

java 复制代码
import ai.onnxruntime.*;

// Load the model and create InferenceSession
String modelPath = "path/to/your/onnx/model";
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession(modelPath);

// Load and preprocess the input image inputTensor
...

// Run inference
OrtSession.Result outputs = session.run(inputTensor);
System.out.println(outputs.get(0).getTensor().getFloatBuffer().get(0));

实现

实现步骤如下:

  • 配置ONNX Runtime Java库
  • 将ONNX模型加载到系统中
  • 设置输入,并验证输出

配置POM

xml 复制代码
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.15.1</version>
</dependency>

相关代码

scss 复制代码
public class LoadACloopOnnx {
    private static final String DEFAULT_MODEL= "onnx/simple/simple_model.onnx";

    public static void main(String[] args) {
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(getResource(DEFAULT_MODEL).toString(),new OrtSession.SessionOptions())){
            for (String name: session.getInputNames()) {
                System.out.println("输入: " + session.getInputInfo().get(name));
            }
            for (String name: session.getOutputNames()) {
                System.out.println("输出: " + session.getOutputInfo().get(name));
            }
        
            Optional.ofNullable(session.getInputInfo().keySet())
                    .orElse(Collections.emptySet())
                    .stream()
                    .findFirst()
                    .ifPresent(key->{
                        try {
                            NodeInfo nodeInfo = session.getInputInfo().get(key);
                            if (nodeInfo.getInfo() instanceof  TensorInfo) {
                                Map<String,OnnxTensor> stringOnnxTensorMap = new HashMap<>();
                                stringOnnxTensorMap.put("input1",OnnxTensor.createTensor(env,new float[]{1}));
                                stringOnnxTensorMap.put("input2",OnnxTensor.createTensor(env,new float[]{2}));
                                try (OrtSession.Result result = session.run(stringOnnxTensorMap)){
                                    for (Map.Entry<String, OnnxValue> entry : result) {
                                        System.out.println(String.format("结果项[%s]",entry.getKey()));
                                        System.out.println("信息:"+entry.getValue().getInfo());
                                        System.out.println("类型:"+entry.getValue().getType());
                                        printMultiArrayHelper(entry.getValue().getValue(),"");
                                    }
                                }
                            }
                        } catch (OrtException e) {
                            e.printStackTrace();
                        }
                    });
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }



    private static void printMultiArrayHelper(Object array, String indent) {
        if (array == null) {
            System.out.println("null");
            return;
        }

        Class<?> componentType = array.getClass().getComponentType();
        if (!componentType.isArray()) {
            System.out.print(indent);
            System.out.print("[ ");
            for (int i = 0; i < Array.getLength(array); i++) {
                if (i > 0) {
                    System.out.print(", ");
                }
                System.out.print(Array.get(array, i));
            }
            System.out.println(" ]");
        } else {
            System.out.println(indent + "[");
            for (int i = 0; i < Array.getLength(array); i++) {
                Object subArray = Array.get(array, i);
                printMultiArrayHelper(subArray, indent + "  ");
            }
            System.out.println(indent + "]");
        }
    }


    private static Path getResource(String name) {
        return Paths.get("src/main/resources").toAbsolutePath().resolve(name);
    }
}

输出结果

ini 复制代码
输入: NodeInfo(name=input1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输入: NodeInfo(name=input2,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输出: NodeInfo(name=output,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
结果项[output]
信息:TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])
类型:ONNX_TYPE_TENSOR
[ 3.0 ]

三、结论

使用ONNX Runtime Java库可以方便地将ONNX模型集成到Java环境中,并与其他子系统进行交互,使用ONNX模型可以方便地在不同的环境和设备上共享和使用。Java系统可以尝试使用该库来加载和执行ONNX模型,并进行集成。

最后,异构模型的支持是大势所趋,联合仿真系统应该积极探索和尝试新的技术和方法,以不断提升系统的性能和功能。ONNX作为一种先进的机器学习模型表示格式,将为联合仿真系统的发展带来新的机遇和挑战,可拓展联合仿真系统的适用范围。

四、参考文档

Get Started with ORT for Java

相关推荐
LIUAWEIO18 小时前
Unix 时间戳换算
前端·后端·unix·database
whinc1 天前
Rust技术周刊 2026年第17周
后端·rust
whinc1 天前
Rust技术周刊 2026年第18周
后端·rust
whinc1 天前
Rust技术周刊 2026年第16周
后端·rust
jieyucx1 天前
Go语言深度解剖:Map扩容机制全解析(增量扩容+等量扩容+渐进式迁移)
开发语言·后端·golang·map·扩容策略
王码码20351 天前
Go语言的内存管理:原理与实战
后端·golang·go·接口
Lee川1 天前
打字机是怎么炼成的:Chat 流式输出深度解析
前端·后端·面试
Lee川1 天前
Token 无感刷新与 Logout:前端安全会话管理实战
前端·后端·react.js
舒一笑1 天前
零后端、零数据库——我做了一个让 10000+ 人成功告白的开源工具
后端·产品·设计师
Java技术小馆1 天前
如何零成本将各种 AI 编程工具接入免费大模型?
后端