spring-ai-alibaba之Rag 增强问答质量

1、RAG(Retrieval-Augmented Generation,检索增强生成) ,该技术通过从外部知识库中检索相关信息,并将其作为提示(Prompt)输入给大型语言模型(LLMs),以增强模型处理知识密集型任务的能力

2、RAG 模块化案例

RAG 可以由一组模块化组件构成 《Rag 模块化》,结构化的工作流程保障 AI 模型生成质量

DocumentSelectFirst
java 复制代码
package com.spring.ai.tutorial.rag.service;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;

import java.util.Collections;
import java.util.List;

public class DocumentSelectFirst implements DocumentPostProcessor {

    @Override
    public List<Document> process(Query query, List<Document> documents) {
        return Collections.singletonList(documents.get(0));
    }
}

实现 DocumentPostProcessor 接口,从文档中挑选第一个

RagModuleController
java 复制代码
package com.spring.ai.tutorial.rag.controller;

import com.spring.ai.tutorial.rag.service.DocumentSelectFirst;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

@RestController
@RequestMapping("/rag/module")
public class RagModuleController {

    private static final Logger logger = LoggerFactory.getLogger(RagSimpleController.class);
    private final SimpleVectorStore simpleVectorStore;
    private final ChatClient.Builder chatClientBuilder;

    public RagModuleController(EmbeddingModel embeddingModel, ChatClient.Builder builder) {
        this.simpleVectorStore = SimpleVectorStore
                .builder(embeddingModel).build();
        this.chatClientBuilder = builder;
    }

    @GetMapping("/add")
    public void add() {
        logger.info("start add data");
        HashMap<String, Object> map = new HashMap<>();
        map.put("year", 2025);
        map.put("name", "yingzi");
        List<Document> documents = List.of(
                new Document("你的姓名是影子,湖南邵阳人,25年硕士毕业于北京科技大学,曾先后在百度、理想、快手实习,曾发表过一篇自然语言处理的sci,现在是一名AI研发工程师"),
                new Document("你的姓名是影子,专业领域包含的数学、前后端、大数据、自然语言处理", Map.of("year", 2024)),
                new Document("你姓名是影子,爱好是发呆、思考、运动", map));
        simpleVectorStore.add(documents);
    }

    @GetMapping("/chat-rag-advisor")
    public String chatRagAdvisor(@RequestParam(value = "query", defaultValue = "你好,请告诉我影子这个人的身份信息") String query) {
        logger.info("start chat with rag-advisor");

        // 1. Pre-Retrieval
            // 1.1 MultiQueryExpander
        MultiQueryExpander multiQueryExpander = MultiQueryExpander.builder()
                .chatClientBuilder(this.chatClientBuilder)
                .build();
            // 1.2 TranslationQueryTransformer
        TranslationQueryTransformer translationQueryTransformer = TranslationQueryTransformer.builder()
                .chatClientBuilder(this.chatClientBuilder)
                .targetLanguage("English")
                .build();

        // 2. Retrieval
            // 2.1 VectorStoreDocumentRetriever
        VectorStoreDocumentRetriever vectorStoreDocumentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(simpleVectorStore)
                .build();
        // 2.2 ConcatenationDocumentJoiner
        ConcatenationDocumentJoiner concatenationDocumentJoiner = new ConcatenationDocumentJoiner();

        // 3. Post-Retrieval
            // 3.1 DocumentSelectFirst
        DocumentSelectFirst documentSelectFirst = new DocumentSelectFirst();

        // 4. Generation
            // 4.1 ContextualQueryAugmenter
        ContextualQueryAugmenter contextualQueryAugmenter = ContextualQueryAugmenter.builder()
                .allowEmptyContext(true)
                .build();

        RetrievalAugmentationAdvisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                // 扩充为原来的3倍
                .queryExpander(multiQueryExpander)
                // 转为英文
                .queryTransformers(translationQueryTransformer)
                // 丛向向量存储中检索文档
                .documentRetriever(vectorStoreDocumentRetriever)
                // 将检索到的文档进行拼接
                .documentJoiner(concatenationDocumentJoiner)
                // 对检索到的文档进行处理,选择第一个
                .documentPostProcessors(documentSelectFirst)
                // 对生成的查询进行上下文增强
                .queryAugmenter(contextualQueryAugmenter)
                .build();

        return this.chatClientBuilder.build().prompt(query)
                .advisors(retrievalAugmentationAdvisor)
                .call().content();
    }
}

在这个例子中,我们使用了所有的 RAG 模块组件

Pre-Retrieval

  1. 扩充问题:MultiQueryExpander
  2. 翻译为英文:TranslationQueryTransformer

Retrieval

  1. 从向量存储中检索文档:VectorStoreDocumentRetriever
  2. 将检索到的文档进行拼接:ConcatenationDocumentJoiner

Post-Retrieval

  1. 选择第一个文档:DocumentSelectFirst

Generation

  1. 对生成的查询进行上下文增强:ContextualQueryAugmenter
3、效果

1、首先,进来的 originalQuery 的原始文本为"你好,请告诉我影子这个人的身份信息"

2、经过 TranslationQueryTransformer 翻译为英文,并扩展增加 3 个,且保留原来的 1 个

3、从向量存储中检索文档

4、将检索到的文档进行拼接

5、选择第一个文档

6、增加的上下文信息,生成回答

4、RetrievalAugmentationAdvisor

RAG 增强器,利用模块化 RAG 组件(Query、Pre-Retrieval、Retrieval、Post-Retrieval、Generation)为用户文本添加额外信息

核心方法是 before、after

before:

  1. 创建原始查询(originalQuery):从用户输入的文本、参数和对话历史中构建一个 Query 对象,作为后续处理的基础
  2. 查询转换(transformedQuery):依次通过 queryTransformers 列表中的每个 QueryTransformer,对原始查询进行转换。每个转换器可以对查询内容进行修改(如规范化、重写等),形成最终的 transformedQuery
  3. 查询扩展(expandedQueries):若配置了 queryExpander,则用它将转换后的查询扩展为一个或多个查询(如同义词扩展、多轮问答等),否则只用转换后的查询本身
  4. 检索相关文档(documentsForQuery):对每个扩展后的查询,异步调用 getDocumentsForQuery 方法,通过 documentRetriever 检索与查询相关的文档。所有结果以 Map<Query, List<List>> 形式收集
  5. 文档合并(documents):使用 documentJoiner 将所有查询检索到的文档合并成一个文档列表,便于后续处理
  6. 文档后处理(Post-process):依次通过 documentPostProcessors 列表中的每个处理器,对合并后的文档进行进一步处理(如去重、排序、摘要等)。处理结果存入上下文 context
  7. 查询增强(Augment):用 queryAugmenter 将原始查询和检索到的文档结合,生成带有文档上下文信息的增强查询(如将文档内容拼接到用户问题后)
  8. 更新请求(Update Request):用增强后的查询内容更新 ChatClientRequest,并将文档上下文写入请求上下文,返回新的请求对象用于后续流程

after:

  1. 将 RAG 过程中检索到的文档添加到元数据中,键为"ragdocumentcontext"
java 复制代码
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.advisor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import reactor.core.scheduler.Scheduler;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;

/**
 * Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
 * building blocks defined in the {@link org.springframework.ai.rag} package and following
 * the Modular RAG Architecture.
 *
 * @author Christian Tzolov
 * @author Thomas Vitale
 * @since 1.0.0
 * @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
 * @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
 * @see <a href="https://export.arxiv.org/abs/2410.20878">arXiv:2410.20878</a>
 */
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

	public static final String DOCUMENT_CONTEXT = "rag_document_context";

	private final List<QueryTransformer> queryTransformers;

	@Nullable
	private final QueryExpander queryExpander;

	private final DocumentRetriever documentRetriever;

	private final DocumentJoiner documentJoiner;

	private final List<DocumentPostProcessor> documentPostProcessors;

	private final QueryAugmenter queryAugmenter;

	private final TaskExecutor taskExecutor;

	private final Scheduler scheduler;

	private final int order;

	private RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
			@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
			@Nullable DocumentJoiner documentJoiner, @Nullable List<DocumentPostProcessor> documentPostProcessors,
			@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
			@Nullable Integer order) {
		Assert.notNull(documentRetriever, "documentRetriever cannot be null");
		Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
		this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
		this.queryExpander = queryExpander;
		this.documentRetriever = documentRetriever;
		this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
		this.documentPostProcessors = documentPostProcessors != null ? documentPostProcessors : List.of();
		this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
		this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
		this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
		this.order = order != null ? order : 0;
	}

	public static Builder builder() {
		return new Builder();
	}

	@Override
	public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable AdvisorChain advisorChain) {
		Map<String, Object> context = new HashMap<>(chatClientRequest.context());

		// 0. Create a query from the user text, parameters, and conversation history.
		Query originalQuery = Query.builder()
			.text(chatClientRequest.prompt().getUserMessage().getText())
			.history(chatClientRequest.prompt().getInstructions())
			.context(context)
			.build();

		// 1. Transform original user query based on a chain of query transformers.
		Query transformedQuery = originalQuery;
		for (var queryTransformer : this.queryTransformers) {
			transformedQuery = queryTransformer.apply(transformedQuery);
		}

		// 2. Expand query into one or multiple queries.
		List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
				: List.of(transformedQuery);

		// 3. Get similar documents for each query.
		Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
			.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
			.toList()
			.stream()
			.map(CompletableFuture::join)
			.collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));

		// 4. Combine documents retrieved based on multiple queries and from multiple data
		// sources.
		List<Document> documents = this.documentJoiner.join(documentsForQuery);

		// 5. Post-process the documents.
		for (var documentPostProcessor : this.documentPostProcessors) {
			documents = documentPostProcessor.process(originalQuery, documents);
		}
		context.put(DOCUMENT_CONTEXT, documents);

		// 5. Augment user query with the document contextual data.
		Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);

		// 6. Update ChatClientRequest with augmented prompt.
		return chatClientRequest.mutate()
			.prompt(chatClientRequest.prompt().augmentUserMessage(augmentedQuery.text()))
			.context(context)
			.build();
	}

	/**
	 * Processes a single query by routing it to document retrievers and collecting
	 * documents.
	 */
	private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
		List<Document> documents = this.documentRetriever.retrieve(query);
		return Map.entry(query, documents);
	}

	@Override
	public ChatClientResponse after(ChatClientResponse chatClientResponse, @Nullable AdvisorChain advisorChain) {
		ChatResponse.Builder chatResponseBuilder;
		if (chatClientResponse.chatResponse() == null) {
			chatResponseBuilder = ChatResponse.builder();
		}
		else {
			chatResponseBuilder = ChatResponse.builder().from(chatClientResponse.chatResponse());
		}
		chatResponseBuilder.metadata(DOCUMENT_CONTEXT, chatClientResponse.context().get(DOCUMENT_CONTEXT));
		return ChatClientResponse.builder()
			.chatResponse(chatResponseBuilder.build())
			.context(chatClientResponse.context())
			.build();
	}

	@Override
	public Scheduler getScheduler() {
		return this.scheduler;
	}

	@Override
	public int getOrder() {
		return this.order;
	}

	private static TaskExecutor buildDefaultTaskExecutor() {
		ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
		taskExecutor.setThreadNamePrefix("ai-advisor-");
		taskExecutor.setCorePoolSize(4);
		taskExecutor.setMaxPoolSize(16);
		taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
		taskExecutor.initialize();
		return taskExecutor;
	}

	public static final class Builder {

		private List<QueryTransformer> queryTransformers;

		private QueryExpander queryExpander;

		private DocumentRetriever documentRetriever;

		private DocumentJoiner documentJoiner;

		private List<DocumentPostProcessor> documentPostProcessors;

		private QueryAugmenter queryAugmenter;

		private TaskExecutor taskExecutor;

		private Scheduler scheduler;

		private Integer order;

		private Builder() {
		}

		public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
			Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
			this.queryTransformers = queryTransformers;
			return this;
		}

		public Builder queryTransformers(QueryTransformer... queryTransformers) {
			Assert.notNull(queryTransformers, "queryTransformers cannot be null");
			Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
			this.queryTransformers = Arrays.asList(queryTransformers);
			return this;
		}

		public Builder queryExpander(QueryExpander queryExpander) {
			this.queryExpander = queryExpander;
			return this;
		}

		public Builder documentRetriever(DocumentRetriever documentRetriever) {
			this.documentRetriever = documentRetriever;
			return this;
		}

		public Builder documentJoiner(DocumentJoiner documentJoiner) {
			this.documentJoiner = documentJoiner;
			return this;
		}

		public Builder documentPostProcessors(List<DocumentPostProcessor> documentPostProcessors) {
			Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
			this.documentPostProcessors = documentPostProcessors;
			return this;
		}

		public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) {
			Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null");
			Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
			this.documentPostProcessors = Arrays.asList(documentPostProcessors);
			return this;
		}

		public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
			this.queryAugmenter = queryAugmenter;
			return this;
		}

		public Builder taskExecutor(TaskExecutor taskExecutor) {
			this.taskExecutor = taskExecutor;
			return this;
		}

		public Builder scheduler(Scheduler scheduler) {
			this.scheduler = scheduler;
			return this;
		}

		public Builder order(Integer order) {
			this.order = order;
			return this;
		}

		public RetrievalAugmentationAdvisor build() {
			return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
					this.documentJoiner, this.documentPostProcessors, this.queryAugmenter, this.taskExecutor,
					this.scheduler, this.order);
		}

	}

}

Query

用于在 RAG 流程中表示查询的类

  • String text:查询的文本内容,用户输入的核心查询语句
  • List<Message> history:当前查询相关的对话历史记录
  • Map<String, Object> context:查询的上下文信息,键值对集合,用于存储与查询相关的额外数据
java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag;

import java.util.List;
import java.util.Map;

import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

/**
 * Represents a query in the context of a Retrieval Augmented Generation (RAG) flow.
 *
 * @param text the text of the query
 * @param history the messages in the conversation history
 * @param context the context of the query
 * @author Thomas Vitale
 * @since 1.0.0
 */
public record Query(String text, List<Message> history, Map<String, Object> context) {

	public Query {
		Assert.hasText(text, "text cannot be null or empty");
		Assert.notNull(history, "history cannot be null");
		Assert.noNullElements(history, "history elements cannot be null");
		Assert.notNull(context, "context cannot be null");
		Assert.noNullElements(context.keySet(), "context keys cannot be null");
	}

	public Query(String text) {
		this(text, List.of(), Map.of());
	}

	public Builder mutate() {
		return new Builder().text(this.text).history(this.history).context(this.context);
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private String text;

		private List<Message> history = List.of();

		private Map<String, Object> context = Map.of();

		private Builder() {
		}

		public Builder text(String text) {
			this.text = text;
			return this;
		}

		public Builder history(List<Message> history) {
			this.history = history;
			return this;
		}

		public Builder history(Message... history) {
			this.history = List.of(history);
			return this;
		}

		public Builder context(Map<String, Object> context) {
			this.context = context;
			return this;
		}

		public Query build() {
			return new Query(this.text, this.history, this.context);
		}

	}

}

Pre-Retrieval

QueryExpander(查询扩展接口类)

作用:

  • 处理不规范的查询:通过提供替代的查询表达式,帮助改善查询质量
  • 分解复杂问题:将复杂的查询拆分为更简单的子查询,便于后续处理
MultiQueryExpander

扩展查询的类,通过使用 LLM 将单个查询扩展为多个语义上多样化的变体,这些变体能从不同角度或方面覆盖原始查询的主题,从而增加检索到相关结果的机会

字段的含义

  • ChatClient chatClient:用于与大语言模型进行交互,生成查询的变体
  • PromptTemplate promptTemplate:定义生成查询变体的提示模版。默认模板要求生成指定数量的查询变体,每个变体需覆盖不同的视角或方面。
  • boolean includeOriginal:是否在生成的查询列表中包含原始查询,默认为 true
  • int numberOfQueries:指定生成的查询变体的数量
java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.preretrieval.query.expansion;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * Uses a large language model to expand a query into multiple semantically diverse
 * variations to capture different perspectives, useful for retrieving additional
 * contextual information and increasing the chances of finding relevant results.
 *
 * <p>
 * Example usage: <pre>{@code
 * MultiQueryExpander expander = MultiQueryExpander.builder()
 *    .chatClientBuilder(chatClientBuilder)
 *    .numberOfQueries(3)
 *    .build();
 * List<Query> queries = expander.expand(new Query("How to run a Spring Boot app?"));
 * }</pre>
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public final class MultiQueryExpander implements QueryExpander {

	private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);

	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			You are an expert at information retrieval and search optimization.
			Your task is to generate {number} different versions of the given query.

			Each variant must cover different perspectives or aspects of the topic,
			while maintaining the core intent of the original query. The goal is to
			expand the search space and improve the chances of finding relevant information.

			Do not explain your choices or add any other text.
			Provide the query variants separated by newlines.

			Original query: {query}

			Query variants:
			""");

	private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true;

	private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;

	private final ChatClient chatClient;

	private final PromptTemplate promptTemplate;

	private final boolean includeOriginal;

	private final int numberOfQueries;

	public MultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
			@Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
		Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

		this.chatClient = chatClientBuilder.build();
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
		this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL;
		this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES;

		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
	}

	@Override
	public List<Query> expand(Query query) {
		Assert.notNull(query, "query cannot be null");

		logger.debug("Generating {} query variants", this.numberOfQueries);

		var response = this.chatClient.prompt()
			.user(user -> user.text(this.promptTemplate.getTemplate())
				.param("number", this.numberOfQueries)
				.param("query", query.text()))
			.call()
			.content();

		if (response == null) {
			logger.warn("Query expansion result is null. Returning the input query unchanged.");
			return List.of(query);
		}

		var queryVariants = Arrays.asList(response.split("\n"));

		if (CollectionUtils.isEmpty(queryVariants) || this.numberOfQueries != queryVariants.size()) {
			logger.warn(
					"Query expansion result does not contain the requested {} variants. Returning the input query unchanged.",
					this.numberOfQueries);
			return List.of(query);
		}

		var queries = queryVariants.stream()
			.filter(StringUtils::hasText)
			.map(queryText -> query.mutate().text(queryText).build())
			.collect(Collectors.toList());

		if (this.includeOriginal) {
			logger.debug("Including the original query in the result");
			queries.add(0, query);
		}

		return queries;
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private ChatClient.Builder chatClientBuilder;

		private PromptTemplate promptTemplate;

		private Boolean includeOriginal;

		private Integer numberOfQueries;

		private Builder() {
		}

		public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
			this.chatClientBuilder = chatClientBuilder;
			return this;
		}

		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		public Builder includeOriginal(Boolean includeOriginal) {
			this.includeOriginal = includeOriginal;
			return this;
		}

		public Builder numberOfQueries(Integer numberOfQueries) {
			this.numberOfQueries = numberOfQueries;
			return this;
		}

		public MultiQueryExpander build() {
			return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal,
					this.numberOfQueries);
		}

	}

}
QueryTransformer(查询转换接口类)

作用:

  1. 查询结构不完整或格式不佳
  2. 查询中的术语存在歧义
  3. 查询中使用了复杂或难以理解的词汇
  4. 查询使用了不受支持的语言
CompressionQueryTransformer

用于压缩对话历史和后续查询的类

作用:将对话上下文和后续查询合并为一个独立的查询,以捕获对话的核心内容。

适用场景:对话历史较长、后续查询与对话上下文相关

各字段含义:

  • ChatClient chatClient:用于与 LLM 交互,生成压缩后的查询
  • PromptTemplate promptTemplate:自定义用于生产压缩查询的提示文本
java 复制代码
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.preretrieval.query.transformation;

import java.util.List;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Uses a large language model to compress a conversation history and a follow-up query
 * into a standalone query that captures the essence of the conversation.
 * <p>
 * This transformer is useful when the conversation history is long and the follow-up
 * query is related to the conversation context.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public class CompressionQueryTransformer implements QueryTransformer {

	private static final Logger logger = LoggerFactory.getLogger(CompressionQueryTransformer.class);

	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			Given the following conversation history and a follow-up query, your task is to synthesize
			a concise, standalone query that incorporates the context from the history.
			Ensure the standalone query is clear, specific, and maintains the user's intent.

			Conversation history:
			{history}

			Follow-up query:
			{query}

			Standalone query:
			""");

	private final ChatClient chatClient;

	private final PromptTemplate promptTemplate;

	public CompressionQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate) {
		Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

		this.chatClient = chatClientBuilder.build();
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;

		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "history", "query");
	}

	@Override
	public Query transform(Query query) {
		Assert.notNull(query, "query cannot be null");

		logger.debug("Compressing conversation history and follow-up query into a standalone query");

		var compressedQueryText = this.chatClient.prompt()
			.user(user -> user.text(this.promptTemplate.getTemplate())
				.param("history", formatConversationHistory(query.history()))
				.param("query", query.text()))
			.call()
			.content();

		if (!StringUtils.hasText(compressedQueryText)) {
			logger.warn("Query compression result is null/empty. Returning the input query unchanged.");
			return query;
		}

		return query.mutate().text(compressedQueryText).build();
	}

	private String formatConversationHistory(List<Message> history) {
		if (history.isEmpty()) {
			return "";
		}

		return history.stream()
			.filter(message -> message.getMessageType().equals(MessageType.USER)
					|| message.getMessageType().equals(MessageType.ASSISTANT))
			.map(message -> "%s: %s".formatted(message.getMessageType(), message.getText()))
			.collect(Collectors.joining("\n"));
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private ChatClient.Builder chatClientBuilder;

		@Nullable
		private PromptTemplate promptTemplate;

		private Builder() {
		}

		public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
			this.chatClientBuilder = chatClientBuilder;
			return this;
		}

		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		public CompressionQueryTransformer build() {
			return new CompressionQueryTransformer(this.chatClientBuilder, this.promptTemplate);
		}

	}

}
RewriteQueryTransformer

重写用户查询的类

作用:通过 LLM 优化查询,以便在查询目标系统时获得更好的结果

适用场景:用户查询冗长、模糊、不包含相关信息

各字段含义

  • PromptTemplate promptTemplate:自定义重写模版
  • ChatClient chatClient:用于与 LLM 进行交互,重写查询
  • String targetSearchSystem:目标系统的名称,用于在提示模板中指定查询的目标系统,默认为"vector store"
java 复制代码
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.preretrieval.query.transformation;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Uses a large language model to rewrite a user query to provide better results when
 * querying a target system, such as a vector store or a web search engine.
 * <p>
 * This transformer is useful when the user query is verbose, ambiguous, or contains
 * irrelevant information that may affect the quality of the search results.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 * @see <a href="https://arxiv.org/pdf/2305.14283">arXiv:2305.14283</a>
 */
public class RewriteQueryTransformer implements QueryTransformer {

	private static final Logger logger = LoggerFactory.getLogger(RewriteQueryTransformer.class);

	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			Given a user query, rewrite it to provide better results when querying a {target}.
			Remove any irrelevant information, and ensure the query is concise and specific.

			Original query:
			{query}

			Rewritten query:
			""");

	private static final String DEFAULT_TARGET = "vector store";

	private final ChatClient chatClient;

	private final PromptTemplate promptTemplate;

	private final String targetSearchSystem;

	public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
			@Nullable String targetSearchSystem) {
		Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

		this.chatClient = chatClientBuilder.build();
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
		this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;

		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
	}

	@Override
	public Query transform(Query query) {
		Assert.notNull(query, "query cannot be null");

		logger.debug("Rewriting query to optimize for querying a {}.", this.targetSearchSystem);

		var rewrittenQueryText = this.chatClient.prompt()
			.user(user -> user.text(this.promptTemplate.getTemplate())
				.param("target", this.targetSearchSystem)
				.param("query", query.text()))
			.call()
			.content();

		if (!StringUtils.hasText(rewrittenQueryText)) {
			logger.warn("Query rewrite result is null/empty. Returning the input query unchanged.");
			return query;
		}

		return query.mutate().text(rewrittenQueryText).build();
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private ChatClient.Builder chatClientBuilder;

		@Nullable
		private PromptTemplate promptTemplate;

		@Nullable
		private String targetSearchSystem;

		private Builder() {
		}

		public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
			this.chatClientBuilder = chatClientBuilder;
			return this;
		}

		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		public Builder targetSearchSystem(String targetSearchSystem) {
			this.targetSearchSystem = targetSearchSystem;
			return this;
		}

		public RewriteQueryTransformer build() {
			return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
		}

	}

}
TranslationQueryTransformer

将用户查询翻译为目标语言的工具类

作用:使用 LLM 将用户查询翻译为目标语言

适用场景:当嵌入模型仅支持特定语言,而用户查询使用不同语言时

各字段含义

  • ChatClient chatClient:与 LLM 交互,翻译为目标语言
  • PromptTemplate promptTemplate:自定义翻译请求的提示模版
  • String targetLanguage:目标语言
java 复制代码
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.preretrieval.query.transformation;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Uses a large language model to translate a query to a target language that is supported
 * by the embedding model used to generate the document embeddings. If the query is
 * already in the target language, it is returned unchanged. If the language of the query
 * is unknown, it is also returned unchanged.
 * <p>
 * This transformer is useful when the embedding model is trained on a specific language
 * and the user query is in a different language.
 * <p>
 * Example usage: <pre>{@code
 * QueryTransformer transformer = TranslationQueryTransformer.builder()
 *    .chatClientBuilder(chatClientBuilder)
 *    .targetLanguage("english")
 *    .build();
 * Query transformedQuery = transformer.transform(new Query("Hvad er Danmarks hovedstad?"));
 * }</pre>
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public final class TranslationQueryTransformer implements QueryTransformer {

	private static final Logger logger = LoggerFactory.getLogger(TranslationQueryTransformer.class);

	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			Given a user query, translate it to {targetLanguage}.
			If the query is already in {targetLanguage}, return it unchanged.
			If you don't know the language of the query, return it unchanged.
			Do not add explanations nor any other text.

			Original query: {query}

			Translated query:
			""");

	private final ChatClient chatClient;

	private final PromptTemplate promptTemplate;

	private final String targetLanguage;

	public TranslationQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
			String targetLanguage) {
		Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
		Assert.hasText(targetLanguage, "targetLanguage cannot be null or empty");

		this.chatClient = chatClientBuilder.build();
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
		this.targetLanguage = targetLanguage;

		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "targetLanguage", "query");
	}

	@Override
	public Query transform(Query query) {
		Assert.notNull(query, "query cannot be null");

		logger.debug("Translating query to target language: {}", this.targetLanguage);

		var translatedQueryText = this.chatClient.prompt()
			.user(user -> user.text(this.promptTemplate.getTemplate())
				.param("targetLanguage", this.targetLanguage)
				.param("query", query.text()))
			.call()
			.content();

		if (!StringUtils.hasText(translatedQueryText)) {
			logger.warn("Query translation result is null/empty. Returning the input query unchanged.");
			return query;
		}

		return query.mutate().text(translatedQueryText).build();
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private ChatClient.Builder chatClientBuilder;

		@Nullable
		private PromptTemplate promptTemplate;

		private String targetLanguage;

		private Builder() {
		}

		public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
			this.chatClientBuilder = chatClientBuilder;
			return this;
		}

		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		public Builder targetLanguage(String targetLanguage) {
			this.targetLanguage = targetLanguage;
			return this;
		}

		public TranslationQueryTransformer build() {
			return new TranslationQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetLanguage);
		}

	}

}

Retrieval

DocumentRetriever(文档检索通用接口)
java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.retrieval.search;

import java.util.List;
import java.util.function.Function;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;

/**
 * Component responsible for retrieving {@link Document}s from an underlying data source,
 * such as a search engine, a vector store, a database, or a knowledge graph.
 *
 * @author Christian Tzolov
 * @author Thomas Vitale
 * @since 1.0.0
 */
public interface DocumentRetriever extends Function<Query, List<Document>> {

	/**
	 * Retrieves relevant documents from an underlying data source based on the given
	 * query.
	 * @param query The query to use for retrieving documents
	 * @return The list of relevant documents
	 */
	List<Document> retrieve(Query query);

	default List<Document> apply(Query query) {
		return retrieve(query);
	}

}
VectorStoreDocumentRetriever

用于从 VectorStore 中检索与输入查询语义相似的文档

各字段含义

  • VectorStore vectorStore:存储和检索文档的向量存储实例
  • Double similarityThreshold:相似度阈值,过滤相似度低于该值的文档
  • Integer topK:返回文档的上限
  • Supplier<Filter.Expression> filterExpression:运行时根据上下文动态生成过滤条件
java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.retrieval.search;

import java.util.List;
import java.util.function.Supplier;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Retrieves documents from a vector store that are semantically similar to the input
 * query. It supports filtering based on metadata, similarity threshold, and top-k
 * results.
 *
 * <p>
 * Example usage: <pre>{@code
 * VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
 *     .vectorStore(vectorStore)
 *     .similarityThreshold(0.73)
 *     .topK(5)
 *     .filterExpression(filterExpression)
 *     .build();
 * List<Document> documents = retriever.retrieve(new Query("example query"));
 * }</pre>
 *
 * <p>
 * The {@link #FILTER_EXPRESSION} context key can be used to provide a filter expression
 * for a specific query. This key accepts either a string representation of a filter
 * expression or a {@link Filter.Expression} object directly.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public final class VectorStoreDocumentRetriever implements DocumentRetriever {

	public static final String FILTER_EXPRESSION = "vector_store_filter_expression";

	private final VectorStore vectorStore;

	private final Double similarityThreshold;

	private final Integer topK;

	// Supplier to allow for lazy evaluation of the filter expression,
	// which may depend on the execution content. For example, you may want to
	// filter dynamically based on the current user's identity or tenant ID.
	private final Supplier<Filter.Expression> filterExpression;

	public VectorStoreDocumentRetriever(VectorStore vectorStore, @Nullable Double similarityThreshold,
			@Nullable Integer topK, @Nullable Supplier<Filter.Expression> filterExpression) {
		Assert.notNull(vectorStore, "vectorStore cannot be null");
		Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0,
				"similarityThreshold must be equal to or greater than 0.0");
		Assert.isTrue(topK == null || topK > 0, "topK must be greater than 0");
		this.vectorStore = vectorStore;
		this.similarityThreshold = similarityThreshold != null ? similarityThreshold
				: SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
		this.topK = topK != null ? topK : SearchRequest.DEFAULT_TOP_K;
		this.filterExpression = filterExpression != null ? filterExpression : () -> null;
	}

	@Override
	public List<Document> retrieve(Query query) {
		Assert.notNull(query, "query cannot be null");
		var requestFilterExpression = computeRequestFilterExpression(query);
		var searchRequest = SearchRequest.builder()
			.query(query.text())
			.filterExpression(requestFilterExpression)
			.similarityThreshold(this.similarityThreshold)
			.topK(this.topK)
			.build();
		return this.vectorStore.similaritySearch(searchRequest);
	}

	/**
	 * Computes the filter expression to use for the current request.
	 * <p>
	 * The filter expression can be provided in the query context using the
	 * {@link #FILTER_EXPRESSION} key. This key accepts either a string representation of
	 * a filter expression or a {@link Filter.Expression} object directly.
	 * <p>
	 * If no filter expression is provided in the context, the default filter expression
	 * configured for this retriever is used.
	 * @param query the query containing potential context with filter expression
	 * @return the filter expression to use for the request
	 */
	private Filter.Expression computeRequestFilterExpression(Query query) {
		var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
		if (contextFilterExpression != null) {
			if (contextFilterExpression instanceof Filter.Expression) {
				return (Filter.Expression) contextFilterExpression;
			}
			else if (StringUtils.hasText(contextFilterExpression.toString())) {
				return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
			}
		}
		return this.filterExpression.get();
	}

	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Builder for {@link VectorStoreDocumentRetriever}.
	 */
	public static final class Builder {

		private VectorStore vectorStore;

		private Double similarityThreshold;

		private Integer topK;

		private Supplier<Filter.Expression> filterExpression;

		private Builder() {
		}

		public Builder vectorStore(VectorStore vectorStore) {
			this.vectorStore = vectorStore;
			return this;
		}

		public Builder similarityThreshold(Double similarityThreshold) {
			this.similarityThreshold = similarityThreshold;
			return this;
		}

		public Builder topK(Integer topK) {
			this.topK = topK;
			return this;
		}

		public Builder filterExpression(Filter.Expression filterExpression) {
			this.filterExpression = () -> filterExpression;
			return this;
		}

		public Builder filterExpression(Supplier<Filter.Expression> filterExpression) {
			this.filterExpression = filterExpression;
			return this;
		}

		public VectorStoreDocumentRetriever build() {
			return new VectorStoreDocumentRetriever(this.vectorStore, this.similarityThreshold, this.topK,
					this.filterExpression);
		}

	}

}
DocumentJoiner(文档统一接口类)

将基于多个查询和多个数据源检索的文档合并为一个单一的文档集合

作用:文档合并(将不同数据源检索的文档合并为一个);去重处理(合并过程中,处理重复文档);排名策略(支持对合并后的文档进行排名处理)

适用场景:从多个查询或多个数据源检索文档,并将结果合并为一个统一集合的场景

java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.retrieval.join;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;

/**
 * A component for combining documents retrieved based on multiple queries and from
 * multiple data sources into a single collection of documents. As part of the joining
 * process, it can also handle duplicate documents and reciprocal ranking strategies.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public interface DocumentJoiner extends Function<Map<Query, List<List<Document>>>, List<Document>> {

	/**
	 * Joins documents retrieved across multiple queries and daa sources.
	 * @param documentsForQuery a map of queries and the corresponding list of documents
	 * retrieved
	 * @return a single collection of documents
	 */
	List<Document> join(Map<Query, List<List<Document>>> documentsForQuery);

	default List<Document> apply(Map<Query, List<List<Document>>> documentsForQuery) {
		return join(documentsForQuery);
	}

}
ConcatenationDocumentJoiner

合并基于多个查询和多个数据源检索到的文档

java 复制代码
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.cloud.ai.application.modulerag.join;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 * Retrieval 步骤 负责查询向量存储等数据系统并检索和用户 query 相关性最高的 Document。
 * DocumentJoiner:将从多个 query 和从多个数据源检索到的 Document 合并为一个 Document 集合;
 */

public class ConcatenationDocumentJoiner implements DocumentJoiner {

	private static final Logger logger = LoggerFactory.getLogger(ConcatenationDocumentJoiner.class);

	@NotNull
	@Override
	public List<Document> join(
			@Nullable Map<Query, List<List<Document>>> documentsForQuery
	) {

		Assert.notNull(documentsForQuery, "documentsForQuery cannot be null");
		Assert.noNullElements(documentsForQuery.keySet(), "documentsForQuery cannot contain null keys");
		Assert.noNullElements(documentsForQuery.values(), "documentsForQuery cannot contain null values");
		logger.debug("Joining documents by concatenation");

		Map<Query, List<List<Document>>> selectDocuments = selectDocuments(documentsForQuery, 10);

		Set<String> seen = new HashSet<>();

		return selectDocuments.values().stream()
				// Flatten List<List<Documents>> to Stream<List<Documents>.
				.flatMap(List::stream)
				// Flatten Stream<List<Documents> to Stream<Documents>.
				.flatMap(List::stream)
				.filter(doc -> {
					List<String> keys = extractKeys(doc);
					for (String key : keys) {
						if (!seen.add(key)) {
							logger.info("Duplicate document metadata: {}",doc.getMetadata());
							// Duplicate keys found.
							return false;
						}
					}
					// All keys are unique.
					return true;
				})
				.collect(Collectors.toList());
	}

	private Map<Query, List<List<Document>>> selectDocuments(
			Map<Query, List<List<Document>>> documentsForQuery,
			int totalDocuments
	) {

		Map<Query, List<List<Document>>> selectDocumentsForQuery = new HashMap<>();

		int numberOfQueries = documentsForQuery.size();

		if (Objects.equals(0, numberOfQueries)) {

			return selectDocumentsForQuery;
		}

		int baseCount = totalDocuments / numberOfQueries;
		int remainder = totalDocuments % numberOfQueries;

		// To ensure consistent distribution. sort the keys (optional)
		List<Query> sortedQueries = new ArrayList<>(documentsForQuery.keySet());
		// Other sort
		// sortedQueries.sort(Comparator.comparing(Query::getSomeProperty));
		Iterator<Query> iterator = sortedQueries.iterator();

		for (int i = 0; i < numberOfQueries; i ++) {
			Query query = sortedQueries.get(i);
			int documentToSelect = baseCount + (i < remainder ? 1 : 0);
			List<List<Document>> originalDocuments = documentsForQuery.get(query);
			List<List<Document>> selectedNestLists = new ArrayList<>();

			int remainingDocuments = documentToSelect;
			for (List<Document> documentList : originalDocuments) {
				if (remainingDocuments <= 0) {
					break;
				}

				List<Document> selectSubList = new ArrayList<>();
				for (Document docs : documentList) {
					if (remainingDocuments <= 0) {
						break;
					}

					selectSubList.add(docs);
					remainingDocuments --;
				}

				if (!selectSubList.isEmpty()) {
					selectedNestLists.add(selectSubList);
				}
			}

			selectDocumentsForQuery.put(query, selectedNestLists);
		}

		return selectDocumentsForQuery;
	}

	private List<String> extractKeys(Document document) {

		List<String> keys = new ArrayList<>();

		if (Objects.nonNull(document)) {
			keys.add(document.getId());
		}

		if (Objects.nonNull(document.getMetadata())) {
			Object src = document.getMetadata().get("source");
			if (src instanceof String) {
				keys.add("SOURCE:" + src);
			}

			Object fn = document.getMetadata().get("file_name");
			if (fn instanceof String) {
				keys.add("FILE_NAME:" + fn);
			}
		}

		return keys;
	}

}

Post-Retrieval

DocumentPostProcessor

检索后,对文档进行逻辑出现,如压缩、排名、选择部分等,通过实现该接口

java 复制代码
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.postretrieval.document;

import java.util.List;
import java.util.function.BiFunction;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;

/**
 * A component for post-processing retrieved documents based on a query, addressing
 * challenges such as "lost-in-the-middle", context length restrictions from the model,
 * and the need to reduce noise and redundancy in the retrieved information.
 * <p>
 * For example, it could rank documents based on their relevance to the query, remove
 * irrelevant or redundant documents, or compress the content of each document to reduce
 * noise and redundancy.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public interface DocumentPostProcessor extends BiFunction<Query, List<Document>, List<Document>> {

	List<Document> process(Query query, List<Document> documents);

	default List<Document> apply(Query query, List<Document> documents) {
		return process(query, documents);
	}

}

Generation

QueryAugmenter(查询增强接口类)

通过将用户查询与额外的上下文数据结合,从而为 LLM 提供更丰富的背景信息

java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.generation.augmentation;

import java.util.List;
import java.util.function.BiFunction;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;

/**
 * A component for augmenting an input query with additional data, useful to provide a
 * large language model with the necessary context to answer the user query.
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public interface QueryAugmenter extends BiFunction<Query, List<Document>, Query> {

	/**
	 * Augments the user query with contextual data.
	 * @param query The user query to augment
	 * @param documents The contextual data to use for augmentation
	 * @return The augmented query
	 */
	Query augment(Query query, List<Document> documents);

	default Query apply(Query query, List<Document> documents) {
		return augment(query, documents);
	}

}
ContextualQueryAugmenter

增强用户查询的类,通过将用户查询与提供的文档内容结合,生成一个增强后的查询,为后续的 RAG 流程提供更丰富的背景信息

各字段的含义

  • PromptTemplate promptTemplate:用户自定义提示模版,用于生成增强查询
  • PromptTemplate emptyContextPromptTemplate:用户自定义为空时的上下文提示模版
  • boolean allowEmptyContext:是否允许空上下文
  • Function<List<Document>, String> documentFormatter:用户自定义的文档格式化函数
java 复制代码
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.rag.generation.augmentation;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
 * Augments the user query with contextual data from the content of the provided
 * documents.
 *
 * <p>
 * Example usage: <pre>{@code
 * QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
 *    .allowEmptyContext(false)
 *    .build();
 * Query augmentedQuery = augmenter.augment(query, documents);
 * }</pre>
 *
 * @author Thomas Vitale
 * @since 1.0.0
 */
public final class ContextualQueryAugmenter implements QueryAugmenter {

	private static final Logger logger = LoggerFactory.getLogger(ContextualQueryAugmenter.class);

	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			Context information is below.

			---------------------
			{context}
			---------------------

			Given the context information and no prior knowledge, answer the query.

			Follow these rules:

			1. If the answer is not in the context, just say that you don't know.
			2. Avoid statements like "Based on the context..." or "The provided information...".

			Query: {query}

			Answer:
			""");

	private static final PromptTemplate DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE = new PromptTemplate("""
			The user query is outside your knowledge base.
			Politely inform the user that you can't answer it.
			""");

	private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = false;

	/**
	 * Default document formatter that just joins document text with newlines
	 */
	private static final Function<List<Document>, String> DEFAULT_DOCUMENT_FORMATTER = documents -> documents.stream()
		.map(Document::getText)
		.collect(Collectors.joining(System.lineSeparator()));

	private final PromptTemplate promptTemplate;

	private final PromptTemplate emptyContextPromptTemplate;

	private final boolean allowEmptyContext;

	private final Function<List<Document>, String> documentFormatter;

	public ContextualQueryAugmenter(@Nullable PromptTemplate promptTemplate,
			@Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext,
			@Nullable Function<List<Document>, String> documentFormatter) {
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
		this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate
				: DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE;
		this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT;
		this.documentFormatter = documentFormatter != null ? documentFormatter : DEFAULT_DOCUMENT_FORMATTER;
		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context");
	}

	@Override
	public Query augment(Query query, List<Document> documents) {
		Assert.notNull(query, "query cannot be null");
		Assert.notNull(documents, "documents cannot be null");

		logger.debug("Augmenting query with contextual data");

		if (documents.isEmpty()) {
			return augmentQueryWhenEmptyContext(query);
		}

		// 1. Collect content from documents.
		String documentContext = this.documentFormatter.apply(documents);

		// 2. Define prompt parameters.
		Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);

		// 3. Augment user prompt with document context.
		return new Query(this.promptTemplate.render(promptParameters));
	}

	private Query augmentQueryWhenEmptyContext(Query query) {
		if (this.allowEmptyContext) {
			logger.debug("Empty context is allowed. Returning the original query.");
			return query;
		}
		logger.debug("Empty context is not allowed. Returning a specific query for empty context.");
		return new Query(this.emptyContextPromptTemplate.render());
	}

	public static Builder builder() {
		return new Builder();
	}

	public static class Builder {

		private PromptTemplate promptTemplate;

		private PromptTemplate emptyContextPromptTemplate;

		private Boolean allowEmptyContext;

		private Function<List<Document>, String> documentFormatter;

		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		public Builder emptyContextPromptTemplate(PromptTemplate emptyContextPromptTemplate) {
			this.emptyContextPromptTemplate = emptyContextPromptTemplate;
			return this;
		}

		public Builder allowEmptyContext(Boolean allowEmptyContext) {
			this.allowEmptyContext = allowEmptyContext;
			return this;
		}

		public Builder documentFormatter(Function<List<Document>, String> documentFormatter) {
			this.documentFormatter = documentFormatter;
			return this;
		}

		public ContextualQueryAugmenter build() {
			return new ContextualQueryAugmenter(this.promptTemplate, this.emptyContextPromptTemplate,
					this.allowEmptyContext, this.documentFormatter);
		}

	}

}
相关推荐
Bug退退退12342 分钟前
RabbitMQ 高级特性之消息分发
java·分布式·spring·rabbitmq
Jack_hrx2 小时前
基于 Drools 的规则引擎性能调优实践:架构、缓存与编译优化全解析
java·性能优化·规则引擎·drools·规则编译
Danceful_YJ2 小时前
15.手动实现BatchNorm(BN)
人工智能·深度学习·神经网络·batchnorm
二进制person2 小时前
数据结构--准备知识
java·开发语言·数据结构
半梦半醒*2 小时前
H3CNE综合实验之机器人
java·开发语言·网络
idolyXyz2 小时前
[spring6: SpringApplication.run]-应用启动
spring
wh_xia_jun2 小时前
医疗数据分析中标准化的作用
人工智能·机器学习
消失的旧时光-19433 小时前
Android模块化架构:基于依赖注入和服务定位器的解耦方案
android·java·架构·kotlin
@ chen3 小时前
Spring Boot 解决跨域问题
java·spring boot·后端
jndingxin4 小时前
OpenCV直线段检测算法类cv::line_descriptor::LSDDetector
人工智能·opencv·算法