1、Pre-Retrieval 增强和转换用户输入,使其更有效地执行检索任务,解决格式不正确的查询、query 语义不清晰、或不受支持的语言等。
1.1 查询重写(QueryTransformer):用户的输入通常比较片面、关键信息较少,不便于大模型理解和回答问题,因此需要使用 prompt 调优手段或者借助大模型改写用户 query。
RewriteQueryTransformer
java
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.rag.preretrieval.query.transformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Uses a large language model to rewrite a user query to provide better results when
* querying a target system, such as a vector store or a web search engine.
* <p>
* This transformer is useful when the user query is verbose, ambiguous, or contains
* irrelevant information that may affect the quality of the search results.
*
* @author Thomas Vitale
* @since 1.0.0
* @see <a href="https://arxiv.org/pdf/2305.14283">arXiv:2305.14283</a>
*/
public class RewriteQueryTransformer implements QueryTransformer {

    private static final Logger logger = LoggerFactory.getLogger(RewriteQueryTransformer.class);

    /** Name of the target search system used when the caller does not supply one. */
    private static final String DEFAULT_TARGET = "vector store";

    /** Rewrite instruction used when the caller does not supply a template. */
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
            Given a user query, rewrite it to provide better results when querying a {target}.
            Remove any irrelevant information, and ensure the query is concise and specific.
            Original query:
            {query}
            Rewritten query:
            """);

    private final ChatClient chatClient;

    private final PromptTemplate promptTemplate;

    private final String targetSearchSystem;

    /**
     * Creates a transformer backed by a chat client built from the given builder.
     *
     * @param chatClientBuilder builder for the chat client that performs the rewrite; required
     * @param promptTemplate rewrite template, or {@code null} to use the default one; it must
     * contain the {@code target} and {@code query} placeholders
     * @param targetSearchSystem name of the system the query is aimed at, or {@code null} to
     * use the default
     */
    public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
            @Nullable String targetSearchSystem) {
        Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

        PromptTemplate effectiveTemplate = DEFAULT_PROMPT_TEMPLATE;
        if (promptTemplate != null) {
            effectiveTemplate = promptTemplate;
        }

        String effectiveTarget = DEFAULT_TARGET;
        if (targetSearchSystem != null) {
            effectiveTarget = targetSearchSystem;
        }

        this.chatClient = chatClientBuilder.build();
        this.promptTemplate = effectiveTemplate;
        this.targetSearchSystem = effectiveTarget;

        PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
    }

    /**
     * Asks the chat model to rewrite the query text for the configured target system.
     *
     * @param query the query to rewrite; must not be {@code null}
     * @return a copy of the query carrying the rewritten text, or the input query itself
     * when the model returns no text
     */
    @Override
    public Query transform(Query query) {
        Assert.notNull(query, "query cannot be null");

        logger.debug("Rewriting query to optimize for querying a {}.", this.targetSearchSystem);

        String rewritten = this.chatClient.prompt()
            .user(spec -> spec.text(this.promptTemplate.getTemplate())
                .param("target", this.targetSearchSystem)
                .param("query", query.text()))
            .call()
            .content();

        if (StringUtils.hasText(rewritten)) {
            return query.mutate().text(rewritten).build();
        }

        logger.warn("Query rewrite result is null/empty. Returning the input query unchanged.");
        return query;
    }

    /**
     * @return a new, empty builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link RewriteQueryTransformer}; the values are handed to the
     * constructor unchanged.
     */
    public static final class Builder {

        private ChatClient.Builder chatClientBuilder;

        @Nullable
        private PromptTemplate promptTemplate;

        @Nullable
        private String targetSearchSystem;

        private Builder() {
        }

        public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
            this.chatClientBuilder = chatClientBuilder;
            return this;
        }

        public Builder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public Builder targetSearchSystem(String targetSearchSystem) {
            this.targetSearchSystem = targetSearchSystem;
            return this;
        }

        public RewriteQueryTransformer build() {
            return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
        }

    }

}
1.2 扩充问题:MultiQueryExpander 查询扩展:将用户 query 扩展为多个语义不同的变体以获得不同视角,有助于检索额外的上下文信息并增加找到相关结果的机会。
MultiQueryExpander
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.preretrieval.query.expansion;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Query expander that asks a chat model for several variants of the user query.
 * <p>
 * Pre-retrieval step: the user input is enhanced and transformed so that retrieval works
 * better, addressing badly formed queries, unclear query semantics or unsupported
 * languages. Expanding one query into multiple semantically different variants gives
 * different perspectives, helps retrieve additional context and raises the chance of
 * finding relevant results.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
public class MultiQueryExpander implements QueryExpander {

    private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);

    /** Default multi-query expansion prompt. */
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(
            """
            You are an expert in information retrieval and search optimization.
            Generate {number} different versions of a given query.
            Each variation should cover a different perspective or aspect of the topic while maintaining the core intent of
            the original query. The goal is to broaden your search and improve your chances of finding relevant information.
            Don't interpret the selection or add additional text.
            Query variants are provided, separated by line breaks.
            Original query: {query}
            Query variants:
            """
    );

    private static final boolean DEFAULT_INCLUDE_ORIGINAL = true;

    private static final int DEFAULT_NUMBER_OF_QUERIES = 3;

    private final ChatClient chatClient;

    private final PromptTemplate promptTemplate;

    private final boolean includeOriginal;

    private final int numberOfQueries;

    /**
     * Creates an expander backed by a chat client built from the given builder.
     *
     * @param chatClientBuilder builder for the chat client that generates the variants; required
     * @param promptTemplate expansion template, or {@code null} for the default; it must
     * contain the {@code number} and {@code query} placeholders
     * @param includeOriginal whether the original query is put first in the result, or
     * {@code null} for the default ({@code true})
     * @param numberOfQueries how many variants to request, or {@code null} for the default (3)
     */
    public MultiQueryExpander(
            ChatClient.Builder chatClientBuilder,
            @Nullable PromptTemplate promptTemplate,
            @Nullable Boolean includeOriginal,
            @Nullable Integer numberOfQueries
    ) {
        Assert.notNull(chatClientBuilder, "ChatClient.Builder must not be null");

        this.chatClient = chatClientBuilder.build();
        this.promptTemplate = promptTemplate == null ? DEFAULT_PROMPT_TEMPLATE : promptTemplate;
        this.includeOriginal = includeOriginal == null ? DEFAULT_INCLUDE_ORIGINAL : includeOriginal;
        this.numberOfQueries = numberOfQueries == null ? DEFAULT_NUMBER_OF_QUERIES : numberOfQueries;

        PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
    }

    /**
     * Expands the query into the configured number of variants.
     *
     * @param query the query to expand; must not be {@code null}
     * @return the variants (preceded by the original query when {@code includeOriginal} is
     * set), or a single-element list holding the original query when the model gives no
     * response or does not return exactly the requested number of non-blank lines
     */
    @NotNull
    @Override
    public List<Query> expand(@Nullable Query query) {
        Assert.notNull(query, "Query must not be null");

        logger.debug("Generating {} queries for query: {}", this.numberOfQueries, query.text());

        String resp = this.chatClient.prompt()
            .user(user -> user.text(this.promptTemplate.getTemplate())
                .param("number", this.numberOfQueries)
                .param("query", query.text()))
            .call()
            .content();

        logger.debug("MultiQueryExpander#expand() Response from chat client: {}", resp);

        if (Objects.isNull(resp)) {
            logger.warn("No response from chat client for query: {}. is return.", query.text());
            return List.of(query);
        }

        // lines() understands \n, \r and \r\n, and strip() drops surrounding whitespace, so a
        // CRLF-terminated or indented response no longer leaks control characters or padding
        // into the variant text.
        List<String> queryVariants = resp.lines()
            .map(String::strip)
            .filter(StringUtils::hasText)
            .toList();

        if (CollectionUtils.isEmpty(queryVariants) || this.numberOfQueries != queryVariants.size()) {
            logger.warn("Query expansion result dose not contain the requested {} variants for query: {}. is return.",
                    this.numberOfQueries, query.text());
            return List.of(query);
        }

        // Blank lines were already removed above, so the variants can be mapped directly.
        List<Query> queries = queryVariants.stream()
            .map(queryText -> query.mutate().text(queryText).build())
            .collect(Collectors.toList());

        if (this.includeOriginal) {
            logger.debug("Including original query in the expanded queries for query: {}", query.text());
            queries.add(0, query);
        }

        logger.debug("Rewrite queries: {}", queries);

        return queries;
    }

    /**
     * @return a new, empty builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link MultiQueryExpander}; unset optional values stay {@code null}
     * and are replaced by the defaults in the constructor.
     */
    public static final class Builder {

        private ChatClient.Builder chatClientBuilder;

        private PromptTemplate promptTemplate;

        private Boolean includeOriginal;

        private Integer numberOfQueries;

        private Builder() {
        }

        public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
            this.chatClientBuilder = chatClientBuilder;
            return this;
        }

        public Builder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public Builder includeOriginal(Boolean includeOriginal) {
            this.includeOriginal = includeOriginal;
            return this;
        }

        public Builder numberOfQueries(Integer numberOfQueries) {
            this.numberOfQueries = numberOfQueries;
            return this;
        }

        public MultiQueryExpander build() {
            return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal, this.numberOfQueries);
        }

    }

}
实例化查询重写和查询扩展
java
package com.alibaba.cloud.ai.application.config;
//import com.alibaba.cloud.ai.application.rag.postretrieval.DashScopeDocumentRanker;
import com.alibaba.cloud.ai.application.modulerag.preretrieval.query.expansion.MultiQueryExpander;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
*/
@Configuration
public class WeSearchConfiguration {
//
// @Bean
// public DashScopeDocumentRanker dashScopeDocumentRanker(
// RerankModel rerankModel
// ) {
// return new DashScopeDocumentRanker(rerankModel);
// }
@Bean
public QueryTransformer queryTransformer(
@Qualifier("dashscopeChatModel") ChatModel chatModel,
@Qualifier("transformerPromptTemplate") PromptTemplate transformerPromptTemplate
) {
//实例化查询重写
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultOptions(
DashScopeChatOptions.builder()
.withModel("qwen-plus")
.build()
).build();
// 创建查询重写转换器
// Pre-Retrieval 增强和转换用户输入,使其更有效地执行检索任务,解决格式不正确的查询、query 语义不清晰、或不受支持的语言等。
// QueryTransformer 查询改写:因为用户的输入通常是片面的,关键信息较少,
// 不便于大模型理解和回答问题。因此需要使用 prompt 调优手段或者大模型改写用户 query
return RewriteQueryTransformer.builder()
.chatClientBuilder(chatClient.mutate())
.promptTemplate(transformerPromptTemplate)
.targetSearchSystem("Web Search")
.build();
}
@Bean
public QueryExpander queryExpander(
@Qualifier("dashscopeChatModel") ChatModel chatModel
) {
//实例化查询扩展
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultOptions(
DashScopeChatOptions.builder()
.withModel("qwen-plus")
.build()
).build();
//多查询扩展是提高RAG系统检索效果的关键技术。在实际应用中,
// 用户的查询往往是简短且不完整的,这可能导致检索结果不够准确或完整。
// Spring AI提供了强大的多查询扩展机制,能够自动生成多个相关的查询变体,
// 从而提高检索的准确性和召回率
// Pre-Retrieval 增强和转换用户输入,使其更有效地执行检索任务,解决格式不正确的查询、query 语义不清晰、或不受支持的语言等。
// QueryExpander 查询扩展:将用户 query 扩展为多个语义不同的变体以获得不同视角,
// 有助于检索额外的上下文信息并增加找到相关结果的机会。
return MultiQueryExpander.builder()
.chatClientBuilder(chatClient.mutate())
.numberOfQueries(2)
.build();
}
}
查询重写提示模板
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* Prompt:
* 1. https://zhuanlan.zhihu.com/p/23929522431
* 2. https://cloud.tencent.com/developer/article/2509465
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
*/
@Configuration
public class PromptTemplateConfig {

    /** Query-rewrite prompt; exposes the {@code target} and {@code query} placeholders. */
    private static final String QUERY_REWRITE_TEMPLATE = """
            Given a user query, rewrite the user question to provide better results when querying {target}.
            You should follow these rules:
            1. Remove any irrelevant information and make sure the query is concise and specific;
            2. The output must be consistent with the language of the user's query;
            3. Ensure better understanding and answers from the perspective of large models.
            Original query:
            {query}
            Query after rewrite:
            """;

    /** Web-search answer prompt; exposes the {@code context} and {@code query} placeholders. */
    private static final String WEB_SEARCH_ANSWER_TEMPLATE = """
            You'll get a set of document contexts that are relevant to the issue.
            Each document begins with a reference number, such as [[x]], where x is a number that can be repeated.
            Documents that are not referenced will be marked as [[null]].
            Use context and refer to it at the end of each sentence, if applicable.
            The context information is as follows:
            ---------------------
            {context}
            ---------------------
            Generate structured responses to user questions given contextual information and without prior knowledge.
            When you answer user questions, follow these rules:
            1. If the answer is not in context, say you don't know;
            2. Don't provide any information that is not relevant to the question, and don't output any duplicate content;
            3. Avoid using "context-based..." or "The provided information..." said;
            4. Your answers must be correct, accurate, and written in an expertly unbiased and professional tone;
            5. The appropriate text structure in the answer is determined according to the characteristics of the content, please include subheadings in the output to improve readability;
            6. When generating a response, provide a clear conclusion or main idea first, without a title;
            7. Make sure each section has a clear subtitle so that users can better understand and refer to your output;
            8. If the information is complex or contains multiple sections, make sure each section has an appropriate heading to create a hierarchical structure;
            9. Please refer to the sentence or section with the reference number at the end in [[x]] format;
            10. If a sentence or section comes from more than one context, list all applicable references, e.g. [[x]][[y]];
            11. Your output answers must be in beautiful and rigorous markdown format.
            12. Because your output is in markdown format, please include the link in the reference document in the form of a hyperlink when referencing the context, so that users can click to view it;
            13. If a reference is marked as [[null]], it does not have to be cited;
            14. Except for Code. Aside from the specific name and citation, your answer must be written in the same language as the question.
            User Issue:
            {query}
            Your answer:
            """;

    /**
     * @return the prompt template used for query rewriting
     */
    @Bean
    public PromptTemplate transformerPromptTemplate() {
        return new PromptTemplate(QUERY_REWRITE_TEMPLATE);
    }

    /**
     * @return the prompt template used to answer from web-search context
     */
    @Bean
    public PromptTemplate queryArgumentPromptTemplate() {
        return new PromptTemplate(WEB_SEARCH_ANSWER_TEMPLATE);
    }

}
2、Retrieval 步骤 负责查询向量存储等数据系统并检索和用户 query 相关性最高的 Document。
2.1 从搜索引擎检索文档:WebSearchRetriever
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag;
import com.alibaba.cloud.ai.application.entity.IQSSearchResponse;
import com.alibaba.cloud.ai.application.exception.SAAAppException;
import com.alibaba.cloud.ai.application.modulerag.core.IQSSearchEngine;
import com.alibaba.cloud.ai.application.modulerag.data.DataClean;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.lang.Nullable;
import java.net.URISyntaxException;
import java.util.List;
/**
 * Document retriever backed by the IQS web search engine.
 * <p>
 * Retrieval step: queries a data system and fetches the documents most relevant to the
 * user query. A {@link DocumentRetriever} may sit on top of different sources (search
 * engine, vector store, database, knowledge graph, ...); this one uses the IQS search
 * engine.
 * <p>
 * Note from the original author: spring-ai no longer supports DocumentRanker as of
 * version 0.8.0.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
public class WebSearchRetriever implements DocumentRetriever {

    private static final Logger logger = LoggerFactory.getLogger(WebSearchRetriever.class);

    private final int maxResults;

    private final DataClean dataCleaner;

    private final IQSSearchEngine searchEngine;

    /**
     * @throws IllegalArgumentException if the builder has no search engine or data cleaner
     */
    private WebSearchRetriever(Builder builder) {
        // Fail at construction time rather than with a NullPointerException on first retrieve().
        if (builder.searchEngine == null) {
            throw new IllegalArgumentException("searchEngine must not be null");
        }
        if (builder.dataCleaner == null) {
            throw new IllegalArgumentException("dataCleaner must not be null");
        }

        // IQS search engine
        this.searchEngine = builder.searchEngine;
        this.maxResults = builder.maxResults;
        // Cleans the raw search response into documents
        this.dataCleaner = builder.dataCleaner;
    }

    /**
     * Searches IQS with the query text, converts the response into documents and limits
     * them to {@code maxResults}.
     *
     * @param query the query to search for; must not be {@code null}
     * @return the cleaned and limited documents
     * @throws IllegalArgumentException if {@code query} is {@code null}
     * @throws SAAAppException if the search engine reports a JSON processing error
     * @throws RuntimeException wrapping a {@link URISyntaxException} raised while cleaning
     */
    @NotNull
    @Override
    public List<Document> retrieve(
            @Nullable Query query
    ) {
        // The parameter is declared nullable, so reject null explicitly instead of failing
        // with a NullPointerException on query.text().
        if (query == null) {
            throw new IllegalArgumentException("Query must not be null");
        }

        // Search
        IQSSearchResponse searchResp;
        try {
            searchResp = searchEngine.search(query.text());
        } catch (JsonProcessingException e) {
            // NOTE(review): the cause is dropped here; pass `e` along if SAAAppException
            // offers a (String, Throwable) constructor.
            throw new SAAAppException("json process error" + e.getMessage());
        }

        // Clean the data
        List<Document> cleanerData;
        try {
            cleanerData = dataCleaner.getData(searchResp);
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }

        // Limit the result
        List<Document> documents = dataCleaner.limitResults(cleanerData, maxResults);

        // Collecting the ids is only useful for the log line, so skip it when debug is off.
        if (logger.isDebugEnabled()) {
            logger.debug("WebSearchRetriever#retrieve() document size: {}, raw documents: {}",
                    documents.size(),
                    documents.stream().map(Document::getId).toArray()
            );
        }

        return documents;
    }

    /**
     * @return a new, empty builder
     */
    public static WebSearchRetriever.Builder builder() {
        return new WebSearchRetriever.Builder();
    }

    /**
     * Fluent builder for {@link WebSearchRetriever}. The search engine and data cleaner are
     * required; {@code maxResults} stays 0 unless set.
     */
    public static final class Builder {

        private IQSSearchEngine searchEngine;

        private int maxResults;

        private DataClean dataCleaner;

        public WebSearchRetriever.Builder searchEngine(IQSSearchEngine searchEngine) {
            this.searchEngine = searchEngine;
            return this;
        }

        public WebSearchRetriever.Builder dataCleaner(DataClean dataCleaner) {
            this.dataCleaner = dataCleaner;
            return this;
        }

        public WebSearchRetriever.Builder maxResults(int maxResults) {
            this.maxResults = maxResults;
            return this;
        }

        public WebSearchRetriever build() {
            return new WebSearchRetriever(this);
        }

    }

}
2.2 将检索到的文档进行拼接:ConcatenationDocumentJoiner
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.join;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
 * Joins the documents retrieved for several queries and data sources into one list.
 * <p>
 * Retrieval step: after the per-query retrieval, the nested result lists are capped,
 * flattened and de-duplicated into a single document collection.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
public class ConcatenationDocumentJoiner implements DocumentJoiner {

    private static final Logger logger = LoggerFactory.getLogger(ConcatenationDocumentJoiner.class);

    /** Total number of documents taken across all queries before de-duplication. */
    private static final int MAX_TOTAL_DOCUMENTS = 10;

    /**
     * Selects a share of documents per query, flattens them and removes duplicates.
     *
     * @param documentsForQuery retrieved documents per query; must not be {@code null} nor
     * contain {@code null} keys or values
     * @return the selected documents without {@code null} entries and without documents whose
     * id, "source" metadata or "file_name" metadata was already seen
     */
    @NotNull
    @Override
    public List<Document> join(
            @Nullable Map<Query, List<List<Document>>> documentsForQuery
    ) {
        Assert.notNull(documentsForQuery, "documentsForQuery cannot be null");
        Assert.noNullElements(documentsForQuery.keySet(), "documentsForQuery cannot contain null keys");
        Assert.noNullElements(documentsForQuery.values(), "documentsForQuery cannot contain null values");

        logger.debug("Joining documents by concatenation");

        Map<Query, List<List<Document>>> selectDocuments = selectDocuments(documentsForQuery, MAX_TOTAL_DOCUMENTS);

        Set<String> seen = new HashSet<>();

        return selectDocuments.values().stream()
            // Flatten List<List<Document>> to Stream<List<Document>>.
            .flatMap(List::stream)
            // Flatten Stream<List<Document>> to Stream<Document>.
            .flatMap(List::stream)
            // A null entry carries nothing to join and would break key extraction.
            .filter(Objects::nonNull)
            .filter(doc -> isFirstOccurrence(doc, seen))
            .collect(Collectors.toList());
    }

    /**
     * Registers the document's keys in {@code seen} and reports whether all of them were new.
     * Keys checked before the first duplicate remain registered.
     */
    private boolean isFirstOccurrence(Document doc, Set<String> seen) {
        for (String key : extractKeys(doc)) {
            if (!seen.add(key)) {
                logger.info("Duplicate document metadata: {}",doc.getMetadata());
                return false;
            }
        }
        return true;
    }

    /**
     * Distributes {@code totalDocuments} over the queries as evenly as possible (the first
     * {@code totalDocuments % queries} queries get one extra) and takes that many documents
     * from the front of each query's nested lists.
     * <p>
     * The result is a {@link HashMap}, so its iteration order is not the input order.
     */
    private Map<Query, List<List<Document>>> selectDocuments(
            Map<Query, List<List<Document>>> documentsForQuery,
            int totalDocuments
    ) {
        Map<Query, List<List<Document>>> selectDocumentsForQuery = new HashMap<>();

        int numberOfQueries = documentsForQuery.size();
        if (numberOfQueries == 0) {
            return selectDocumentsForQuery;
        }

        int baseCount = totalDocuments / numberOfQueries;
        int remainder = totalDocuments % numberOfQueries;

        // Snapshot of the keys so that the index decides which queries get the extra document.
        List<Query> queries = new ArrayList<>(documentsForQuery.keySet());

        for (int i = 0; i < numberOfQueries; i++) {
            Query query = queries.get(i);
            int remainingDocuments = baseCount + (i < remainder ? 1 : 0);

            List<List<Document>> selectedNestLists = new ArrayList<>();

            for (List<Document> documentList : documentsForQuery.get(query)) {
                if (remainingDocuments <= 0) {
                    break;
                }
                // Tolerate a missing inner list instead of failing on iteration.
                if (documentList == null) {
                    continue;
                }

                List<Document> selectSubList = new ArrayList<>();
                for (Document doc : documentList) {
                    if (remainingDocuments <= 0) {
                        break;
                    }
                    selectSubList.add(doc);
                    remainingDocuments--;
                }

                if (!selectSubList.isEmpty()) {
                    selectedNestLists.add(selectSubList);
                }
            }

            selectDocumentsForQuery.put(query, selectedNestLists);
        }

        return selectDocumentsForQuery;
    }

    /**
     * Collects the de-duplication keys of a document: its id plus, when present as strings,
     * the "source" and "file_name" metadata values (prefixed to keep the namespaces apart).
     *
     * @return the keys, or an empty list for a {@code null} document
     */
    private List<String> extractKeys(Document document) {
        List<String> keys = new ArrayList<>();

        // Return early: the metadata lookup below must not run on a null document.
        if (Objects.isNull(document)) {
            return keys;
        }

        keys.add(document.getId());

        if (Objects.nonNull(document.getMetadata())) {
            Object src = document.getMetadata().get("source");
            if (src instanceof String) {
                keys.add("SOURCE:" + src);
            }

            Object fn = document.getMetadata().get("file_name");
            if (fn instanceof String) {
                keys.add("FILE_NAME:" + fn);
            }
        }

        return keys;
    }

}
3、Post-Retrieval 此处没有实现该步骤
4、Generation
4.1 对生成的查询进行上下文增强:CustomContextQueryAugmenter
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.prompt;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
 * Query augmenter that combines the user query with the text of the retrieved documents,
 * giving the model the context it needs to answer.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
public class CustomContextQueryAugmenter implements QueryAugmenter {

    private static final Logger logger = LoggerFactory.getLogger(CustomContextQueryAugmenter.class);

    // Uses the single-brace {context} placeholder, the same form as {query}; the previous
    // doubled braces did not form a plain placeholder.
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(
            """
            Context information is below:
            {context}
            Given the context information and no prior knowledge, answer the query.
            Follow these rules:
            1. If the answer is not in the context, just say that you don't know.
            2. Avoid statements like "Based on the context...." or "The provided information...".
            Query: {query}
            Answer:
            """
    );

    private static final PromptTemplate DEFAULT_EMPTY_PROMPT_TEMPLATE = new PromptTemplate(
            """
            The user query is outside your knowledge base.
            Politely inform the user that you cannot answer the query.
            """
    );

    private static final boolean DEFAULT_ALLOW_EMPTY_PROMPT = false;

    private final PromptTemplate promptTemplate;

    private final PromptTemplate emptyPromptTemplate;

    private final boolean allowEmptyContext;

    /**
     * @param promptTemplate template rendered with the {@code query} and {@code context}
     * parameters, or {@code null} for the default
     * @param emptyPromptTemplate template rendered when no documents were retrieved and an
     * empty context is not allowed, or {@code null} for the default
     * @param allowEmptyContext whether the original query is returned unchanged when no
     * documents were retrieved, or {@code null} for the default ({@code false})
     */
    public CustomContextQueryAugmenter(
            @Nullable PromptTemplate promptTemplate,
            @Nullable PromptTemplate emptyPromptTemplate,
            @Nullable Boolean allowEmptyContext
    ) {
        this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
        this.emptyPromptTemplate = emptyPromptTemplate != null ? emptyPromptTemplate : DEFAULT_EMPTY_PROMPT_TEMPLATE;
        this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_PROMPT;

        // Log the resolved fields: the constructor arguments are nullable, and calling
        // getTemplate() on a null argument made it impossible to fall back to the defaults.
        logger.debug("CustomContextQueryAugmenter promptTemplate: {}", this.promptTemplate.getTemplate());
        logger.debug("CustomContextQueryAugmenter emptyPromptTemplate: {}", this.emptyPromptTemplate.getTemplate());
        logger.debug("CustomContextQueryAugmenter allowEmptyContext: {}", this.allowEmptyContext);

        PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context");
    }

    /**
     * Renders the prompt template with the query text and the numbered document texts.
     *
     * @param query the user query; must not be {@code null}
     * @param documents the retrieved documents; must not be {@code null}
     * @return a new query holding the rendered prompt; when {@code documents} is empty, either
     * the original query or the rendered empty-context prompt, depending on
     * {@code allowEmptyContext}
     */
    @NotNull
    @Override
    public Query augment(
            @Nullable Query query,
            @Nullable List<Document> documents
    ) {
        Assert.notNull(query, "Query must not be null");
        Assert.notNull(documents, "Documents must not be null");

        logger.debug("Augmenting query: {}", query);

        if (documents.isEmpty()) {
            logger.debug("No documents found. Augmenting query with empty context.");
            return augmentQueryWhenEmptyContext(query);
        }

        logger.debug("Documents found. Augmenting query with context.");

        // 1. Collect the document texts, each prefixed with a 1-based [[n]] reference marker.
        AtomicInteger idCounter = new AtomicInteger(1);
        String documentContext = documents.stream()
            .map(document -> {
                String text = document.getText();
                return "[[" + (idCounter.getAndIncrement()) + "]]" + text;
            })
            .collect(Collectors.joining("\n-----------------------------------------------\n"));

        // 2. Define prompt parameters.
        Map<String, Object> promptParameters = Map.of(
                "query", query.text(),
                "context", documentContext
        );

        // 3. Augment user prompt with document context.
        return new Query(this.promptTemplate.render(promptParameters));
    }

    /**
     * Returns the query unchanged when an empty context is allowed, otherwise a query built
     * from the empty-context template.
     */
    private Query augmentQueryWhenEmptyContext(Query query) {
        if (this.allowEmptyContext) {
            logger.debug("Empty context is allowed. Returning the original query.");
            return query;
        }

        logger.debug("Empty context is not allowed. Returning a specific query for empty context.");
        return new Query(this.emptyPromptTemplate.render());
    }

    /**
     * @return a new, empty builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link CustomContextQueryAugmenter}; unset values stay {@code null}
     * and are replaced by the defaults in the constructor.
     */
    public static final class Builder {

        private PromptTemplate promptTemplate;

        private PromptTemplate emptyPromptTemplate;

        private Boolean allowEmptyContext;

        public Builder() {
        }

        public CustomContextQueryAugmenter.Builder withPromptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public CustomContextQueryAugmenter.Builder withEmptyPromptTemplate(PromptTemplate emptyPromptTemplate) {
            this.emptyPromptTemplate = emptyPromptTemplate;
            return this;
        }

        public CustomContextQueryAugmenter.Builder withAllowEmptyContext(Boolean allowEmptyContext) {
            this.allowEmptyContext = allowEmptyContext;
            return this;
        }

        public CustomContextQueryAugmenter build() {
            return new CustomContextQueryAugmenter(promptTemplate, emptyPromptTemplate, allowEmptyContext);
        }

    }

}
5、RetrievalAugmentationAdvisor
java
/**
 * Assembles the retrieval-augmentation advisor from the web-search retriever, the query
 * rewriter, the multi-query expander, a context augmenter and a concatenating joiner.
 */
private RetrievalAugmentationAdvisor createRetrievalAugmentationAdvisor() {
    // Context augmenter: custom answer template, default empty-context template,
    // empty context allowed.
    CustomContextQueryAugmenter contextAugmenter =
            new CustomContextQueryAugmenter(queryArgumentPromptTemplate, null, true);
    ConcatenationDocumentJoiner joiner = new ConcatenationDocumentJoiner();

    return RetrievalAugmentationAdvisor.builder()
            // Document retriever
            .documentRetriever(webSearchRetriever)
            // Query rewriting
            .queryTransformers(queryTransformer)
            .queryAugmenter(contextAugmenter)
            // Multi-query expansion
            .queryExpander(queryExpander)
            .documentJoiner(joiner)
            .build();
}
SAAWebSearchService
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.service;
import com.alibaba.cloud.ai.application.advisor.ReasoningContentAdvisor;
import com.alibaba.cloud.ai.application.modulerag.WebSearchRetriever;
import com.alibaba.cloud.ai.application.modulerag.core.IQSSearchEngine;
import com.alibaba.cloud.ai.application.modulerag.data.DataClean;
import com.alibaba.cloud.ai.application.modulerag.join.ConcatenationDocumentJoiner;
import com.alibaba.cloud.ai.application.modulerag.prompt.CustomContextQueryAugmenter;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.Map;
import java.util.logging.Logger;
/**
 * Chat service that answers a prompt with the help of web search: the request goes
 * through a {@link RetrievalAugmentationAdvisor} (query rewriting, multi-query expansion,
 * web retrieval, context augmentation) and the model answer is streamed back.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
@Service
public class SAAWebSearchService {

    // It works better here with DeepSeek-R1
    private static final String DEFAULT_WEB_SEARCH_MODEL = "deepseek-r1";

    private static final Logger log = Logger.getLogger(SAAWebSearchService.class.getName());

    /** Maximum number of digits accepted inside a {@code [[n]]} tag; an int never needs more. */
    private static final int MAX_TAG_DIGITS = 10;

    private final DataClean dataCleaner;

    private final ChatClient chatClient;

    private final QueryExpander queryExpander;

    private final QueryTransformer queryTransformer;

    private final WebSearchRetriever webSearchRetriever;

    private final SimpleLoggerAdvisor simpleLoggerAdvisor;

    private final PromptTemplate queryArgumentPromptTemplate;

    private final ReasoningContentAdvisor reasoningContentAdvisor;

    /**
     * Wires the collaborators and builds the chat client and the web search retriever.
     *
     * @param dataCleaner converts raw search responses into documents and tracks result links
     * @param queryExpander multi-query expansion
     * @param searchEngine search engine used by the retriever
     * @param queryTransformer query rewriting
     * @param simpleLoggerAdvisor logging advisor added to every chat call
     * @param chatModel the DashScope chat model
     * @param queryArgumentPromptTemplate prompt template used to augment the query with context
     */
    public SAAWebSearchService(
            DataClean dataCleaner,
            QueryExpander queryExpander,
            IQSSearchEngine searchEngine,
            QueryTransformer queryTransformer,
            SimpleLoggerAdvisor simpleLoggerAdvisor,
            @Qualifier("dashscopeChatModel") ChatModel chatModel,
            @Qualifier("queryArgumentPromptTemplate") PromptTemplate queryArgumentPromptTemplate
    ) {
        // data cleaning
        this.dataCleaner = dataCleaner;
        // query rewriting
        this.queryTransformer = queryTransformer;
        // multi-query expansion
        this.queryExpander = queryExpander;
        // web-search prompt template
        this.queryArgumentPromptTemplate = queryArgumentPromptTemplate;
        // reasoning content for DeepSeek-r1 is integrated into the output
        this.reasoningContentAdvisor = new ReasoningContentAdvisor(1);

        // Build chatClient
        this.chatClient = ChatClient.builder(chatModel)
                .defaultOptions(
                        DashScopeChatOptions.builder()
                                .withModel(DEFAULT_WEB_SEARCH_MODEL)
                                // enable incremental output in stream mode
                                .withIncrementalOutput(true)
                                .build()
                ).build();

        this.simpleLoggerAdvisor = simpleLoggerAdvisor;

        this.webSearchRetriever = WebSearchRetriever.builder()
                .searchEngine(searchEngine)
                .dataCleaner(dataCleaner)
                .maxResults(2)
                .build();
    }

    /**
     * Streams the model answer for the given prompt, augmented with web search results.
     *
     * @param prompt the user's question
     * @return the answer as a stream of text chunks
     */
    public Flux<String> chat(String prompt) {
        // Link embedding (embedLinks) is intentionally not applied to the stream yet;
        // it is kept below as an idea to build on.
        return chatClient.prompt()
                .advisors(
                        createRetrievalAugmentationAdvisor(),
                        reasoningContentAdvisor,
                        simpleLoggerAdvisor
                ).user(prompt)
                .stream()
                .content();
    }

    /**
     * Replaces {@code [[n]]} citation tags in a streamed text with markdown links taken
     * from {@code webLink}. Currently not wired into {@link #chat(String)}.
     * <p>
     * A tag may be split across chunks at any position (even between the two brackets),
     * so characters that could still turn into a tag are held back until the tag is
     * either completed or ruled out. Text that cannot become a tag is released
     * immediately, so an unterminated {@code [[} never stalls the rest of the stream.
     *
     * @param contentStream the streamed answer chunks
     * @param webLink map from citation number to URL
     * @return the stream with resolvable tags rewritten; all other text is unchanged
     */
    private Flux<String> embedLinks(Flux<String> contentStream, Map<Integer, String> webLink) {
        // defer: every subscription gets its own buffer instead of sharing one
        return Flux.defer(() -> {
            StringBuilder pending = new StringBuilder();
            // map (not flatMap): the parser is stateful and must see chunks strictly in order
            return contentStream
                    .map(chunk -> rewriteChunk(chunk, pending, webLink))
                    .concatWith(Flux.defer(() -> {
                        // Stream ended inside a candidate tag: emit the held-back text as-is
                        if (pending.length() > 0) {
                            return Flux.just(pending.toString());
                        }
                        return Flux.empty();
                    }));
        });
    }

    /**
     * Feeds one chunk through the tag parser.
     *
     * @param chunk the text to process
     * @param pending characters held back from earlier chunks; updated in place
     * @param webLink map from citation number to URL
     * @return the text that can be emitted for this chunk (possibly empty)
     */
    private static String rewriteChunk(String chunk, StringBuilder pending, Map<Integer, String> webLink) {
        StringBuilder output = new StringBuilder(chunk.length());
        for (int i = 0; i < chunk.length(); i++) {
            char c = chunk.charAt(i);
            if (continuesTag(pending, c)) {
                pending.append(c);
                if (endsWithClosingBrackets(pending)) {
                    output.append(resolveLink(pending.toString(), webLink));
                    pending.setLength(0);
                }
                continue;
            }
            // c rules out the candidate tag: release the held-back characters as plain text
            output.append(pending);
            pending.setLength(0);
            if (c == '[') {
                // might be the first bracket of a new "[["
                pending.append(c);
            } else {
                output.append(c);
            }
        }
        return output.toString();
    }

    /**
     * Tells whether {@code c} keeps {@code pending} a valid prefix of {@code [[digits]]}.
     */
    private static boolean continuesTag(StringBuilder pending, char c) {
        int length = pending.length();
        if (length == 0) {
            return false;
        }
        if (length == 1) {
            // pending is "[": only a second '[' opens a tag
            return c == '[';
        }
        if (pending.charAt(length - 1) == ']') {
            // pending is "[[digits]": only the second ']' completes it
            return c == ']';
        }
        if (Character.isDigit(c)) {
            return length - 2 < MAX_TAG_DIGITS;
        }
        // the first ']' is only meaningful after at least one digit
        return c == ']' && length > 2;
    }

    /** Tells whether the buffer ends with {@code ]]}, i.e. holds a complete tag. */
    private static boolean endsWithClosingBrackets(StringBuilder pending) {
        int length = pending.length();
        return length >= 2 && pending.charAt(length - 1) == ']' && pending.charAt(length - 2) == ']';
    }

    /**
     * Turns a {@code [[n]]} tag into the markdown link {@code [n](url)}.
     *
     * @param tag the complete tag text
     * @param webLink map from citation number to URL
     * @return the markdown link, or {@code tag} unchanged when it is not a numeric tag or
     * no URL is known for its number
     */
    private static String resolveLink(String tag, Map<Integer, String> webLink) {
        if (tag.startsWith("[[") && tag.endsWith("]]")) {
            String keyStr = tag.substring(2, tag.length() - 2); // Remove [[ and ]]
            try {
                int key = Integer.parseInt(keyStr);
                String url = webLink.get(key);
                if (url != null) {
                    return "[" + key + "](" + url + ")";
                }
            } catch (NumberFormatException ignored) {
                // Not a valid int (empty or out of range): keep the original tag
                log.fine(() -> "Leaving non-numeric link tag untouched: " + tag);
            }
        }
        return tag; // Return original tag if no match
    }

    /**
     * Assembles the advisor that enriches a chat request with retrieved web documents.
     *
     * @return a newly built {@link RetrievalAugmentationAdvisor}
     */
    private RetrievalAugmentationAdvisor createRetrievalAugmentationAdvisor() {
        return RetrievalAugmentationAdvisor.builder()
                // document retriever
                .documentRetriever(webSearchRetriever)
                // query rewriting
                .queryTransformers(queryTransformer)
                .queryAugmenter(
                        new CustomContextQueryAugmenter(
                                queryArgumentPromptTemplate,
                                null,
                                true)
                )
                // multi-query expansion
                .queryExpander(queryExpander)
                .documentJoiner(new ConcatenationDocumentJoiner())
                .build();
    }

}
6、IQS搜索引擎
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.core;
import com.alibaba.cloud.ai.application.entity.IQSSearchResponse;
import com.alibaba.cloud.ai.application.exception.SAAAppException;
import com.alibaba.cloud.ai.application.modulerag.IQSSearchProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
/**
 * Client for the Alibaba Cloud IQS unified search API
 * (<a href="https://help.aliyun.com/document_detail/2883041.html">IQS search</a>).
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
@Component
@EnableConfigurationProperties(IQSSearchProperties.class)
public class IQSSearchEngine {

    private static final String TIME_RANGE = "OneWeek";

    private static final String BASE_URL = "https://cloud-iqs.aliyuncs.com/";

    private final RestClient restClient;

    private final ObjectMapper objectMapper;

    private final IQSSearchProperties iqsSearchProperties;

    /**
     * Builds the REST client with the IQS base URL, default headers and error handler.
     *
     * @param objectMapper mapper used to serialize the request body
     * @param restClientBuilder builder the REST client is created from
     * @param iqsSearchProperties search configuration; supplies the API key
     * @param responseErrorHandler handler registered as the default status handler
     */
    public IQSSearchEngine(
            ObjectMapper objectMapper,
            RestClient.Builder restClientBuilder,
            IQSSearchProperties iqsSearchProperties,
            ResponseErrorHandler responseErrorHandler
    ) {
        // Use the injected mapper instead of silently discarding it for a fresh instance
        this.objectMapper = objectMapper;
        // Must be assigned before getHeaders() is consumed below, which reads the API key
        this.iqsSearchProperties = iqsSearchProperties;
        this.restClient = restClientBuilder.baseUrl(BASE_URL)
                .defaultHeaders(getHeaders())
                .defaultStatusHandler(responseErrorHandler)
                .build();
    }

    /**
     * Runs a generic search restricted to the last week.
     *
     * @param query the search query
     * @return the parsed search response
     * @throws JsonProcessingException if the request body cannot be serialized
     * @throws SAAAppException if the response is not {@code 200 OK} or has no body
     */
    public IQSSearchResponse search(String query) throws JsonProcessingException {
        Map<String, Boolean> reqDataContents = Map.of(
                "mainText", true,
                // The markdown output of IQS is currently of poor quality, so it is disabled.
                "markdownText", false,
                "rerankScore", true);

        Map<String, Object> reqData = new HashMap<>();
        reqData.put("query", query);
        reqData.put("timeRange", TIME_RANGE);
        reqData.put("engineType", "Generic");
        reqData.put("contents", reqDataContents);
        String jsonReqData = objectMapper.writeValueAsString(reqData);

        // query and timeRange are sent both as URI variables (encoded by RestClient) and in
        // the JSON body, as before. NOTE(review): confirm whether the API needs both.
        ResponseEntity<IQSSearchResponse> response = this.restClient.post()
                .uri(
                        "/search/unified?query={query}&timeRange={timeRange}",
                        query,
                        TIME_RANGE
                ).contentType(MediaType.APPLICATION_JSON)
                .body(jsonReqData)
                .retrieve()
                .toEntity(IQSSearchResponse.class);

        return genericSearchResult(response);
    }

    /**
     * Unwraps the response body.
     *
     * @param response the HTTP response of a search call
     * @return the body when the status is {@code 200 OK} and a body is present
     * @throws SAAAppException otherwise
     */
    private IQSSearchResponse genericSearchResult(ResponseEntity<IQSSearchResponse> response) {
        IQSSearchResponse body = response.getBody();
        if (Objects.equals(response.getStatusCode(), HttpStatus.OK) && Objects.nonNull(body)) {
            return body;
        }
        // A 200 without a body also ends up here, hence the explicit hint in the message
        throw new SAAAppException(
                "Failed to search, status code: " + response.getStatusCode().value()
                        + (Objects.isNull(body) ? " (empty response body)" : ""));
    }

    /**
     * Default headers for every request: JSON content type, user agent and, when an API
     * key is configured, the {@code X-API-Key} header.
     */
    private Consumer<HttpHeaders> getHeaders() {
        return httpHeaders -> {
            httpHeaders.setContentType(MediaType.APPLICATION_JSON);
            httpHeaders.set("user-agent", userAgent());
            String apiKey = this.iqsSearchProperties.getApiKey();
            if (StringUtils.hasText(apiKey)) {
                httpHeaders.set("X-API-Key", apiKey);
            }
        };
    }

    /** Builds the user-agent string from the application name/version and JVM/OS properties. */
    private static String userAgent() {
        return String.format(
                "%s/%s; java/%s; platform/%s; processor/%s",
                "SpringAiAlibabaPlayground",
                "1.0.0",
                System.getProperty("java.version"),
                System.getProperty("os.name"),
                System.getProperty("os.arch"));
    }

}
数据清洗
java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.application.modulerag.data;
import com.alibaba.cloud.ai.application.entity.IQSSearchResponse;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import java.net.URISyntaxException;
import java.util.*;
/**
 * Data cleansing: filters out useless search results and converts the remaining ones
 * into Spring AI {@link Document} objects.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
@Component
public class DataClean {

    /** Results whose main text is shorter than this are dropped. */
    private static final int MIN_MAIN_TEXT_LENGTH = 10;

    /**
     * Result position (1-based) to link. Static and shared by every caller, so it is
     * wrapped in a synchronized map to stay consistent under concurrent searches.
     * NOTE(review): entries are never cleared and every search overwrites the same
     * positions, so the content reflects a mix of the most recent searches.
     */
    private static final Map<Integer, String> WEB_LINK_MAP = Collections.synchronizedMap(new HashMap<>());

    /**
     * Converts a search response into documents, skipping items without usable main text.
     * Each document carries the query-level metadata plus the metadata of its own page
     * item, and the item's link (if any) is recorded under its 1-based result position.
     *
     * @param respData the search response; {@code null} or a response without page items
     * yields an empty list
     * @return the documents, in result order
     * @throws URISyntaxException declared for callers; not thrown by this implementation
     */
    public List<Document> getData(IQSSearchResponse respData) throws URISyntaxException {
        List<Document> documents = new ArrayList<>();
        if (Objects.isNull(respData) || Objects.isNull(respData.pageItems())) {
            return documents;
        }

        Map<String, Object> queryMetadata = getQueryMetadata(respData);
        var pageItems = respData.pageItems();

        for (int i = 0; i < pageItems.size(); i++) {
            IQSSearchResponse.PageItem pageItem = pageItems.get(i);
            if (Objects.isNull(pageItem)
                    || !StringUtils.hasText(pageItem.mainText())
                    || pageItem.mainText().length() < MIN_MAIN_TEXT_LENGTH) {
                // Skip items with main text that is missing or too short
                continue;
            }

            // Merge explicitly into a fresh map per document: this keeps both the query
            // and the page item metadata no matter whether the builder merges or replaces
            // on repeated metadata(Map) calls, and documents never share a map instance.
            Map<String, Object> documentMetadata = new HashMap<>(queryMetadata);
            documentMetadata.putAll(getPageItemMetadata(pageItem));

            Document document = new Document.Builder()
                    .metadata(documentMetadata)
                    .text(pageItem.mainText())
                    .score(pageItem.rerankScore())
                    .build();

            if (Objects.nonNull(pageItem.link())) {
                // Keyed by the position in the raw result list, not in the filtered list
                WEB_LINK_MAP.put(i + 1, pageItem.link());
            }
            documents.add(document);
        }

        return documents;
    }

    /**
     * Returns the shared position-to-link map.
     *
     * @return the live map; it keeps changing as further searches are processed
     */
    public Map<Integer, String> getWebLink() {
        return WEB_LINK_MAP;
    }

    /**
     * Extracts the query-level metadata ({@code query}, {@code timeRange},
     * {@code filters}) from the response.
     *
     * @return a new map; empty when the response carries no query context or original query
     */
    private Map<String, Object> getQueryMetadata(IQSSearchResponse respData) {
        Map<String, Object> docsMetadata = new HashMap<>();

        var queryContext = respData.queryContext();
        if (Objects.isNull(queryContext) || Objects.isNull(queryContext.originalQuery())) {
            return docsMetadata;
        }

        var originalQuery = queryContext.originalQuery();
        docsMetadata.put("query", originalQuery.query());
        if (Objects.nonNull(originalQuery.timeRange())) {
            docsMetadata.put("timeRange", originalQuery.timeRange());
            // NOTE(review): "filters" is populated with the time range, exactly as before.
            // This looks like a copy-paste slip - confirm which response field was meant.
            docsMetadata.put("filters", originalQuery.timeRange());
        }

        return docsMetadata;
    }

    /**
     * Copies the non-null attributes of a page item into a metadata map.
     *
     * @return a new map; empty when {@code pageItem} is {@code null}
     */
    private Map<String, Object> getPageItemMetadata(IQSSearchResponse.PageItem pageItem) {
        Map<String, Object> pageItemMetadata = new HashMap<>();
        if (Objects.isNull(pageItem)) {
            return pageItemMetadata;
        }

        putIfPresent(pageItemMetadata, "hostname", pageItem.hostname());
        putIfPresent(pageItemMetadata, "title", pageItem.title());
        putIfPresent(pageItemMetadata, "markdownText", pageItem.markdownText());
        putIfPresent(pageItemMetadata, "link", pageItem.link());
        putIfPresent(pageItemMetadata, "mainText", pageItem.mainText());
        putIfPresent(pageItemMetadata, "rerankScore", pageItem.rerankScore());
        putIfPresent(pageItemMetadata, "publishedTime", pageItem.publishedTime());
        putIfPresent(pageItemMetadata, "snippet", pageItem.snippet());

        return pageItemMetadata;
    }

    /** Adds the entry only when {@code value} is not {@code null}. */
    private static void putIfPresent(Map<String, Object> target, String key, Object value) {
        if (Objects.nonNull(value)) {
            target.put(key, value);
        }
    }

    /**
     * Truncates the list to at most {@code minResults} leading documents.
     *
     * @param documents the documents to limit
     * @param minResults the maximum number of documents to keep (despite the name);
     * negative values are treated as zero
     * @return a view of the leading part of {@code documents}
     */
    public List<Document> limitResults(List<Document> documents, int minResults) {
        int limit = Math.max(0, Math.min(documents.size(), minResults));
        return documents.subList(0, limit);
    }

}