Full source code download: Multi-agent RAG application based on Spring AI Alibaba (CSDN download)
Also available on GitHub: 1998y12/multi-agent-rag-spring: a multi-agent RAG application with Spring AI Alibaba
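Introduction
Alibaba recently released Spring AI Alibaba, an AI framework built on Spring AI and deeply integrated with the Alibaba Cloud Bailian platform. It supports ChatBot, workflow and multi-agent application development, and it is open source on GitHub.
Here are some of the features listed on the official website: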
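- Graph multi-agent framework. With Spring AI Alibaba Graph, developers can quickly build workflow and multi-agent applications without dealing with low-level concerns such as flow orchestration and context/memory management. It can generate Graph code automatically from a Dify DSL and supports visual debugging of graphs.
- AI ecosystem integration that addresses the pain points enterprises face when deploying agents. Spring AI Alibaba integrates deeply with the Bailian platform, providing model access and a RAG knowledge-base solution; it plugs seamlessly into AI observability products such as ARMS and Langfuse; and it supports enterprise-grade MCP integration, including distributed registration and discovery through the Nacos MCP Registry and automatic routing.
- Exploring general-purpose agent products and platforms with autonomous planning. The community has released JManus, an agent built on Spring AI Alibaba that aims to match the product capabilities of general-purpose agents such as Manus. Beyond that, the community is actively exploring autonomous planning in agent development, giving developers more flexible low-code, high-code and zero-code options for building agents and speeding up their adoption in vertical enterprise domains.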
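In practice, the development experience feels close to the design philosophy of LangGraph in Python: both describe the architecture and data flow of an AI system as a graph/workflow, so Spring AI Alibaba Graph can be seen as a Java counterpart of LangGraph.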
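The three most important concepts in Spring AI Alibaba are state, node and edge. By combining them you can build complex multi-agent architectures in which the agents cooperate.
- State: a snapshot of the whole AI application at the current moment; the data structure shared by all agents.
- Node: a function that holds an agent's processing logic. It takes the current state as input, performs some computation or side effect, and returns the state updates. In Spring AI Alibaba, a node is an object implementing the NodeAction interface: its apply method processes the state and returns a Map containing the updated keys.
- Edge: a function that decides which node runs next based on the current state. In Spring AI Alibaba, an edge is an object implementing the EdgeAction interface: its apply method runs the decision logic and returns a string that identifies the next node.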
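To make the three concepts concrete, here is a minimal, framework-free sketch of how state, nodes and edges interact. It is not the Spring AI Alibaba API: `MiniGraph`, its nested `NodeAction`/`EdgeAction` interfaces and the `END` marker are hypothetical stand-ins that only mirror the idea that a node returns a partial state update and an edge returns the name of the next node.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Framework-free sketch of the state / node / edge idea (not the Spring AI Alibaba API)
class MiniGraph {
    // A node reads the current state and returns only the keys it wants to update
    interface NodeAction extends Function<Map<String, Object>, Map<String, Object>> {}

    // An edge reads the current state and returns the name of the next node
    interface EdgeAction extends Function<Map<String, Object>, String> {}

    static final String END = "__END__";

    private final Map<String, NodeAction> nodes = new HashMap<>();
    private final Map<String, EdgeAction> edges = new HashMap<>();

    MiniGraph addNode(String name, NodeAction node) {
        nodes.put(name, node);
        return this;
    }

    MiniGraph addEdge(String from, EdgeAction edge) {
        edges.put(from, edge);
        return this;
    }

    /** Runs from start until an edge returns END (or maxSteps is hit), merging each node's update into the shared state. */
    Map<String, Object> run(String start, Map<String, Object> input, int maxSteps) {
        Map<String, Object> state = new HashMap<>(input);
        String current = start;
        for (int step = 0; step < maxSteps && !END.equals(current); step++) {
            NodeAction node = nodes.get(current);
            if (node == null) {
                throw new IllegalStateException("Unknown node: " + current);
            }
            state.putAll(node.apply(state)); // the node returns a partial update
            EdgeAction edge = edges.get(current);
            current = (edge == null) ? END : edge.apply(state); // no outgoing edge means stop
        }
        return state;
    }

    public static void main(String[] args) {
        MiniGraph graph = new MiniGraph()
                .addNode("generate", s -> Map.of("generation", "Answer to: " + s.get("question")))
                .addEdge("generate", s -> END);
        System.out.println(graph.run("generate", Map.of("question", "What is RAG?"), 10).get("generation"));
    }
}
```

The `maxSteps` guard is a design choice of the sketch: a graph with a cycle (for example "rewrite the question, retrieve again") needs some bound so it cannot loop forever.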
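Feature overview
Based on the official Spring AI Alibaba documentation and a RAG example from LangGraph, I implemented an improved, adaptive multi-agent RAG application. Its main features are:
1. For questions related to the knowledge base, generate the answer from the knowledge base, with some enhancement and post-processing.
2. For questions unrelated to the knowledge base, obtain the relevant knowledge through an external tool, then answer following the RAG flow.
Although it is only a small demo, it covers most of the building blocks: the RAG pipeline (document splitting, vector store, retriever, LLM generation), dynamic rendering of prompt templates, external Tool integration, model memory design and the use of Spring's prebuilt RAG agent. Everything is finally wired together with the Spring AI Alibaba workflow (graph) framework so that multiple agents cooperate.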
Workflow graph
The overall structure of the workflow graph is as follows:

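Workflow
The user enters a question, and a routing edge classifies it:
1. If the question relates to the knowledge base, Spring's prebuilt, enhanced RAG generates an answer.
1.1 Grade the answer:
1.1.1 Does it contain "hallucinations", i.e. statements not supported by the facts? If so, regenerate; otherwise continue.
1.1.2 Does it address the user's question? If not, go to the question-rewriting node to rewrite the question; otherwise continue.
1.2 Produce the final answer.
2. If the question is unrelated to the knowledge base, call the external tool to obtain knowledge documents.
2.1 Once the documents are obtained, follow the same RAG flow as above. The difference is that this branch does not use the RAG agent prebuilt by Spring AI but a custom RAG agent, and it also grades the retrieved documents for relevance to the question.
2.2 Produce the final answer.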
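The branching above boils down to two lookup tables from an edge's return value to the next node. The sketch below writes them out in plain Java as a reading aid; the node names (`retrieve`, `web_search`, `generate`, `transform_query`) are placeholders and the real wiring is done later when the workflow is built. As far as I know, Spring AI Alibaba takes a similar outcome-to-node map when conditional edges are added to the graph.

```java
import java.util.Map;

// Sketch of the conditional-edge mappings for the flow above; node names are placeholders
class WorkflowRoutes {
    static final String END = "END";

    // Routing edge at the start: outcome -> first node of the branch
    static final Map<String, String> ROUTE_QUESTION = Map.of(
            "vectorstore", "retrieve",    // knowledge-base question: prebuilt Spring RAG
            "web_search", "web_search");  // out-of-scope question: external search tool

    // Answer-grading edge: outcome -> next node
    static final Map<String, String> GRADE_GENERATION = Map.of(
            "hallucination", "generate",   // not grounded in the documents: generate again
            "unuseful", "transform_query", // grounded but off-target: rewrite the question
            "useful", END);                // good answer: finish

    /** Looks up the next node, failing loudly if an edge returns an outcome nobody mapped. */
    static String next(Map<String, String> mapping, String outcome) {
        String target = (outcome == null) ? null : mapping.get(outcome);
        if (target == null) {
            throw new IllegalArgumentException("Unmapped edge outcome: " + outcome);
        }
        return target;
    }
}
```

Failing on an unmapped outcome makes a typo in an edge's return string (or an unexpected LLM answer) show up immediately instead of silently ending the run.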
Implementation
1. Configuration files
pom.xml:
XML
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- <parent>-->
<!-- <groupId>org.springframework.boot</groupId>-->
<!-- <artifactId>spring-boot-starter-parent</artifactId>-->
<!-- <version>3.5.3</version>-->
<!-- <relativePath/> lookup parent from repository -->
<!-- </parent>-->
<groupId>com.ai</groupId>
<artifactId>spring-ai-alibaba-demo</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>spring-ai-alibaba-demo</name>
<description>spring-ai-alibaba-demo</description>
<properties>
<!-- Spring Boot -->
<spring-boot.version>3.5.3</spring-boot.version>
<!-- Spring AI -->
<spring-ai.version>1.0.0</spring-ai.version>
<!-- Spring AI Alibaba -->
<spring-ai-alibaba.version>1.0.0.2</spring-ai-alibaba.version>
<!-- Jdk -->
<java.version>21</java.version>
<!-- Maven Compiler -->
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${spring-boot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-bom</artifactId>
<version>${spring-ai-alibaba.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<!-- Spring Boot Starter Web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Spring Boot Starter Test -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<!-- Spring AI LLM -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-openai</artifactId>
</dependency>
<!-- Spring AI Chat Client -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
</dependency>
<!-- Spring AI Model Tool Integration -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-tool</artifactId>
</dependency>
<!-- Spring AI RAG Framework -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-rag</artifactId>
</dependency>
<!-- Spring AI Document Reader for Markdown -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-markdown-document-reader</artifactId>
</dependency>
<!-- Spring AI Alibaba Graph -->
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-graph-core</artifactId>
</dependency>
<!-- Gson for Graph serialization -->
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
application.yaml
server:
  port: 6666
spring:
  application:
    name: ai-demo-application
  ai:
    # This demo uses an OpenAI model; Alibaba Cloud Bailian, which is integrated into Spring AI Alibaba, can be used instead
    openai:
      # OpenAI API key
      api-key: sk-yourapikey
      # OpenAI endpoint, or the address provided by a proxy provider
      base-url: https://api.openai.com
      # Chat model settings
      chat:
        options:
          model: gpt-4o-mini
          max-tokens: 4096
        completions-path: /v1/chat/completions
      # Embedding model settings, used to vectorize text
      embedding:
        options:
          model: text-embedding-3-small
# Custom web search tool API; Tavily is used here
tavily:
  api-key: tvly-dev-yourapikey
  base-url: https://api.tavily.com
2. ChatClient design
Full code:
java
package com.ai.demo.config;
import com.ai.demo.tool.WebSearchTool;
import lombok.AllArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@Configuration
@AllArgsConstructor
public class ChatClientConfig {
private final WebSearchTool webSearchTool;
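/**
* Memory type: a fixed-size message window
* <p>This is the memory type Spring AI uses by default when it auto-configures the ChatMemory bean (it also works without declaring this bean).</p>
* @param chatMemoryRepository where the messages are stored
* @return ChatMemory instance
*/
@Bean
public ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository) // use the repository bean declared below
.maxMessages(10)
.build();
}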
/**
* Memory storage: an in-memory ChatMemoryRepository
* <p>By default, if no other repository is configured, Spring AI auto-configures a ChatMemoryRepository bean of type InMemoryChatMemoryRepository (it also works without declaring this bean).</p>
* @return ChatMemoryRepository instance
*/
@Bean
public ChatMemoryRepository chatMemoryRepository() {
return new InMemoryChatMemoryRepository();
}
/**
* General-purpose OpenAI LLM client
* @param chatModel the chat model
* @return ChatClient instance
*/
@Bean
@Primary
public ChatClient openAiChatClient(ChatModel chatModel) {
return ChatClient.builder(chatModel).defaultOptions(ChatOptions.builder().temperature(0.8).build()).build();
}
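/**
* Question router ChatClient
* Routes the user's question to a data source (vector store or web search)
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient QuestionRouterChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
You are an expert at routing user input/questions to one of the following components:
1. Vector store (vectorstore)
Choose vectorstore when the user's question relates to the documents in the knowledge base.
Knowledge base:
{knowledge_base}
2. Web search (web_search)
Choose web_search when the user's question:
- needs up-to-date, real-time information (news, weather, stock prices, etc.)
- falls outside the scope of the knowledge base
Make the best routing decision.
""";
return ChatClient.builder(chatModel).defaultSystem(systemPrompt)
.defaultUser(u -> u.text("User question: {question}"))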
.defaultOptions(ChatOptions.builder().temperature(0.0).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRouter").build())
.build();
}
/**
* Web search ChatClient
* Searches the web with the tool and returns the relevant information
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient WebSearchChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
Based on the user's question, search the web with the tool and return the relevant information.
Today's date is: {date}
""";
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(u -> u.text("User question: {question}"))
// Use ToolCallingChatOptions here rather than ChatOptions
.defaultOptions(ToolCallingChatOptions.builder().temperature(0.8).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("WebSearch").build())
.defaultTools(webSearchTool)
.build();
}
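/**
* Adaptive RAG ChatClient
* Answers the user's question from the combined results of the vector store and web search
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient AdaptiveRagChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
You are a professional question-answering assistant. Answer the user's question based on the provided context.
Guidelines:
1. Prefer the provided context when answering
2. You may use the documents' names and descriptions to understand the background
3. If several documents are relevant, combine information from multiple sources
4. Keep the answer accurate and concise, usually 2-3 sentences
5. Mention the sources where appropriate
6. If you are unsure or do not know the answer, say so honestly
""";
String userPrompt = """
Question:
{question}
Retrieved context:
{context}
Answer the question based on the information above.
""";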
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(userPrompt)
.defaultOptions(ChatOptions.builder().temperature(0.7).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("AdaptiveRag").build())
.build();
}
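/**
* Grades whether the LLM-generated answer is grounded in the retrieved facts
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient HallucinationChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
You are a grader assessing whether an LLM-generated answer is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'yes' means the answer is grounded in / supported by the set of facts.
""";
String userPrompt = """
Set of facts:
{documents}
LLM-generated answer:
{generation}
""";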
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(userPrompt)
.defaultOptions(ChatOptions.builder().temperature(0.0).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("Hallucination").build())
.build();
}
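/**
* Grades whether the LLM-generated answer addresses/resolves the user's question
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient AnswerGraderChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
You are a grader assessing whether an answer addresses/resolves a question.
Give a binary score 'yes' or 'no'. 'yes' means the answer addresses/resolves the question.
""";
String userPrompt = """
User question:
{question}
LLM-generated answer:
{generation}
""";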
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(userPrompt)
.defaultOptions(ChatOptions.builder().temperature(0.8).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("AnswerGrader").build())
.build();
}
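/**
* Question rewriter ChatClient
* Rewrites the user's question into a clearer, more specific form to improve retrieval
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient QuestionRewriterChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
You are an expert at rewriting user questions into clearer, more specific forms to improve retrieval.
Rewriting rules:
1. Preserve the original meaning, but make the question more explicit
2. Use more specific, targeted keywords; avoid vague or ambiguous wording
3. If the question is too broad, try to break it down into more specific sub-questions
""";
String userPrompt = """
Original question:
{question}
Rewrite this question to improve retrieval:
""";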
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(userPrompt)
.defaultOptions(ChatOptions.builder().temperature(0.0).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRewriter").build())
.build();
}
}
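(1) Rendering prompts dynamically at runtime
Parameters can be injected into a prompt template at runtime, for example:
java
questionRouterChatClient.prompt()
.system(s -> s.param("knowledge_base", "Knowledge about large language models"))
.user(u -> u.param("question", "What is a large language model?"))
.call()
.entity(RouteQueryEntity.class)
.toString();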
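Conceptually, `param("question", ...)` fills the `{question}` placeholder of the template defined in the bean. The sketch below is a hand-rolled renderer that illustrates this substitution under the assumption that placeholders are simple identifiers in braces; Spring AI's own renderer (StringTemplate-based, with `{}` delimiters) is more capable, and, as far as I know, it also rejects templates with unfilled variables by default, which this sketch imitates by throwing. `PromptTemplateSketch` is a hypothetical name.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical sketch of {placeholder} substitution, not Spring AI's renderer
class PromptTemplateSketch {
    // Matches {name} placeholders such as {question} or {knowledge_base}
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{([A-Za-z_][A-Za-z0-9_]*)\\}");

    /** Replaces every {name} with params.get(name); fails fast if a placeholder has no value. */
    static String render(String template, Map<String, ?> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            String key = m.group(1);
            if (!params.containsKey(key)) {
                throw new IllegalArgumentException("Missing template variable: " + key);
            }
            // quoteReplacement so that '$' or '\' inside values are inserted literally
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(params.get(key))));
        }
        m.appendTail(out);
        return out.toString();
    }

    public static void main(String[] args) {
        System.out.println(render("User question: {question}", Map.of("question", "What is an LLM?")));
    }
}
```

Failing fast on a missing variable is usually preferable to sending a prompt that still contains a literal `{knowledge_base}` to the model.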
(2) ChatMemory design (model memory)
Declare the memory type and where the memory is stored:
java
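/**
* Memory type: a fixed-size message window
* <p>This is the memory type Spring AI uses by default when it auto-configures the ChatMemory bean (it also works without declaring this bean).</p>
* @param chatMemoryRepository where the messages are stored
* @return ChatMemory instance
*/
@Bean
public ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {
return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository) // use the repository bean declared below
.maxMessages(10)
.build();
}
/**
* Memory storage: an in-memory ChatMemoryRepository
* <p>By default, if no other repository is configured, Spring AI auto-configures a ChatMemoryRepository bean of type InMemoryChatMemoryRepository (it also works without declaring this bean).</p>
* @return ChatMemoryRepository instance
*/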
@Bean
public ChatMemoryRepository chatMemoryRepository() {
return new InMemoryChatMemoryRepository();
}
Memory is attached to a ChatClient through an advisor provided by Spring AI, for example:
java
ChatClient.builder(chatModel).defaultSystem(systemPrompt)
.defaultUser(u -> u.text("User question: {question}"))
.defaultOptions(ChatOptions.builder().temperature(0.0).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRouter").build()) // memory is attached here; "QuestionRouter" is a conversation id used only by this agent
.build();
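Because all agents share one `ChatMemory` bean, the conversation id is what keeps each agent's history separate, and `maxMessages(10)` bounds each history. The sketch below models only that behaviour in plain Java: one bounded window per conversation id, evicting the oldest message first. It is a simplified, hypothetical stand-in; as I understand it, the real `MessageWindowChatMemory` also treats system messages specially when evicting, which is not modelled here.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of a per-conversation message window (not Spring AI's class)
class WindowMemorySketch {
    private final int maxMessages;
    // One bounded history per conversation id, e.g. "QuestionRouter" or "WebSearch"
    private final Map<String, Deque<String>> store = new HashMap<>();

    WindowMemorySketch(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, String message) {
        Deque<String> window = store.computeIfAbsent(conversationId, id -> new ArrayDeque<>());
        window.addLast(message);
        while (window.size() > maxMessages) {
            window.removeFirst(); // evict the oldest message
        }
    }

    /** Returns a copy of the window, oldest message first. */
    List<String> get(String conversationId) {
        return new ArrayList<>(store.getOrDefault(conversationId, new ArrayDeque<>()));
    }
}
```

A side effect worth keeping in mind: the routers and graders also keep history, so earlier questions can influence later routing or grading decisions.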
3. Tool design
(1) Web search tool
Tavily is used as the search API; first, wrap it in a tool class:
java
package com.ai.demo.tool;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@Component
@Slf4j
public class WebSearchTool {
private final WebClient webClient;
public WebSearchTool(WebClient.Builder webClientBuilder,
@Value("${tavily.base-url}") String baseUrl,
@Value("${tavily.api-key}") String apiKey) {
this.webClient = webClientBuilder
.baseUrl(baseUrl)
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
.build();
}
/**
* Searches the web using the Tavily API
*
* @param query the search query string
* @return TavilyResponse containing the search results and related information
*/
@Tool(description = "Search the web using the Tavily API")
public TavilyResponse search(@ToolParam(description = "search query to look up") String query) {
TavilyRequest request = TavilyRequest.builder()
.query(query)
.maxResults(3) // return 3 results by default
.searchDepth("basic") // basic search depth by default
.includeAnswer(false) // no generated answer by default
.includeRawContent(false) // no raw page content by default
.includeImages(false) // no images by default
.build();
if (request.getQuery() == null || request.getQuery().isEmpty()) {
throw new IllegalArgumentException("Query parameter is required.");
}
log.info("Received TavilyRequest: {}", request);
// Build the request payload with all parameters, setting defaults where necessary
TavilyRequest requestWithApiKey = TavilyRequest.builder().query(request.getQuery())
.searchDepth(request.getSearchDepth() != null ? request.getSearchDepth() : "basic")
.topic(request.getTopic() != null ? request.getTopic() : "general")
.days(request.getDays() != null ? request.getDays() : 300)
.maxResults(request.getMaxResults() != 0 ? request.getMaxResults() : 10)
.includeImages(request.isIncludeImages()).includeImageDescriptions(request.isIncludeImageDescriptions())
.includeAnswer(request.isIncludeAnswer()).includeRawContent(request.isIncludeRawContent())
.includeDomains(
request.getIncludeDomains() != null ? request.getIncludeDomains() : Collections.emptyList())
.excludeDomains(
request.getExcludeDomains() != null ? request.getExcludeDomains() : Collections.emptyList())
.build();
log.debug("Sending request to Tavily API: query={}, searchDepth={}, topic={}, days={}, maxResults={}",
requestWithApiKey.getQuery(), requestWithApiKey.getSearchDepth(), requestWithApiKey.getTopic(),
requestWithApiKey.getDays(), requestWithApiKey.getMaxResults());
try {
TavilyResponse response = webClient.post()
.uri(uriBuilder -> uriBuilder.path("/search").build())
.bodyValue(requestWithApiKey)
.retrieve()
.bodyToMono(TavilyResponse.class)
.block();
log.info("Received response from Tavily API for query: {}", requestWithApiKey.getQuery());
return response;
} catch (Exception e) {
log.error("Error occurred while calling Tavily API: {}", e.getMessage(), e);
throw new RuntimeException("Failed to fetch search results from Tavily API", e);
}
}
/**
* Request object for the Tavily API.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonClassDescription("Request object for the Tavily API")
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class TavilyRequest {
@JsonProperty("query")
@JsonPropertyDescription("The main search query.")
private String query;
@JsonProperty("api_key")
@JsonPropertyDescription("API key for authentication with Tavily.")
private String apiKey;
@JsonProperty("search_depth")
@JsonPropertyDescription("The depth of the search. Accepted values: 'basic', 'advanced'. Default is 'basic'.")
private String searchDepth;
@JsonProperty("topic")
@JsonPropertyDescription("The category of the search. Accepted values: 'general', 'news'. Default is 'general'.")
private String topic;
@JsonProperty("days")
@JsonPropertyDescription("The number of days back from the current date to include in search results. Default is 3. Only applies to 'news' topic.")
private Integer days;
@JsonProperty("time_range")
@JsonPropertyDescription("The time range for search results. Accepted values: 'day', 'week', 'month', 'year' or 'd', 'w', 'm', 'y'. Default is none.")
private String timeRange;
@JsonProperty("max_results")
@JsonPropertyDescription("The maximum number of search results to return. Default is 5.")
private int maxResults;
@JsonProperty("include_images")
@JsonPropertyDescription("Whether to include a list of query-related images in the response. Default is false.")
private boolean includeImages;
@JsonProperty("include_image_descriptions")
@JsonPropertyDescription("When 'include_images' is true, adds descriptive text for each image. Default is false.")
private boolean includeImageDescriptions;
@JsonProperty("include_answer")
@JsonPropertyDescription("Whether to include a short answer to the query, generated from search results. Default is false.")
private boolean includeAnswer;
@JsonProperty("include_raw_content")
@JsonPropertyDescription("Whether to include the cleaned and parsed HTML content of each search result. Default is false.")
private boolean includeRawContent;
@JsonProperty("include_domains")
@JsonPropertyDescription("A list of domains to specifically include in search results. Default is an empty list.")
private List<String> includeDomains;
@JsonProperty("exclude_domains")
@JsonPropertyDescription("A list of domains to specifically exclude from search results. Default is an empty list.")
private List<String> excludeDomains;
}
/**
* Response object for the Tavily API.
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonClassDescription("Response object for the Tavily API")
public static class TavilyResponse {
@JsonProperty("query")
private String query;
@JsonProperty("follow_up_questions")
private List<String> followUpQuestions;
@JsonProperty("answer")
private String answer;
@JsonDeserialize(using = ImageDeserializer.class)
@JsonProperty("images")
private List<Image> images;
@JsonProperty("results")
private List<Result> results;
@JsonProperty("response_time")
private float responseTime;
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class Image {
@JsonProperty("url")
private String url;
@JsonProperty("description")
private String description;
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class Result {
@JsonProperty("title")
private String title;
@JsonProperty("url")
private String url;
@JsonProperty("content")
private String content;
@JsonProperty("raw_content")
private String rawContent;
@JsonProperty("score")
private float score;
@JsonProperty("published_date")
private String publishedDate;
}
}
public static class ImageDeserializer extends JsonDeserializer<List<TavilyResponse.Image>> {
@Override
public List<TavilyResponse.Image> deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser);
List<TavilyResponse.Image> images = new ArrayList<>();
if (node.isArray()) {
for (JsonNode element : node) {
// If element is a string, treat it as a URL
if (element.isTextual()) {
images.add(new TavilyResponse.Image(element.asText(), null));
}
// If element is an object, map it to Image
else if (element.isObject()) {
String url = element.has("url") ? element.get("url").asText() : null;
String description = element.has("description") ? element.get("description").asText() : null;
images.add(new TavilyResponse.Image(url, description));
}
}
}
return images;
}
}
}
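For readers without WebFlux, the HTTP call the tool makes can be expressed with the JDK's own `java.net.http` client. This is a hedged sketch, assuming only what the tool above already does (a POST to `{base-url}/search` with a JSON body and a `Bearer` Authorization header); the hand-built JSON and the class name `TavilyRequestSketch` are illustrative, and the request is only built here, not sent.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical JDK-only equivalent of the WebClient request in WebSearchTool (built, not sent)
class TavilyRequestSketch {
    /** Hand-built JSON body; escapes only backslash, double quote and newline (not a full JSON encoder). */
    static String jsonBody(String query, int maxResults) {
        String escaped = query.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n");
        return "{\"query\":\"" + escaped + "\",\"max_results\":" + maxResults
                + ",\"search_depth\":\"basic\"}";
    }

    /** POST to {baseUrl}/search with the same headers WebSearchTool sets. */
    static HttpRequest build(String baseUrl, String apiKey, String query, int maxResults) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/search"))
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + apiKey)
                .POST(HttpRequest.BodyPublishers.ofString(jsonBody(query, maxResults)))
                .build();
    }
}
```

Sending it would be a matter of `HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())`; in the Spring application the WebClient version plus Jackson mapping into `TavilyResponse` remains the more convenient choice.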
(2) Integrating the tool into an agent
The tool is registered through defaultTools:
java
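// existing code...
// The config class is annotated with @AllArgsConstructor,
// so the tool is injected through the constructor
private final WebSearchTool webSearchTool;
/**
* Web search ChatClient
* Searches the web with the tool and returns the relevant information
* @param chatModel the chat model
* @param chatMemory the chat memory
* @return ChatClient instance
*/
@Bean
public ChatClient WebSearchChatClient(ChatModel chatModel, ChatMemory chatMemory) {
String systemPrompt = """
Based on the user's question, search the web with the tool and return the relevant information.
Today's date is: {date}
""";
return ChatClient.builder(chatModel)
.defaultSystem(systemPrompt)
.defaultUser(u -> u.text("User question: {question}"))
// Use ToolCallingChatOptions here rather than ChatOptions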
.defaultOptions(ToolCallingChatOptions.builder().temperature(0.8).build())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("WebSearch").build())
.defaultTools(webSearchTool)
.build();
}
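With `defaultTools(webSearchTool)`, Spring AI advertises the `@Tool` method to the model (by default under the method name, here `search`) together with its description, and when the model asks for the tool, the framework runs the method and feeds the result back. The framework-free sketch below shows only that dispatch step; `ToolDispatchSketch` and `ToolDef` are hypothetical names, and the real tool-calling loop, argument JSON schema and result serialization are handled by Spring AI.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical, framework-free sketch of the "model requests a tool -> app runs it" step
class ToolDispatchSketch {
    // name + description are what the model sees; fn is what actually runs
    record ToolDef(String name, String description, Function<String, String> fn) {}

    private final Map<String, ToolDef> tools = new HashMap<>();

    ToolDispatchSketch register(ToolDef tool) {
        tools.put(tool.name(), tool);
        return this;
    }

    /** The tool list advertised to the model, one "name: description" line per tool. */
    String describeTools() {
        return tools.values().stream()
                .map(t -> t.name() + ": " + t.description())
                .sorted()
                .collect(Collectors.joining("\n"));
    }

    /** Runs the tool the model asked for; the returned text goes back to the model as the tool result. */
    String invoke(String toolName, String argument) {
        ToolDef tool = tools.get(toolName);
        if (tool == null) {
            return "Error: unknown tool " + toolName; // report the failure to the model instead of crashing
        }
        return tool.fn().apply(argument);
    }
}
```

This is also why the `@Tool` description matters: it is the only hint the model has about when to call the tool.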
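4. Node design
As mentioned above, a node calls one of the ChatClients defined earlier and updates the state of the whole graph/workflow.
Note that the ChatClient used by each node is passed in when the graph is built (the final step, building the workflow).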
(1)GenerationNode
This node generates the answer using the custom RAG prompt template.
java
package com.ai.demo.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Builder
public class GenerationNode implements NodeAction {
private final ChatClient chatClient;
@Override
public Map<String, Object> apply(OverAllState state) {
String query = state.value("question", "");
List<Document> documents = state.value("documents", List.of());
String generation = chatClient.prompt()
.user(u -> u.param("question", query)
.param("context", documents
.stream()
.map(Document::getText)
.collect(Collectors.joining("\n\n"))))
.call()
.content();
// Update the state
HashMap<String, Object> resultMap = new HashMap<>();
resultMap.put("generation", generation);
return resultMap;
}
}
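GenerationNode builds `{context}` by joining the document texts with blank lines. The AdaptiveRag system prompt, however, tells the model it may use document names and sources, which only works if that metadata reaches the prompt. The sketch below is a hypothetical variant of the formatting step that prefixes each chunk with its `title` metadata (the key WebSearchNode sets) and skips empty texts; `Doc` is a stand-in for Spring AI's `Document`, not the real class.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical variant of the context formatting in GenerationNode
class ContextFormatterSketch {
    // Stand-in for Spring AI's Document: text plus metadata
    record Doc(String text, Map<String, Object> metadata) {}

    /** Joins document texts with blank lines, prefixing each with its "title" metadata when present. */
    static String format(List<Doc> docs) {
        return docs.stream()
                .filter(d -> d.text() != null && !d.text().isBlank())
                .map(d -> {
                    Object title = d.metadata().get("title");
                    return title == null ? d.text() : "[" + title + "]\n" + d.text();
                })
                .collect(Collectors.joining("\n\n"));
    }
}
```

In the real node the same idea would read `document.getMetadata().get("title")` alongside `Document::getText`.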
(2)RetrieveNode
This node calls Spring AI's prebuilt RAG agent, which is driven mainly by the RetrievalAugmentationAdvisor class.
java
package com.ai.demo.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.AbstractMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;
@Builder
public class RetrieveNode implements NodeAction {
private final DocumentRetriever documentRetriever;
private final ChatClient chatClient;
private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;
private final ChatMemory chatMemory = MessageWindowChatMemory.builder()
.maxMessages(10)
.build();
private final BaseChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
.build();
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
String query = state.value("question", "");
// The advisor must be supplied through the builder
if (retrievalAugmentationAdvisor == null) {
throw new IllegalStateException("RetrievalAugmentationAdvisor must not be null");
}
// Option 1: retrieve documents directly with the retriever
// List<Document> documents = this.documentRetriever.retrieve(Query.builder()
// .text(query)
// .build());
// Option 2 (used here): the prebuilt RAG advisor retrieves, augments and generates in one call.
// Register each advisor exactly once; adding the RAG advisor twice would run retrieval/augmentation twice.
ChatResponse response = chatClient.prompt()
.advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
.advisors(advisors -> advisors.param(CONVERSATION_ID,
"PrebuiltSpringRAG"))
.user(query)
.call()
.chatResponse();
// assert is disabled by default at runtime, so check explicitly before dereferencing
if (response == null) {
throw new IllegalStateException("The model returned no response");
}
// Generated answer
String generation = Optional.ofNullable(response.getResult()).map(Generation::getOutput).map(AbstractMessage::getText).orElse("");
System.out.println("Generation: " + generation);
// Documents retrieved by RetrievalAugmentationAdvisor are exposed in the response metadata
List<Document> retrievedDocuments = response.getMetadata().get("rag_document_context");
if (retrievedDocuments == null) {
retrievedDocuments = List.of();
}
System.out.println("Documents: " + retrievedDocuments);
// Update the state
HashMap<String, Object> resultMap = new HashMap<>();
resultMap.put("question", query);
resultMap.put("documents", retrievedDocuments);
resultMap.put("generation", generation);
return resultMap;
}
}
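Roughly speaking, the advisor performs two steps before the model is called: retrieval (find the chunks most similar to the question, which with `SimpleVectorStore` means comparing embedding vectors) and augmentation (insert the retrieved text into the prompt). The plain-Java sketch below shows those two steps with cosine similarity over in-memory vectors. It is an illustration under simplifying assumptions: the vectors are given rather than produced by an embedding model, and the augmented-prompt wording is illustrative, not Spring AI's template. `RetrieveAugmentSketch` and `Chunk` are hypothetical names.

```java
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of retrieve-then-augment, not Spring AI's implementation
class RetrieveAugmentSketch {
    record Chunk(String text, double[] vector) {}

    /** Cosine similarity of two vectors of equal length (0 if either is all zeros). */
    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0 || nb == 0) ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    /** Retrieval: the topK chunks most similar to the query vector, best first. */
    static List<Chunk> retrieve(double[] query, List<Chunk> store, int topK) {
        return store.stream()
                .sorted(Comparator.comparingDouble((Chunk c) -> cosine(query, c.vector())).reversed())
                .limit(topK)
                .toList();
    }

    /** Augmentation: place the retrieved text in front of the question. */
    static String augment(String question, List<Chunk> context) {
        StringBuilder sb = new StringBuilder("Context information is below.\n---------------------\n");
        for (Chunk c : context) {
            sb.append(c.text()).append('\n');
        }
        return sb.append("---------------------\nQuery: ").append(question).toString();
    }
}
```

The retrieved chunks are exactly what RetrieveNode later reads back from the response metadata and stores under `documents` for the grading edge.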
(3)TransformQueryNode
This node rewrites the user's question.
java
package com.ai.demo.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import java.util.HashMap;
import java.util.Map;
@Builder
@Slf4j
public class TransformQueryNode implements NodeAction {
private final ChatClient chatClient;
@Override
public Map<String, Object> apply(OverAllState state) {
String question = state.value("question", String.class).orElse("");
String betterQuestion = chatClient.prompt()
.user(u -> u.param("question", question))
.call()
.content();
log.info("Rewritten question: {}", betterQuestion);
// Update the state
HashMap<String, Object> resultMap = new HashMap<>();
resultMap.put("question", betterQuestion);
return resultMap;
}
}
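Because an "unuseful" grade sends the flow back through TransformQueryNode, a question the knowledge base simply cannot answer could cycle indefinitely. One hedged way to bound this, not part of the original code, is to count rewrites in the state and let the edge give up after a limit. In the sketch, the `rewrite_count` key, the limit of 2 and the `give_up` outcome are all hypothetical; state maps are plain `Map`s standing in for `OverAllState`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical retry guard for the rewrite loop (key names and limit are assumptions)
class RewriteGuardSketch {
    static final int MAX_REWRITES = 2;

    /** Node-style update: the new question plus an incremented rewrite counter. */
    static Map<String, Object> rewrite(Map<String, Object> state, String betterQuestion) {
        int count = (int) state.getOrDefault("rewrite_count", 0);
        Map<String, Object> update = new HashMap<>();
        update.put("question", betterQuestion);
        update.put("rewrite_count", count + 1);
        return update;
    }

    /** Edge-style decision after an "unuseful" answer: rewrite again until the limit, then stop. */
    static String afterUnusefulAnswer(Map<String, Object> state) {
        int count = (int) state.getOrDefault("rewrite_count", 0);
        return count < MAX_REWRITES ? "transform_query" : "give_up";
    }
}
```

In a real graph, `give_up` would map to a node that returns the best answer so far or an honest "I don't know".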
(4)WebSearchNode
This node calls the web search agent and converts its results into a collection of documents.
java
package com.ai.demo.node;
import com.ai.demo.tool.WebSearchTool;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@Builder
@Slf4j
public class WebSearchNode implements NodeAction {
private final ChatClient chatClient;
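@Override
public Map<String, Object> apply(OverAllState state) {
String query = state.value("question", "");
WebSearchTool.TavilyResponse response = chatClient.prompt()
.system(s -> s.param("date", LocalDate.now().toString()))
.user(u -> u.param("question", query))
.call()
.entity(WebSearchTool.TavilyResponse.class);
// assert is disabled by default at runtime, so check explicitly
Objects.requireNonNull(response, "Web search returned no response");
log.debug("WebSearchNode response: {}", response);
// Convert the results into Document objects; collect into an ArrayList so the answer can be prepended
List<WebSearchTool.TavilyResponse.Result> results = Objects.requireNonNullElse(response.getResults(), List.of());
List<Document> documents = results.stream()
.filter(result -> result.getContent() != null && !result.getContent().isBlank())
.map(result -> new Document(result.getContent(),
Map.of("origin", Objects.toString(result.getUrl(), ""),
"title", Objects.toString(result.getTitle(), ""))))
.collect(Collectors.toCollection(ArrayList::new));
// "answer" is empty when include_answer is false (the tool's default); Document needs non-null text
if (response.getAnswer() != null && !response.getAnswer().isBlank()) {
documents.addFirst(new Document(response.getAnswer(),
Map.of("origin", "Web Search Answer", "title", "Web Search Answer")));
}
// Update the state
HashMap<String, Object> resultMap = new HashMap<>();
resultMap.put("question", query);
resultMap.put("documents", documents);
return resultMap;
}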
}
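5. Edge design
An edge decides which node the flow goes to next.
(1) RouteQuestionEdge
The routing edge decides whether the user's question is answered from the vector store or via web search.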
java
package com.ai.demo.edge;
import com.ai.demo.entity.RouteQueryEntity;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class RouteQuestionEdge implements EdgeAction {
private final ChatClient questionRouterChatClient;
public RouteQuestionEdge(@Qualifier("QuestionRouterChatClient") ChatClient questionRouterChatClient) {
this.questionRouterChatClient = questionRouterChatClient;
}
@Override
public String apply(OverAllState state) {
log.info("---------- Edge: route question ----------");
String question = state.value("question", String.class).orElse("");
// Decide the data source
RouteQueryEntity response = questionRouterChatClient.prompt()
.user(u -> u.param("question", question))
.system(s -> s.param("knowledge_base", "Knowledge about Spring AI Alibaba"))
.call()
.entity(RouteQueryEntity.class);
log.info("Routed to: {}", response);
// assert is disabled by default at runtime, so check explicitly
if (response == null) {
throw new IllegalStateException("Question router returned no decision");
}
return response.dataSource();
}
}
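The edge returns whatever string the model put into `dataSource()`, so a reply such as `"VectorStore"` or `"web search"` would not match any mapped outcome. A defensive option, not in the original code, is to normalize the value before returning it. The sketch assumes `dataSource()` returns a `String` (RouteQueryEntity is not shown in this article) and that `vectorstore` and `web_search` are the mapped outcomes; `RouteNormalizerSketch` and the choice of fallback are hypothetical.

```java
import java.util.Locale;
import java.util.Set;

// Hypothetical normalizer for the router's output (assumes dataSource() is a String)
class RouteNormalizerSketch {
    // The edge outcomes the graph knows how to handle
    static final Set<String> ALLOWED = Set.of("vectorstore", "web_search");

    /** Normalizes the model's raw data-source string; anything unrecognised goes to the fallback branch. */
    static String normalize(String raw, String fallback) {
        if (raw == null) {
            return fallback;
        }
        String value = raw.trim().toLowerCase(Locale.ROOT).replace('-', '_').replace(' ', '_');
        return ALLOWED.contains(value) ? value : fallback;
    }
}
```

Used as `return RouteNormalizerSketch.normalize(response.dataSource(), "web_search");`, an unexpected answer would fall back to web search instead of failing the run.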
(2)GradeGenerationEdge
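This edge grades the quality of the answer and decides whether it can be returned, must be regenerated, or requires the question to be rewritten.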
java
package com.ai.demo.edge;
import com.ai.demo.entity.GradeScore;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
public class GradeGenerationEdge implements EdgeAction {
private final ChatClient hallucinationGrader;
private final ChatClient answerGrader;
public GradeGenerationEdge(@Qualifier("HallucinationChatClient") ChatClient hallucinationGrader,
@Qualifier("AnswerGraderChatClient") ChatClient answerGrader) {
this.hallucinationGrader = hallucinationGrader;
this.answerGrader = answerGrader;
}
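/**
* Grades the quality of the generation
* @param state graph state
* @return "hallucination" if the answer is not grounded in the facts and must be regenerated;
* "unuseful" if the answer does not address the question and the question must be rewritten;
* "useful" if the answer addresses the question.
*/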
@Override
public String apply(OverAllState state) {
log.info("---------- Edge: check whether the generation is grounded in the facts ----------");
String question = state.value("question", String.class).orElse("");
String generation = state.value("generation", String.class).orElse("");
List<Document> documents = state.value("documents", List.of());
GradeScore hallucinationGradeScore = hallucinationGrader.prompt()
.user(u -> u.param("documents", formatDocs(documents))
.param("generation", generation))
.call()
.entity(GradeScore.class);
assert hallucinationGradeScore != null;
if (!"yes".equals(hallucinationGradeScore.binaryScore())) {
log.info("---------- Decision: generation is not grounded in the facts, retry ----------");
return "hallucination";
}
log.info("---------- Decision: generation is grounded in the facts ----------");
GradeScore answerGradeScore = answerGrader.prompt()
.user(u -> u.param("question", question)
.param("generation", generation))
.call()
.entity(GradeScore.class);
assert answerGradeScore != null;
if ("yes".equals(answerGradeScore.binaryScore())) {
log.info("---------- Decision: the answer addresses the question ----------");
return "useful";
}
log.info("---------- Decision: the answer does not address the question ----------");
return "unuseful";
}
private String formatDocs(List<Document> documents) {
return documents.stream().map(Document::getText).collect(Collectors.joining("\n\n"));
}
}
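The branching in GradeGenerationEdge.apply() depends only on the two graders' binaryScore values, so it can be pulled out into a pure function and unit-tested without calling an LLM. The following is a minimal sketch of that idea; GradeDecision and decide() are hypothetical names. The Supplier preserves the original behaviour of calling the answer grader only after the hallucination check passes. Unlike the original equals("yes"), isYes() also trims the value and ignores case, on the assumption that a model may return "Yes"; drop that normalization to match the original exactly.

```java
import java.util.function.Supplier;

/**
 * Hypothetical, LLM-free extraction of GradeGenerationEdge's branching logic.
 * The arguments are the graders' binaryScore values ("yes" / "no").
 */
final class GradeDecision {
    static final String HALLUCINATION = "hallucination";
    static final String USEFUL = "useful";
    static final String UNUSEFUL = "unuseful";

    private GradeDecision() {}

    /** Null-safe; unlike the original equals("yes"), also accepts "Yes" or " yes ". */
    static boolean isYes(String binaryScore) {
        return binaryScore != null && "yes".equalsIgnoreCase(binaryScore.trim());
    }

    /**
     * @param groundedScore verdict of the hallucination grader
     * @param answerScore   verdict of the answer grader, evaluated lazily
     * @return the label that the conditional edge maps to the next node
     */
    static String decide(String groundedScore, Supplier<String> answerScore) {
        if (!isYes(groundedScore)) {
            return HALLUCINATION; // not grounded: regenerate
        }
        // Only grounded answers are checked against the question, as in the original.
        return isYes(answerScore.get()) ? USEFUL : UNUSEFUL;
    }
}
```

In apply(), the two ChatClient calls would then only produce the scores passed to decide(), with the answer grader call wrapped in the Supplier.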
6. Building the workflow
(1) RAG configuration
Configures the concrete parameters of each RAG module.
java
package com.ai.demo.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.transformer.splitter.TextSplitter;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class RagConfig {
/**
* Text splitter that cuts documents into smaller chunks
* @return TextSplitter instance
*/
@Bean
TextSplitter textSplitter() {
return TokenTextSplitter.builder()
// The maximum number of tokens per chunk, the split window size, etc. can also be configured here
.build();
}
/**
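* Vector store that holds the embeddings of the document chunks
* @param embeddingModel embedding model used to vectorize the chunks
* @return VectorStore instance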
*/
@Bean
VectorStore vectorStore(EmbeddingModel embeddingModel) {
return SimpleVectorStore.builder(embeddingModel)
.build();
}
/**
* Document retriever that fetches relevant chunks from the vector store
* @param vectorStore vector store
* @return DocumentRetriever instance
*/
@Bean
DocumentRetriever documentRetriever(VectorStore vectorStore) {
return VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.similarityThreshold(0.50)
.build();
}
/**
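* Compression query transformer: compresses the conversation history and the follow-up query into a standalone query that captures the essence of the conversation
* <p> <em>Pre-retrieval enhancement</em>: useful when the conversation history is long and the follow-up query depends on the conversation context</p>
* @param chatClient chat client
* @return CompressionQueryTransformer instance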
*/
@Bean
CompressionQueryTransformer compressionQueryTransformer(ChatClient chatClient) {
return CompressionQueryTransformer.builder()
.chatClientBuilder(chatClient.mutate())
.build();
}
/**
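* Rewrite query transformer: rewrites the user query
* <p> <em>Pre-retrieval enhancement</em>: useful when the user query is verbose, ambiguous, or contains irrelevant information that could hurt search quality </p>
* @param chatClient chat client
* @return RewriteQueryTransformer instance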
*/
@Bean
RewriteQueryTransformer rewriteQueryTransformer(ChatClient chatClient) {
return RewriteQueryTransformer.builder()
.chatClientBuilder(chatClient.mutate())
.build();
}
/**
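* Translation query transformer: translates the user query into the target language
* <p> <em>Pre-retrieval enhancement</em>: useful when the user query is not written in the target language</p>
* @param chatClient chat client
* @return TranslationQueryTransformer instance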
*/
@Bean
TranslationQueryTransformer translationQueryTransformer(ChatClient chatClient) {
return TranslationQueryTransformer.builder()
.chatClientBuilder(chatClient.mutate())
.targetLanguage("chinese")
.build();
}
/**
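* Multi-query expander: generates several queries to improve retrieval coverage
* <p> <em>Pre-retrieval enhancement</em>: uses the LLM to generate semantically varied queries from different perspectives</p>
* @param chatClient chat client
* @return MultiQueryExpander instance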
*/
@Bean
MultiQueryExpander multiQueryExpander(ChatClient chatClient) {
return MultiQueryExpander.builder()
.chatClientBuilder(chatClient.mutate())
.numberOfQueries(3)
.build();
}
}
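The similarityThreshold(0.50) setting is easy to misread. As I understand it, SimpleVectorStore scores every stored chunk by the cosine similarity between its embedding and the query embedding, discards chunks scoring below the threshold, and returns at most topK of the rest (topK is left at its default in the configuration above). The pure-Java sketch below models that behaviour with toy 2-D vectors; it is an illustration rather than Spring AI's code, and ThresholdSearch, Doc and Hit are hypothetical names.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Simplified in-memory model of threshold-filtered similarity search
 * (hypothetical stand-in for SimpleVectorStore behind VectorStoreDocumentRetriever).
 * Real embeddings would come from the EmbeddingModel; here they are toy vectors.
 */
final class ThresholdSearch {
    record Doc(String id, double[] embedding) {}
    record Hit(String id, double score) {}

    private ThresholdSearch() {}

    /** Cosine similarity: dot(a, b) / (|a| * |b|), in [-1, 1]. */
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Scores every document, drops those below the threshold, keeps the best topK. */
    static List<Hit> search(double[] query, List<Doc> docs, int topK, double threshold) {
        return docs.stream()
                .map(d -> new Hit(d.id(), cosine(query, d.embedding())))
                .filter(h -> h.score() >= threshold)
                .sorted(Comparator.comparingDouble(Hit::score).reversed())
                .limit(topK)
                .collect(Collectors.toList());
    }
}
```

Raising the threshold trades recall for precision: set too high, the retriever may return no chunks at all, leaving the hallucination grader with nothing to check the answer against.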
(2) Building the graph
java
package com.ai.demo.config;
import com.ai.demo.edge.GradeGenerationEdge;
import com.ai.demo.edge.RouteQuestionEdge;
import com.ai.demo.node.GenerationNode;
import com.ai.demo.node.RetrieveNode;
import com.ai.demo.node.TransformQueryNode;
import com.ai.demo.node.WebSearchNode;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
@Configuration
@Slf4j
public class GraphConfig {
private final RouteQuestionEdge routeQuestionEdge;
private final GradeGenerationEdge gradeGenerationEdge;
private final ChatClient commonChatClient;
private final ChatClient webSearchClient;
private final ChatClient ragChatClient;
private final ChatClient questionRewriterChatClient;
private final DocumentRetriever documentRetriever;
private final CompressionQueryTransformer compressionQueryTransformer;
private final RewriteQueryTransformer rewriteQueryTransformer;
private final TranslationQueryTransformer translationQueryTransformer;
public GraphConfig(RouteQuestionEdge routeQuestionEdge, GradeGenerationEdge gradeGenerationEdge,
ChatClient commonChatClient,
@Qualifier("WebSearchChatClient") ChatClient webSearchClient,
@Qualifier("AdaptiveRagChatClient") ChatClient ragChatClient,
@Qualifier("QuestionRewriterChatClient") ChatClient questionRewriterChatClient,
DocumentRetriever documentRetriever,
CompressionQueryTransformer compressionQueryTransformer,
RewriteQueryTransformer rewriteQueryTransformer,
TranslationQueryTransformer translationQueryTransformer) {
this.routeQuestionEdge = routeQuestionEdge;
this.gradeGenerationEdge = gradeGenerationEdge;
this.commonChatClient = commonChatClient;
this.webSearchClient = webSearchClient;
this.ragChatClient = ragChatClient;
this.questionRewriterChatClient = questionRewriterChatClient;
this.documentRetriever = documentRetriever;
this.compressionQueryTransformer = compressionQueryTransformer;
this.rewriteQueryTransformer = rewriteQueryTransformer;
this.translationQueryTransformer = translationQueryTransformer;
}
@Bean
public StateGraph graph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
OverAllStateFactory stateFactory = () -> {
OverAllState state = new OverAllState();
state.registerKeyAndStrategy("question", new ReplaceStrategy());
state.registerKeyAndStrategy("generation", new ReplaceStrategy());
state.registerKeyAndStrategy("documents", new ReplaceStrategy());
return state;
};
StateGraph stateGraph = new StateGraph("Spring AI Alibaba Graph Demo", stateFactory);
// Add nodes
stateGraph.addNode("prebuilt_rag_generation", AsyncNodeAction.node_async(RetrieveNode.builder()
.chatClient(commonChatClient)
.documentRetriever(documentRetriever)
.retrievalAugmentationAdvisor(RetrievalAugmentationAdvisor.builder()
.documentRetriever(documentRetriever)
.queryTransformers(compressionQueryTransformer, translationQueryTransformer, rewriteQueryTransformer)
.build())
.build()));
stateGraph.addNode("web_search",
AsyncNodeAction.node_async(WebSearchNode.builder().chatClient(webSearchClient).build()));
stateGraph.addNode("self_rag_generation",
AsyncNodeAction.node_async(GenerationNode.builder().chatClient(ragChatClient).build()));
stateGraph.addNode("transform_query",
AsyncNodeAction.node_async(TransformQueryNode.builder().chatClient(questionRewriterChatClient).build()));
// Decide between vector store retrieval and web search
stateGraph.addConditionalEdges(StateGraph.START, AsyncEdgeAction.edge_async(routeQuestionEdge),
Map.of("vectorstore", "prebuilt_rag_generation", "web_search", "web_search"));
// Vector store chain
stateGraph.addConditionalEdges("prebuilt_rag_generation", AsyncEdgeAction.edge_async(gradeGenerationEdge),
Map.of("useful", StateGraph.END,
"unuseful", "transform_query",
"hallucination", "prebuilt_rag_generation"));
// Web search chain
stateGraph.addEdge("web_search", "self_rag_generation");
stateGraph.addConditionalEdges("self_rag_generation", AsyncEdgeAction.edge_async(gradeGenerationEdge),
Map.of("useful", StateGraph.END,
"unuseful", "transform_query",
"hallucination", "self_rag_generation"));
// The rewritten question is answered by the self-built RAG node
stateGraph.addEdge("transform_query", "self_rag_generation");
// Print the graph as a Mermaid diagram
GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.MERMAID,
"Adaptive rag flow");
log.info("\n=== Adaptive rag Flow ===");
log.info(representation.content());
log.info("==================================\n");
return stateGraph;
}
}
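To make the wiring above concrete, here is a toy, framework-free model of how a state graph with conditional edges executes. Each node returns a partial state whose keys overwrite the old values (the effect of registering ReplaceStrategy for every key), and each conditional edge returns a label that is looked up in the route map passed to addConditionalEdges. ToyGraph and its methods are hypothetical and only mimic the shape of StateGraph; they are not the Spring AI Alibaba API. The maxSteps guard is not part of the original configuration: a grader that keeps answering "hallucination" would otherwise send prebuilt_rag_generation back to itself indefinitely, so whether the framework caps iterations is worth checking in its documentation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Toy state-graph runner (hypothetical, NOT the Spring AI Alibaba API). */
final class ToyGraph {
    static final String START = "START";
    static final String END = "END";

    record Result(Map<String, Object> state, List<String> path) {}

    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new HashMap<>();
    private final Map<String, Function<Map<String, Object>, String>> edges = new HashMap<>();

    ToyGraph addNode(String name, Function<Map<String, Object>, Map<String, Object>> action) {
        nodes.put(name, action);
        return this;
    }

    /** Unconditional edge: always go from 'from' to 'to'. */
    ToyGraph addEdge(String from, String to) {
        edges.put(from, state -> to);
        return this;
    }

    /** Conditional edge: the edge action returns a label, the route map turns it into a node name. */
    ToyGraph addConditionalEdges(String from, Function<Map<String, Object>, String> edge,
                                 Map<String, String> routes) {
        edges.put(from, state -> {
            String label = edge.apply(state);
            String target = routes.get(label);
            if (target == null) {
                throw new IllegalStateException("No route for label '" + label + "' from " + from);
            }
            return target;
        });
        return this;
    }

    /** Runs from START until END, failing fast instead of looping forever. */
    Result run(Map<String, Object> input, int maxSteps) {
        Map<String, Object> state = new HashMap<>(input);
        List<String> path = new ArrayList<>();
        String current = route(START, state);
        while (!END.equals(current)) {
            if (path.size() >= maxSteps) {
                throw new IllegalStateException("Exceeded " + maxSteps + " steps at node " + current);
            }
            Function<Map<String, Object>, Map<String, Object>> action = nodes.get(current);
            if (action == null) {
                throw new IllegalStateException("Unknown node: " + current);
            }
            path.add(current);
            // Replace semantics: every key the node returns overwrites the previous value.
            state.putAll(action.apply(state));
            current = route(current, state);
        }
        return new Result(state, path);
    }

    private String route(String from, Map<String, Object> state) {
        Function<Map<String, Object>, String> edge = edges.get(from);
        if (edge == null) {
            throw new IllegalStateException("No outgoing edge from " + from);
        }
        return edge.apply(state);
    }
}
```

Wiring the toy with the same node names and route maps as GraphConfig, a grader stub that answers "hallucination" once and then "useful" yields the path prebuilt_rag_generation, prebuilt_rag_generation, END.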
(3) Exposing the API
java
package com.ai.demo.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.util.ResourceUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {
@Value("classpath:documents/faq.md")
Resource file1;
@Value("classpath:documents/overview.md")
Resource file2;
private final CompiledGraph compiledGraph;
private final VectorStore vectorStore;
private static final String SAVE_PATH = System.getProperty("user.dir") + "/src/main/resources/vectorstore/vectorstore.json";
@SneakyThrows
public GraphController(@Qualifier("graph") StateGraph stateGraph, VectorStore vectorStore) {
this.vectorStore = vectorStore;
this.compiledGraph = stateGraph.compile();
}
@GetMapping(value = "/add")
public void addDocuments() {
// If the vector store was saved before, load it from disk instead of rebuilding it
File file = new File(SAVE_PATH);
if (file.exists()) {
log.info("load vector store from {}", SAVE_PATH);
((SimpleVectorStore) vectorStore).load(file);
return;
}
log.info("start add documents");
var markdownReader1 = new MarkdownDocumentReader(file1, MarkdownDocumentReaderConfig.builder()
.withAdditionalMetadata("title", "Spring AI Alibaba FAQ")
.withAdditionalMetadata("summary", "关于Spring AI Alibaba的常见问题和解答")
.build());
List<Document> documents = new ArrayList<>(markdownReader1.get());
var markdownReader2 = new MarkdownDocumentReader(file2, MarkdownDocumentReaderConfig.builder()
.withAdditionalMetadata("title", "Spring AI Alibaba Overview")
.withAdditionalMetadata("summary", "关于Spring AI Alibaba的概述")
.build());
documents.addAll(markdownReader2.get());
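// Add the documents to the vector store (they are embedded on insertion)
vectorStore.add(documents);
// Persist to disk; create the target directory first in case it does not exist yet
file.getParentFile().mkdirs();
((SimpleVectorStore) vectorStore).save(file);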
}
@GetMapping(value = "/expand")
public Map<String, Object> expand(@RequestParam(value = "query", defaultValue = "你好,我想知道一些关于大模型的知识", required = false) String query) {
RunnableConfig runnableConfig = RunnableConfig.builder().threadId("001").build();
Map<String, Object> objectMap = new HashMap<>();
objectMap.put("question", query);
Optional<OverAllState> invoke = this.compiledGraph.invoke(objectMap, runnableConfig);
return invoke.map(OverAllState::data).orElse(new HashMap<>());
}
}
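Demo
First call the /graph/add endpoint to load the knowledge base (it is built and persisted on the first call and loaded from disk afterwards).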
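Both endpoints can be called from any HTTP client. Below is a minimal sketch using the JDK's java.net.http.HttpClient; it assumes the application runs locally on Spring Boot's default port 8080, and GraphApiClient is a hypothetical helper. The question has to be percent-encoded because it travels as a URL parameter and usually contains Chinese text.

```java
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

/** Hypothetical helper for calling the /graph endpoints of the demo app. */
final class GraphApiClient {
    private final HttpClient http = HttpClient.newHttpClient();
    private final String baseUrl; // e.g. "http://localhost:8080" (assumed default port)

    GraphApiClient(String baseUrl) {
        this.baseUrl = baseUrl;
    }

    URI addUri() {
        return URI.create(baseUrl + "/graph/add");
    }

    /** Percent-encodes the question; URLEncoder's "+" for spaces is turned into "%20". */
    URI expandUri(String query) {
        String encoded = URLEncoder.encode(query, StandardCharsets.UTF_8).replace("+", "%20");
        return URI.create(baseUrl + "/graph/expand?query=" + encoded);
    }

    /** Loads or builds the knowledge base; call once before asking questions. */
    int add() throws IOException, InterruptedException {
        HttpRequest request = HttpRequest.newBuilder(addUri()).GET().build();
        return http.send(request, HttpResponse.BodyHandlers.discarding()).statusCode();
    }

    /** Runs the graph and returns the final state map as the JSON text produced by the controller. */
    String expand(String query) throws IOException, InterruptedException {
        HttpRequest request = HttpRequest.newBuilder(expandUri(query)).GET().build();
        return http.send(request, HttpResponse.BodyHandlers.ofString()).body();
    }
}
```

A typical session calls add() once and then expand() for each question; the returned JSON contains the question, generation and documents keys registered in the state factory.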
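(1) Basic RAG
The router recognizes that the question concerns the knowledge base, takes the vectorstore branch to the prebuilt_rag_generation node, and continues with the subsequent steps.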

(2) Answer self-evaluation
The grader judges that the answer is not grounded in the facts, so the answer is regenerated.

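(3) Web search
The router recognizes that the question requires a web search, routes it to the web_search node, and continues with the subsequent steps.
Question 1: Did G2 win their LoL MSI match yesterday?
AI answer: Yes, G2 beat GAM 3-2 in yesterday's MSI match and advanced to the knockout stage.
Question 2: What's the weather in Beijing tomorrow?
AI answer: Tomorrow (July 2, 2025), Beijing is forecast to be sunny, with a high of 29°C, a low of 22°C, and light winds.
(4) Model memory
When the second question asks about "its features", the model uses the conversation context to infer that "it" refers to Spring AI Alibaba.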

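Future improvements and outlook
-
On the improvement side, the project could be extended with a database, a front end, and so on. Since the framework already provides a streaming interface, the remaining work is much like a traditional Spring Boot + Vue project. In addition, building on Spring AI Alibaba makes it easier to integrate with the rest of Alibaba's ecosystem later, such as Nacos.
-
As for the Spring AI Alibaba framework itself, it worked well in practice; compared with LangGraph, it adds quite a few abstractions that suit Java's language features. However, many parts currently require Spring injection, and there is still some hard-coding. Hopefully Alibaba will refine this further and improve the documentation or provide more demos.
-
Likewise, I hope the framework will gain something like LangGraph's inject mechanism for context, tool calls, and so on (unlike state, these are invisible to the LLM), along with more extensible prebuilt agents and workflows and support for parent/child graphs, further integrated with the Alibaba ecosystem.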