聊聊Spring AI的RetrievalAugmentationAdvisor

本文主要研究一下Spring AI的RetrievalAugmentationAdvisor

BaseAdvisor

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java

复制代码
public interface BaseAdvisor extends CallAroundAdvisor, StreamAroundAdvisor {

	Scheduler DEFAULT_SCHEDULER = Schedulers.boundedElastic();

	@Override
	default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
		Assert.notNull(advisedRequest, "advisedRequest cannot be null");
		Assert.notNull(chain, "chain cannot be null");

		AdvisedRequest processedAdvisedRequest = before(advisedRequest);
		AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
		return after(advisedResponse);
	}

	@Override
	default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
		Assert.notNull(advisedRequest, "advisedRequest cannot be null");
		Assert.notNull(chain, "chain cannot be null");
		Assert.notNull(getScheduler(), "scheduler cannot be null");

		Flux<AdvisedResponse> advisedResponses = Mono.just(advisedRequest)
			.publishOn(getScheduler())
			.map(this::before)
			.flatMapMany(chain::nextAroundStream);

		return advisedResponses.map(ar -> {
			if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
				ar = after(ar);
			}
			return ar;
		}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
	}

	@Override
	default String getName() {
		return this.getClass().getSimpleName();
	}

	/**
	 * Logic to be executed before the rest of the advisor chain is called.
	 */
	AdvisedRequest before(AdvisedRequest request);

	/**
	 * Logic to be executed after the rest of the advisor chain is called.
	 */
	AdvisedResponse after(AdvisedResponse advisedResponse);

	/**
	 * Scheduler used for processing the advisor logic when streaming.
	 */
	default Scheduler getScheduler() {
		return DEFAULT_SCHEDULER;
	}

}

BaseAdvisor继承了CallAroundAdvisor、StreamAroundAdvisor,并为aroundCall、aroundStream提供了default实现:在调用chain的next之前执行before方法,之后执行after方法;其中aroundStream只对满足AdvisedResponseStreamUtils.onFinishReason()的响应执行after。

RetrievalAugmentationAdvisor

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

复制代码
/**
 * {@link BaseAdvisor} that augments the user text with retrieved documents before the
 * request continues down the chain: the query is transformed, optionally expanded,
 * used for document retrieval, and finally augmented with the joined documents.
 */
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

	/**
	 * Key under which the retrieved documents are stored in the advise context and
	 * later attached to the chat response metadata.
	 */
	public static final String DOCUMENT_CONTEXT = "rag_document_context";

	private final List<QueryTransformer> queryTransformers;

	@Nullable
	private final QueryExpander queryExpander;

	private final DocumentRetriever documentRetriever;

	private final DocumentJoiner documentJoiner;

	private final QueryAugmenter queryAugmenter;

	private final TaskExecutor taskExecutor;

	private final Scheduler scheduler;

	private final int order;

	/**
	 * Creates the advisor. Only {@code documentRetriever} is mandatory; every other
	 * collaborator falls back to a default when {@code null} is passed.
	 * @param queryTransformers transformers applied in order to the original query;
	 * defaults to an empty list
	 * @param queryExpander optional expander; when absent the transformed query is used
	 * as the only query
	 * @param documentRetriever retriever invoked once per query; must not be null
	 * @param documentJoiner defaults to a new {@code ConcatenationDocumentJoiner}
	 * @param queryAugmenter defaults to {@code ContextualQueryAugmenter.builder().build()}
	 * @param taskExecutor executor for the per-query retrieval; defaults to
	 * {@link #buildDefaultTaskExecutor()}
	 * @param scheduler streaming scheduler; defaults to
	 * {@link BaseAdvisor#DEFAULT_SCHEDULER}
	 * @param order advisor order; defaults to 0
	 */
	public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
			@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
			@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
			@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
		Assert.notNull(documentRetriever, "documentRetriever cannot be null");
		Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");

		if (queryTransformers == null) {
			this.queryTransformers = List.of();
		}
		else {
			this.queryTransformers = queryTransformers;
		}
		this.queryExpander = queryExpander;
		this.documentRetriever = documentRetriever;
		this.documentJoiner = (documentJoiner == null) ? new ConcatenationDocumentJoiner() : documentJoiner;
		this.queryAugmenter = (queryAugmenter == null) ? ContextualQueryAugmenter.builder().build() : queryAugmenter;
		this.taskExecutor = (taskExecutor == null) ? buildDefaultTaskExecutor() : taskExecutor;
		this.scheduler = (scheduler == null) ? BaseAdvisor.DEFAULT_SCHEDULER : scheduler;
		this.order = (order == null) ? 0 : order;
	}

	/**
	 * @return a new builder for this advisor
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Runs the retrieval flow and returns a copy of the request whose user text is the
	 * augmented query and whose advise context additionally holds the retrieved
	 * documents under {@link #DOCUMENT_CONTEXT}.
	 */
	@Override
	public AdvisedRequest before(AdvisedRequest request) {
		// Work on a copy so the incoming request's context map is not modified.
		Map<String, Object> adviseContext = new HashMap<>(request.adviseContext());

		// Step 0: build the query from the rendered user text, the history and the context.
		String renderedUserText = new PromptTemplate(request.userText(), request.userParams()).render();
		Query originalQuery = Query.builder()
			.text(renderedUserText)
			.history(request.messages())
			.context(adviseContext)
			.build();

		// Step 1: feed the query through every transformer, in list order.
		Query transformedQuery = originalQuery;
		for (var transformer : this.queryTransformers) {
			transformedQuery = transformer.apply(transformedQuery);
		}

		// Step 2: expand into one or more queries; without an expander the transformed
		// query is the only one.
		List<Query> expandedQueries;
		if (this.queryExpander == null) {
			expandedQueries = List.of(transformedQuery);
		}
		else {
			expandedQueries = this.queryExpander.expand(transformedQuery);
		}

		// Step 3: submit one retrieval per query to the task executor. All futures are
		// collected first so that every retrieval is submitted before any join blocks.
		List<CompletableFuture<Map.Entry<Query, List<Document>>>> pendingRetrievals = expandedQueries.stream()
			.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
			.toList();
		Map<Query, List<List<Document>>> documentsForQuery = pendingRetrievals.stream()
			.map(CompletableFuture::join)
			.collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));

		// Step 4: merge the per-query results and expose them through the context.
		List<Document> documents = this.documentJoiner.join(documentsForQuery);
		adviseContext.put(DOCUMENT_CONTEXT, documents);

		// Step 5: augment the original (untransformed) query with the documents.
		String augmentedText = this.queryAugmenter.augment(originalQuery, documents).text();

		// Step 6: hand back the request carrying the augmented text and updated context.
		return AdvisedRequest.from(request).userText(augmentedText).adviseContext(adviseContext).build();
	}

	/**
	 * Retrieves the documents for one query and pairs them with that query.
	 */
	private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
		return Map.entry(query, this.documentRetriever.retrieve(query));
	}

	/**
	 * Copies the {@link #DOCUMENT_CONTEXT} entry of the advise context into the chat
	 * response metadata. When the advised response carries no chat response, an empty
	 * one is built to hold the metadata.
	 */
	@Override
	public AdvisedResponse after(AdvisedResponse advisedResponse) {
		ChatResponse.Builder responseBuilder = (advisedResponse.response() != null)
				? ChatResponse.builder().from(advisedResponse.response()) : ChatResponse.builder();
		responseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
		return new AdvisedResponse(responseBuilder.build(), advisedResponse.adviseContext());
	}

	/**
	 * @return the scheduler chosen at construction time
	 */
	@Override
	public Scheduler getScheduler() {
		return this.scheduler;
	}

	/**
	 * @return the order chosen at construction time
	 */
	@Override
	public int getOrder() {
		return this.order;
	}

	/**
	 * Builds the executor used when none is supplied: an initialized
	 * {@code ThreadPoolTaskExecutor} with 4 core threads, at most 16 threads, the
	 * thread name prefix {@code ai-advisor-} and a
	 * {@code ContextPropagatingTaskDecorator}.
	 */
	private static TaskExecutor buildDefaultTaskExecutor() {
		ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
		executor.setThreadNamePrefix("ai-advisor-");
		executor.setCorePoolSize(4);
		executor.setMaxPoolSize(16);
		executor.setTaskDecorator(new ContextPropagatingTaskDecorator());
		executor.initialize();
		return executor;
	}

	//......

}

RetrievalAugmentationAdvisor实现了BaseAdvisor接口。其before方法先通过queryTransformers依次转换原始query,再使用queryExpander扩展query(未配置queryExpander时直接使用转换后的query),接着通过taskExecutor异步调用documentRetriever为每个扩展后的query检索相关文档,然后通过documentJoiner.join合并检索结果,并以rag_document_context为key放到context中,最后通过queryAugmenter.augment(originalQuery, documents)生成最终的augmentedQuery;其after方法把before存入context的rag_document_context(即检索到的文档)作为metadata附加到response中。

示例

复制代码
	/**
	 * Runs a RAG call whose query goes through a {@code RewriteQueryTransformer}
	 * before retrieval, then checks the answer text and its relevancy.
	 */
	@Test
	void ragWithRewrite() {
		String question = "Where are the main characters going?";

		// Collaborators of the advisor, built up front for readability.
		var rewriteTransformer = RewriteQueryTransformer.builder()
			.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
			.targetSearchSystem("vector store")
			.build();
		var retriever = VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build();

		RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
			.queryTransformers(rewriteTransformer)
			.documentRetriever(retriever)
			.build();

		ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build();
		ChatResponse chatResponse = chatClient.prompt().user(question).advisors(ragAdvisor).call().chatResponse();

		assertThat(chatResponse).isNotNull();

		String answer = chatResponse.getResult().getOutput().getText();
		System.out.println(answer);
		assertThat(answer).containsIgnoringCase("Loch of the Stars");

		evaluateRelevancy(question, chatResponse);
	}

这里使用了RewriteQueryTransformer、VectorStoreDocumentRetriever

小结

Spring AI定义了BaseAdvisor,它继承了CallAroundAdvisor、StreamAroundAdvisor,并为aroundCall、aroundStream提供了default实现:在调用chain的next之前执行before方法,之后执行after方法。RetrievalAugmentationAdvisor实现了BaseAdvisor接口,其before方法先通过queryTransformers依次转换原始query,再使用queryExpander扩展query,接着通过documentRetriever为每个扩展后的query检索相关文档,然后通过documentJoiner.join合并检索结果,并以rag_document_context为key放到context中,最后通过queryAugmenter.augment(originalQuery, documents)生成最终的augmentedQuery;其after方法把before存入context的rag_document_context(即检索到的文档)作为metadata附加到response中。

doc

相关推荐
大写-凌祁3 小时前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
焦耳加热4 小时前
阿德莱德大学Nat. Commun.:盐模板策略实现废弃塑料到单原子催化剂的高值转化,推动环境与能源催化应用
人工智能·算法·机器学习·能源·材料工程
深空数字孪生4 小时前
储能调峰新实践:智慧能源平台如何保障风电消纳与电网稳定?
大数据·人工智能·物联网
wan5555cn4 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威5 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
今天也要学习吖5 小时前
谷歌nano banana官方Prompt模板发布,解锁六大图像生成风格
人工智能·学习·ai·prompt·nano banana·谷歌ai
Hello123网站5 小时前
glean-企业级AI搜索和知识发现平台
人工智能·产品运营·ai工具
AKAMAI5 小时前
Queue-it 为数十亿用户增强在线体验
人工智能·云原生·云计算
索迪迈科技5 小时前
INDEMIND亮相2025科技创变者大会,以机器人空间智能技术解锁具身智能新边界
人工智能·机器人·扫地机器人·空间智能·陪伴机器人
栒U6 小时前
一文从零部署vLLM+qwen0.5b(mac本地版,不可以实操GPU单元)
人工智能·macos·vllm