spring-ai-alibaba 简化版NL2SQL

在前面的文章中分析过spring-ai-alibaba-starter-nl2sql(spring-ai-alibaba 1.0.0.2 学习(十五)------自然语言生成sql-CSDN博客)也提到过nl2sql虽然效果好,但是使用起来确实并不太方便,并不能即插即用,所以想能不能设计一个简化版的nl2sql

设计方案

最初是想做成一个advisor,但是后来觉得行不通,advisor会对所有对话进行拦截,而对话中可能只有部分是希望转换为sql的需求

那就只能采用工具的方式进行注册,由大模型按需调用

虽然说是设计一个简化版,但是太简单了也就没有实用性了

所以还是需要保留相当一部分扩展性,最终添加了许多可插拔的组件接口,这里参考了spring-ai的RetrievalAugmentaionAdvisor的实现,具体组件接口后面会详细说明

大概的流程就是根据用户查询来检索相关表结构,然后调用大模型生成sql,整个流程中每个环节前后会预留用户可介入的接口

简单调用方式

java 复制代码
    @GetMapping("/tool")
    public String tool(String input) {
        return chatClient.prompt()
                .tools(new SimpleNl2SqlTool(schemaRetriever, chatClient))
                .user(input)
                .call()
                .content();
    }

其中,入参schemaRetriever负责获取schema信息,chatClient负责调用大模型生成sql。

复杂调用方式

java 复制代码
SimpleNl2SqlTool(DocumentRetriever schemaRetriever, 
        PromptTemplate promptTemplate, 
        DocumentRetriever infoRetriever, 
        QueryAugmenter queryAugmenter, 
        DocumentPostProcessor documentPostProcessor, 
        QueryTransformer queryTransformer, 
        ChatClient chatClient)

其实调用方式是一样的,只是在创建工具的时候需要添加许多额外的组件(由用户自己按需实现):

复制代码
schemaRetriever:数据库表结构检索器,如果表比较少,甚至可以不使用向量库检索而是全量返回
复制代码
queryTransformer:查询转换器,流程中的第一步,可以对用户输入进行增强,例如补充数据库dialect信息,添加sql设定(如关联表不超过3张,不选取无关字段,join代替子查询等),将其中的本月本年等词转换为具体日期等
复制代码
infoRetriever:补充信息检索器,补充信息可以是相关表未在数据库中进行设置的外键,也可以是相关业务术语的解释,这里不一定是检索向量库,内部实现甚至可以写死返回固定的Document
复制代码
queryAugmenter:根据检索到的补充信息对用户输入再次进行增强
复制代码
documentPostProcessor:对检索到的schema信息进行增强,比如重排序等
复制代码
promptTemplate:自定义的系统消息提示词模板,需预留schema占位符

以上就是预留的各环节前后,用户可自行介入的接口,也可以参照spring-ai-alibaba-starter-nl2sql中的各种增强措施,工具类具体代码实现如下

java 复制代码
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class SimpleNl2SqlTool {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
            你是一名数据库专家,参考给出的schema信息,为需求编写一条可执行的sql语句。
            注意:请只返回sql语句。如果无法根据给出的schema信息做出回答,则输出"不支持回答该问题"。
            schema 如下,以----------环绕
            ----------
            {schema}
            ----------
            """);

    private static final Logger log = LoggerFactory.getLogger(SimpleNl2SqlTool.class);

    /**
     * 获取数据库的schema信息
     */
    private DocumentRetriever schemaRetriever;

    private PromptTemplate promptTemplate;

    /**
     * 获取补充信息,例如相关表未设置的外键,甚至相关业务术语的解释
     */
    private DocumentRetriever infoRetriever;

    /**
     * 根据补充信息对用户需求进行补充增强
     */
    private QueryAugmenter queryAugmenter;

    /**
     * 对数据库的schema信息进行后处理, 例如重排序等
     */
    private DocumentPostProcessor documentPostProcessor;

    /**
     * 对请求进行补充增强,例如添加默认的数据库类型dialect
     */
    private QueryTransformer queryTransformer;

    private ChatClient chatClient;

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, PromptTemplate promptTemplate, DocumentRetriever infoRetriever, QueryAugmenter queryAugmenter, DocumentPostProcessor documentPostProcessor, QueryTransformer queryTransformer, ChatClient chatClient) {
        Assert.notNull(schemaRetriever, "schemaRetriever can not be null");
        Assert.notNull(promptTemplate, "promptTemplate can not be null");
        Assert.notNull(chatClient, "chatClient can not be null");
        this.schemaRetriever = schemaRetriever;
        this.promptTemplate = promptTemplate;
        this.infoRetriever = infoRetriever;
        this.queryAugmenter = queryAugmenter;
        this.documentPostProcessor = documentPostProcessor;
        this.queryTransformer = queryTransformer;
        this.chatClient = chatClient;
    }

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, ChatClient chatClient) {
        this.schemaRetriever = schemaRetriever;
        promptTemplate = DEFAULT_PROMPT_TEMPLATE;
        this.chatClient = chatClient;
    }

    @Tool(description = "将需求转换为sql")
    public String nl2sql(@ToolParam(description = "希望转换为sql的需求描述") String request) {
        Query query = new Query(request);

        query = queryTransformer == null ? query : queryTransformer.transform(query);

        List<Document> infoDocuments = infoRetriever == null ? List.of() : infoRetriever.retrieve(query);
        query = queryAugmenter == null ? query : queryAugmenter.augment(query, infoDocuments);

        List<Document> schemaDocuments = schemaRetriever.retrieve(query);
        schemaDocuments = documentPostProcessor == null ? schemaDocuments : documentPostProcessor.process(query, schemaDocuments);
        String schema = joinDocument(schemaDocuments);

        String systemMessage = promptTemplate.render(Map.of("schema", schema));

        return chatClient.prompt().system(systemMessage).user(query.text()).call().content();
    }

    private static String joinDocument(List<Document> documents) {
        return documents.stream().map(Document::getText)
                .collect(Collectors.joining(System.lineSeparator()));
    }


}
相关推荐
HuiSoul2003 小时前
Spring MVC
java·后端·spring mvc
摇滚侠5 小时前
面试实战 问题二十四 Spring 框架中循环依赖问题的解决方法
java·后端·spring
三木水7 小时前
Spring-rabbit使用实战七
java·分布式·后端·spring·消息队列·java-rabbitmq·java-activemq
别来无恙1497 小时前
Spring Boot文件下载功能实现详解
java·spring boot·后端·数据导出
optimistic_chen7 小时前
【Java EE初阶 --- 网络原理】JVM
java·jvm·笔记·网络协议·java-ee
weixin_456904277 小时前
Java泛型与委托
java·spring boot·spring
悟能不能悟8 小时前
能刷java题的网站
java·开发语言
程序员陆通9 小时前
Java高并发场景下的缓存穿透问题定位与解决方案
java·开发语言·缓存
北执南念9 小时前
Java多线程基础总结
java
David爱编程10 小时前
JDK vs JRE:到底有什么本质区别?99% 的人都答不上来
java·后端