Python工具将本地Embedding转换成onnx格式

不想依赖于第三方远程api的方式。

比如之前写过的调用重排模型的方式,如果后端使用该模型,需要经过python 提供接口才能调用大模型,这样服务器就需要维护2套项目。为了解决这个问题,采用下面的方式。

(简单提一下:虽然目前不管是 重排、向量、模型都可以转onnx格式,但是目前我后端使用的是LangChain4j,它仅提供了向量模型使用onnx格式的方法)

好了,开始将向量模型转换成onnx吧~~~

一:使用pyCharm工具安装依赖

这里建议使用python3.8版本,尝试了高版本执行此命令不支持!!!

python 复制代码
pip install optimum onnx onnxruntime sentence-transformers

二: 下载模型到本地,可以放到当前项目目录。

三:模型下载后,转换为onnx格式

--model (指定的model) (解压后的onnx存放目录名)

python 复制代码
optimum-cli export onnx --task sentence-similarity --model Yuan-embedding-1.0 onnx-yuan-embedding-1.0

四、LangChain4j本地运行Onnx模型

1、使用转换后的ONNX模型进行文本嵌入

java 复制代码
/**
     * 测试本地ONNX嵌入模型 (验证本地ONNX格式嵌入模型数据的功能)
     */
    @Test
    public void testLocalEmbeddingInsert(){
        // 从文件系统加载文档
        Document document = FileSystemDocumentLoader.loadDocument("E:\\hotelquestion.txt");
        String saveDir = "D:/develop/workspace_python/onnx_server/onnx-yuan-embedding-1.0";
        // 构建嵌入模型
        OnnxEmbeddingModel onnxEmbeddingModel = new OnnxEmbeddingModel(saveDir + "/model.onnx", saveDir + "/tokenizer.json", PoolingMode.MEAN);
        System.out.println("Embedding dimension: " + onnxEmbeddingModel.dimension());
        // 创建文档分段器,将文档分割成最大长度为300,重叠部分为10的段落
        DocumentSplitter documentSplitter = DocumentSplitters.recursive(300, 10);
        // 执行分段
        List<TextSegment> textSegments = documentSplitter.split(document);
        textSegments.forEach(segment -> {
            Metadata metadata = segment.metadata();
            metadata.put(Constants.KB_ID, 1l);
        });
        System.out.println("文档分段数量: " + textSegments.size());

        // 创建内存中的嵌入存储(可替换为milvus等持久化存储)
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        embeddingStore = MilvusEmbeddingStore.builder()
                .host("192.168.1.19")
                .port(19530)
                .databaseName("default")//默认数据库
                .collectionName("vector_store_test")//表
                .dimension(onnxEmbeddingModel.dimension())
                .idFieldName("embeddingId")// 指定存储唯一标识符的字段名
                .textFieldName("text")// 指定存储原始文本的字段名
                .vectorFieldName("embedding") // 指定存储向量数据的字段名
                .metadataFieldName("metadata")// 指定存储元数据的字段名
                .build();

        // 逐个处理文本段落,生成嵌入并存储
        for (int i = 0; i < textSegments.size(); i++) {
            TextSegment textSegment = textSegments.get(i);
            // 为单个文本段落生成嵌入向量
            Response<Embedding> embeddingResponse = onnxEmbeddingModel.embed(textSegment);
            Embedding embedding = embeddingResponse.content();
            // 将嵌入向量和文本段落添加到向量数据库中
            embeddingStore.add(embedding, textSegment);
            System.out.println("已存储第 " + (i + 1) + " 个段落");
        }

    }

效果

2、使用转换后的ONNX模型进行向量搜索

java 复制代码
/**
     * 测试本地ONNX嵌入模型 (验证本地ONNX格式嵌入模型的功能)
     */
    @Test
    public void testLocalEmbedding() {
        String pathToModel = "D:\\develop\\workspace_python\\onnx_server\\onnx-yuan-embedding-1.0\\model.onnx";
        String pathToTokenizer = "D:\\develop\\workspace_python\\onnx_server\\onnx-yuan-embedding-1.0\\tokenizer.json";
        PoolingMode poolingMode = PoolingMode.MEAN;
        EmbeddingModel embeddingModel = new OnnxEmbeddingModel(pathToModel, pathToTokenizer, poolingMode);
        // 生成文本嵌入
        Response<Embedding> response = embeddingModel.embed("酒店的标准退房时间是几点");
        Embedding embedding = response.content();
        System.out.println(response.content().dimension());
        // 创建Milvus向量数据库存储
        MilvusEmbeddingStore embeddingStore = MilvusEmbeddingStore.builder()
                .host("192.168.1.100")
                .port(19530)
                .databaseName("default")//默认数据库
                .collectionName("vector_store_test")//表
                .idFieldName("embeddingId")// 指定存储唯一标识符的字段名
                .textFieldName("text")// 指定存储原始文本的字段名
                .vectorFieldName("embedding") // 指定存储向量数据的字段名
                .metadataFieldName("metadata")// 指定存储元数据的字段名
                .build();
        // 构建搜索请求
        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).filter(new IsEqualTo("kb_id", 1l)).build();
        // 执行搜索
        EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
        System.out.println(result.matches());

    }
相关推荐
小沈同学呀3 天前
SpringAI+MCPServer实战-StreamableHTTP协议打造企业级AI工具服务
人工智能·微服务架构·springai·mcpserver·javaai·streamablehttp
莫逸风4 天前
【AgentScope】6.文件系统(Filesystem)详解
开发语言·windows·springai·agentscope·agnet
中草药z4 天前
【RAG】工程化实战:全链路原理复盘 + 方案选型 + 实战高阶玩法
java·深度学习·机器学习·阿里云·rag·springai
阿昌喜欢吃黄桃8 天前
Java优质开源AI项目
java·ai·langchain·开源·rag·springai·langchain4j
流放深圳8 天前
抓住 AI 人工智能的风口之第 5 章 —— 使用视觉大模型(Vision-Language Model)支持图片识别,完善电商智能客服项目
人工智能·视觉大模型·图片识别·springai·vision-language
莫逸风9 天前
【AgentScope】3. 工作空间(Workspace)详解
java·ai·agent·springai·agentscope
莫逸风11 天前
【AgentScope】1. HarnessAgent 总览详解
springai·agentscope·agnet
Maiko Star12 天前
理解 RAG 的“为什么”与 Spring AI 实战初体验
人工智能·rag·springai
Maiko Star12 天前
SpringAI 模型 API 调用中的错误处理、重试与熔断降级实战
错误处理·springai
装不满的克莱因瓶15 天前
SpringAI Alibaba Tool工具调用机制实战-注解注册与函数调用全流程
人工智能·ai·tools·智能体·springai·tool