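NL2SQLParser.doParse() wraps natural-language parsing: it calls chatLayerService.parse(req) to perform the semantic parsing, which is split into two phases, MAPPING and PARSING. The complete NL2SQL processing flow can be seen in ChatWorkflowEngine.java.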
```java
public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx) {
    ParseResp parseResult = queryCtx.getParseResp();
    queryCtx.setChatWorkflowState(initialState);
    while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
        switch (queryCtx.getChatWorkflowState()) {
            case MAPPING:
                performMapping(queryCtx);
                if (queryCtx.getMapInfo().isEmpty()) {
                    parseResult.setState(ParseResp.ParseState.FAILED);
                    parseResult.setErrorMsg(
                            "No semantic entities can be mapped against user question.");
                    queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
                } else {
                    queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
                }
                break;
            case PARSING:
                performParsing(queryCtx);
                if (queryCtx.getCandidateQueries().isEmpty()) {
                    parseResult.setState(ParseResp.ParseState.FAILED);
                    parseResult.setErrorMsg("No semantic queries can be parsed out.");
                    queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
                } else {
                    List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
                            .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
                    parseResult.setSelectedParses(parseInfos);
                    if (queryCtx.needSQL()) {
                        queryCtx.setChatWorkflowState(ChatWorkflowState.S2SQL_CORRECTING);
                    } else {
                        parseResult.setState(ParseResp.ParseState.COMPLETED);
                        queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
                    }
                }
                break;
            case S2SQL_CORRECTING:
                performCorrecting(queryCtx);
                queryCtx.setChatWorkflowState(ChatWorkflowState.TRANSLATING);
                break;
            case TRANSLATING:
                long start = System.currentTimeMillis();
                performTranslating(queryCtx, parseResult);
                parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
                queryCtx.setChatWorkflowState(ChatWorkflowState.PHYSICAL_SQL_CORRECTING);
                break;
            case PHYSICAL_SQL_CORRECTING:
                performPhysicalSqlCorrecting(queryCtx);
                queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
                break;
            default:
                if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
                    parseResult.setState(ParseResp.ParseState.COMPLETED);
                }
                queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
                break;
        }
    }
}
```
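As the code shows, NL2SQL moves through the following workflow states:
MAPPING: semantic mapping phase, performMapping()
PARSING: semantic parsing phase, performParsing()
S2SQL_CORRECTING: semantic SQL (S2SQL) correction phase, performCorrecting()
TRANSLATING: SQL translation phase, performTranslating()
PHYSICAL_SQL_CORRECTING: physical SQL correction phase, performPhysicalSqlCorrecting()
FINISHED: terminal state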
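The transitions between these states can be condensed into a small, self-contained sketch. It keeps only the control flow of start(): the three boolean flags are hypothetical stand-ins for "mapping found entities", "parsing produced candidates" and queryCtx.needSQL(), and the class and method names are illustrative, not part of the project.

```java
import java.util.ArrayList;
import java.util.List;

class WorkflowSketch {
    enum State { MAPPING, PARSING, S2SQL_CORRECTING, TRANSLATING, PHYSICAL_SQL_CORRECTING, FINISHED }

    /** Returns the states visited in order, ending with FINISHED. */
    static List<State> trace(boolean mapped, boolean parsed, boolean needSql) {
        List<State> visited = new ArrayList<>();
        State state = State.MAPPING;
        while (state != State.FINISHED) {
            visited.add(state);
            switch (state) {
                case MAPPING:
                    // No mapped entities: stop early, otherwise go on to parsing.
                    state = mapped ? State.PARSING : State.FINISHED;
                    break;
                case PARSING:
                    // Only continue to correction when candidates exist and SQL is needed.
                    state = (parsed && needSql) ? State.S2SQL_CORRECTING : State.FINISHED;
                    break;
                case S2SQL_CORRECTING:
                    state = State.TRANSLATING;
                    break;
                case TRANSLATING:
                    state = State.PHYSICAL_SQL_CORRECTING;
                    break;
                default:
                    // PHYSICAL_SQL_CORRECTING (and anything unexpected) ends the workflow.
                    state = State.FINISHED;
                    break;
            }
        }
        visited.add(State.FINISHED);
        return visited;
    }

    public static void main(String[] args) {
        System.out.println(trace(true, true, true));   // full pipeline
        System.out.println(trace(false, true, true));  // mapping found nothing
    }
}
```

The sketch makes the two early exits easy to see: an empty mapping result and an empty candidate list both jump straight to FINISHED.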
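There are three core components:
SchemaMapper: responsible for semantic entity mapping
SemanticParser: responsible for semantic parsing
SemanticCorrector: responsible for SQL correction (including LLMPhysicalSqlCorrector)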
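An earlier post, the KeywordMapper walkthrough, explained the mapping process using KeywordMapper (the keyword mapper) as the example. This post focuses on the PARSING phase.
1. Parser types
Five classes implement SemanticParser:
AggregateTypeParser: aggregate type parser
LLMSqlParser: LLM-based SQL parser
QueryTypeParser: query type parser (detail or aggregate)
RuleSqlParser: rule-based SQL parser
TimeRangeParser: time range parser
The analysis below focuses on LLMSqlParser, the LLM-based SQL parser.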
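To illustrate how several parsers can share one contract, here is a minimal sketch. It assumes, without confirmation from the code shown here, that each parser exposes a single parse method and that the parsers are run one after another against a shared context. Context, Parser, defaultParsers() and the keyword checks are all hypothetical and much simpler than the real classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ParserChainSketch {
    /** Hypothetical stand-in for the shared query context. */
    static class Context {
        final String question;
        final List<String> notes = new ArrayList<>();

        Context(String question) {
            this.question = question;
        }
    }

    /** Hypothetical single-method parser contract. */
    interface Parser {
        void parse(Context ctx);
    }

    /** Two toy parsers: one detects a time range, one decides the query type. */
    static List<Parser> defaultParsers() {
        Parser timeRange = ctx -> {
            if (ctx.question.contains("last 7 days")) {
                ctx.notes.add("timeRange=7d");
            }
        };
        Parser queryType = ctx ->
                ctx.notes.add(ctx.question.contains("total") ? "queryType=AGGREGATE" : "queryType=DETAIL");
        return Arrays.asList(timeRange, queryType);
    }

    /** Runs every parser in order; each one enriches the same context. */
    static Context runAll(List<Parser> parsers, String question) {
        Context ctx = new Context(question);
        for (Parser parser : parsers) {
            parser.parse(ctx);
        }
        return ctx;
    }

    public static void main(String[] args) {
        System.out.println(runAll(defaultParsers(), "total visits in the last 7 days").notes);
    }
}
```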
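2. LLMSqlParser
The actual parsing is done in LLMSqlParser.tryParse(). The part to focus on is how the response object is obtained from the request object, because that step contains the NL2SQL conversion.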
```java
private void tryParse(ChatQueryContext queryCtx, Long dataSetId) {
    LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
    LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
    int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
    // Build the request object
    LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId);
    int currentRetry = 1;
    Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
    ParseResult parseResult = null;
    while (currentRetry <= maxRetries) {
        log.info("currentRetryRound:{}, start runText2SQL", currentRetry);
        try {
            // Obtain the response from the request; this step includes SQL generation
            LLMResp llmResp = requestService.runText2SQL(llmReq);
            if (Objects.nonNull(llmResp)) {
                // deduplicate the S2SQL result list and build parserInfo
                sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
                if (MapUtils.isNotEmpty(sqlRespMap)) {
                    parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq)
                            .llmResp(llmResp).build();
                    break;
                }
            }
        } catch (Exception e) {
            log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
        }
        ChatModelConfig chatModelConfig = llmReq.getChatAppConfig()
                .get(OnePassSCSqlGenStrategy.APP_KEY).getChatModelConfig();
        Double temperature = chatModelConfig.getTemperature();
        if (temperature == 0) {
            // After a failed round, add randomness to reduce ineffective retries
            chatModelConfig.setTemperature(0.5);
        }
        currentRetry++;
    }
    if (MapUtils.isEmpty(sqlRespMap)) {
        return;
    }
    for (Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
        String sql = entry.getKey();
        double sqlWeight = entry.getValue().getSqlWeight();
        responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
    }
}
```
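The retry pattern in tryParse(), retrying a bounded number of times and raising the temperature from 0 to 0.5 after an unsuccessful round, can be isolated as follows. This is a sketch under simplifying assumptions: ModelConfig and callWithRetry are hypothetical names, and the model call is reduced to a plain function that may return null or throw.

```java
import java.util.Optional;
import java.util.function.Function;

class RetrySketch {
    /** Hypothetical mutable model settings; only the temperature matters here. */
    static class ModelConfig {
        double temperature;

        ModelConfig(double temperature) {
            this.temperature = temperature;
        }
    }

    /** Calls the model up to maxRetries times and returns the first non-null result. */
    static <T> Optional<T> callWithRetry(ModelConfig config, int maxRetries,
            Function<ModelConfig, T> call) {
        for (int round = 1; round <= maxRetries; round++) {
            try {
                T result = call.apply(config);
                if (result != null) {
                    return Optional.of(result);
                }
            } catch (RuntimeException e) {
                System.err.println("round " + round + " failed: " + e.getMessage());
            }
            // A deterministic model (temperature 0) would likely repeat the same
            // unusable answer, so add some randomness before the next round.
            if (config.temperature == 0) {
                config.temperature = 0.5;
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        ModelConfig config = new ModelConfig(0.0);
        // Toy "model": yields nothing at temperature 0, a query otherwise.
        Optional<String> sql = callWithRetry(config, 3,
                c -> c.temperature == 0 ? null : "SELECT 1");
        System.out.println(sql.orElse("no result") + " at temperature " + config.temperature);
    }
}
```

Note that the bump is applied after every unsuccessful round, whether the round threw an exception or simply produced no usable SQL, which matches the placement of the temperature check in the loop above.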
The call chain that produces the response object:
LLMSqlParser.tryParse() -> LLMRequestService.runText2SQL() -> SqlGenStrategy.generate() -> OnePassSCSqlGenStrategy.generate()
```java
public LLMResp generate(LLMReq llmReq) {
    // Initialize the response object and set the query text
    LLMResp llmResp = new LLMResp();
    llmResp.setQuery(llmReq.getQueryText());
    // 1.recall exemplars
    log.debug("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
    // Retrieve the few-shot exemplars
    List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
    // 2.generate sql generation prompt for each self-consistency inference
    // Configure the LLM parameters
    ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
    ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
    // Enable JSON-format output when it is configured
    if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
        chatModelConfig.setJsonFormat(true);
        chatModelConfig
                .setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
    }
    // Build the semantic SQL extractor service with LangChain4j
    ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
    SemanticSqlExtractor extractor =
            AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
    // Generate one prompt per exemplar set for parallel inference
    Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
    for (List<Text2SQLExemplar> exemplars : exemplarsList) {
        llmReq.setDynamicExemplars(exemplars);
        // Generate the prompt; the schema information is passed in here
        Prompt prompt = generatePrompt(llmReq, llmResp, chatApp);
        prompt2Exemplar.put(prompt, exemplars);
    }
    // 3.perform multiple self-consistency inferences parallelly
    Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
    prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
        // Each parallel call yields one candidate SQL, keyed by its SQL text
        SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
        output2Prompt.put(s2Sql.getSql(), prompt);
        keyPipelineLog.info("OnePassSCSqlGenStrategy modelReq:\n{} \nmodelResp:\n{}",
                prompt.text(), s2Sql);
    });
    // 4.format response.
    // Pick the best SQL by self-consistency voting
    Pair<String, Map<String, Double>> sqlMapPair =
            ResponseHelper.selfConsistencyVote(Lists.newArrayList(output2Prompt.keySet()));
    llmResp.setSqlOutput(sqlMapPair.getLeft());
    // Record the exemplars used and the SQL weights, then build the final response
    List<Text2SQLExemplar> usedExemplars =
            prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
    llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
    return llmResp;
}
```
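The body of ResponseHelper.selfConsistencyVote() is not shown here, so the following is only a generic sketch of self-consistency voting, not the project's implementation. It assumes that candidates are compared by exact string equality, that a candidate's weight is its share of the votes, and that ties go to the candidate seen first. VoteSketch, weights() and winner() are hypothetical names.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class VoteSketch {
    /** Maps each distinct candidate to its share of the votes, in first-seen order. */
    static Map<String, Double> weights(List<String> candidates) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String sql : candidates) {
            counts.merge(sql, 1, Integer::sum);
        }
        Map<String, Double> weights = new LinkedHashMap<>();
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            weights.put(entry.getKey(), entry.getValue() / (double) candidates.size());
        }
        return weights;
    }

    /** Returns the candidate with the largest share, or null for an empty list. */
    static String winner(List<String> candidates) {
        String best = null;
        double bestWeight = -1;
        for (Map.Entry<String, Double> entry : weights(candidates).entrySet()) {
            if (entry.getValue() > bestWeight) {
                best = entry.getKey();
                bestWeight = entry.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> candidates = Arrays.asList(
                "SELECT SUM(pv) FROM t", "SELECT COUNT(pv) FROM t",
                "SELECT SUM(pv) FROM t", "SELECT SUM(pv) FROM t");
        System.out.println(winner(candidates));
        System.out.println(weights(candidates));
    }
}
```

The idea behind the technique is that an answer the model reaches repeatedly, from prompts built with different exemplar sets, is more likely to be correct than one it produces only once.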
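What OnePassSCSqlGenStrategy.java does:
1. Semantic understanding: converts the natural-language query into semantic SQL
2. Schema constraints: passes the schema information into the prompt so that the generated SQL conforms to the database schema
3. Quality assurance: improves SQL accuracy through a self-consistency strategy combined with few-shot learning and prompt engineering
4. Explainability: uses the inner class SemanticSql to describe the reasoning behind the generated SQL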
```java
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp, ChatApp chatApp) {
    StringBuilder exemplars = new StringBuilder();
    for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
        String exemplarStr = String.format("\nQuestion:%s,Schema:%s,SideInfo:%s,SQL:%s",
                exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSideInfo(),
                exemplar.getSql());
        exemplars.append(exemplarStr);
    }
    String dataSemantics = promptHelper.buildSchemaStr(llmReq);
    String sideInformation = promptHelper.buildSideInformation(llmReq);
    llmResp.setSchema(dataSemantics);
    llmResp.setSideInfo(sideInformation);
    Map<String, Object> variable = new HashMap<>();
    variable.put("exemplar", exemplars);
    variable.put("question", llmReq.getQueryText());
    variable.put("schema", dataSemantics);
    variable.put("information", sideInformation);
    // use custom prompt template if provided.
    String promptTemplate = chatApp.getPrompt();
    return PromptTemplate.from(promptTemplate).apply(variable);
}
```
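generatePrompt() fills four variables (exemplar, question, schema, information) into a template through LangChain4j's PromptTemplate, which uses {{name}} placeholders. To show what that substitution amounts to without depending on the library, here is a plain-Java sketch. The render() helper, the template text and the sample values are all hypothetical; the real template comes from chatApp.getPrompt().

```java
import java.util.LinkedHashMap;
import java.util.Map;

class PromptSketch {
    /** Replaces every {{name}} placeholder with the string form of its value. */
    static String render(String template, Map<String, Object> variables) {
        String out = template;
        for (Map.Entry<String, Object> entry : variables.entrySet()) {
            out = out.replace("{{" + entry.getKey() + "}}", String.valueOf(entry.getValue()));
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical template; only the four variable names come from generatePrompt().
        String template = "Examples:{{exemplar}}\nSchema:{{schema}}\n"
                + "SideInfo:{{information}}\nQuestion:{{question}}";
        Map<String, Object> variables = new LinkedHashMap<>();
        // A StringBuilder works as a value because it is converted with String.valueOf().
        variables.put("exemplar", new StringBuilder(
                "\nQuestion:total visits,Schema:visits(pv),SideInfo:none,SQL:SELECT SUM(pv) FROM visits"));
        variables.put("question", "total visits in the last 7 days");
        variables.put("schema", "visits(pv, visit_date)");
        variables.put("information", "today is a weekday");
        System.out.println(render(template, variables));
    }
}
```

Because the exemplars, the schema and the side information all end up in a single prompt string, the few-shot examples and the schema constraints reach the model together in one call.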