spring-ai-alibaba 简化版NL2SQL

在前面的文章中分析过spring-ai-alibaba-starter-nl2sql(spring-ai-alibaba 1.0.0.2 学习(十五)------自然语言生成sql-CSDN博客)也提到过nl2sql虽然效果好,但是使用起来确实并不太方便,并不能即插即用,所以想能不能设计一个简化版的nl2sql

设计方案

最初是想做成一个advisor,但是后来觉得行不通,advisor会对所有对话进行拦截,而对话中可能只有部分是希望转换为sql的需求

那就只能采用工具的方式进行注册,由大模型按需调用

虽然说是设计一个简化版,但是太简单了也就没有实用性了

所以还是需要保留相当一部分扩展性,最终添加了许多可插拔的组件接口,这里参考了spring-ai的RetrievalAugmentaionAdvisor的实现,具体组件接口后面会详细说明

大概的流程就是根据用户查询来检索相关表结构,然后调用大模型生成sql,整个流程中每个环节前后会预留用户可介入的接口

简单调用方式

java 复制代码
    @GetMapping("/tool")
    public String tool(String input) {
        return chatClient.prompt()
                .tools(new SimpleNl2SqlTool(schemaRetriever, chatClient))
                .user(input)
                .call()
                .content();
    }

其中,入参schemaRetriever负责获取schema信息,chatClient负责调用大模型生成sql。

复杂调用方式

java 复制代码
SimpleNl2SqlTool(DocumentRetriever schemaRetriever, 
        PromptTemplate promptTemplate, 
        DocumentRetriever infoRetriever, 
        QueryAugmenter queryAugmenter, 
        DocumentPostProcessor documentPostProcessor, 
        QueryTransformer queryTransformer, 
        ChatClient chatClient)

其实调用方式是一样的,只是在创建工具的时候需要添加许多额外的组件(由用户自己按需实现):

复制代码
schemaRetriever:数据库表结构检索器,如果表比较少,甚至可以不使用向量库检索而是全量返回
复制代码
queryTransformer:查询转换器,流程中的第一步,可以对用户输入进行增强,例如补充数据库dialect信息,添加sql设定(如关联表不超过3张,不选取无关字段,join代替子查询等),将其中的本月本年等词转换为具体日期等
复制代码
infoRetriever:补充信息检索器,补充信息可以是相关表未在数据库中进行设置的外键,也可以是相关业务术语的解释,这里不一定是检索向量库,内部实现甚至可以写死返回固定的Document
复制代码
queryAugmenter:根据检索到的补充信息对用户输入再次进行增强
复制代码
documentPostProcessor:对检索到的schema信息进行增强,比如重排序等
复制代码
promptTemplate:自定义的系统消息提示词模板,需预留schema占位符

以上就是预留的各环节前后,用户可自行介入的接口,也可以参照spring-ai-alibaba-starter-nl2sql中的各种增强措施,工具类具体代码实现如下

java 复制代码
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class SimpleNl2SqlTool {

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
            你是一名数据库专家,参考给出的schema信息,为需求编写一条可执行的sql语句。
            注意:请只返回sql语句。如果无法根据给出的schema信息做出回答,则输出"不支持回答该问题"。
            schema 如下,以----------环绕
            ----------
            {schema}
            ----------
            """);

    private static final Logger log = LoggerFactory.getLogger(SimpleNl2SqlTool.class);

    /**
     * 获取数据库的schema信息
     */
    private DocumentRetriever schemaRetriever;

    private PromptTemplate promptTemplate;

    /**
     * 获取补充信息,例如相关表未设置的外键,甚至相关业务术语的解释
     */
    private DocumentRetriever infoRetriever;

    /**
     * 根据补充信息对用户需求进行补充增强
     */
    private QueryAugmenter queryAugmenter;

    /**
     * 对数据库的schema信息进行后处理, 例如重排序等
     */
    private DocumentPostProcessor documentPostProcessor;

    /**
     * 对请求进行补充增强,例如添加默认的数据库类型dialect
     */
    private QueryTransformer queryTransformer;

    private ChatClient chatClient;

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, PromptTemplate promptTemplate, DocumentRetriever infoRetriever, QueryAugmenter queryAugmenter, DocumentPostProcessor documentPostProcessor, QueryTransformer queryTransformer, ChatClient chatClient) {
        Assert.notNull(schemaRetriever, "schemaRetriever can not be null");
        Assert.notNull(promptTemplate, "promptTemplate can not be null");
        Assert.notNull(chatClient, "chatClient can not be null");
        this.schemaRetriever = schemaRetriever;
        this.promptTemplate = promptTemplate;
        this.infoRetriever = infoRetriever;
        this.queryAugmenter = queryAugmenter;
        this.documentPostProcessor = documentPostProcessor;
        this.queryTransformer = queryTransformer;
        this.chatClient = chatClient;
    }

    public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, ChatClient chatClient) {
        this.schemaRetriever = schemaRetriever;
        promptTemplate = DEFAULT_PROMPT_TEMPLATE;
        this.chatClient = chatClient;
    }

    @Tool(description = "将需求转换为sql")
    public String nl2sql(@ToolParam(description = "希望转换为sql的需求描述") String request) {
        Query query = new Query(request);

        query = queryTransformer == null ? query : queryTransformer.transform(query);

        List<Document> infoDocuments = infoRetriever == null ? List.of() : infoRetriever.retrieve(query);
        query = queryAugmenter == null ? query : queryAugmenter.augment(query, infoDocuments);

        List<Document> schemaDocuments = schemaRetriever.retrieve(query);
        schemaDocuments = documentPostProcessor == null ? schemaDocuments : documentPostProcessor.process(query, schemaDocuments);
        String schema = joinDocument(schemaDocuments);

        String systemMessage = promptTemplate.render(Map.of("schema", schema));

        return chatClient.prompt().system(systemMessage).user(query.text()).call().content();
    }

    private static String joinDocument(List<Document> documents) {
        return documents.stream().map(Document::getText)
                .collect(Collectors.joining(System.lineSeparator()));
    }


}
相关推荐
沙子迷了蜗牛眼9 分钟前
当展示列表使用 URL.createObjectURL 的创建临时图片、视频无法加载问题
java·前端·javascript·vue.js
ganshenml11 分钟前
【Android】 开发四角版本全解析:AS、AGP、Gradle 与 JDK 的配套关系
android·java·开发语言
我命由我1234512 分钟前
Kotlin 运算符 - == 运算符与 === 运算符
android·java·开发语言·java-ee·kotlin·android studio·android-studio
小途软件18 分钟前
ssm327校园二手交易平台的设计与实现+vue
java·人工智能·pytorch·python·深度学习·语言模型
alonewolf_9922 分钟前
Java类加载机制深度解析:从双亲委派到热加载实战
java·开发语言
追梦者12323 分钟前
springboot整合minio
java·spring boot·后端
云游26 分钟前
Jaspersoft Studio community edition 7.0.3的应用
java·报表
帅气的你31 分钟前
Spring Boot 集成 AOP 实现日志记录与接口权限校验
java·spring boot
zhglhy1 小时前
Spring Data Slice使用指南
java·spring
win x1 小时前
Redis 主从复制
java·数据库·redis