spring-ai-alibaba 简化版NL2SQL

在前面的文章中分析过spring-ai-alibaba-starter-nl2sql(spring-ai-alibaba 1.0.0.2 学习(十五)------自然语言生成sql-CSDN博客)也提到过nl2sql虽然效果好,但是使用起来确实并不太方便,并不能即插即用,所以想能不能设计一个简化版的nl2sql

设计方案

最初是想做成一个advisor,但是后来觉得行不通,advisor会对所有对话进行拦截,而对话中可能只有部分是希望转换为sql的需求

那就只能采用工具的方式进行注册,由大模型按需调用

虽然说是设计一个简化版,但是太简单了也就没有实用性了

所以还是需要保留相当一部分扩展性,最终添加了许多可插拔的组件接口,这里参考了spring-ai的RetrievalAugmentaionAdvisor的实现,具体组件接口后面会详细说明

大概的流程就是根据用户查询来检索相关表结构,然后调用大模型生成sql,整个流程中每个环节前后会预留用户可介入的接口

简单调用方式

java 复制代码
    @GetMapping("/tool")
    public String tool(String input) {
        return chatClient.prompt()
                .tools(new SimpleNl2SqlTool(schemaRetriever, chatClient))
                .user(input)
                .call()
                .content();
    }

其中,入参schemaRetriever负责获取schema信息,chatClient负责调用大模型生成sql。

复杂调用方式

java 复制代码
SimpleNl2SqlTool(DocumentRetriever schemaRetriever, 
        PromptTemplate promptTemplate, 
        DocumentRetriever infoRetriever, 
        QueryAugmenter queryAugmenter, 
        DocumentPostProcessor documentPostProcessor, 
        QueryTransformer queryTransformer, 
        ChatClient chatClient)

其实调用方式是一样的,只是在创建工具的时候需要添加许多额外的组件(由用户自己按需实现):

复制代码
schemaRetriever:数据库表结构检索器,如果表比较少,甚至可以不使用向量库检索而是全量返回
复制代码
queryTransformer:查询转换器,流程中的第一步,可以对用户输入进行增强,例如补充数据库dialect信息,添加sql设定(如关联表不超过3张,不选取无关字段,join代替子查询等),将其中的本月本年等词转换为具体日期等
复制代码
infoRetriever:补充信息检索器,补充信息可以是相关表未在数据库中进行设置的外键,也可以是相关业务术语的解释,这里不一定是检索向量库,内部实现甚至可以写死返回固定的Document
复制代码
queryAugmenter:根据检索到的补充信息对用户输入再次进行增强
复制代码
documentPostProcessor:对检索到的schema信息进行增强,比如重排序等
复制代码
promptTemplate:自定义的系统消息提示词模板,需预留schema占位符

以上就是预留的各环节前后,用户可自行介入的接口,也可以参照spring-ai-alibaba-starter-nl2sql中的各种增强措施,工具类具体代码实现如下

java 复制代码
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class SimpleNl2SqlTool {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
            你是一名数据库专家,参考给出的schema信息,为需求编写一条可执行的sql语句。
            注意:请只返回sql语句。如果无法根据给出的schema信息做出回答,则输出"不支持回答该问题"。
            schema 如下,以----------环绕
            ----------
            {schema}
            ----------
            """);

    private static final Logger log = LoggerFactory.getLogger(SimpleNl2SqlTool.class);

    /**
     * 获取数据库的schema信息
     */
    private DocumentRetriever schemaRetriever;

    private PromptTemplate promptTemplate;

    /**
     * 获取补充信息,例如相关表未设置的外键,甚至相关业务术语的解释
     */
    private DocumentRetriever infoRetriever;

    /**
     * 根据补充信息对用户需求进行补充增强
     */
    private QueryAugmenter queryAugmenter;

    /**
     * 对数据库的schema信息进行后处理, 例如重排序等
     */
    private DocumentPostProcessor documentPostProcessor;

    /**
     * 对请求进行补充增强,例如添加默认的数据库类型dialect
     */
    private QueryTransformer queryTransformer;

    private ChatClient chatClient;

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, PromptTemplate promptTemplate, DocumentRetriever infoRetriever, QueryAugmenter queryAugmenter, DocumentPostProcessor documentPostProcessor, QueryTransformer queryTransformer, ChatClient chatClient) {
        Assert.notNull(schemaRetriever, "schemaRetriever can not be null");
        Assert.notNull(promptTemplate, "promptTemplate can not be null");
        Assert.notNull(chatClient, "chatClient can not be null");
        this.schemaRetriever = schemaRetriever;
        this.promptTemplate = promptTemplate;
        this.infoRetriever = infoRetriever;
        this.queryAugmenter = queryAugmenter;
        this.documentPostProcessor = documentPostProcessor;
        this.queryTransformer = queryTransformer;
        this.chatClient = chatClient;
    }

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, ChatClient chatClient) {
        this.schemaRetriever = schemaRetriever;
        promptTemplate = DEFAULT_PROMPT_TEMPLATE;
        this.chatClient = chatClient;
    }

    @Tool(description = "将需求转换为sql")
    public String nl2sql(@ToolParam(description = "希望转换为sql的需求描述") String request) {
        Query query = new Query(request);

        query = queryTransformer == null ? query : queryTransformer.transform(query);

        List<Document> infoDocuments = infoRetriever == null ? List.of() : infoRetriever.retrieve(query);
        query = queryAugmenter == null ? query : queryAugmenter.augment(query, infoDocuments);

        List<Document> schemaDocuments = schemaRetriever.retrieve(query);
        schemaDocuments = documentPostProcessor == null ? schemaDocuments : documentPostProcessor.process(query, schemaDocuments);
        String schema = joinDocument(schemaDocuments);

        String systemMessage = promptTemplate.render(Map.of("schema", schema));

        return chatClient.prompt().system(systemMessage).user(query.text()).call().content();
    }

    private static String joinDocument(List<Document> documents) {
        return documents.stream().map(Document::getText)
                .collect(Collectors.joining(System.lineSeparator()));
    }


}
相关推荐
老神在在00115 分钟前
SpringMVC1
java·前端·学习·spring
程序猿小D1 小时前
[附源码+数据库+毕业论文+开题报告]基于Spring+MyBatis+MySQL+Maven+jsp实现的车辆运输管理系统,推荐!
java·数据库·mysql·spring·毕业设计·开题报告·车辆运输管理系统
Bug退退退1232 小时前
RabbitMQ 高级特性之消息分发
java·分布式·spring·rabbitmq
Jack_hrx3 小时前
基于 Drools 的规则引擎性能调优实践:架构、缓存与编译优化全解析
java·性能优化·规则引擎·drools·规则编译
二进制person4 小时前
数据结构--准备知识
java·开发语言·数据结构
半梦半醒*4 小时前
H3CNE综合实验之机器人
java·开发语言·网络
消失的旧时光-19434 小时前
Android模块化架构:基于依赖注入和服务定位器的解耦方案
android·java·架构·kotlin
@ chen5 小时前
Spring Boot 解决跨域问题
java·spring boot·后端
洛_尘5 小时前
Java EE进阶2:前端 HTML+CSS+JavaScript
java·前端·java-ee
转转技术团队6 小时前
转转上门隐私号系统的演进
java·后端