A Multi-Agent RAG Application Built on Spring AI Alibaba

Full source code download: "Multi-agent RAG application based on Spring AI Alibaba" (CSDN Downloads)

Also available on GitHub: 1998y12/multi-agent-rag-spring: a multi-agent RAG application with Spring AI Alibaba

Preface

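Alibaba recently released Spring AI Alibaba, an open-source AI framework on GitHub. It is built on Spring AI, integrates deeply with the Alibaba Cloud Bailian platform, and supports ChatBot, workflow, and multi-agent application development.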

The official website lists the following features:

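  1. Graph multi-agent framework. Spring AI Alibaba Graph lets developers quickly build workflows and multi-agent applications without handling low-level concerns such as flow orchestration and context memory management. It can generate Graph code automatically from Dify DSL and supports visual debugging of graphs.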

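  2. AI ecosystem integration that addresses the pain points enterprises face when deploying agents. Spring AI Alibaba integrates deeply with the Bailian platform and provides model access and RAG knowledge base solutions. It supports seamless integration with AI observability products such as ARMS and Langfuse. It also supports enterprise-grade MCP integration, including the Nacos MCP Registry for distributed registration and discovery, and automatic Router routing.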

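  3. Exploring general-purpose agent products and platforms with autonomous planning. The community has released JManus, an agent built on Spring AI Alibaba that matches the product capabilities of general-purpose agents such as Manus. The community is also actively exploring autonomous planning for agent development. The goal is to give developers more flexible options, from low-code and high-code to zero-code agent building, and to speed up agent adoption in vertical enterprise businesses.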

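In practice, its design philosophy feels close to LangGraph in the Python ecosystem. Both describe the architecture and data flow of an AI system as a graph (workflow), so Spring AI Alibaba can be seen as a Java counterpart of LangGraph.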

The three most important concepts in Spring AI Alibaba are state, node, and edge. Combining them lets you build complex multi-agent architectures in which agents collaborate with each other.

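  • State: a snapshot of the whole AI application at the current moment; the data structure shared by all agents.
  • Node: a function that implements an agent's concrete processing logic. It takes the current state as input, performs some computation or side effect, and returns the updated state. In Spring AI Alibaba, a node is an object that implements the NodeAction interface. Its apply method processes the state and returns a Map holding the state keys to update.
  • Edge: a function that decides which node runs next based on the current state. In Spring AI Alibaba, an edge is an object that implements the EdgeAction interface. Its apply method performs the routing logic and returns a string that identifies the next node.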
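To make the three concepts concrete, here is a minimal plain-Java sketch that does not depend on the framework. It is not the Spring AI Alibaba API. `MiniGraph`, `MiniNode`, and `MiniEdge` are hypothetical stand-ins that only mirror the roles described above: a node returns a partial state update as a Map, and an edge returns the name of the next node.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for a node: reads the state, returns only the keys it wants to update
interface MiniNode extends Function<Map<String, Object>, Map<String, Object>> {}

// Hypothetical stand-in for an edge: reads the state, returns the name of the next node
interface MiniEdge extends Function<Map<String, Object>, String> {}

class MiniGraph {
    static final String END = "__END__";
    private static final int MAX_STEPS = 100;

    private final Map<String, MiniNode> nodes = new HashMap<>();
    private final Map<String, MiniEdge> edges = new HashMap<>();

    MiniGraph addNode(String name, MiniNode node) {
        nodes.put(name, node);
        return this;
    }

    // One outgoing edge per node; a fixed edge is simply an edge that returns a constant
    MiniGraph addEdge(String from, MiniEdge edge) {
        edges.put(from, edge);
        return this;
    }

    Map<String, Object> run(String start, Map<String, Object> input) {
        Map<String, Object> state = new HashMap<>(input); // the shared state
        String current = start;
        for (int step = 0; !END.equals(current); step++) {
            if (step >= MAX_STEPS) throw new IllegalStateException("Too many steps; possible cycle");
            MiniNode node = nodes.get(current);
            MiniEdge edge = edges.get(current);
            if (node == null || edge == null) throw new IllegalStateException("Incomplete node: " + current);
            state.putAll(node.apply(state)); // merge the node's partial update into the state
            current = edge.apply(state);     // the edge picks the next node from the updated state
        }
        return state;
    }
}
```

The step cap matters because the workflow in this article contains loops (regenerate the answer, rewrite the question). A real graph framework adds much more on top of this loop, for example asynchronous execution and persistence.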

Feature Overview

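Building on the official Spring AI Alibaba documentation and a LangGraph RAG example, I implemented an improved multi-agent adaptive RAG application. Its main features are:

1. For questions related to the knowledge base, it generates answers from the knowledge base, with some enhancement and post-processing.

2. For questions unrelated to the knowledge base, it fetches relevant knowledge through external tools and then answers through the RAG flow.

It is only a small demo, but it covers most of the components: building the RAG pipeline (document splitting, vector store, retriever, LLM generation), dynamic rendering of prompt templates, external tool integration, model memory design, and Spring's prebuilt RAG agent. Multi-agent collaboration is then implemented with the Spring AI Alibaba workflow (graph) framework.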

Workflow Graph

The overall structure of the workflow graph is as follows:

[Figure: overall workflow graph]

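Workflow

The user enters a question, and a routing edge classifies it:

  1. If the question relates to the knowledge base, Spring's prebuilt enhanced RAG generates an answer;

1.1 Evaluate the answer:

1.1.1 Check for "hallucination", i.e. content that the facts do not support. If there is any, regenerate the answer; otherwise go to the next step;

1.1.2 Check whether the answer addresses the user's question. If it does not, go to the question-rewriting node to rewrite the question; otherwise go to the next step;

1.2 Produce the final answer.

  2. If the question is unrelated to the knowledge base, call the external tool to fetch knowledge documents;

2.1 Once the documents are fetched, enter the same RAG flow as above. The difference is that this RAG is not the prebuilt agent provided by Spring AI but a custom-built RAG agent, and it also grades each retrieved document for relevance to the question.

2.2 Produce the final answer.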
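The branching in the steps above comes down to a few decision functions over the router and grader outputs. The following plain-Java sketch is one way to write them. The node names (`retrieve`, `web_search`) and the 'yes'/'no' grade values are assumptions modeled on the prompts later in this article, not names taken from the project.

```java
import java.util.Locale;

// Hypothetical decision logic mirroring the workflow above; node names are illustrative
class WorkflowRouting {
    // Outcome of grading a generated answer (steps 1.1.1 and 1.1.2)
    enum Next { REGENERATE, REWRITE_QUESTION, FINISH }

    // Entry routing edge: which branch handles the question
    static String routeQuestion(String datasource) {
        return switch (datasource.trim().toLowerCase(Locale.ROOT)) {
            case "vectorstore" -> "retrieve";   // branch 1: prebuilt Spring RAG
            case "web_search" -> "web_search";  // branch 2: external tool, then custom RAG
            default -> throw new IllegalArgumentException("Unknown datasource: " + datasource);
        };
    }

    // Hallucination check first, then whether the answer addresses the question
    static Next gradeGeneration(String groundedGrade, String answersQuestionGrade) {
        if (!isYes(groundedGrade)) {
            return Next.REGENERATE;       // not supported by the facts: generate again
        }
        if (!isYes(answersQuestionGrade)) {
            return Next.REWRITE_QUESTION; // grounded but off target: rewrite the question
        }
        return Next.FINISH;               // grounded and on topic: final answer
    }

    private static boolean isYes(String grade) {
        return grade != null && grade.trim().equalsIgnoreCase("yes");
    }
}
```

In practice you would also cap the number of regenerate/rewrite rounds, because a strict grader could otherwise keep the loop running indefinitely.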

Code Implementation

1. Configuration Files

pom.xml:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
<!--    <parent>-->
<!--        <groupId>org.springframework.boot</groupId>-->
<!--        <artifactId>spring-boot-starter-parent</artifactId>-->
<!--        <version>3.5.3</version>-->
<!--        <relativePath/> &lt;!&ndash; lookup parent from repository &ndash;&gt;-->
<!--    </parent>-->
    <groupId>com.ai</groupId>
    <artifactId>spring-ai-alibaba-demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>spring-ai-alibaba-demo</name>
    <description>spring-ai-alibaba-demo</description>

    <properties>
        <!-- Spring Boot -->
        <spring-boot.version>3.5.3</spring-boot.version>

        <!-- Spring AI -->
        <spring-ai.version>1.0.0</spring-ai.version>

        <!-- Spring AI Alibaba -->
        <spring-ai-alibaba.version>1.0.0.2</spring-ai-alibaba.version>

        <!-- Jdk -->
        <java.version>21</java.version>

        <!-- Maven Compiler -->
        <maven.compiler.source>21</maven.compiler.source>
        <maven.compiler.target>21</maven.compiler.target>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>${spring-boot.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>com.alibaba.cloud.ai</groupId>
                <artifactId>spring-ai-alibaba-bom</artifactId>
                <version>${spring-ai-alibaba.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <!-- Spring Boot Starter Web -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <!-- Lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- Spring Boot Starter Test -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- Spring AI LLM -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
        </dependency>

        <!-- Spring AI Chat Client -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
        </dependency>

        <!-- Spring AI Model Tool Integration -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-autoconfigure-model-tool</artifactId>
        </dependency>

        <!-- Spring AI RAG Framework -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-rag</artifactId>
        </dependency>

        <!-- Spring AI Document Reader for Markdown -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-markdown-document-reader</artifactId>
        </dependency>

        <!-- Spring AI Alibaba Graph -->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-graph-core</artifactId>
        </dependency>

        <!-- Gson for Graph serialization -->
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <annotationProcessorPaths>
                        <path>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </path>
                    </annotationProcessorPaths>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

application.yaml

server:
  port: 6666

spring:
  application:
    name: ai-demo-application

  ai:
    # OpenAI models are used here; in practice you can also use Alibaba Cloud Bailian, which is integrated into Spring AI Alibaba
    openai:
      # openai api key
      api-key: sk-yourapikey
      # OpenAI API endpoint, or the address provided by a proxy service
      base-url: https://api.openai.com
      # Chat model settings
      chat:
        options:
          model: gpt-4o-mini
          max-tokens: 4096
        completions-path: /v1/chat/completions
      # Embedding model settings, used to vectorize text
      embedding:
        options:
          model: text-embedding-3-small

# Custom web search tool API (Tavily is used here)
tavily:
  api-key: tvly-dev-yourapikey
  base-url: https://api.tavily.com

2. ChatClient Design

Full code:

package com.ai.demo.config;

import com.ai.demo.tool.WebSearchTool;
import lombok.AllArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

@Configuration
@AllArgsConstructor
public class ChatClientConfig {

    private final WebSearchTool webSearchTool;

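    /**
     * Memory type: a fixed-size message window.
     * <p>This is the memory type Spring AI uses by default when it auto-configures the ChatMemory bean (it also works without this declaration).</p>
     * @return the ChatMemory instance
     */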
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();
    }

    /**
     * Memory storage: an in-memory ChatMemoryRepository.
     * <p>If no other repository is configured, Spring AI auto-configures a ChatMemoryRepository bean of type InMemoryChatMemoryRepository by default (it also works without this declaration).</p>
     * @return the ChatMemoryRepository instance
     */
    @Bean
    public ChatMemoryRepository chatMemoryRepository() {
        return new InMemoryChatMemoryRepository();
    }

    /**
     * General-purpose OpenAI LLM client
     * @param chatModel the chat model
     * @return the ChatClient instance
     */
    @Bean
    @Primary
    public ChatClient openAiChatClient(ChatModel chatModel) {
        return ChatClient.builder(chatModel).defaultOptions(ChatOptions.builder().temperature(0.8).build()).build();
    }

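    /**
     * Question router ChatClient
     * Routes the user's question to a data source (vector store or web search)
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */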
    @Bean
    public ChatClient QuestionRouterChatClient(ChatModel chatModel, ChatMemory chatMemory) {
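        String systemPrompt = """
                    You are an instruction-routing expert. Route the user's input/question to one of the following components:
                
                    1. Vector store (vectorstore)
                    Choose vectorstore when the user's question relates to the documents in the knowledge base.
                
                    Knowledge base information:
                    {knowledge_base}
                
                    2. Web search (web_search)
                    Choose web_search when the user's question:
                    - needs up-to-date, real-time information (news, weather, stock prices, etc.)
                    - falls outside the scope of the knowledge base
                
                    Make the best routing decision.
                """;
        return ChatClient.builder(chatModel).defaultSystem(systemPrompt)
                .defaultUser(u -> u.text("User question: {question}"))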
                .defaultOptions(ChatOptions.builder().temperature(0.0).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRouter").build())
                .build();
    }

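    /**
     * Web search ChatClient
     * Searches the web with a tool and returns the relevant information
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */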
    @Bean
    public ChatClient WebSearchChatClient(ChatModel chatModel, ChatMemory chatMemory) {

        String systemPrompt = """
               Based on the user's question, search the web with the available tool and return the relevant information.
               
               Today's date is: {date}
               """;

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(u -> u.text("User question: {question}"))
                // Use ToolCallingChatOptions here rather than ChatOptions
                .defaultOptions(ToolCallingChatOptions.builder().temperature(0.8).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("WebSearch").build())
                .defaultTools(webSearchTool)
                .build();
    }

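    /**
     * Adaptive RAG ChatClient
     * Answers user questions by combining results from the vector store and web search
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */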
    @Bean
    public ChatClient AdaptiveRagChatClient(ChatModel chatModel, ChatMemory chatMemory) {
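        String systemPrompt = """
               You are a professional question-answering assistant. Answer the user's question based on the provided context.
               
               Guidelines:
               1. Prefer the provided context when answering
               2. You may use the names and descriptions of the documents to understand their background
               3. If several documents are relevant, combine information from multiple sources
               4. Keep the answer accurate and concise, usually 2-3 sentences
               5. Mention the source of the information where appropriate
               6. If you are unsure or do not know the answer, say so honestly
               """;

        String userPrompt = """
                Question:
                {question}
                
                Retrieved context:
                {context}
                
                Answer the question based on the information above.
                """;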

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(userPrompt)
                .defaultOptions(ChatOptions.builder().temperature(0.7).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("AdaptiveRag").build())
                .build();
    }

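    /**
     * Grades whether the LLM-generated answer is grounded in the retrieved facts
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */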
    @Bean
    public ChatClient HallucinationChatClient(ChatModel chatModel, ChatMemory chatMemory) {

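        String systemPrompt = """
                You are a grader assessing whether an LLM-generated answer is grounded in / supported by a set of retrieved facts.
                
                Give a binary score, 'yes' or 'no'. 'yes' means the answer is grounded in / supported by the set of facts.
                """;

        String userPrompt = """
                Set of facts:
                {documents}
                
                LLM-generated answer:
                {generation}
                """;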

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(userPrompt)
                .defaultOptions(ChatOptions.builder().temperature(0.0).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("Hallucination").build())
                .build();
    }

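    /**
     * Grades whether the LLM-generated answer addresses/resolves the user's question
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */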
    @Bean
    public ChatClient AnswerGraderChatClient(ChatModel chatModel, ChatMemory chatMemory) {

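        String systemPrompt = """
                You are a grader assessing whether an answer addresses / resolves a question.
                
                Give a binary score, 'yes' or 'no'. 'yes' means the answer addresses / resolves the question.
                """;

        String userPrompt = """
                User question:
                {question}
                
                LLM-generated answer:
                {generation}
                """;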

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(userPrompt)
                .defaultOptions(ChatOptions.builder().temperature(0.8).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("AnswerGrader").build())
                .build();
    }

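    /**
     * Question rewriter ChatClient
     * Rewrites the user's question into a clearer, more specific form to improve retrieval
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */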
    @Bean
    public ChatClient QuestionRewriterChatClient(ChatModel chatModel, ChatMemory chatMemory) {

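        String systemPrompt = """
                You are an expert at rewriting questions. Rewrite the user's question into a clearer, more specific form to improve retrieval.
                
                Rewriting rules:
                1. Keep the original meaning, but make the question more explicit
                2. Use more specific, targeted keywords; avoid vague or ambiguous wording
                3. If the question is too broad, try to break it down into more specific sub-questions
                """;

        String userPrompt = """
                Original question:
                {question}
                
                Rewrite this question to improve retrieval:
                """;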

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(userPrompt)
                .defaultOptions(ChatOptions.builder().temperature(0.0).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRewriter").build())
                .build();
    }
}
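The router call in the runtime-rendering snippet below uses `.entity(RouteQueryEntity.class)`, but this article never shows `RouteQueryEntity` itself. Here is one plausible shape. It assumes a single `datasource` field that holds either `vectorstore` or `web_search`, the two options named in the router's system prompt. Spring AI's `entity(...)` can map model output onto a Java record. The field name and the helper methods are hypothetical.

```java
import java.util.Set;

// Hypothetical structured-output type for the question router.
// The component name "datasource" is an assumption; the model's JSON must use the same name.
record RouteQueryEntity(String datasource) {

    // The two routes offered in the router's system prompt
    static final Set<String> ALLOWED = Set.of("vectorstore", "web_search");

    // True if the model picked one of the known routes
    boolean isValid() {
        return datasource != null && ALLOWED.contains(datasource);
    }

    // Convenience check used when choosing the next node
    boolean useVectorStore() {
        return "vectorstore".equals(datasource);
    }
}
```

Keeping validation out of the record's constructor means an unexpected model output still deserializes. The routing edge can then decide how to handle it, for example by falling back to web search.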

(1) Rendering prompts dynamically at runtime

Parameters can be injected into a prompt template at runtime, for example:

questionRouterChatClient.prompt()
        .system(s -> s.param("knowledge_base", "关于大模型的知识"))
        .user(u -> u.param("question", "什么是大模型"))
        .call()
        .entity(RouteQueryEntity.class)
        .toString();
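Conceptually, each `param(...)` call fills a `{name}` placeholder in the template before the prompt goes to the model. Spring AI does this with its own template renderer. In Spring AI 1.0 that renderer is based on StringTemplate, and with its default settings it rejects templates whose variables are not all supplied. The plain-Java `renderTemplate` below is only a hypothetical illustration of the substitution and the fail-fast check. The same check is why a variable such as `{date}` in the web search prompt must be supplied at call time.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper illustrating {placeholder} substitution; not Spring AI's renderer
class PromptTemplates {
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{([A-Za-z_][A-Za-z0-9_]*)\\}");

    static String renderTemplate(String template, Map<String, ?> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            String name = m.group(1);
            if (!params.containsKey(name)) {
                // Fail fast instead of sending a half-rendered prompt to the model
                throw new IllegalStateException("Missing value for template variable: " + name);
            }
            // quoteReplacement keeps '$' and '\' in values from being read as group references
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(params.get(name))));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```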

(2) ChatMemory design (model memory)

Declare the memory type and how memory is stored:

    /**
     * Memory type: a fixed-size message window.
     * <p>This is the memory type Spring AI uses by default when it auto-configures the ChatMemory bean (it also works without this declaration).</p>
     * @return the ChatMemory instance
     */
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();
    }

    /**
     * Memory storage: an in-memory ChatMemoryRepository.
     * <p>If no other repository is configured, Spring AI auto-configures a ChatMemoryRepository bean of type InMemoryChatMemoryRepository by default (it also works without this declaration).</p>
     * @return the ChatMemoryRepository instance
     */
    @Bean
    public ChatMemoryRepository chatMemoryRepository() {
        return new InMemoryChatMemoryRepository();
    }

Memory is attached through an advisor provided by Spring AI, for example:

ChatClient.builder(chatModel).defaultSystem(systemPrompt)
        .defaultUser(u -> u.text("User question: {question}"))
        .defaultOptions(ChatOptions.builder().temperature(0.0).build())
        // Memory is attached here; "QuestionRouter" is a conversation id unique to this agent
        .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("QuestionRouter").build())
        .build();
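All agents share one `ChatMemory` bean, but each uses its own `conversationId` (QuestionRouter, WebSearch, AdaptiveRag, ...), so their histories never mix. The idea can be sketched in plain Java as one store keyed by conversation id, where each entry keeps at most `maxMessages` recent messages. `WindowMemory` is a hypothetical class. The real `MessageWindowChatMemory` stores Spring AI `Message` objects in a `ChatMemoryRepository`, and according to its documentation it also preserves system messages when evicting. This sketch ignores that detail.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of a per-conversation sliding window of messages
class WindowMemory {
    private final int maxMessages;
    private final Map<String, Deque<String>> conversations = new HashMap<>();

    WindowMemory(int maxMessages) {
        if (maxMessages <= 0) throw new IllegalArgumentException("maxMessages must be positive");
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, String message) {
        Deque<String> window = conversations.computeIfAbsent(conversationId, id -> new ArrayDeque<>());
        window.addLast(message);
        // Evict the oldest messages once the window is over capacity
        while (window.size() > maxMessages) {
            window.removeFirst();
        }
    }

    // Oldest-to-newest snapshot of one conversation; empty if the id is unknown
    List<String> get(String conversationId) {
        return List.copyOf(conversations.getOrDefault(conversationId, new ArrayDeque<>()));
    }
}
```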

3. Tool Design

(1) Web search tool

Tavily serves as the search API. First, wrap it in a tool class:

package com.ai.demo.tool;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

@Component
@Slf4j
public class WebSearchTool {

    private final WebClient webClient;

    public WebSearchTool(WebClient.Builder webClientBuilder,
                         @Value("${tavily.base-url}") String baseUrl,
                         @Value("${tavily.api-key}") String apiKey) {
        this.webClient = webClientBuilder
                .baseUrl(baseUrl)
                .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                .build();
    }

    /**
     * Performs a web search through the Tavily API.
     *
     * @param query the search query string
     * @return a TavilyResponse containing the search results and related information
     */
    @Tool(description = "Search the web using the Tavily API")
    public TavilyResponse search(@ToolParam(description = "search query to look up") String query) {

        TavilyRequest request = TavilyRequest.builder()
                .query(query)
                .maxResults(3) // return 3 results by default
                .searchDepth("basic") // default search depth: basic
                .includeAnswer(false) // do not include a generated answer
                .includeRawContent(false) // do not include raw page content
                .includeImages(false) // do not include images
                .build();

        if (request.getQuery() == null || request.getQuery().isEmpty()) {
            throw new IllegalArgumentException("Query parameter is required.");
        }
        log.info("Received TavilyRequest: {}", request);

        // Build the request payload with all parameters, setting defaults where necessary
        TavilyRequest requestWithApiKey = TavilyRequest.builder().query(request.getQuery())
                .searchDepth(request.getSearchDepth() != null ? request.getSearchDepth() : "basic")
                .topic(request.getTopic() != null ? request.getTopic() : "general")
                .days(request.getDays() != null ? request.getDays() : 300)
                .maxResults(request.getMaxResults() != 0 ? request.getMaxResults() : 10)
                .includeImages(request.isIncludeImages()).includeImageDescriptions(request.isIncludeImageDescriptions())
                .includeAnswer(request.isIncludeAnswer()).includeRawContent(request.isIncludeRawContent())
                .includeDomains(
                        request.getIncludeDomains() != null ? request.getIncludeDomains() : Collections.emptyList())
                .excludeDomains(
                        request.getExcludeDomains() != null ? request.getExcludeDomains() : Collections.emptyList())
                .build();

        log.debug("Sending request to Tavily API: query={}, searchDepth={}, topic={}, days={}, maxResults={}",
                requestWithApiKey.getQuery(), requestWithApiKey.getSearchDepth(), requestWithApiKey.getTopic(),
                requestWithApiKey.getDays(), requestWithApiKey.getMaxResults());

        try {
            TavilyResponse response = webClient.post()
                    .uri(uriBuilder -> uriBuilder.path("/search").build())
                    .bodyValue(requestWithApiKey)
                    .retrieve()
                    .bodyToMono(TavilyResponse.class)
                    .block();

            log.info("Received response from Tavily API for query: {}", requestWithApiKey.getQuery());
            return response;
        } catch (Exception e) {
            log.error("Error occurred while calling Tavily API: {}", e.getMessage(), e);
            throw new RuntimeException("Failed to fetch search results from Tavily API", e);
        }
    }

    /**
     * Request object for the Tavily API.
     */
    @Data
    @Builder
    @NoArgsConstructor
    @AllArgsConstructor
    @JsonClassDescription("Request object for the Tavily API")
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public static class TavilyRequest {

        @JsonProperty("query")
        @JsonPropertyDescription("The main search query.")
        private String query;

        @JsonProperty("api_key")
        @JsonPropertyDescription("API key for authentication with Tavily.")
        private String apiKey;

        @JsonProperty("search_depth")
        @JsonPropertyDescription("The depth of the search. Accepted values: 'basic', 'advanced'. Default is 'basic'.")
        private String searchDepth;

        @JsonProperty("topic")
        @JsonPropertyDescription("The category of the search. Accepted values: 'general', 'news'. Default is 'general'.")
        private String topic;

        @JsonProperty("days")
        @JsonPropertyDescription("The number of days back from the current date to include in search results. Default is 3. Only applies to 'news' topic.")
        private Integer days;

        @JsonProperty("time_range")
        @JsonPropertyDescription("The time range for search results. Accepted values: 'day', 'week', 'month', 'year' or 'd', 'w', 'm', 'y'. Default is none.")
        private String timeRange;

        @JsonProperty("max_results")
        @JsonPropertyDescription("The maximum number of search results to return. Default is 5.")
        private int maxResults;

        @JsonProperty("include_images")
        @JsonPropertyDescription("Whether to include a list of query-related images in the response. Default is false.")
        private boolean includeImages;

        @JsonProperty("include_image_descriptions")
        @JsonPropertyDescription("When 'include_images' is true, adds descriptive text for each image. Default is false.")
        private boolean includeImageDescriptions;

        @JsonProperty("include_answer")
        @JsonPropertyDescription("Whether to include a short answer to the query, generated from search results. Default is false.")
        private boolean includeAnswer;

        @JsonProperty("include_raw_content")
        @JsonPropertyDescription("Whether to include the cleaned and parsed HTML content of each search result. Default is false.")
        private boolean includeRawContent;

        @JsonProperty("include_domains")
        @JsonPropertyDescription("A list of domains to specifically include in search results. Default is an empty list.")
        private List<String> includeDomains;

        @JsonProperty("exclude_domains")
        @JsonPropertyDescription("A list of domains to specifically exclude from search results. Default is an empty list.")
        private List<String> excludeDomains;
    }

    /**
     * Response object for the Tavily API.
     */
    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    @JsonClassDescription("Response object for the Tavily API")
    public static class TavilyResponse {
        @JsonProperty("query")
        private String query;

        @JsonProperty("follow_up_questions")
        private List<String> followUpQuestions;

        @JsonProperty("answer")
        private String answer;

        @JsonDeserialize(using = ImageDeserializer.class)
        @JsonProperty("images")
        private List<Image> images;

        @JsonProperty("results")
        private List<Result> results;

        @JsonProperty("response_time")
        private float responseTime;

        @Data
        @NoArgsConstructor
        @AllArgsConstructor
        public static class Image {
            @JsonProperty("url")
            private String url;

            @JsonProperty("description")
            private String description;
        }

        @Data
        @NoArgsConstructor
        @AllArgsConstructor
        public static class Result {
            @JsonProperty("title")
            private String title;

            @JsonProperty("url")
            private String url;

            @JsonProperty("content")
            private String content;

            @JsonProperty("raw_content")
            private String rawContent;

            @JsonProperty("score")
            private float score;

            @JsonProperty("published_date")
            private String publishedDate;
        }
    }

    public static class ImageDeserializer extends JsonDeserializer<List<TavilyResponse.Image>> {
        @Override
        public List<TavilyResponse.Image> deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException {

            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
            List<TavilyResponse.Image> images = new ArrayList<>();

            if (node.isArray()) {
                for (JsonNode element : node) {
                    // If element is a string, treat it as a URL
                    if (element.isTextual()) {
                        images.add(new TavilyResponse.Image(element.asText(), null));
                    }
                    // If element is an object, map it to Image
                    else if (element.isObject()) {
                        String url = element.has("url") ? element.get("url").asText() : null;
                        String description = element.has("description") ? element.get("description").asText() : null;
                        images.add(new TavilyResponse.Image(url, description));
                    }
                }
            }

            return images;
        }
    }
}
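Before the web results go through the custom RAG flow, they have to become text that the generation prompt can use as `{context}`. Below is a hypothetical sketch of that conversion. `SearchHit` is a simplified stand-in for `TavilyResponse.Result` (only `title`, `url`, and `content`), and the output format is an arbitrary choice, not necessarily what the project does.

```java
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Simplified, hypothetical stand-in for TavilyResponse.Result
record SearchHit(String title, String url, String content) {}

class SearchContext {
    // Turn non-empty results into one context string, separated by blank lines,
    // the same separator GenerationNode uses when joining document texts
    static String toContext(List<SearchHit> hits) {
        if (hits == null) return "";
        return hits.stream()
                .filter(Objects::nonNull)
                .filter(h -> h.content() != null && !h.content().isBlank())
                .map(h -> "Source: " + h.title() + " (" + h.url() + ")\n" + h.content().strip())
                .collect(Collectors.joining("\n\n"));
    }
}
```

In the Spring application, each hit could instead become a Spring AI `Document` (text plus metadata such as the URL). The existing GenerationNode could then consume web results and vector store results the same way.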

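(2) Integrating the tool into an agent

The tool is registered through defaultTools. The system prompt contains a {date} placeholder, so a value for it must be supplied at call time, e.g. .system(s -> s.param("date", LocalDate.now().toString())).

    // ... existing code ...

    // The config class is annotated with @AllArgsConstructor,
    // so the tool is injected through the constructor
    private final WebSearchTool webSearchTool;

    /**
     * Web search ChatClient
     * Searches the web with a tool and returns the relevant information
     * @param chatModel the chat model
     * @param chatMemory the shared chat memory
     * @return the ChatClient instance
     */
    @Bean
    public ChatClient WebSearchChatClient(ChatModel chatModel, ChatMemory chatMemory) {

        String systemPrompt = """
               Based on the user's question, search the web with the available tool and return the relevant information.
               
               Today's date is: {date}
               """;

        return ChatClient.builder(chatModel)
                .defaultSystem(systemPrompt)
                .defaultUser(u -> u.text("User question: {question}"))
                // Use ToolCallingChatOptions here rather than ChatOptions
                .defaultOptions(ToolCallingChatOptions.builder().temperature(0.8).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId("WebSearch").build())
                .defaultTools(webSearchTool)
                .build();
    }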

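4. Node Design

As mentioned above, a node calls one of the ChatClients defined earlier and updates the state of the whole graph (workflow).

Each node's ChatClient is passed in when the graph is assembled, in the final step where the workflow is built.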
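It helps to spell out what "updating the state" means. A node returns only the keys it changed, and the framework merges them into the overall state. Spring AI Alibaba lets you choose a merge strategy per key, such as replace or append. The sketch below is a hypothetical plain-Java illustration of those two behaviors, not the library's strategy API.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical merge of a node's partial update into the overall state
class StateMerge {
    // Replace: the new value overwrites the old one
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Append: old and new values are concatenated into one list
    static final BinaryOperator<Object> APPEND = (oldValue, newValue) -> {
        List<Object> merged = new ArrayList<>();
        appendValue(merged, oldValue);
        appendValue(merged, newValue);
        return merged;
    };

    private static void appendValue(List<Object> target, Object value) {
        if (value instanceof Collection<?> values) {
            target.addAll(values);   // flatten collections
        } else if (value != null) {
            target.add(value);       // single values become one element
        }
    }

    // Returns a new state; the input state is left untouched
    static Map<String, Object> merge(Map<String, Object> state, Map<String, Object> update,
                                     Map<String, BinaryOperator<Object>> strategies) {
        Map<String, Object> next = new HashMap<>(state);
        update.forEach((key, value) -> {
            // Keys without an explicit strategy default to replace
            BinaryOperator<Object> strategy = strategies.getOrDefault(key, REPLACE);
            next.put(key, strategy.apply(state.get(key), value));
        });
        return next;
    }
}
```

With replace semantics, TransformQueryNode can overwrite `question` with the rewritten one. Append semantics suit keys that should accumulate, such as a message history.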

(1) GenerationNode

This node generates an answer using the custom RAG prompt template.

package com.ai.demo.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Builder
public class GenerationNode implements NodeAction {

    private final ChatClient chatClient;

    @Override
    public Map<String, Object> apply(OverAllState state) {
        String query = state.value("question", "");
        List<Document> documents = state.value("documents", List.of());

        String generation = chatClient.prompt()
                .user(u -> u.param("question", query)
                        .param("context", documents
                                .stream()
                                .map(Document::getText)
                                .collect(Collectors.joining("\n\n"))))
                .call()
                .content();

        // Update the state
        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("generation", generation);

        return resultMap;
    }
}
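Stripped of Spring types, GenerationNode is a pure function from the state to a one-key update. The sketch below makes that testable without a model. It replaces the ChatClient call with a `BiFunction<String, String, String>` that maps (question, context) to an answer. The stand-in function and the class name `GenerationStep` are hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical, framework-free version of GenerationNode's logic
class GenerationStep {
    private final BiFunction<String, String, String> llm; // (question, context) -> answer

    GenerationStep(BiFunction<String, String, String> llm) {
        this.llm = llm;
    }

    Map<String, Object> apply(Map<String, Object> state) {
        String question = (String) state.getOrDefault("question", "");
        @SuppressWarnings("unchecked")
        List<String> documents = (List<String>) state.getOrDefault("documents", List.of());
        // Same joining rule as GenerationNode: document texts separated by a blank line
        String context = String.join("\n\n", documents);
        // Only the "generation" key is returned; the graph merges it into the state
        return Map.of("generation", llm.apply(question, context));
    }
}
```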

(2) RetrieveNode

This node calls Spring AI's prebuilt RAG agent, which is driven mainly by the RetrievalAugmentationAdvisor class.

package com.ai.demo.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.AbstractMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

@Builder
public class RetrieveNode implements NodeAction {

    private final DocumentRetriever documentRetriever;

    private final ChatClient chatClient;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    private final ChatMemory chatMemory = MessageWindowChatMemory.builder()
            .maxMessages(10)
            .build();

    private final BaseChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
            .build();

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String query = state.value("question", "");

        // The RAG advisor is required for this node to work
        if (retrievalAugmentationAdvisor == null) {
            throw new IllegalStateException("RetrievalAugmentationAdvisor must not be null");
        }

        // Alternative: retrieve documents directly with the DocumentRetriever
//        List<Document> documents = this.documentRetriever.retrieve(Query.builder()
//                .text(query)
//                .build());

        // Call the model through the chat-memory and RAG advisors (each registered exactly once)
        ChatResponse response = chatClient.prompt()
                .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
                .advisors(advisors -> advisors.param(CONVERSATION_ID,
                        "PrebuiltSpringRAG"))
                .user(query)
                .call()
                .chatResponse();

        // Extract the generated text, defaulting to an empty string
        String generation = Optional.ofNullable(response)
                .map(ChatResponse::getResult)
                .map(Generation::getOutput)
                .map(AbstractMessage::getText)
                .orElse("");
        System.out.println("Generation: " + generation);

        // The advisor exposes the retrieved documents in the response metadata ("rag_document_context")
        List<Document> retrievedDocuments = Optional.ofNullable(response)
                .map(r -> r.getMetadata().<List<Document>>get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT))
                .orElse(List.of());

        System.out.println("Documents: " + retrievedDocuments);

        // Update the state
        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("question", query);
        resultMap.put("documents", retrievedDocuments);
        resultMap.put("generation", generation);
        return resultMap;
    }
}
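Conceptually, RetrievalAugmentationAdvisor runs three stages before the model call: it passes the query through the configured query transformers (here compression, translation and rewrite), retrieves documents with the transformed query, then augments the original user question with the retrieved context. The plain-Java sketch below models only that flow; the interfaces (UnaryOperator for a transformer, a Function for the retriever) and the augmentation template are hypothetical stand-ins, not Spring AI's types or its actual prompt.

```java
import java.util.List;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical model of the advisor's pre-call stages; names are illustrative only.
final class RagPipelineSketch {

    private final List<UnaryOperator<String>> queryTransformers; // applied in order
    private final Function<String, List<String>> retriever;       // stands in for a DocumentRetriever

    RagPipelineSketch(List<UnaryOperator<String>> queryTransformers,
                      Function<String, List<String>> retriever) {
        this.queryTransformers = queryTransformers;
        this.retriever = retriever;
    }

    /** Returns the augmented user message that would be sent to the model. */
    String augment(String userQuery) {
        // 1. Pre-retrieval: each transformer receives the previous transformer's output
        String query = userQuery;
        for (UnaryOperator<String> transformer : queryTransformers) {
            query = transformer.apply(query);
        }
        // 2. Retrieval with the transformed query
        List<String> documents = retriever.apply(query);
        // 3. Augmentation: the retrieved context is placed before the ORIGINAL question
        String context = String.join("\n\n", documents);
        return "Context information is below.\n---\n" + context + "\n---\nQuery: " + userQuery;
    }
}
```

Keeping the original question in the final prompt while retrieving with the transformed one means the model still answers what the user actually asked, even if the rewrite drifted slightly.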

(3)TransformQueryNode

This node rewrites the user question.

package com.ai.demo.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;

import java.util.HashMap;
import java.util.Map;

@Builder
@Slf4j
public class TransformQueryNode implements NodeAction {

    private final ChatClient chatClient;

    @Override
    public Map<String, Object> apply(OverAllState state) {
        String question = state.value("question", String.class).orElse("");
        String betterQuestion = chatClient.prompt()
                .user(u -> u.param("question", question))
                .call()
                .content();
        log.info("Rewritten question: {}", betterQuestion);
        // Update the state
        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("question", betterQuestion);
        return resultMap;
    }
}
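TransformQueryNode only supplies the value of the question parameter; the rewrite instructions themselves live in the default user prompt of the QuestionRewriterChatClient, which is configured elsewhere. To illustrate roughly how param("question", ...) fills a {question} placeholder, here is a simplified renderer in plain Java. It is a sketch, not Spring AI's renderer (which is based on StringTemplate and supports much more syntax), and the template text used with it is hypothetical.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Simplified stand-in for prompt template rendering; only handles {name} placeholders.
final class PromptTemplateSketch {

    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)}");

    static String render(String template, Map<String, Object> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder sb = new StringBuilder();
        while (m.find()) {
            String name = m.group(1);
            if (!params.containsKey(name)) {
                // Fail fast instead of sending an unfilled placeholder to the model
                throw new IllegalArgumentException("Missing template variable: " + name);
            }
            // quoteReplacement keeps characters such as '$' in user input literal
            m.appendReplacement(sb, Matcher.quoteReplacement(String.valueOf(params.get(name))));
        }
        m.appendTail(sb);
        return sb.toString();
    }
}
```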

(4)WebSearchNode

This node calls the web search agent and turns its results into a collection of documents.

package com.ai.demo.node;

import com.ai.demo.tool.WebSearchTool;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;

import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Builder
@Slf4j
public class WebSearchNode implements NodeAction {

    private final ChatClient chatClient;

    @Override
    public Map<String, Object> apply(OverAllState state) {
        String query = state.value("question", "");
        WebSearchTool.TavilyResponse response = chatClient.prompt()
                .system(s -> s.param("date", LocalDate.now().toString()))
                .user(u -> u.param("question", query))
                .call()
                .entity(WebSearchTool.TavilyResponse.class);

        if (response == null) {
            throw new IllegalStateException("Web search returned no response");
        }
        log.debug("WebSearchNode response: {}", response);

        // Convert the search results into Document objects (collected into a mutable list)
        List<Document> documents = response.getResults().stream()
                .map(result -> new Document(result.getContent(),
                        Map.of("origin", result.getUrl(),
                                "title", result.getTitle())))
                .collect(Collectors.toCollection(ArrayList::new));
        // Put the summarised answer first, if the search returned one
        if (response.getAnswer() != null && !response.getAnswer().isBlank()) {
            documents.add(0, new Document(response.getAnswer(),
                    Map.of("origin", "Web Search Answer", "title", "Web Search Answer")));
        }

        // Update the state
        HashMap<String, Object> resultMap = new HashMap<>();
        resultMap.put("question", query);
        resultMap.put("documents", documents);
        return resultMap;
    }
}
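The conversion step can be checked without calling the search API. The sketch below uses hypothetical records in place of WebSearchTool.TavilyResponse (whose definition is not shown in this section) and of Spring AI's Document. It builds an explicitly mutable list so the summary answer can be put first, and skips the answer when it is missing, for example when the search request did not ask for one.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical shapes mirroring the fields WebSearchNode reads: answer, results[].content/url/title.
record SearchResult(String title, String url, String content) {}
record SearchResponse(String answer, List<SearchResult> results) {}
record Doc(String text, Map<String, Object> metadata) {}

final class WebSearchToDocs {

    static List<Doc> toDocuments(SearchResponse response) {
        List<Doc> docs = new ArrayList<>(); // mutable, unlike the unspecified result of Collectors.toList()
        if (response.answer() != null && !response.answer().isBlank()) {
            // The summarised answer goes first, tagged with a synthetic origin
            docs.add(new Doc(response.answer(),
                    Map.of("origin", "Web Search Answer", "title", "Web Search Answer")));
        }
        for (SearchResult r : response.results()) {
            // Each hit keeps its URL and title as metadata for later citation
            docs.add(new Doc(r.content(), Map.of("origin", r.url(), "title", r.title())));
        }
        return docs;
    }
}
```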

5. Edge design

An edge decides which node the flow enters next.

(1)RouteQuestionEdge

The routing edge decides whether the user question is sent to vector-store retrieval or to web search.

package com.ai.demo.edge;

import com.ai.demo.entity.RouteQueryEntity;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class RouteQuestionEdge implements EdgeAction {

    private final ChatClient questionRouterChatClient;

    public RouteQuestionEdge(@Qualifier("QuestionRouterChatClient") ChatClient questionRouterChatClient) {
        this.questionRouterChatClient = questionRouterChatClient;
    }

    @Override
    public String apply(OverAllState state) {
        log.info("---------- Edge: route question ----------");

        String question = state.value("question", String.class).orElse("");

        // Decide the data source ("vectorstore" or "web_search")
        RouteQueryEntity response = questionRouterChatClient.prompt()
                .user(u -> u.param("question", question))
                .system(s -> s.param("knowledge_base", "关于Spring AI Alibaba的相关知识"))
                .call()
                .entity(RouteQueryEntity.class);

        if (response == null) {
            throw new IllegalStateException("Question router returned no decision");
        }
        log.info("Routing to: {}", response);

        return response.dataSource();
    }
}
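The string returned by an EdgeAction is only a key; the map passed to addConditionalEdges when the graph is built translates it into a node name. The plain-Java sketch below models that lookup and fails loudly on a label outside the mapping. The whitespace and case normalisation is an assumption added here (LLM output is not always exact), not something the original edge does.

```java
import java.util.Locale;
import java.util.Map;

// Plain-Java model of a conditional edge: route label -> next node name.
final class ConditionalRouter {

    private final Map<String, String> mapping;

    ConditionalRouter(Map<String, String> mapping) {
        this.mapping = Map.copyOf(mapping); // keys are expected in lower case
    }

    String nextNode(String routeLabel) {
        // Assumption: tolerate stray whitespace and capitalisation from the model
        String key = routeLabel == null ? "" : routeLabel.trim().toLowerCase(Locale.ROOT);
        String target = mapping.get(key);
        if (target == null) {
            throw new IllegalStateException("Unknown route label: " + routeLabel);
        }
        return target;
    }
}
```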

(2)GradeGenerationEdge

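Evaluates the quality of the answer and decides whether it can be returned, must be regenerated, or requires the question to be rewritten.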

package com.ai.demo.edge;

import com.ai.demo.entity.GradeScore;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.stream.Collectors;

@Component
@Slf4j
public class GradeGenerationEdge implements EdgeAction {

    private final ChatClient hallucinationGrader;

    private final ChatClient answerGrader;

    public GradeGenerationEdge(@Qualifier("HallucinationChatClient") ChatClient hallucinationGrader,
            @Qualifier("AnswerGraderChatClient") ChatClient answerGrader) {
        this.hallucinationGrader = hallucinationGrader;
        this.answerGrader = answerGrader;
    }

    /**
     * Evaluates the quality of the generation.
     * @param state graph state
     * @return "hallucination" if the generation is not grounded in the documents and must be regenerated;
     * "unuseful" if the generation does not address the question, so the question should be rewritten;
     * "useful" if the generation addresses the question.
     */
    @Override
    public String apply(OverAllState state) {
        log.info("---------- Edge: check whether the generation is grounded in the documents ----------");
        String question = state.value("question", String.class).orElse("");
        String generation = state.value("generation", String.class).orElse("");
        List<Document> documents = state.value("documents", List.of());

        // Step 1: hallucination check against the retrieved documents
        GradeScore hallucinationGradeScore = hallucinationGrader.prompt()
                .user(u -> u.param("documents", formatDocs(documents))
                        .param("generation", generation))
                .call()
                .entity(GradeScore.class);

        if (hallucinationGradeScore == null || !"yes".equals(hallucinationGradeScore.binaryScore())) {
            log.info("---------- Decision: generation is not grounded, regenerate ----------");
            return "hallucination";
        }

        log.info("---------- Decision: generation is grounded ----------");
        // Step 2: check whether the generation actually answers the question
        GradeScore answerGradeScore = answerGrader.prompt()
                .user(u -> u.param("question", question)
                        .param("generation", generation))
                .call()
                .entity(GradeScore.class);

        if (answerGradeScore != null && "yes".equals(answerGradeScore.binaryScore())) {
            log.info("---------- Decision: generation answers the question ----------");
            return "useful";
        }

        log.info("---------- Decision: generation does not answer the question ----------");
        return "unuseful";
    }

    private String formatDocs(List<Document> documents) {
        return documents.stream().map(Document::getText).collect(Collectors.joining("\n\n"));
    }
}
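Without the LLM calls, the edge is a two-step decision with a short circuit: the answer grader only runs when the hallucination grader says "yes". The pure function below reproduces that order with the two graders abstracted as hypothetical boolean suppliers, which makes the routing easy to unit-test without a model.

```java
import java.util.function.BooleanSupplier;

// Decision logic of GradeGenerationEdge with the two graders abstracted as suppliers.
final class GenerationGrader {

    static String decide(BooleanSupplier groundedInDocuments, BooleanSupplier answersQuestion) {
        if (!groundedInDocuments.getAsBoolean()) {
            return "hallucination"; // regenerate
        }
        // Only evaluated when the generation is grounded, which saves one LLM call otherwise
        return answersQuestion.getAsBoolean() ? "useful" : "unuseful";
    }
}
```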

6. Building the workflow

(1)RAG configuration

Configures the concrete parameters of each RAG module.

package com.ai.demo.config;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.transformer.splitter.TextSplitter;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class RagConfig {

    /**
     * Text splitter that cuts documents into smaller chunks
     * @return TextSplitter instance
     */
    @Bean
    TextSplitter textSplitter() {
        return TokenTextSplitter.builder()
                // Further parameters, such as the maximum number of tokens per chunk and the chunk window size, can be set here
                .build();
    }

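    /**
     * Vector store that holds the vector representations of the document chunks
     * @param embeddingModel embedding model used to vectorize the chunks
     * @return VectorStore instance
     */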
    @Bean
    VectorStore vectorStore(EmbeddingModel embeddingModel) {
        return SimpleVectorStore.builder(embeddingModel)
                .build();
    }

    /**
     * Document retriever that fetches relevant chunks from the vector store
     * @param vectorStore the vector store
     * @return DocumentRetriever instance
     */
    @Bean
    DocumentRetriever documentRetriever(VectorStore vectorStore) {
        return VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();
    }

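    /**
     * Compression query transformer: compresses the conversation history and a follow-up query into a standalone query that captures the essence of the conversation
     * <p> <em>Pre-retrieval enhancement</em>, useful when the conversation history is long and the follow-up query depends on the conversation context</p>
     * @param chatClient chat client
     * @return CompressionQueryTransformer instance
     */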
    @Bean
    CompressionQueryTransformer compressionQueryTransformer(ChatClient chatClient) {
        return CompressionQueryTransformer.builder()
                .chatClientBuilder(chatClient.mutate())
                .build();
    }

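    /**
     * Rewrite query transformer: rewrites the user query
     * <p> <em>Pre-retrieval enhancement</em>, useful when the user query is verbose, ambiguous, or contains irrelevant information that may hurt search quality </p>
     * @param chatClient chat client
     * @return RewriteQueryTransformer instance
     */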
    @Bean
    RewriteQueryTransformer rewriteQueryTransformer(ChatClient chatClient) {
        return RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClient.mutate())
                .build();
    }

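    /**
     * Translation query transformer: translates the user query into the target language
     * <p> <em>Pre-retrieval enhancement</em>, useful when the user query is not in the target language</p>
     * @param chatClient chat client
     * @return TranslationQueryTransformer instance
     */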
    @Bean
    TranslationQueryTransformer translationQueryTransformer(ChatClient chatClient) {
        return TranslationQueryTransformer.builder()
                .chatClientBuilder(chatClient.mutate())
                .targetLanguage("chinese")
                .build();
    }

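    /**
     * Multi-query expander: generates several queries to improve retrieval coverage
     * <p> <em>Pre-retrieval enhancement</em>, uses the LLM to generate semantically varied queries from different perspectives</p>
     * @param chatClient chat client
     * @return MultiQueryExpander instance
     */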
    @Bean
    MultiQueryExpander multiQueryExpander(ChatClient chatClient) {
        return MultiQueryExpander.builder()
                .chatClientBuilder(chatClient.mutate())
                .numberOfQueries(3)
                .build();
    }
}
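similarityThreshold(0.50) tells the retriever to drop any chunk whose similarity to the query is below 0.5, and the retriever also caps the number of results (topK). As a toy model of that filtering, the sketch below ranks in-memory vectors by cosine similarity and keeps at most topK hits at or above the threshold. It is not SimpleVectorStore's code, and the vectors are hand-made rather than produced by an EmbeddingModel.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Toy in-memory retriever: cosine similarity + similarity threshold + topK.
final class ThresholdRetriever {

    record Hit(String id, double score) {}

    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0 || nb == 0) ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    static List<Hit> search(Map<String, double[]> store, double[] query, int topK, double threshold) {
        return store.entrySet().stream()
                .map(e -> new Hit(e.getKey(), cosine(e.getValue(), query)))
                .filter(h -> h.score() >= threshold)                       // similarity threshold
                .sorted(Comparator.comparingDouble(Hit::score).reversed()) // best match first
                .limit(topK)                                               // topK cap
                .toList();
    }
}
```

A higher threshold trades recall for precision: with 0.5, loosely related chunks are discarded before they can distract the generator.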

(2)Building the graph

package com.ai.demo.config;

import com.ai.demo.edge.GradeGenerationEdge;
import com.ai.demo.edge.RouteQuestionEdge;
import com.ai.demo.node.GenerationNode;
import com.ai.demo.node.RetrieveNode;
import com.ai.demo.node.TransformQueryNode;
import com.ai.demo.node.WebSearchNode;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Map;

@Configuration
@Slf4j
public class GraphConfig {

    private final RouteQuestionEdge routeQuestionEdge;

    private final GradeGenerationEdge gradeGenerationEdge;

    private final ChatClient commonChatClient;

    private final ChatClient webSearchClient;

    private final ChatClient ragChatClient;

    private final ChatClient questionRewriterChatClient;

    private final DocumentRetriever documentRetriever;

    private final CompressionQueryTransformer compressionQueryTransformer;

    private final RewriteQueryTransformer rewriteQueryTransformer;

    private final TranslationQueryTransformer translationQueryTransformer;

    public GraphConfig(RouteQuestionEdge routeQuestionEdge, GradeGenerationEdge gradeGenerationEdge,
            ChatClient commonChatClient,
            @Qualifier("WebSearchChatClient") ChatClient webSearchClient,
            @Qualifier("AdaptiveRagChatClient") ChatClient ragChatClient,
            @Qualifier("QuestionRewriterChatClient") ChatClient questionRewriterChatClient,
            DocumentRetriever documentRetriever,
            CompressionQueryTransformer compressionQueryTransformer,
            RewriteQueryTransformer rewriteQueryTransformer,
            TranslationQueryTransformer translationQueryTransformer) {
        this.routeQuestionEdge = routeQuestionEdge;
        this.gradeGenerationEdge = gradeGenerationEdge;
        this.commonChatClient = commonChatClient;
        this.webSearchClient = webSearchClient;
        this.ragChatClient = ragChatClient;
        this.questionRewriterChatClient = questionRewriterChatClient;
        this.documentRetriever = documentRetriever;
        this.compressionQueryTransformer = compressionQueryTransformer;
        this.rewriteQueryTransformer = rewriteQueryTransformer;
        this.translationQueryTransformer = translationQueryTransformer;
    }

    @Bean
    public StateGraph graph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        OverAllStateFactory stateFactory = () -> {
            OverAllState state = new OverAllState();
            state.registerKeyAndStrategy("question", new ReplaceStrategy());
            state.registerKeyAndStrategy("generation", new ReplaceStrategy());
            state.registerKeyAndStrategy("documents", new ReplaceStrategy());
            return state;
        };

        StateGraph stateGraph = new StateGraph("Spring AI Alibaba Graph Demo", stateFactory);

        // Add nodes
        stateGraph.addNode("prebuilt_rag_generation", AsyncNodeAction.node_async(RetrieveNode.builder()
                        .chatClient(commonChatClient)
                        .documentRetriever(documentRetriever)
                        .retrievalAugmentationAdvisor(RetrievalAugmentationAdvisor.builder()
                                .documentRetriever(documentRetriever)
                                .queryTransformers(compressionQueryTransformer, translationQueryTransformer, rewriteQueryTransformer)
                                .build())
                        .build()));
        stateGraph.addNode("web_search",
                AsyncNodeAction.node_async(WebSearchNode.builder().chatClient(webSearchClient).build()));
        stateGraph.addNode("self_rag_generation",
                AsyncNodeAction.node_async(GenerationNode.builder().chatClient(ragChatClient).build()));
        stateGraph.addNode("transform_query",
                AsyncNodeAction.node_async(TransformQueryNode.builder().chatClient(questionRewriterChatClient).build()));

        // Route to vector-store retrieval or to web search
        stateGraph.addConditionalEdges(StateGraph.START, AsyncEdgeAction.edge_async(routeQuestionEdge),
                Map.of("vectorstore", "prebuilt_rag_generation", "web_search", "web_search"));

        // Vector-store chain
        stateGraph.addConditionalEdges("prebuilt_rag_generation", AsyncEdgeAction.edge_async(gradeGenerationEdge),
                Map.of("useful", StateGraph.END,
                        "unuseful", "transform_query",
                        "hallucination", "prebuilt_rag_generation"));

        // Web-search chain
        stateGraph.addEdge("web_search", "self_rag_generation");
        stateGraph.addConditionalEdges("self_rag_generation", AsyncEdgeAction.edge_async(gradeGenerationEdge),
                Map.of("useful", StateGraph.END,
                        "unuseful", "transform_query",
                        "hallucination", "self_rag_generation"));

        // After rewriting the question, generate again with the self-built RAG node
        stateGraph.addEdge("transform_query", "self_rag_generation");

        // Print the graph as a Mermaid diagram
        GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.MERMAID,
                "Adaptive rag flow");
        log.info("\n=== Adaptive rag Flow ===");
        log.info(representation.content());
        log.info("==================================\n");

        return stateGraph;
    }
}
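The edges above form cycles: a "hallucination" verdict sends the flow back to the same generation node, and "unuseful" goes through transform_query back to self_rag_generation, so a model that keeps failing the graders could loop many times. The plain-Java sketch below mimics the execution model (nodes return partial updates that replace state keys, as with ReplaceStrategy; edges pick the next node) and adds a hypothetical maxSteps guard. It is not the Spring AI Alibaba engine; check the framework's own configuration for an equivalent iteration limit.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical mini executor: replace-strategy state updates, routed edges, and a step limit.
final class MiniGraph {

    static final String END = "__END__";

    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new HashMap<>();
    private final Map<String, Function<Map<String, Object>, String>> edges = new HashMap<>();

    MiniGraph node(String name, Function<Map<String, Object>, Map<String, Object>> action) {
        nodes.put(name, action);
        return this;
    }

    /** Edge leaving a node: a function of the state that returns the next node name (or END). */
    MiniGraph edge(String from, Function<Map<String, Object>, String> router) {
        edges.put(from, router);
        return this;
    }

    Map<String, Object> run(String start, Map<String, Object> input, int maxSteps) {
        Map<String, Object> state = new HashMap<>(input);
        String current = start;
        for (int step = 0; step < maxSteps; step++) {
            // Replace strategy: every key returned by the node overwrites the previous value
            state.putAll(nodes.get(current).apply(state));
            current = edges.get(current).apply(state);
            if (END.equals(current)) {
                return state;
            }
        }
        throw new IllegalStateException("Exceeded " + maxSteps + " steps; next node would be: " + current);
    }
}
```

For example, a generation node that always fails the grader would hit the limit and raise an error instead of calling the model indefinitely.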

(3)Exposing the API

package com.ai.demo.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.io.File;
import java.util.*;

@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {

    @Value("classpath:documents/faq.md")
    Resource file1;

    @Value("classpath:documents/overview.md")
    Resource file2;

    private final CompiledGraph compiledGraph;

    private final VectorStore vectorStore;

    private static final String SAVE_PATH = System.getProperty("user.dir") + "/src/main/resources/vectorstore/vectorstore.json";

    @SneakyThrows
    public GraphController(@Qualifier("graph") StateGraph stateGraph, VectorStore vectorStore) {
        this.vectorStore = vectorStore;
        this.compiledGraph = stateGraph.compile();
    }

    @GetMapping(value = "/add")
    public void addDocuments() {
        // If a persisted vector store exists, load it and skip re-embedding
        File file = new File(SAVE_PATH);
        if (file.exists()) {
            log.info("load vector store from {}", SAVE_PATH);
            ((SimpleVectorStore) vectorStore).load(file);
            return;
        }

        log.info("start add documents");
        var markdownReader1 = new MarkdownDocumentReader(file1, MarkdownDocumentReaderConfig.builder()
                .withAdditionalMetadata("title", "Spring AI Alibaba FAQ")
                .withAdditionalMetadata("summary", "关于Spring AI Alibaba的常见问题和解答")
                .build());
        List<Document> documents = new ArrayList<>(markdownReader1.get());

        var markdownReader2 = new MarkdownDocumentReader(file2, MarkdownDocumentReaderConfig.builder()
                .withAdditionalMetadata("title", "Spring AI Alibaba Overview")
                .withAdditionalMetadata("summary", "关于Spring AI Alibaba的概述")
                .build());
        documents.addAll(markdownReader2.get());

        // Add the documents to the vector store (embeddings are computed here)
        vectorStore.add(documents);

        // Persist the vector store to a JSON file
        ((SimpleVectorStore) vectorStore).save(file);
    }

    @GetMapping(value = "/expand")
    public Map<String, Object> expand(@RequestParam(value = "query", defaultValue = "你好,我想知道一些关于大模型的知识", required = false) String query) {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId("001").build();
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("question", query);
        Optional<OverAllState> invoke = this.compiledGraph.invoke(objectMap, runnableConfig);
        return invoke.map(OverAllState::data).orElse(new HashMap<>());
    }
}
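The /add endpoint follows a load-or-build pattern: reuse the saved JSON file if it exists, otherwise build the store and save it. Note that SAVE_PATH points into src/main/resources, which only works when the application runs from the project directory. The sketch below shows the same pattern with java.nio and a caller-supplied builder; the class name and the file format (one line per chunk, so chunks must not contain newlines) are hypothetical, chosen only to keep the example self-contained.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical load-or-build cache for an expensive index, persisted to a file.
final class PersistedIndex {

    static List<String> loadOrBuild(Path file, Supplier<List<String>> builder) throws IOException {
        if (Files.exists(file)) {
            // Fast path: reuse the previously persisted chunks
            return Files.readAllLines(file);
        }
        List<String> chunks = builder.get();   // expensive step, e.g. reading and embedding documents
        Path parent = file.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent);   // make sure the target directory exists
        }
        Files.write(file, chunks);             // one chunk per line
        return chunks;
    }
}
```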

Demo

First, call the /add endpoint to load the knowledge base.

(1)Basic RAG

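The AI recognizes that the question relates to the knowledge base, routes it along the vectorstore branch (the prebuilt_rag_generation node), and continues with the subsequent steps.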

(2)Answer self-evaluation

The AI judges that the answer is not grounded in the facts and regenerates it.

(3)Web search

The AI recognizes that the question needs a web search, routes it to the web_search node, and continues with the subsequent steps.

Question 1: Did G2 win their LoL MSI match yesterday?

AI answer: Yes, G2 beat GAM 3-2 in yesterday's MSI match and advanced to the knockout stage.

Question 2: What will the weather be like in Beijing tomorrow?

AI answer: Tomorrow (July 2, 2025) Beijing will be sunny, with a high of 29°C, a low of 22°C, and light winds.

(4)Conversation memory

When the second question asks about "its features", the model uses the conversation context to work out that "it" refers to Spring AI Alibaba.
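The follow-up works because RetrieveNode attaches a MessageChatMemoryAdvisor backed by MessageWindowChatMemory with maxMessages(10), under the fixed conversation ID "PrebuiltSpringRAG", so earlier turns are sent along with the new question. The plain-Java sketch below models a per-conversation sliding window; the class name is hypothetical and it ignores details of the real MessageWindowChatMemory, such as how system messages are handled. Two consequences worth noting: within this section, memory is attached only in RetrieveNode (the knowledge-base path), and because the ID is hard-coded, every caller shares one history; deriving the ID from the request would separate users.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Simplified sliding-window chat memory keyed by conversation ID.
final class WindowChatMemory {

    private final int maxMessages;
    private final Map<String, Deque<String>> conversations = new HashMap<>();

    WindowChatMemory(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, String message) {
        Deque<String> window = conversations.computeIfAbsent(conversationId, id -> new ArrayDeque<>());
        window.addLast(message);
        while (window.size() > maxMessages) {
            window.removeFirst(); // evict the oldest message once the window is full
        }
    }

    List<String> get(String conversationId) {
        return List.copyOf(conversations.getOrDefault(conversationId, new ArrayDeque<>()));
    }
}
```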

Future improvements and outlook

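  1. Going forward, the project could be extended with more features, such as a database and a frontend. Since streaming interfaces are already built in, the remaining work is much like a traditional Spring Boot + Vue project. Building on Spring AI Alibaba should also make it easier to integrate with other parts of Alibaba's ecosystem later, such as Nacos.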

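  2. As for the Spring AI Alibaba framework itself, it works well in practice, and compared with LangGraph it adds many abstractions that suit Java's language features. However, many parts currently rely on Spring injection, and some values are hard-coded. Hopefully Alibaba will refine this further, improve the documentation, or provide more demos.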

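  3. Likewise, I hope the framework will support something like LangGraph's injection of context, tool-call arguments and so on (which, unlike state, is invisible to the LLM), provide more extensible prebuilt agents or workflows, implement parent/child graphs, and integrate further with the Alibaba ecosystem.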

References

  1. https://github.com/alibaba/spring-ai-alibaba/tree/main

  2. https://github.com/langchain-ai/langgraph

  3. https://github.com/apappascs/spring-ai-examples/
