在前面的文章中分析过spring-ai-alibaba-starter-nl2sql(spring-ai-alibaba 1.0.0.2 学习(十五)------自然语言生成sql-CSDN博客)也提到过nl2sql虽然效果好,但是使用起来确实并不太方便,并不能即插即用,所以想能不能设计一个简化版的nl2sql
设计方案
最初是想做成一个advisor,但是后来觉得行不通,advisor会对所有对话进行拦截,而对话中可能只有部分是希望转换为sql的需求
那就只能采用工具的方式进行注册,由大模型按需调用
虽然说是设计一个简化版,但是太简单了也就没有实用性了
所以还是需要保留相当一部分扩展性,最终添加了许多可插拔的组件接口,这里参考了spring-ai的RetrievalAugmentaionAdvisor的实现,具体组件接口后面会详细说明
大概的流程就是根据用户查询来检索相关表结构,然后调用大模型生成sql,整个流程中每个环节前后会预留用户可介入的接口
简单调用方式
java
@GetMapping("/tool")
public String tool(String input) {
return chatClient.prompt()
.tools(new SimpleNl2SqlTool(schemaRetriever, chatClient))
.user(input)
.call()
.content();
}
其中,入参schemaRetriever负责获取schema信息,chatClient负责调用大模型生成sql。
复杂调用方式
java
SimpleNl2SqlTool(DocumentRetriever schemaRetriever,
PromptTemplate promptTemplate,
DocumentRetriever infoRetriever,
QueryAugmenter queryAugmenter,
DocumentPostProcessor documentPostProcessor,
QueryTransformer queryTransformer,
ChatClient chatClient)
其实调用方式是一样的,只是在创建工具的时候需要添加许多额外的组件(由用户自己按需实现):
schemaRetriever:数据库表结构检索器,如果表比较少,甚至可以不使用向量库检索而是全量返回
queryTransformer:查询转换器,流程中的第一步,可以对用户输入进行增强,例如补充数据库dialect信息,添加sql设定(如关联表不超过3张,不选取无关字段,join代替子查询等),将其中的本月本年等词转换为具体日期等
infoRetriever:补充信息检索器,补充信息可以是相关表未在数据库中进行设置的外键,也可以是相关业务术语的解释,这里不一定是检索向量库,内部实现甚至可以写死返回固定的Document
queryAugmenter:根据检索到的补充信息对用户输入再次进行增强
documentPostProcessor:对检索到的schema信息进行增强,比如重排序等
promptTemplate:自定义的系统消息提示词模板,需预留schema占位符
以上就是预留的各环节前后,用户可自行介入的接口,也可以参照spring-ai-alibaba-starter-nl2sql中的各种增强措施,工具类具体代码实现如下
java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class SimpleNl2SqlTool {
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
你是一名数据库专家,参考给出的schema信息,为需求编写一条可执行的sql语句。
注意:请只返回sql语句。如果无法根据给出的schema信息做出回答,则输出"不支持回答该问题"。
schema 如下,以----------环绕
----------
{schema}
----------
""");
private static final Logger log = LoggerFactory.getLogger(SimpleNl2SqlTool.class);
/**
* 获取数据库的schema信息
*/
private DocumentRetriever schemaRetriever;
private PromptTemplate promptTemplate;
/**
* 获取补充信息,例如相关表未设置的外键,甚至相关业务术语的解释
*/
private DocumentRetriever infoRetriever;
/**
* 根据补充信息对用户需求进行补充增强
*/
private QueryAugmenter queryAugmenter;
/**
* 对数据库的schema信息进行后处理, 例如重排序等
*/
private DocumentPostProcessor documentPostProcessor;
/**
* 对请求进行补充增强,例如添加默认的数据库类型dialect
*/
private QueryTransformer queryTransformer;
private ChatClient chatClient;
public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, PromptTemplate promptTemplate, DocumentRetriever infoRetriever, QueryAugmenter queryAugmenter, DocumentPostProcessor documentPostProcessor, QueryTransformer queryTransformer, ChatClient chatClient) {
Assert.notNull(schemaRetriever, "schemaRetriever can not be null");
Assert.notNull(promptTemplate, "promptTemplate can not be null");
Assert.notNull(chatClient, "chatClient can not be null");
this.schemaRetriever = schemaRetriever;
this.promptTemplate = promptTemplate;
this.infoRetriever = infoRetriever;
this.queryAugmenter = queryAugmenter;
this.documentPostProcessor = documentPostProcessor;
this.queryTransformer = queryTransformer;
this.chatClient = chatClient;
}
public SimpleNl2SqlTool(DocumentRetriever schemaRetriever, ChatClient chatClient) {
this.schemaRetriever = schemaRetriever;
promptTemplate = DEFAULT_PROMPT_TEMPLATE;
this.chatClient = chatClient;
}
@Tool(description = "将需求转换为sql")
public String nl2sql(@ToolParam(description = "希望转换为sql的需求描述") String request) {
Query query = new Query(request);
query = queryTransformer == null ? query : queryTransformer.transform(query);
List<Document> infoDocuments = infoRetriever == null ? List.of() : infoRetriever.retrieve(query);
query = queryAugmenter == null ? query : queryAugmenter.augment(query, infoDocuments);
List<Document> schemaDocuments = schemaRetriever.retrieve(query);
schemaDocuments = documentPostProcessor == null ? schemaDocuments : documentPostProcessor.process(query, schemaDocuments);
String schema = joinDocument(schemaDocuments);
String systemMessage = promptTemplate.render(Map.of("schema", schema));
return chatClient.prompt().system(systemMessage).user(query.text()).call().content();
}
private static String joinDocument(List<Document> documents) {
return documents.stream().map(Document::getText)
.collect(Collectors.joining(System.lineSeparator()));
}
}