Spring AI Alibaba

代码地址:https://gitee.com/CodeMao01/springai-examples

一、前期准备及概述

智谱大模型:先在智谱AI开放平台注册账号,并创建好 API Key

可以先通过 Apifox 或 Postman 调用一下模型 API 进行验证,参考智谱AI开放文档中的「使用概述」

Spring AI Alibaba 在 Spring AI 的基础上增加了 Graph,可以用来进行多智能体编排及工作流(Multi-Agent / Workflow)

版本说明:

JDK:17 (官方最低要求17)

SpringAI :1.0.0

SpringBoot : 3.4.0

Spring AI Alibaba : 1.0.0.4

二、快速入门

父pom

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<!--
    Parent (aggregator) POM of the springai-examples tutorial project.
    It pins the Spring Boot / Spring AI / Spring AI Alibaba versions for every module
    and declares the build plugins shared by all modules.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <packaging>pom</packaging>
    <!-- One module per example; add new modules here as the tutorial progresses. -->
    <modules>
        <module>01_springai-alibaba-quick-start</module>
    </modules>
    <!--
        Keep this version equal to ${spring-boot.version}: versions managed explicitly by the parent
        take precedence over the spring-boot-dependencies BOM imported below, so a mismatch silently
        mixes two Spring Boot versions in one build.
    -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.4.0</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.example</groupId>
    <artifactId>springai-examples</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springai-examples</name>
    <description>springai-examples</description>
    <url/>
    <licenses>
        <license/>
    </licenses>
    <developers>
        <developer/>
    </developers>
    <scm>
        <connection/>
        <developerConnection/>
        <tag/>
        <url/>
    </scm>
    <properties>
        <java.version>17</java.version>
        <maven.compiler.source>${java.version}</maven.compiler.source>
        <maven.compiler.target>${java.version}</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Spring AI / Spring AI Alibaba -->
        <spring-ai.version>1.0.0</spring-ai.version>
        <spring-ai-alibaba.version>1.0.0.4</spring-ai-alibaba.version>
        <!-- Spring Boot (must match the parent version above) -->
        <spring-boot.version>3.4.0</spring-boot.version>
        <!-- maven plugin -->
        <maven-deploy-plugin.version>3.1.1</maven-deploy-plugin.version>
        <flatten-maven-plugin.version>1.3.0</flatten-maven-plugin.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
    </properties>

    <!-- BOMs: versions of the Spring Boot, Spring AI and Spring AI Alibaba artifacts used by the modules. -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>${spring-boot.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>com.alibaba.cloud.ai</groupId>
                <artifactId>spring-ai-alibaba-bom</artifactId>
                <version>${spring-ai-alibaba.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>${spring-boot.version}</version>
            </plugin>
            <!-- The examples are never deployed to a repository. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-deploy-plugin</artifactId>
                <version>${maven-deploy-plugin.version}</version>
                <configuration>
                    <skip>true</skip>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>${maven-compiler-plugin.version}</version>
                <configuration>
                    <release>${java.version}</release>
                    <!-- Keep parameter names in bytecode (used e.g. for reflection-based parameter binding). -->
                    <compilerArgs>
                        <compilerArg>-parameters</compilerArg>
                    </compilerArgs>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>flatten-maven-plugin</artifactId>
                <version>${flatten-maven-plugin.version}</version>
                <inherited>true</inherited>
                <executions>
                    <execution>
                        <id>flatten</id>
                        <phase>process-resources</phase>
                        <goals>
                            <goal>flatten</goal>
                        </goals>
                        <configuration>
                            <updatePomFile>true</updatePomFile>
                            <flattenMode>ossrh</flattenMode>
                            <pomElements>
                                <distributionManagement>remove</distributionManagement>
                                <dependencyManagement>remove</dependencyManagement>
                                <repositories>remove</repositories>
                                <scm>keep</scm>
                                <url>keep</url>
                                <organization>resolve</organization>
                            </pomElements>
                        </configuration>
                    </execution>
                    <execution>
                        <id>flatten.clean</id>
                        <phase>clean</phase>
                        <goals>
                            <goal>clean</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

</project>

子pom-01_springai-alibaba-quick-start

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<!-- Module 01 (quick start): a Spring Boot web application that talks to ZhiPu AI through Spring AI. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>com.example</groupId>
        <artifactId>springai-examples</artifactId>
        <version>0.0.1-SNAPSHOT</version>
        <!-- relativePath: the parent POM is taken from the directory above (../pom.xml), so it does not have to be fetched from a remote repository -->
        <relativePath>../pom.xml</relativePath>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>01_springai-alibaba-quick-start</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>01_springai-alibaba-quick-start</name>
    <description>01_springai-alibaba-quick-start</description>
    <!-- Empty metadata elements, presumably (as generated by Spring Initializr) to override values inherited from the parent -->
    <url/>
    <licenses>
        <license/>
    </licenses>
    <developers>
        <developer/>
    </developers>
    <scm>
        <connection/>
        <developerConnection/>
        <tag/>
        <url/>
    </scm>
    <properties>
        <java.version>17</java.version>
    </properties>
    <dependencies>
        <!-- Spring MVC for the REST controllers -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Spring AI ZhiPu AI starter; the version comes from the spring-ai-bom imported in the parent POM -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-zhipuai</artifactId>
        </dependency>
        <!-- Lombok annotations (e.g. @Slf4j), needed at compile time only -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <scope>provided</scope>
        </dependency>
    </dependencies>


    <build>
        <plugins>
            <!-- Packages the module as an executable Spring Boot jar -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

子工程-01_springai-alibaba-quick-start的配置文件

yaml 复制代码
# Configuration of module 01 (quick start).
server:
  port: 8080
  servlet:
    encoding:
      # Always apply UTF-8 to servlet requests/responses (force: true)
      charset: UTF-8
      enabled: true
      force: true
spring:
  application:
    name: 01_springai-alibaba-quick-start
  ai:
    zhipuai:
      api-key: ${ZHIPU_KEY} # API key, resolved from the ZHIPU_KEY environment variable / property
      base-url: "https://open.bigmodel.cn/api/paas" # ZhiPu AI API base URL
      chat:
        options:
          model: glm-4.5 # default chat model; individual requests can override it via ChatOptions

2.1、chatModel

ChatModel.call 的三种调用方式及流式输出:

参数一:直接传用户问题

参数二:SystemMessage + UserMessage

参数三:Prompt = Message + ChatOptions

流式输出(stream):模型生成的内容以 Flux 的形式分段返回,可以理解为一个字一个字地“蹦”出来,而不是等全部生成完再一次性返回

java 复制代码
package com.example.controller;

import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.List;

/**
 * Shows the ways of calling a {@link ChatModel} directly: a plain string, system + user messages,
 * a full {@link Prompt} carrying per-request {@link ChatOptions}, and a streaming call.
 */
@RestController
@RequestMapping("/zhipuai")
public class ZhipuController {

    /** System prompt shared by the message-based endpoints. */
    private static final String SYSTEM_PROMPT = "你是一个AI助手";

    private final ChatModel chatModel;

    /** The ChatModel is injected through the constructor. */
    public ZhipuController(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /** Overload 1: pass the user's question as a plain string. */
    @GetMapping("/simple")
    public String simpleChat(@RequestParam(name = "query") String query) {
        return chatModel.call(query);
    }

    /** Overload 2: pass a system message followed by the user message. */
    @GetMapping("/message")
    public String messageChat(@RequestParam(name = "query") String query) {
        SystemMessage system = new SystemMessage(SYSTEM_PROMPT);
        UserMessage user = new UserMessage(query);
        return chatModel.call(system, user);
    }

    /**
     * Overload 3: pass a {@link Prompt}, i.e. messages plus options.
     * The options here override the model settings from the configuration for this request only.
     */
    @GetMapping("/chatOptions")
    public String chatOptions(@RequestParam(name = "query") String query) {
        SystemMessage system = new SystemMessage(SYSTEM_PROMPT);
        UserMessage user = new UserMessage(query);
        ChatOptions options = ZhiPuAiChatOptions.builder()
                .model("glm-4.5")
                .temperature(0.9)
                .maxTokens(10000)
                .build();

        ChatResponse response = chatModel.call(new Prompt(List.of(system, user), options));
        return response.getResult().getOutput().getText();
    }

    /** Streaming call: the reply is emitted as a Flux of text fragments. */
    @GetMapping("/stream/chat")
    public Flux<String> streamChat(@RequestParam(name = "query") String query) {
        return chatModel.stream(query);
    }
}

2.2、chatClient

ChatClient 比 ChatModel 更常用:它在 ChatModel 之上提供了链式(fluent)API,并增加了 Advisor 等功能

2.2.1、基本使用

java 复制代码
package com.example.controller;

import com.example.advisor.SGCallAdvisor1;
import com.example.advisor.SGCallAdvisor2;
import com.example.advisor.SimpleMessageChatMemoryAdvisor;
import com.example.entity.Book;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.function.Consumer;

/**
 * ChatClient examples: prompt objects vs. the fluent API, structured output, and streaming.
 *
 * <p>The advisor and chat-memory imports are presumably used by the endpoints added to this
 * controller in the following sections.
 */
@RestController
@RequestMapping("/chatClient")
public class ZhipuChatClientController {

    private static final String SYSTEM_PROMPT = "你是一个AI助手";
    private static final String BOOK_PROMPT = "给我随机生成一本书,要求书名和作者都是中文";

    private final ChatClient chatClient;

    public ZhipuChatClientController(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /** Builds a full Prompt (messages + options) and hands it to the ChatClient. */
    @GetMapping("/simple")
    public String simpleChat(@RequestParam(name = "query") String query) {
        List<org.springframework.ai.chat.messages.Message> messages =
                List.of(new SystemMessage(SYSTEM_PROMPT), new UserMessage(query));
        Prompt prompt = new Prompt(messages, requestOptions());
        return chatClient.prompt(prompt).call().content();
    }

    /** Same request through the fluent API; returns the whole ChatResponse (metadata included). */
    @GetMapping("/simple2")
    public ChatResponse simpleChat2(@RequestParam(name = "query") String query) {
        ChatOptions options = requestOptions();
        return chatClient.prompt()
                .system(SYSTEM_PROMPT)
                .user(query)
                .options(options)
                .call()
                .chatResponse();
    }

    /** Structured output: the reply is converted into a {@link Book}. */
    @GetMapping("/response")
    public Book response() {
        return chatClient.prompt()
                .user(BOOK_PROMPT)
                .call()
                .entity(Book.class);
    }

    /** Streams the reply text as it is generated. */
    @GetMapping("/stream")
    public Flux<String> stream() {
        return chatClient.prompt()
                .user(BOOK_PROMPT)
                .stream()
                .content();
    }

    /** Per-request options: glm-4.5, temperature 0.9, at most 10000 tokens. */
    private static ChatOptions requestOptions() {
        return ChatOptions.builder()
                .model("glm-4.5")
                .temperature(0.9)
                .maxTokens(10000)
                .build();
    }

}

2.2.2、advisor

java 复制代码
/**
 * Calls the model through the two logging advisors SGCallAdvisor1 and SGCallAdvisor2.
 * NOTE(review): Spring AI presumably orders advisors by getOrder() rather than by argument order,
 * so SGCallAdvisor1 (order 0) should run before SGCallAdvisor2 (order 1); check the log output.
 */
@GetMapping("/advisor")
public String advisor() {
    SGCallAdvisor2 second = new SGCallAdvisor2();
    SGCallAdvisor1 first = new SGCallAdvisor1();
    return chatClient.prompt()
            .user("你是谁?")
            .advisors(second, first)
            .call()
            .content();
}

SGCallAdvisor1:

java 复制代码
package com.example.advisor;

import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;

/**
 * Demo call advisor that logs a line before and after handing the request to the rest of the chain.
 * Its order is 0.
 */
@Slf4j
public class SGCallAdvisor1 implements CallAdvisor {

    private static final String NAME = "SGCallAdvisor1-name";
    private static final int ORDER = 0;

    @Override
    public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) {
        log.info("SGCallAdvisor1-adviseCall start");
        // Delegate to the next advisor (and ultimately the model); the response is passed through unchanged
        ChatClientResponse response = chain.nextCall(request);
        log.info("SGCallAdvisor1-adviseCall end");
        return response;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public int getOrder() {
        return ORDER;
    }
}

SGCallAdvisor2:

java 复制代码
package com.example.advisor;

import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;

/**
 * Demo call advisor that logs a line before and after handing the request to the rest of the chain.
 * Its order is 1.
 */
@Slf4j
public class SGCallAdvisor2 implements CallAdvisor {

    private static final String NAME = "SGCallAdvisor2-name";
    private static final int ORDER = 1;

    @Override
    public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) {
        log.info("SGCallAdvisor2-adviseCall start");
        // Delegate to the next advisor (and ultimately the model); the response is passed through unchanged
        ChatClientResponse response = chain.nextCall(request);
        log.info("SGCallAdvisor2-adviseCall end");
        return response;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public int getOrder() {
        return ORDER;
    }
}

2.2.3、chatMemory

2.2.3.1、自定义chatMemory
java 复制代码
package com.example.advisor;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hand-written chat-memory advisor that keeps per-conversation message history in memory.
 *
 * <p>{@link #before} appends the request's messages to the history of the conversation named by the
 * {@code "conversationId"} advisor parameter and sends the whole history to the model.
 * {@link #after} appends the model's reply to that history.
 *
 * <p>The history is stored in a static map shared by all instances and is never evicted, so it grows
 * for the lifetime of the JVM. That is fine for a demo but not for production.
 */
public class SimpleMessageChatMemoryAdvisor implements BaseAdvisor {

    /** Advisor-context key the caller must set, e.g. {@code advisorSpec.param("conversationId", id)}. */
    private static final String CONVERSATION_ID_KEY = "conversationId";

    /**
     * conversationId -> message history.
     * A ConcurrentHashMap updated only through {@code compute} is used because the advisor runs on
     * concurrent web requests; {@code compute} serializes updates per conversation.
     */
    private static final Map<String, List<Message>> chatMemory = new ConcurrentHashMap<>();

    /**
     * Appends the current messages to the conversation history and sends the full history instead.
     *
     * @throws IllegalStateException if the {@code conversationId} advisor parameter is missing
     */
    @Override
    public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
        String conversationId = conversationIdOf(chatClientRequest.context());
        if (conversationId == null) {
            throw new IllegalStateException(
                    "SimpleMessageChatMemoryAdvisor requires the advisor param '" + CONVERSATION_ID_KEY + "'");
        }
        List<Message> currentMessages = chatClientRequest.prompt().getInstructions();

        // Append atomically and take a snapshot. The outgoing prompt gets its own copy, so later
        // updates to the stored history (e.g. in after()) never mutate a prompt that was already sent.
        List<Message> snapshot = new ArrayList<>();
        chatMemory.compute(conversationId, (id, history) -> {
            List<Message> updated = (history == null) ? new ArrayList<>() : history;
            updated.addAll(currentMessages);
            snapshot.addAll(updated);
            return updated;
        });

        // ChatClientRequest/Prompt are immutable, so build new instances carrying the full history
        Prompt newPrompt = chatClientRequest.prompt().mutate()
                .messages(snapshot)
                .build();
        return chatClientRequest.mutate()
                .prompt(newPrompt)
                .build();
    }

    /**
     * Records the model's reply in the conversation history.
     * Responses without a result or without a conversation id are passed through untouched.
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        // Check for null before any dereference
        if (chatClientResponse == null) {
            return null;
        }
        var chatResponse = chatClientResponse.chatResponse();
        if (chatResponse == null || chatResponse.getResult() == null) {
            return chatClientResponse;
        }
        String conversationId = conversationIdOf(chatClientResponse.context());
        if (conversationId == null) {
            return chatClientResponse;
        }

        AssistantMessage output = chatResponse.getResult().getOutput();
        chatMemory.compute(conversationId, (id, history) -> {
            List<Message> updated = (history == null) ? new ArrayList<>() : history;
            updated.add(output);
            return updated;
        });
        return chatClientResponse;
    }

    @Override
    public int getOrder() {
        return 0;
    }

    /** Reads the conversation id from an advisor context, or returns null when it is absent. */
    private static String conversationIdOf(Map<String, Object> context) {
        Object value = (context == null) ? null : context.get(CONVERSATION_ID_KEY);
        return (value == null) ? null : value.toString();
    }
}
java 复制代码
/**
 * Chats with the hand-written SimpleMessageChatMemoryAdvisor.
 * The conversationId is handed to the advisor as an advisor parameter, and requests with the same id
 * share one history.
 */
@GetMapping("/simpleMessageChatMemoryAdvisor")
public String simpleMessageChatMemoryAdvisor(@RequestParam(name = "query") String query,
                                             @RequestParam(name = "conversationId") String conversationId) {
    return chatClient.prompt()
            .user(query)
            .advisors(spec -> spec.param("conversationId", conversationId))
            .advisors(new SimpleMessageChatMemoryAdvisor())
            .call()
            .content();
}
2.2.3.2、使用SpringAi的chatMemory
java 复制代码
package com.example.controller;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;

/**
 * Chat memory using Spring AI's built-in {@link MessageWindowChatMemory} and {@link MessageChatMemoryAdvisor}.
 */
@RestController
@RequestMapping("/chatMemory")
public class ZhipuChatMemoryController {
    private final ChatClient chatClient;

    /** Registers the memory advisor as a default advisor, so every request through this client uses it. */
    public ZhipuChatMemoryController(ChatClient.Builder builder) {
        // Windowed chat memory with the builder's default settings
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder().build();

        // Advisor that loads/stores the conversation history around each call
        MessageChatMemoryAdvisor chatMemoryAdvisor = MessageChatMemoryAdvisor
                .builder(chatMemory)
                .build();
        this.chatClient = builder.defaultAdvisors(chatMemoryAdvisor).build();
    }

    /**
     * Chats within the conversation identified by {@code conversationId}.
     * The path keeps its original spelling ("Adivsor") so existing clients keep working.
     */
    @GetMapping("/chatMemoryAdivsor")
    public String simpleMessageChatMemoryAdvisor(@RequestParam(name = "query") String query,
                                                 @RequestParam(name = "conversationId") String conversationId) {
        return chatClient.prompt().user(query)
                // Tell the memory advisor which conversation this request belongs to
                .advisors(spec -> spec.param(ChatMemory.CONVERSATION_ID, conversationId))
                .call().content();
    }

    /** Standalone demo of rendering prompt templates into messages (not used by the endpoint). */
    public static void main(String[] args) {
        // PromptTemplate#createMessage renders the template into a user message
        PromptTemplate userPromptTemplate = new PromptTemplate("你是一个AI人工助手, 你的名字是:{name}, 你的语言风格是:{voice}, 用户的问题是:{query}");
        Message message = userPromptTemplate.createMessage(Map.of("name", "小a", "voice", "幽默", "query", "你叫什么名字?"));
        System.out.println(message);

        // SystemPromptTemplate renders into a SystemMessage; a plain PromptTemplate would yield a UserMessage
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate("你是一个AI人工助手, 你的名字是:{name}, 你的语言风格是:{voice}");
        Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "小a", "voice", "幽默"));
        System.out.println(systemMessage);

    }
}

2.3、RAG

2.3.1、RAG介绍、功能以及流程

RAG:检索增强生成。通过外部数据库检索出相关的知识 ,然后把问题和相关知识一起给到大模型来让模型生成回答。

作用:当询问 AI 它不知道的知识时,模型容易产生幻觉;RAG 相当于给模型连接了一个外部知识库,既可以减少幻觉,也可以用来构建垂直领域的应用

流程:

存储:文档切片→进行embedding→存储到向量数据库

查询:用户查询问题→转换成向量→检索向量数据库→添加到prompt→返回给大模型→用户

2.3.2、数据嵌入及相似度搜索

这里选用 Redis Stack 作为向量数据库,方便本地开发使用

pom:

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<!-- Module 02 (RAG): ZhiPu AI chat/embedding models plus a Redis vector store. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>com.example</groupId>
        <artifactId>springai-examples</artifactId>
        <version>0.0.1-SNAPSHOT</version>
        <!-- relativePath: the parent POM is taken from the directory above (../pom.xml), so it does not have to be fetched from a remote repository -->
        <relativePath>../pom.xml</relativePath>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>02_springai-alibaba-rag</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>02_springai-alibaba-rag</name>
    <description>02_springai-alibaba-rag</description>
    <!-- Empty metadata elements, presumably (as generated by Spring Initializr) to override values inherited from the parent -->
    <url/>
    <licenses>
        <license/>
    </licenses>
    <developers>
        <developer/>
    </developers>
    <scm>
        <connection/>
        <developerConnection/>
        <tag/>
        <url/>
    </scm>
    <properties>
        <java.version>17</java.version>
    </properties>
    <dependencies>
        <!-- Spring MVC for the REST controllers -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Redis vector store (VectorStore bean); versions come from the spring-ai-bom imported in the parent POM -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-vector-store-redis</artifactId>
        </dependency>
        <!-- ZhiPu AI chat and embedding models -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-starter-model-zhipuai</artifactId>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <!-- Packages the module as an executable Spring Boot jar -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

yaml:

yaml 复制代码
# Server configuration
server:
  port: 8080 # HTTP port the application listens on (8080 is also Spring Boot's default)
spring:
  # Basic application info
  application:
    name: 02_springai-alibaba-rag # application name (must be nested under spring.application)
  # Spring AI configuration
  ai:
    # ZhiPu AI configuration; chat and embedding options must be nested under spring.ai.zhipuai
    zhipuai:
      api-key: ${ZHIPU_KEY} # ZhiPu API key, resolved from the ZHIPU_KEY environment variable
      chat:
        options:
          model: glm-4.6 # chat model name (GLM-4.6)
      embedding:
        options:
          model: embedding-3 # embedding model name (embedding-3)
          dimensions: 256 # dimensionality of the generated embedding vectors
    # Vector store configuration
    vectorstore:
      redis:
        initialize-schema: true # create the Redis vector index on startup (needed on first deployment)
        prefix: sangeng_rag_prefix # Redis key prefix that separates this application's vectors from other data
        index-name: sangeng_rag_index # name of the Redis vector index
  # Redis connection used by the vector store
  data:
    redis:
      host: localhost # Redis host
      port: 6379 # Redis port (6379 is the Redis default)

java:

java 复制代码
package com.example.controller;

import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.*;

import java.util.List;

/**
 * Minimal endpoints for storing text in the vector store and running similarity searches against it.
 */
@RestController
@RequestMapping("/redis")
public class RedisController {

    /** Maximum number of documents returned by a search. */
    private static final int TOP_K = 3;
    /** Matches with a similarity score below this threshold are discarded. */
    private static final double SIMILARITY_THRESHOLD = 0.5;

    private final VectorStore vectorStore;

    public RedisController(VectorStore vectorStore) {
        this.vectorStore = vectorStore;
    }

    /** Embeds {@code content} as a single document and stores it in the vector store. */
    @GetMapping("/import")
    public void importData(@RequestParam(name = "content") String content) {
        Document document = new Document(content);
        vectorStore.add(List.of(document));
    }

    /** Returns up to {@link #TOP_K} stored documents most similar to {@code query}. */
    @PostMapping("/query")
    public List<Document> query(@RequestParam(name = "query") String query) {
        SearchRequest request = SearchRequest.builder()
                .query(query)
                .topK(TOP_K)
                .similarityThreshold(SIMILARITY_THRESHOLD)
                .build();
        return vectorStore.similaritySearch(request);
    }
}

2.3.3、咖啡店客服实战

  1. 引入rag以及csv的pom
xml 复制代码
<!-- Apache Commons CSV for parsing QA.csv (explicit version) -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-csv</artifactId>
    <version>1.10.0</version>
</dependency>

<!-- Spring AI RAG module (RetrievalAugmentationAdvisor, VectorStoreDocumentRetriever); version from spring-ai-bom -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-rag</artifactId>
</dependency>
  2. 导入咖啡店智能知识库
  3. 使用 RetrievalAugmentationAdvisor 进行检索

RetrievalAugmentationAdvisor 的原理其实就是:先通过 VectorStore 检索出相关数据,再把这些数据组装到 UserMessage 中一起发送给大模型

java 复制代码
package com.example.controller;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.ClassPathResource;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
 * Coffee-shop customer-service demo.
 * It imports a Q&A knowledge base from {@code QA.csv} into the vector store and answers questions
 * with retrieval-augmented generation.
 */
@RestController
@RequestMapping("/coffee")
public class CoffeeController {
    private final VectorStore vectorStore;
    private final ChatClient chatClient;

    /** Builds a ChatClient whose requests are augmented with documents retrieved from the vector store. */
    public CoffeeController(VectorStore vectorStore, ChatClient.Builder chatClientBuilder) {
        this.vectorStore = vectorStore;

        // Retrieve at most 3 documents whose similarity to the question is at least 0.5
        VectorStoreDocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .topK(3)
                .similarityThreshold(0.5)
                .build();
        RetrievalAugmentationAdvisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .build();
        this.chatClient = chatClientBuilder.defaultAdvisors(retrievalAugmentationAdvisor).build();
    }

    /**
     * Imports every row of {@code QA.csv} (columns {@code question}, {@code answer}) as one document.
     *
     * @return a success message with the number of imported rows, or an error message if reading failed
     */
    @GetMapping("/import")
    public String importCsv() {
        try {
            List<Document> documents = readQaDocuments();

            // Embed and store the documents in the vector database
            vectorStore.add(documents);

            return "成功导入" + documents.size() + "条记录到向量数据库";
        } catch (IOException e) {
            // NOTE(review): prefer a logger here; no logger is set up in this class
            e.printStackTrace();
            return "导入失败:" + e.getMessage();
        }
    }

    /**
     * Reads {@code QA.csv} from the classpath and turns every data row into one {@link Document}.
     * The reader and parser are closed on every path, including when parsing fails.
     *
     * @return one document per CSV row, in file order
     * @throws IOException if the file cannot be opened or parsed
     */
    private List<Document> readQaDocuments() throws IOException {
        ClassPathResource resource = new ClassPathResource("QA.csv");
        // Decode explicitly as UTF-8: on Java 17 the default charset is platform dependent, and the
        // knowledge base presumably contains Chinese text.
        try (Reader reader = new InputStreamReader(resource.getInputStream(), StandardCharsets.UTF_8);
             CSVParser csvParser = CSVFormat.DEFAULT.builder()
                     .setHeader()               // first row holds the column names
                     .setSkipHeaderRecord(true) // do not return the header row as a record
                     .build()
                     .parse(reader)) {

            List<Document> documents = new ArrayList<>();
            for (CSVRecord row : csvParser) {
                String question = row.get("question");
                String answer = row.get("answer");

                // One document per Q&A pair, so a retrieved chunk always carries its answer
                String content = "问题:" + question + "\n答案:" + answer;
                documents.add(new Document(content));
            }
            return documents;
        }
    }

    /**
     * RAG question answering.
     * The RetrievalAugmentationAdvisor registered in the constructor retrieves related documents
     * and adds them to the prompt before the model generates the answer.
     *
     * @param question the user's question
     * @return the model's answer based on the retrieved information
     */
    @GetMapping("/rag-ask")
    public String ragAskQuestion(@RequestParam("question") String question) {
        return chatClient.prompt()
                .system("你是三更咖啡的服务员, 你需要回答用户的问题")
                .user(question)
                .call()
                .content();
    }
}

2.4、tool Calling

含义:把工具信息一并发给大模型,由大模型判断是否需要调用工具,从而让大模型具备调用外部工具的能力(比如获取当前时间)

流程:用户问题 + tool信息 → LLM → tool calling → LLM → 用户

2.4.1、定义工具

java 复制代码
package com.example.tools;

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;

import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;

/**
 * Tool object registered on the ChatClient (defaultTools): reports the current time in a given time zone.
 */
public class TimeTools {

    /** Output format, e.g. "2024-01-01 12:00:00 CST". DateTimeFormatter is immutable and thread-safe. */
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Returns the current time in {@code zoneID}.
     * The parameter name is kept because it is what the model sees in the tool's input schema.
     *
     * @param zoneID a time-zone id such as {@code Asia/Shanghai}
     * @return the current time formatted as {@code yyyy-MM-dd HH:mm:ss z}
     * @throws java.time.DateTimeException if {@code zoneID} is not a valid zone id
     */
    @Tool(description = "通过时区id获取当前时间")
    public String getTimeByZoneID(@ToolParam(description = "时区id, 比如Asia/Shanghai") String zoneID) {
        return ZonedDateTime.now(ZoneId.of(zoneID)).format(FORMATTER);
    }

}

2.4.2、把工具信息传递给大模型

java 复制代码
// Register the RAG advisor plus a TimeTools instance as local tools; the model can then decide
// to call getTimeByZoneID when a question needs the current time.
this.chatClient = chatClient
        .defaultAdvisors(retrievalAugmentationAdvisor)
        .defaultTools(new TimeTools())
        .build();

2.5、MCP

2.5.1、MCP概念

MCP(Model Context Protocol,模型上下文协议)是 Anthropic 公司推出的开放协议,旨在标准化大语言模型与外部工具、数据源的交互方式。

官方文档:https://modelcontextprotocol.io/docs/learn/architecture

MCP Host:管理和协调 MCP Client 的 AI 应用程序,比如 Claude Code
MCP Client:与 MCP Server 保持连接,从 MCP Server 获取上下文并供 MCP Host 使用的组件
MCP Server(即下文的 mcp-service):为 AI 模型提供工具等上下文的服务

大白话:MCP就是满足AI调用的一种协议标准,这样的话大家都可以写mcp-service提供各种各样的工具,比如操作浏览器,操作飞书等等供AI调用,从而达到百花齐放的效果

2.5.2、实战-提供查询当前时间的工具

2.5.2.1、mcp-service
  1. pom依赖
xml 复制代码
<!-- MCP server starter (WebFlux transport); no version given, presumably managed by the spring-ai-bom in the parent POM -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-mcp-server-webflux</artifactId>
</dependency>
  2. 定义工具
java 复制代码
package com.example.tools;

import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;

import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;

/**
 * Spring bean whose {@code @Tool} methods are published by the MCP server (see McpConfig):
 * it reports the current time in a given time zone.
 */
@Component
public class TimeTools {

    /** Output format, e.g. "2024-01-01 12:00:00 CST". DateTimeFormatter is immutable and thread-safe. */
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Returns the current time in {@code zoneId}.
     * The parameter name is kept because it is what clients see in the tool's input schema.
     *
     * @param zoneId a time-zone id such as {@code Asia/Shanghai}
     * @return the current time formatted as {@code yyyy-MM-dd HH:mm:ss z}
     * @throws java.time.DateTimeException if {@code zoneId} is not a valid zone id
     */
    @Tool(description = "通过时区id获取当前时间")
    public String getTimeByZoneID(@ToolParam(description = "时区id, 比如Asia/Shanghai") String zoneId) {
        return ZonedDateTime.now(ZoneId.of(zoneId)).format(FORMATTER);
    }
}
  3. 配置工具对外提供(ToolCallbackProvider)
java 复制代码
package com.example.config;

import com.example.tools.TimeTools;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Registers the {@link TimeTools} {@code @Tool} methods as a ToolCallbackProvider bean so
 * they can be offered to MCP clients.
 */
@Configuration
public class McpConfig {

    /**
     * @param timeTools the Spring-managed tool object whose annotated methods are exposed
     * @return a method-based provider wrapping {@code timeTools}
     */
    @Bean
    public ToolCallbackProvider toolCallbackProvider(TimeTools timeTools) {
        MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
                .toolObjects(timeTools)
                .build();
        return provider;
    }
}
2.5.2.2、mcp-client
  1. pom依赖
xml 复制代码
<!-- WebFlux-based Spring AI MCP client starter; server connections are configured
     under spring.ai.mcp.client in application.yaml -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
</dependency>
  1. application.yaml配置
yaml 复制代码
# MCP client: one SSE connection per MCP server; the connection name (server1) is a user-chosen key.
spring:
  ai:
    mcp:
      client:
        sse:
          connections:
            server1:
              url: http://localhost:8080 # URL of the MCP server (the mcp-service module)
  1. 使用工具
java 复制代码
this.chatClient = chatClient
        .defaultAdvisors(retrievalAugmentationAdvisor)
        // Local registration disabled: the tools now come from the MCP server(s).
//                .defaultTools(new TimeTools())
        // NOTE(review): toolCallbackProvider is expected to be the MCP client's provider
        // (connections configured under spring.ai.mcp.client) — confirm at the injection site.
        .defaultToolCallbacks(toolCallbackProvider.getToolCallbacks())
        .build();

<!-- 这是一张图片,ocr 内容为:REDISVECTORSTOREAPPLICATION X 田 线程和变量 控制台 环境 映射 运行状况 BEAN 对表达式求值(ENTER)或添加监视(CTRL+SHIFT+ENTER) THISZHIPUAIAPI@11019 API(O MESSAGES (ARRAYLIST@11212) SIZE 4... 探索元素 -1- (ZHIPURISCHACOMPLETIONMESSAGEQ11219)'CHATCOMPLETIONMESSAGELAWCONTENTECONTECONTECONTECONININININI ----------------------------------------------------------------------------------------------------- 3 3 "GLM-4-AIR MODEL STOP STREAM - BOOLEAN@11024) FALSE TEMPERATURE{DOUBLE@11214)0.7 支持依赖项'JAVALIO.PROJECTORREACTORREACTORE, SPRING'的插件 REACTIVE STREAMS,SPRING 当前尚未安装. TOPP NULL TOOLS :(ARRAYLIST@11215) SIZE ; 探索元素 不再建议 配置插件... TOOLCHOICENULL -->

2.5.3、网页爬虫MCP使用

2.5.3.1、docker部署mcp
yaml 复制代码
# Runs the fetcher-mcp server container, published on host port 3000.
# NOTE(review): the top-level "version" key is obsolete in the Compose Specification and is
# ignored (with a warning) by Docker Compose v2; it is kept here for older docker-compose installs.
version: "3.8"
services:
  fetcher-mcp:
    image: ghcr.io/jae-jae/fetcher-mcp:latest
    container_name: fetcher-mcp
    restart: unless-stopped
    ports:
      - "3000:3000"
    environment:
      - NODE_ENV=production
    # Using host network mode on Linux hosts can improve browser access efficiency
    # network_mode: "host"
    volumes:
      # For Playwright, may need to share certain system paths
      - /tmp:/tmp
    # Health check
    healthcheck:
      # NOTE(review): assumes the image ships wget — confirm, otherwise the container is reported unhealthy
      test: ["CMD", "wget", "--spider", "-q", "http://localhost:3000"]
      interval: 30s
      timeout: 10s
      retries: 3

启动docker-compose:docker-compose up -d

2.5.3.2、添加MCP服务相关配置
yaml 复制代码
    # One SSE connection per MCP server; tools from all connections are offered to the model.
    mcp:
      client:
        sse:
          connections:
            server1:
              url: http://localhost:8080 # URL of the local mcp-service (TimeTools)
            fetcher-mcp:
              url: http://localhost:3000 # fetcher-mcp container, port 3000 as published in docker-compose
java 复制代码
package com.example.controller;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST endpoint that answers questions using tools supplied by the configured MCP servers
 * (for example the fetcher-mcp web-page crawler).
 */
@RestController
@RequestMapping("/external")
public class ExternalMcpServiceCall {

    /** System prompt that casts the model as a web-crawling and summarizing assistant. */
    private static final String SYSTEM_PROMPT = "你是一个网页爬虫专家, 你可以调用工具爬取指定网页的内容并进行总结";

    private final ChatClient chatClient;

    /**
     * @param chatClient           builder used to create the client
     * @param toolCallbackProvider source of the tool callbacks registered as defaults
     */
    public ExternalMcpServiceCall(ChatClient.Builder chatClient, ToolCallbackProvider toolCallbackProvider) {
        chatClient.defaultToolCallbacks(toolCallbackProvider.getToolCallbacks());
        this.chatClient = chatClient.build();
    }

    /**
     * @param question the user's request, e.g. a URL plus what to extract from it
     * @return the model's final text answer
     */
    @GetMapping("/fetcher")
    public String fetcher(@RequestParam("question") String question) {
        ChatClient.ChatClientRequestSpec request = chatClient.prompt()
                .system(SYSTEM_PROMPT)
                .user(question);
        return request.call().content();
    }
}

2.6、Graph

2.6.1、概述

假设你需要处理一个真实的商业需求:比如自动化审核一份商业合同。

拆分:

  • 理解内容:先请AI通读合同,提炼摘要和关键条款。
  • 合规检查:再让AI根据公司政策库,判断合同是否有合规风险。
  • 风险评估:接着,要求AI从法律、财务等多维度给出风险评分。
  • 人工介入:最后,必须将AI的分析结果提交给法务专家,由专家做出"批准"、"拒绝"或"要求修改"的最终决定。
  • 后续动作:系统根据专家的决策,自动进入不同的处理流程(如生成公文、起草拒绝函等)。

你会发现它变成了一个多步骤、有状态、且需要协调AI自动化和人类决策的复杂流程。这就是我们常说的 AI工作流(AI Workflow)或智能体(AI Agent)要处理的核心问题。

在没有Graph之前,实现这样的流程会非常麻烦,需要我们写很多硬编码去编排实现。

而 Spring AI Alibaba Graph 就是为了优雅地解决这些问题而生的。它是一个强大的工作流编排引擎,让你能像画流程图一样,直观地定义和执行复杂的AI应用流程。

State (状态) :负责在不同步骤间安全地传递和共享数据的容器

Node (节点) :代表一个执行单元,可以是一个AI调用、一个数据库操作,或一段业务逻辑

Edge (边):定义节点之间的连接关系和流转方向。可以是固定的简单边,也可以是能根据State内容决定下一步的条件边。

2.6.2、快速入门

  1. pom依赖
xml 复制代码
<dependencies>
    <!-- Web layer for the REST endpoints that trigger the graphs -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Zhipu AI chat model backing the injected ChatClient.Builder -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-zhipuai</artifactId>
    </dependency>
    <!-- Spring AI Alibaba Graph core (StateGraph, CompiledGraph, node/edge actions).
         No version here: presumably managed by the parent pom's dependencyManagement — confirm. -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
    </dependency>
</dependencies>
  1. application.yml配置
yaml 复制代码
# Server configuration
server:
  port: 8889 # HTTP port of this application (Spring Boot's default is 8080)

spring:
  # Basic application info
  application:
    name: sangeng-graph # application name
  # Spring AI configuration
  ai: # Zhipu AI model settings
    zhipuai:
      api-key: ${ZHIPU_KEY} # Zhipu API key, taken from the ZHIPU_KEY environment variable
      chat:
        options:
          model: glm-4.6 # chat model name (GLM-4.6)
  1. 定义状态图
java 复制代码
package com.example.config;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Map;


/**
 * Defines the quick-start graph: START -> node1 -> node2 -> END.
 */
@Configuration
public class GraphConfiguration {
    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Builds and compiles the two-node demo graph. Both state keys use ReplaceStrategy, so the
     * values written by node2 overwrite those written by node1.
     *
     * @return the compiled graph, registered as bean {@code quickStartGraph}
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("quickStartGraph")
    public CompiledGraph quickStartGraph() throws GraphStateException {
        // State schema: input1 and input2, each replaced on every update
        KeyStrategyFactory keyStrategyFactory =
                () -> Map.of("input1", new ReplaceStrategy(), "input2", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("quickStartGraph", keyStrategyFactory);

        NodeAction firstNode = state -> {
            log.info("node1 state:{}", state);
            return Map.of("input1", 1, "input2", 1);
        };
        NodeAction secondNode = state -> {
            log.info("node2 state:{}", state);
            return Map.of("input1", 2, "input2", 2);
        };
        stateGraph.addNode("node1", AsyncNodeAction.node_async(firstNode));
        stateGraph.addNode("node2", AsyncNodeAction.node_async(secondNode));

        // Straight line: START -> node1 -> node2 -> END
        stateGraph.addEdge(StateGraph.START, "node1");
        stateGraph.addEdge("node1", "node2");
        stateGraph.addEdge("node2", StateGraph.END);

        return stateGraph.compile();
    }

}
  1. 测试调用
java 复制代码
package com.example.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;
import java.util.Optional;


/**
 * REST endpoint for running the quick-start graph.
 */
@RestController
@RequestMapping("/graph")
public class GraphController {
    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    /** Injected by type, so exactly one CompiledGraph bean must exist at this stage. */
    private final CompiledGraph graph;

    public GraphController(CompiledGraph graph) {
        this.graph = graph;
    }

    /**
     * Runs the graph with an empty initial state.
     *
     * @return the final state, or empty if the run produced none
     */
    @GetMapping("/test")
    public Optional<OverAllState> test() {
        Map<String, Object> initialState = Map.of();
        Optional<OverAllState> finalState = graph.call(initialState);
        log.info("test call:{}", finalState);
        return finalState;
    }
}

2.6.3、API概述

2.6.3.1、KeyStrategyFactory

用来定义图中的状态有哪些数据,并且定义这些数据的更新策略是什么。

有三种策略分别是:替换(ReplaceStrategy),合并(MergeStrategy),追加(AppendStrategy)

java 复制代码
// Declares the state keys and how a node's returned value for each key is applied.
KeyStrategyFactory keyStrategyFactory = () -> { 
    HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
    // Replace: the new value overwrites the old one
    keyStrategyHashMap.put("input1", new ReplaceStrategy());
    // Merge: intended for Map values; new entries are merged into the old map
    keyStrategyHashMap.put("input2", new MergeStrategy());
    // Append: intended for List values; new elements are appended to the old list
    keyStrategyHashMap.put("input3", new AppendStrategy()); 
    return keyStrategyHashMap; 
};

替换(ReplaceStrategy): 新值替换掉老值 (常用,更加灵活)

合并(MergeStrategy):适合Map类型的数据,将新Map中的数据合并到老Map中

追加(AppendStrategy):适合List类型的数据,新List的数据追加到老List中

2.6.3.2、AsyncNodeAction&NodeAction

NodeAction 是Graph中对节点的抽象。我们只需要实现NodeAction接口,在apply方法中定义节点的执行逻辑,并返回需要更新到状态中的数据即可。

java 复制代码
/**
 * A synchronous graph node: reads the current state and returns the partial state update.
 */
@FunctionalInterface
public interface NodeAction {
    /**
     * @param state the current overall graph state
     * @return key/value pairs to write back into the state, each applied with that key's KeyStrategy
     * @throws Exception any failure; when wrapped by AsyncNodeAction.node_async it completes the
     *                   returned future exceptionally
     */
    Map<String, Object> apply(OverAllState state) throws Exception; 
}

AsyncNodeAction 是异步节点接口,它提供了一个静态方法 node_async,可以把 NodeAction 转换成 AsyncNodeAction。注意:从下面的源码可以看到,node_async 会在调用线程中直接执行 NodeAction,然后返回一个已经完成的 CompletableFuture。

java 复制代码
/*
 * Copyright 2024-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.cloud.ai.graph.action;

import com.alibaba.cloud.ai.graph.OverAllState;
import io.opentelemetry.context.Context;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

/**
 * Represents an asynchronous node action that operates on an agent state and returns
 * state update.
 *
 * <p>The result is a CompletableFuture holding the partial state update (key -> new value).
 */
@FunctionalInterface
public interface AsyncNodeAction extends Function<OverAllState, CompletableFuture<Map<String, Object>>> {

	/**
	 * Applies this action to the given agent state.
	 * @param state the agent state
	 * @return a CompletableFuture representing the result of the action
	 */
	CompletableFuture<Map<String, Object>> apply(OverAllState state);

	/**
	 * Creates an asynchronous node action from a synchronous node action.
	 * <p>The sync action runs immediately on the calling thread (no executor is involved);
	 * the returned future is already completed, normally or exceptionally, when it is returned.
	 * @param syncAction the synchronous node action
	 * @return an asynchronous node action
	 */
	static AsyncNodeAction node_async(NodeAction syncAction) {
		return state -> {
			// NOTE(review): the captured OpenTelemetry context is not used in this listing
			Context context = Context.current();
			CompletableFuture<Map<String, Object>> result = new CompletableFuture<>();
			try {
				result.complete(syncAction.apply(state));
			}
			catch (Exception e) {
				// Checked exceptions from the sync action are delivered through the future
				result.completeExceptionally(e);
			}
			return result;
		};
	}

}
2.6.3.4、StateGraph

状态图的抽象,需要配置状态(通过KeyStrategyFactory ),节点,边。

配置好后通过compile方法编译成CompiledGraph后才可以供调用。

2.6.3.5、CompiledGraph

CompiledGraph是StateGraph编译后的结果,只有CompiledGraph才能用来执行。

一般我们是把StateGraph定义好后调用其compile方法得到一个CompiledGraph放入Spring容器中。

然后在需要的时候从容器中注入然后再调用。

2.6.4、英语造句翻译小助手实战

  1. 定义造句node节点
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model to write an English sentence using the {@code word} from
 * state, and stores the result under {@code sentence}.
 */
public class SentenceConstructionNode implements NodeAction {

    /** Prompt template; {@code {word}} is filled from graph state. */
    private static final String PROMPT = "你是一个英语造句专家,能 够基于给定的单词进行造句。"
            + "要求只返回最终造好的句子,不要返回其他信息。 给定的单词:{word}";

    private final ChatClient chatClient;

    public SentenceConstructionNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code word} (defaults to "")
     * @return {@code sentence -> generated sentence}; "" if the model returned no text
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String word = state.value("word", "");

        // Ask the model to build a sentence with the word
        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("word", word);
        String sentence = chatClient.prompt().user(promptTemplate.render()).call().content();

        // content() may return null and Map.of rejects null values; strip trailing newlines
        return Map.of("sentence", sentence == null ? "" : sentence.strip());
    }
}
  1. 定义翻译node节点
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model to translate the {@code sentence} from state, and stores
 * the result under {@code translation}.
 */
public class TranslationNode implements NodeAction {

    /** Prompt template; {@code {sentence}} is filled from graph state. */
    private static final String PROMPT = "你是一个英语翻译专家,能 够对句子进行翻译。" +
            "要求只返回翻译的结果不要返回其他信息。要翻译的句子:{sentence}";

    private final ChatClient chatClient;

    public TranslationNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code sentence} (defaults to "")
     * @return {@code translation -> translated text}; "" if the model returned no text
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String sentence = state.value("sentence", "");

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("sentence", sentence);
        String translation = chatClient.prompt().user(promptTemplate.render()).call().content();

        // content() may return null and Map.of rejects null values; strip trailing newlines
        return Map.of("translation", translation == null ? "" : translation.strip());
    }
}
  1. 定义Graph
java 复制代码
package com.example.config;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.SentenceConstructionNode;
import com.example.node.TranslationNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Map;


/**
 * Graph definitions: the two-node quick-start demo, and the word -> sentence -> translation
 * pipeline.
 */
@Configuration
public class GraphConfiguration {
    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    private static final String SENTENCE_NODE = "SentenceConstructionNode";
    private static final String TRANSLATION_NODE = "TranslationNode";

    /**
     * Two-node demo graph; both keys use ReplaceStrategy, so node2's values overwrite node1's.
     *
     * @return the compiled graph, registered as bean {@code quickStartGraph}
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("quickStartGraph")
    public CompiledGraph quickStartGraph() throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory =
                () -> Map.of("input1", new ReplaceStrategy(), "input2", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("quickStartGraph", keyStrategyFactory);

        NodeAction firstNode = state -> {
            log.info("node1 state:{}", state);
            return Map.of("input1", 1, "input2", 1);
        };
        NodeAction secondNode = state -> {
            log.info("node2 state:{}", state);
            return Map.of("input1", 2, "input2", 2);
        };
        stateGraph.addNode("node1", AsyncNodeAction.node_async(firstNode));
        stateGraph.addNode("node2", AsyncNodeAction.node_async(secondNode));

        stateGraph.addEdge(StateGraph.START, "node1");
        stateGraph.addEdge("node1", "node2");
        stateGraph.addEdge("node2", StateGraph.END);
        return stateGraph.compile();
    }

    /**
     * Pipeline graph: START -> sentence construction -> translation -> END.
     *
     * @param chatClient builder handed to each node (every node builds its own client)
     * @return the compiled graph, registered as bean {@code simpleGraph}
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("simpleGraph")
    public CompiledGraph simpleGraph(ChatClient.Builder chatClient) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> Map.of("word", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("simpleGraph", keyStrategyFactory);

        NodeAction sentenceNode = new SentenceConstructionNode(chatClient);
        stateGraph.addNode(SENTENCE_NODE, AsyncNodeAction.node_async(sentenceNode));
        NodeAction translationNode = new TranslationNode(chatClient);
        stateGraph.addNode(TRANSLATION_NODE, AsyncNodeAction.node_async(translationNode));

        stateGraph.addEdge(StateGraph.START, SENTENCE_NODE);
        stateGraph.addEdge(SENTENCE_NODE, TRANSLATION_NODE);
        stateGraph.addEdge(TRANSLATION_NODE, StateGraph.END);

        return stateGraph.compile();
    }

}
  1. controller测试调用
java 复制代码
package com.example.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;
import java.util.Optional;


/**
 * REST endpoints for running the compiled demo graphs.
 */
@RestController
@RequestMapping("/graph")
public class GraphController {
    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;

    /** Graphs are selected by bean name because more than one CompiledGraph bean exists. */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
    }

    /** Runs the quick-start graph with an empty initial state. */
    @GetMapping("/test")
    public Optional<OverAllState> test() {
        Optional<OverAllState> outcome = quickStartGraph.call(Map.of());
        log.info("test call:{}", outcome);
        return outcome;
    }

    /**
     * Runs the sentence/translation pipeline.
     *
     * @param word the word to build a sentence from
     * @return every entry of the final state, or an empty map if no state was produced
     */
    @GetMapping("simpleGraph")
    public Map<String, Object> simpleGraph(@RequestParam("word") String word) {
        Map<String, Object> input = Map.of("word", word);
        return simpleGraph.call(input)
                .map(OverAllState::data)
                .orElseGet(Map::of);
    }
}

2.6.5、条件边-笑话生成实战

功能:生成笑话→评估笑话→优秀则结束,不够优秀则优化笑话再结束

  1. 生成笑话node
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model for a short joke about the {@code topic} from state, and
 * stores it under {@code joke}.
 */
public class GenerateJokeNode implements NodeAction {

    /**
     * Prompt template; {@code {topic}} is filled from graph state. A sentence break separates
     * the instruction from the topic, so "内容" and "主题" are not read as one phrase.
     */
    private static final String PROMPT = "你需要写一个关于指定主题的短笑话。要求返回的结果中只能包含笑话的内容。主题:{topic}";

    private final ChatClient chatClient;

    public GenerateJokeNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code topic} (defaults to "")
     * @return {@code joke -> generated joke}; "" if the model returned no text
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String topic = state.value("topic", "");

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("topic", topic);
        String joke = chatClient.prompt().user(promptTemplate.render()).call().content();

        // content() may return null and Map.of rejects null values; strip trailing newlines
        return Map.of("joke", joke == null ? "" : joke.strip());
    }
}
  1. 评估笑话node
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model to grade the {@code joke} from state and stores a routing
 * verdict under {@code result}: always exactly "优秀" or "不够优秀".
 */
public class EvaluateJokesNode implements NodeAction {

    /** Verdict labels; they must match the keys of the conditional edge that follows this node. */
    private static final String EXCELLENT = "优秀";
    private static final String NOT_EXCELLENT = "不够优秀";

    /** Prompt template; {@code {joke}} is filled from graph state. Asks only for the verdict. */
    private static final String PROMPT = "你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。然后基于评分结果进行评价。如果大于等于3分 评价:优秀 否则评价:不够优秀\n"
            + "要求结果只返回最后的评价,不要其他内容。"
            + "要评分的笑话:{joke}";

    private final ChatClient chatClient;

    public EvaluateJokesNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code joke} (defaults to "")
     * @return {@code result -> "优秀" | "不够优秀"}
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();

        return Map.of("result", toVerdict(content));
    }

    /**
     * Maps free-form model output onto one of the two routing labels. Replies such as
     * "评价:优秀" or text with trailing newlines would otherwise match no edge.
     * "不够优秀" is checked first because it contains "优秀"; unrecognized or missing output is
     * treated as not excellent.
     */
    private static String toVerdict(String content) {
        if (content == null) {
            return NOT_EXCELLENT;
        }
        String text = content.strip();
        if (text.contains(NOT_EXCELLENT)) {
            return NOT_EXCELLENT;
        }
        return text.contains(EXCELLENT) ? EXCELLENT : NOT_EXCELLENT;
    }
}
  1. 优化笑话node
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model to make the {@code joke} from state funnier, and stores the
 * rewritten joke under {@code newJoke} (the original joke is left untouched).
 */
public class EnhanceJokeQualityNode implements NodeAction {

    /** Prompt template; {@code {joke}} is filled from graph state. Asks only for the improved joke. */
    private static final String PROMPT = "你是一个笑话优化专家,你 能够优化笑话,让它更加搞笑。"
            + "要求只返回优化后的笑话,不要返回其他信息。要优化的笑话:{joke}";

    private final ChatClient chatClient;

    public EnhanceJokeQualityNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code joke} (defaults to "")
     * @return {@code newJoke -> improved joke}; falls back to the input joke if the model
     *         returned no text
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("joke", joke);
        String improved = chatClient.prompt().user(promptTemplate.render()).call().content();

        // content() may return null and Map.of rejects null values
        boolean noText = improved == null || improved.isBlank();
        return Map.of("newJoke", noText ? joke : improved.strip());
    }
}
  1. graph配置
java 复制代码
package com.example.config;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Map;


/**
 * Conditional joke graph: generate -> evaluate -> END if "优秀", otherwise enhance -> END.
 */
@Configuration
public class GraphConfiguration {
    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * @param chatClient builder handed to each node (every node builds its own client)
     * @return the compiled graph, registered as bean {@code conditionalGraph}
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("conditionalGraph")
    public CompiledGraph conditionalGraph(ChatClient.Builder chatClient) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> Map.of("topic", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("conditionalGraph", keyStrategyFactory);

        NodeAction generate = new GenerateJokeNode(chatClient);
        stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(generate));
        NodeAction evaluate = new EvaluateJokesNode(chatClient);
        stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(evaluate));
        NodeAction enhance = new EnhanceJokeQualityNode(chatClient);
        stateGraph.addNode("优化笑话", AsyncNodeAction.node_async(enhance));

        // Route on the verdict written by the evaluation node; a missing verdict counts as "优秀"
        EdgeAction routeByVerdict = state -> state.value("result", "优秀");
        Map<String, String> routes = Map.of("优秀", StateGraph.END, "不够优秀", "优化笑话");

        stateGraph.addEdge(StateGraph.START, "生成笑话");
        stateGraph.addEdge("生成笑话", "评估笑话");
        stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(routeByVerdict), routes);
        stateGraph.addEdge("优化笑话", StateGraph.END);

        return stateGraph.compile();
    }

}
  1. controller测试
java 复制代码
package com.example.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;
import java.util.Optional;


/**
 * REST endpoints for running the compiled demo graphs.
 */
@RestController
@RequestMapping("/graph")
public class GraphController {
    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;

    /** Graphs are selected by bean name because more than one CompiledGraph bean exists. */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
    }


    /**
     * Runs the conditional joke graph.
     *
     * @param topic joke topic, stored under the {@code topic} state key
     * @return every entry of the final state, or an empty map if no state was produced
     */
    @GetMapping("conditionalGraph")
    public Map<String, Object> conditionalGraph(@RequestParam("topic") String topic) {
        Map<String, Object> input = Map.of("topic", topic);
        return conditionalGraph.call(input)
                .map(OverAllState::data)
                .orElseGet(Map::of);
    }
}

2.6.6、循环-优化笑话直到达到固定分数或循环次数

功能:生成笑话→评估笑话→如果分数大于6分或已评估满3次则结束,否则优化笑话,优化后再回到评估笑话节点

  1. 生成笑话节点沿用上一节的 GenerateJokeNode
  2. 循环评估笑话(相比上一节,只需修改提示词和判断逻辑)
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.example.controller.GraphController;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that scores the {@code joke} from state (0-10) and decides whether the loop
 * continues.
 *
 * <p>Writes {@code result} ("loop" or "break"), {@code score}, and the incremented
 * {@code loopCount}. The loop stops when the score is strictly greater than
 * {@code targetScore}, or on the {@code maxLoopCount}-th evaluation.
 */
public class LoopEvaluateJokesNode implements NodeAction {

    private static final Logger log = LoggerFactory.getLogger(LoopEvaluateJokesNode.class);

    /** First run of digits in the model reply, e.g. "8" in "8分" or "8/10". */
    private static final java.util.regex.Pattern SCORE_PATTERN = java.util.regex.Pattern.compile("\\d+");

    /** Prompt template; {@code {joke}} is filled from graph state. */
    private static final String PROMPT = "你是一个笑话评分专家,能 够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。要求打分只能是整数\n"
            + "要求结果只返回最后的打分,不要其他内容。"
            + "要评分的笑话::{joke}";

    private final ChatClient chatClient;
    /** The loop stops once a score is strictly greater than this value. */
    private final Integer targetScore;
    /** Maximum number of evaluations before the loop is forced to stop. */
    private final Integer maxLoopCount;

    public LoopEvaluateJokesNode(ChatClient.Builder builder, Integer targetScore, Integer maxLoopCount) {
        this.chatClient = builder.build();
        this.targetScore = targetScore;
        this.maxLoopCount = maxLoopCount;
    }

    /**
     * @param state graph state; reads {@code joke} (default "") and {@code loopCount}
     *              (1-based number of the current evaluation, default 1)
     * @return {@code result}, {@code loopCount} (incremented) and {@code score}
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");
        Integer loopCount = state.value("loopCount", 1);

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();

        int score = parseScore(content);
        log.info("score: {}, loopCount:{}, joke:{}", score, loopCount, joke);

        // Stop when the joke is good enough or the evaluation budget is used up
        String result = (score > targetScore || loopCount >= maxLoopCount) ? "break" : "loop";

        // The counter must be written back under the same key it is read from ("loopCount");
        // otherwise it never advances and the maxLoopCount guard can never fire.
        return Map.of("result", result, "loopCount", loopCount + 1, "score", score);
    }

    /**
     * Extracts the integer score from the model reply, tolerating extra text such as "8分".
     * Returns 0 (keep looping; the loop-count guard still bounds the run) when no usable
     * number is present.
     */
    private static int parseScore(String content) {
        if (content != null) {
            java.util.regex.Matcher matcher = SCORE_PATTERN.matcher(content);
            if (matcher.find()) {
                try {
                    return Integer.parseInt(matcher.group());
                } catch (NumberFormatException ignored) {
                    // digit run too long for an int; fall through to the default
                }
            }
        }
        log.warn("could not parse score from model reply: {}", content);
        return 0;
    }
}
  1. 循环优化笑话
java 复制代码
package com.example.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import java.util.Map;

/**
 * Graph node that asks the model to make the {@code joke} from state funnier, and writes the
 * result back to {@code joke} so the next evaluation scores the improved version.
 */
public class LoopEnhanceJokeQualityNode implements NodeAction {

    /** Prompt template; {@code {joke}} is filled from graph state. Asks only for the improved joke. */
    private static final String PROMPT = "你是一个笑话优化专家,你 能够优化笑话,让它更加搞笑。"
            + "要求只返回优化后的笑话,不要返回其他信息。要优化的笑话:{joke}";

    private final ChatClient chatClient;

    public LoopEnhanceJokeQualityNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * @param state graph state; reads {@code joke} (defaults to "")
     * @return {@code joke -> improved joke}; keeps the current joke if the model returned no text
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");

        PromptTemplate promptTemplate = new PromptTemplate(PROMPT);
        promptTemplate.add("joke", joke);
        String improved = chatClient.prompt().user(promptTemplate.render()).call().content();

        // content() may return null (and Map.of rejects nulls); an empty reply must not wipe the joke
        boolean noText = improved == null || improved.isBlank();
        return Map.of("joke", noText ? joke : improved.strip());
    }
}
  1. graph
java 复制代码
package com.example.config;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Map;


/**
 * Looping joke graph:
 * START -> 生成笑话 -> 评估笑话 -(loop)-> 优化笑话 -> 评估笑话 ... -(break)-> END.
 */
@Configuration
public class GraphConfiguration {
    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * @param chatClient builder handed to each node (every node builds its own client)
     * @return the compiled graph, registered as bean {@code loopGraph}
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("loopGraph")
    public CompiledGraph loopGraph(ChatClient.Builder chatClient) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> Map.of("topic", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("loopGraph", keyStrategyFactory);

        int targetScore = 6;   // stop once a score is above this
        int maxEvaluations = 3; // stop after this many evaluations regardless of score

        NodeAction generate = new GenerateJokeNode(chatClient);
        stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(generate));
        NodeAction evaluate = new LoopEvaluateJokesNode(chatClient, targetScore, maxEvaluations);
        stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(evaluate));
        NodeAction enhance = new LoopEnhanceJokeQualityNode(chatClient);
        stateGraph.addNode("优化笑话", AsyncNodeAction.node_async(enhance));

        // The evaluation node writes "loop" or "break" under "result"
        EdgeAction routeByResult = state -> state.value("result", "loop");

        stateGraph.addEdge(StateGraph.START, "生成笑话");
        stateGraph.addEdge("生成笑话", "评估笑话");
        stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(routeByResult),
                Map.of("loop", "优化笑话", "break", StateGraph.END));
        // Close the loop: every enhanced joke is evaluated again
        stateGraph.addEdge("优化笑话", "评估笑话");
        return stateGraph.compile();
    }
}
  1. controller测试
java 复制代码
package com.example.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;
import java.util.Optional;


/**
 * REST endpoints for running the compiled demo graphs.
 */
@RestController
@RequestMapping("/graph")
public class GraphController {
    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;
    private final CompiledGraph loopGraph;

    /** Graphs are selected by bean name because more than one CompiledGraph bean exists. */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph,
                           @Qualifier("loopGraph") CompiledGraph loopGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
        this.loopGraph = loopGraph;
    }

    /**
     * Runs the looping joke graph.
     *
     * @param topic joke topic, stored under the {@code topic} state key
     * @return every entry of the final state, or an empty map if no state was produced
     */
    @GetMapping("loopGraph")
    public Map<String, Object> loopGraph(@RequestParam("topic") String topic) {
        Map<String, Object> input = Map.of("topic", topic);
        return loopGraph.call(input)
                .map(OverAllState::data)
                .orElseGet(Map::of);
    }
}

2.6.7、保存状态

  1. graph
java 复制代码
package com.example.config;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;


@Configuration
public class GraphConfiguration {
    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Builds the "saveGraph" demo: a single-node graph (START -> 对话存储 -> END) whose node
     * appends the incoming {@code msg} to the {@code historyMsg} list in the graph state.
     * The PlantUML representation of the graph is logged when the bean is created.
     *
     * @param chatClient chat client builder injected by Spring; not used by this graph
     * @return the compiled graph
     * @throws GraphStateException if the graph definition is rejected by the framework
     */
    @Bean("saveGraph")
    public CompiledGraph saveGraph(ChatClient.Builder chatClient) throws GraphStateException {
        // Define the state graph.
        // NOTE(review): no key strategies are registered (empty map). Confirm that the framework
        // version in use keeps "msg"/"historyMsg" without an explicit strategy; otherwise
        // register one per key (e.g. ReplaceStrategy).
        KeyStrategyFactory keyStrategyFactory = Map::of;
        StateGraph stateGraph = new StateGraph("saveGraph", keyStrategyFactory);
        stateGraph.addNode("对话存储", AsyncNodeAction.node_async(new NodeAction() {
            @Override
            public Map<String, Object> apply(OverAllState state) throws Exception {
                String msg = state.value("msg", "");
                // Build a fresh list instead of mutating the one held by the state: that object
                // may be shared with earlier state snapshots, and it is not guaranteed to be an
                // ArrayList (reading it as ArrayList could throw ClassCastException).
                List<Object> previous = state.value("historyMsg", List.of());
                List<Object> historyMsg = new ArrayList<>(previous);
                historyMsg.add(msg);
                return Map.of("historyMsg", historyMsg);
            }
        }));
        // Define the edges.
        stateGraph.addEdge(StateGraph.START, "对话存储");
        stateGraph.addEdge("对话存储", StateGraph.END);
        // Log the PlantUML representation of the graph.
        GraphRepresentation graphRepresentation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "saveGraph");
        log.info("\n===expander UML FLOW ===");
        log.info(graphRepresentation.content());
        log.info("==========================\n");
        return stateGraph.compile();
    }
}
  2. controller测试
java 复制代码
package com.example.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;
import java.util.Optional;


@RestController
@RequestMapping("/graph")
public class GraphController {
    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    // One compiled graph per demo endpoint, each injected by its bean name.
    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;
    private final CompiledGraph loopGraph;
    private final CompiledGraph saveGraph;

    /**
     * Receives the compiled demo graphs; each one is selected by bean name via {@code @Qualifier}.
     */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph,
                           @Qualifier("loopGraph") CompiledGraph loopGraph,
                           @Qualifier("saveGraph") CompiledGraph saveGraph) {
        this.saveGraph = saveGraph;
        this.loopGraph = loopGraph;
        this.conditionalGraph = conditionalGraph;
        this.simpleGraph = simpleGraph;
        this.quickStartGraph = quickStartGraph;
    }

    /**
     * Runs the save graph for one message of a conversation.
     *
     * @param msg            value placed under the {@code "msg"} key of the graph input
     * @param conversationId used as the thread id of the run configuration (presumably so the
     *                       framework can associate successive calls with one conversation — verify)
     * @return the data of the final graph state, or an empty map when no state is produced
     */
    @GetMapping("saveGraph")
    public Map<String, Object> saveGraph(@RequestParam("msg") String msg,
                                         @RequestParam("conversationId") String conversationId) {
        RunnableConfig threadConfig = RunnableConfig.builder()
                .threadId(conversationId)
                .build();
        Map<String, Object> inputs = Map.of("msg", msg);
        return saveGraph.call(inputs, threadConfig)
                .map(OverAllState::data)
                .orElseGet(Map::of);
    }
}

可视化图地址:http://www.plantuml.com/plantuml/uml/SyfFKj2rKt3CoKnELR1Io4ZDoSa70000

相关推荐
迅利科技1 小时前
CATIA:高端制造的“数字母体”
人工智能·科技·制造
计算机安禾1 小时前
【c++面向对象编程】第5篇:类与对象(四):赋值运算符重载
java·前端·c++
Honey Ro1 小时前
pytorch中的损失函数使用
人工智能·pytorch·深度学习
weixin_435208161 小时前
大模型 Agent 面试高频100题——基础篇
人工智能·深度学习·自然语言处理·面试·职场和发展·aigc
青稞社区.1 小时前
OpenAI 翁家翌:“启发式学习”的强化学习新范式
人工智能·经验分享·学习·agi
QYR-分析1 小时前
全球及中国固定翼无人机光电吊舱行业发展现状与前景分析
人工智能·无人机
AI人工智能+电脑小能手1 小时前
【大白话说Java面试题 第45题】【JVM篇】第5题:JVM中,对象何时会进入老年代?
java·开发语言·jvm·后端·面试
扬帆破浪1 小时前
免费开源AI软件.桌面单机版,可移动的AI知识库,察元 AI桌面版:公司只允许装签名应用 给察元AI打企业内部分发包
人工智能·windows·电脑·知识图谱
luck_bor1 小时前
使用接口定义规范,实现类完成具体逻辑
java·开发语言