代码地址:https://gitee.com/CodeMao01/springai-examples
一、前期准备及概述
智谱大模型,注册并创建好ApiKey:智谱AI开放平台
可以先用 Apifox 或 Postman 调用一下模型 API,确认 ApiKey 可用,参考:使用概述 - 智谱AI开放文档
Spring AI Alibaba 在 Spring AI 的基础上增加了 Graph,可以用来进行多智能体编排及工作流(Multi-Agent/Workflow)
版本说明:
JDK:17 (官方最低要求17)
SpringAI :1.0.0
SpringBoot : 3.4.0
Spring AI Alibaba : 1.0.0.4
二、快速入门
父pom
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- Aggregator/parent pom: shared versions and build configuration for all example modules. -->
    <packaging>pom</packaging>
    <modules>
        <module>01_springai-alibaba-quick-start</module>
    </modules>
    <!--
        Keep this version equal to ${spring-boot.version} below (write it as a literal value).
        The parent supplies plugin management and defaults, while the spring-boot-dependencies BOM
        imported in <dependencyManagement> supplies dependency versions; giving them different
        values mixes two Spring Boot generations in one build.
    -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.4.0</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.example</groupId>
    <artifactId>springai-examples</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springai-examples</name>
    <description>springai-examples</description>
    <url/>
    <licenses>
        <license/>
    </licenses>
    <developers>
        <developer/>
    </developers>
    <scm>
        <connection/>
        <developerConnection/>
        <tag/>
        <url/>
    </scm>
    <!--
        No "project.version" property here: ${project.version} always refers to <version> above,
        so a property with that name only misleads readers.
    -->
    <properties>
        <java.version>17</java.version>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Spring AI / Spring AI Alibaba BOM versions -->
        <spring-ai.version>1.0.0</spring-ai.version>
        <spring-ai-alibaba.version>1.0.0.4</spring-ai-alibaba.version>
        <!-- Spring Boot: keep equal to the <parent> version above -->
        <spring-boot.version>3.4.0</spring-boot.version>
        <!-- maven plugin -->
        <maven-deploy-plugin.version>3.1.1</maven-deploy-plugin.version>
        <flatten-maven-plugin.version>1.3.0</flatten-maven-plugin.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
    </properties>
    <!--
        BOMs that manage dependency versions for every module, so child poms declare dependencies
        without <version>. When several imported BOMs manage the same artifact, the first import wins.
    -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>${spring-boot.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>com.alibaba.cloud.ai</groupId>
                <artifactId>spring-ai-alibaba-bom</artifactId>
                <version>${spring-ai-alibaba.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>${spring-boot.version}</version>
            </plugin>
            <!-- The examples are never deployed to a remote repository. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-deploy-plugin</artifactId>
                <version>${maven-deploy-plugin.version}</version>
                <configuration>
                    <skip>true</skip>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>${maven-compiler-plugin.version}</version>
                <configuration>
                    <release>17</release>
                    <compilerArgs>
                        <!-- Keep parameter names in bytecode; Spring Framework 6.1+ relies on this
                             for reflective parameter-name discovery. -->
                        <compilerArg>-parameters</compilerArg>
                    </compilerArgs>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>flatten-maven-plugin</artifactId>
                <version>${flatten-maven-plugin.version}</version>
                <inherited>true</inherited>
                <executions>
                    <execution>
                        <id>flatten</id>
                        <phase>process-resources</phase>
                        <goals>
                            <goal>flatten</goal>
                        </goals>
                        <configuration>
                            <updatePomFile>true</updatePomFile>
                            <flattenMode>ossrh</flattenMode>
                            <pomElements>
                                <distributionManagement>remove</distributionManagement>
                                <dependencyManagement>remove</dependencyManagement>
                                <repositories>remove</repositories>
                                <scm>keep</scm>
                                <url>keep</url>
                                <organization>resolve</organization>
                            </pomElements>
                        </configuration>
                    </execution>
                    <execution>
                        <id>flatten.clean</id>
                        <phase>clean</phase>
                        <goals>
                            <goal>clean</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
子pom-01_springai-alibaba-quick-start
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits build settings and BOM-managed dependency versions from the springai-examples parent pom. -->
<parent>
<groupId>com.example</groupId>
<artifactId>springai-examples</artifactId>
<version>0.0.1-SNAPSHOT</version>
<!-- relativePath: resolve the parent pom from the parent directory on disk instead of fetching it from a remote repository -->
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>com.example</groupId>
<artifactId>01_springai-alibaba-quick-start</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>01_springai-alibaba-quick-start</name>
<description>01_springai-alibaba-quick-start</description>
<url/>
<licenses>
<license/>
</licenses>
<developers>
<developer/>
</developers>
<scm>
<connection/>
<developerConnection/>
<tag/>
<url/>
</scm>
<properties>
<java.version>17</java.version>
</properties>
<!-- No <version> on these dependencies: they come from the BOMs imported in the parent's dependencyManagement. -->
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Spring AI Zhipu AI model starter: provides the ChatModel used by the controllers in this module. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>
<!-- Lombok is only needed at compile time (provided scope), e.g. for @Slf4j on the advisors. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Spring Boot packaging plugin; its version comes from the parent pom. -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
子工程-01_springai-alibaba-quick-start的配置文件
yaml
server:
  port: 8080
  servlet:
    encoding:
      charset: UTF-8
      enabled: true
      force: true  # force UTF-8 on both requests and responses so Chinese text is not garbled
spring:
  application:
    name: 01_springai-alibaba-quick-start
  ai:
    zhipuai:
      api-key: ${ZHIPU_KEY}  # API key, read from the ZHIPU_KEY environment variable
      base-url: "https://open.bigmodel.cn/api/paas"  # Zhipu AI API base URL
      chat:
        options:
          model: glm-4.5  # default chat model; can be overridden per request via ChatOptions
2.1、chatModel
ChatModel#call 的三种调用方式及流式输出:
方式一:直接传入用户问题字符串
方式二:传入 SystemMessage + UserMessage
方式三:传入 Prompt(Prompt = Messages + ChatOptions)
流式输出:模型边生成边返回,结果以 Flux<String> 的形式分段推送给客户端,效果就像一个字一个字地蹦出来。
java
package com.example.controller;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import java.util.List;
/**
 * Demonstrates calling a {@link ChatModel} directly: with a plain string, with explicit messages,
 * with a full {@link Prompt} carrying per-request options, and as a stream.
 */
@RestController
@RequestMapping("/zhipuai")
public class ZhipuController {

    /** System prompt used by the endpoints that send an explicit system message. */
    private static final String SYSTEM_PROMPT = "你是一个AI助手";

    private final ChatModel chatModel;

    /**
     * Creates the controller; the {@link ChatModel} is injected through the constructor.
     *
     * @param chatModel the auto-configured chat model (Zhipu AI in this module)
     */
    public ZhipuController(ChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Sends the user question as a plain string.
     *
     * @param query the user question
     * @return the model's reply text
     */
    @GetMapping("/simple")
    public String simpleChat(@RequestParam(name = "query") String query) {
        return chatModel.call(query);
    }

    /**
     * Sends a system message followed by the user message.
     *
     * @param query the user question
     * @return the model's reply text
     */
    @GetMapping("/message")
    public String messageChat(@RequestParam(name = "query") String query) {
        SystemMessage systemMessage = new SystemMessage(SYSTEM_PROMPT);
        UserMessage userMessage = new UserMessage(query);
        return chatModel.call(systemMessage, userMessage);
    }

    /**
     * Sends a {@link Prompt} (Prompt = Messages + ChatOptions) with per-request options.
     *
     * @param query the user question
     * @return the reply text, or an empty string when the response contains no generation
     */
    @GetMapping("/chatOptions")
    public String chatOptions(@RequestParam(name = "query") String query) {
        SystemMessage systemMessage = new SystemMessage(SYSTEM_PROMPT);
        UserMessage userMessage = new UserMessage(query);
        ChatOptions chatOptions = ZhiPuAiChatOptions.builder()
                .maxTokens(10000)
                .model("glm-4.5")
                .temperature(0.9)
                .build();
        Prompt prompt = new Prompt(List.of(systemMessage, userMessage), chatOptions);
        ChatResponse chatResponse = chatModel.call(prompt);
        // ChatResponse#getResult() returns null when there are no generations; guard it instead
        // of failing the request with a NullPointerException.
        if (chatResponse == null || chatResponse.getResult() == null) {
            return "";
        }
        return chatResponse.getResult().getOutput().getText();
    }

    /**
     * Streams the reply chunk by chunk as the model generates it.
     *
     * @param query the user question
     * @return a stream of reply fragments
     */
    @GetMapping("/stream/chat")
    public Flux<String> streamChat(@RequestParam(name = "query") String query) {
        return chatModel.stream(query);
    }
}
2.2、chatClient
比 ChatModel 更常用:它在 ChatModel 的基础上提供了链式调用的 API,并增加了 Advisor 等扩展功能
2.2.1、基本使用
java
package com.example.controller;
import com.example.advisor.SGCallAdvisor1;
import com.example.advisor.SGCallAdvisor2;
import com.example.advisor.SimpleMessageChatMemoryAdvisor;
import com.example.entity.Book;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.function.Consumer;
/**
 * Demonstrates the fluent {@link ChatClient} API: prompts with options, full {@link ChatResponse}
 * access, structured output mapped to {@link Book}, and streaming.
 */
@RestController
@RequestMapping("/chatClient")
public class ZhipuChatClientController {

    /** System prompt shared by /simple and /simple2. */
    private static final String SYSTEM_PROMPT = "你是一个AI助手";

    /** User prompt shared by /response and /stream. */
    private static final String RANDOM_BOOK_PROMPT = "给我随机生成一本书,要求书名和作者都是中文";

    private final ChatClient chatClient;

    /**
     * @param builder the auto-configured ChatClient builder
     */
    public ZhipuChatClientController(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /** Sends a pre-built {@link Prompt} (messages + options) and returns the reply text. */
    @GetMapping("/simple")
    public String simpleChat(@RequestParam(name = "query") String query) {
        SystemMessage systemMessage = new SystemMessage(SYSTEM_PROMPT);
        UserMessage userMessage = new UserMessage(query);
        Prompt prompt = new Prompt(List.of(systemMessage, userMessage), buildChatOptions());
        return chatClient.prompt(prompt).call().content();
    }

    /** Same request built with the fluent API; returns the full {@link ChatResponse} (including metadata). */
    @GetMapping("/simple2")
    public ChatResponse simpleChat2(@RequestParam(name = "query") String query) {
        return chatClient.prompt()
                .system(SYSTEM_PROMPT)
                .user(query)
                .options(buildChatOptions())
                .call()
                .chatResponse();
    }

    /** Structured output: the reply is converted into a {@link Book}. */
    @GetMapping("/response")
    public Book response() {
        return chatClient.prompt()
                .user(RANDOM_BOOK_PROMPT)
                .call()
                .entity(Book.class);
    }

    /** Streams the reply text as it is generated. */
    @GetMapping("/stream")
    public Flux<String> stream() {
        return chatClient.prompt()
                .user(RANDOM_BOOK_PROMPT)
                .stream()
                .content();
    }

    /**
     * Builds the per-request options used by /simple and /simple2.
     * A fresh instance is created per call so requests never share a mutable options object.
     */
    private static ChatOptions buildChatOptions() {
        return ChatOptions.builder()
                .maxTokens(10000)
                .model("glm-4.5")
                .temperature(0.9)
                .build();
    }
}
2.2.2、advisor
java
/**
 * Calls the model with two demo advisors attached to this request only.
 * Spring AI orders advisors by {@code getOrder()} (lower values run first / outermost), so
 * SGCallAdvisor1 (order 0) wraps SGCallAdvisor2 (order 1) regardless of the argument order below.
 */
@GetMapping("/advisor")
public String advisor() {
    String question = "你是谁?";
    return chatClient.prompt()
            .user(question)
            .advisors(new SGCallAdvisor2(), new SGCallAdvisor1())
            .call()
            .content();
}
SGCallAdvisor1:
java
package com.example.advisor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
/**
 * Demo {@link CallAdvisor} that logs before and after delegating to the rest of the advisor chain.
 * It has order 0; advisors with lower order values run earlier in the chain.
 */
@Slf4j
public class SGCallAdvisor1 implements CallAdvisor {

    private static final String NAME = "SGCallAdvisor1-name";
    private static final int ORDER = 0;

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public int getOrder() {
        return ORDER;
    }

    /** Logs "start", hands the request to the next advisor, logs "end" and returns its response. */
    @Override
    public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) {
        log.info("SGCallAdvisor1-adviseCall start");
        ChatClientResponse response = chain.nextCall(request);
        log.info("SGCallAdvisor1-adviseCall end");
        return response;
    }
}
SGCallAdvisor2:
java
package com.example.advisor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
/**
 * Demo {@link CallAdvisor} that logs before and after delegating to the rest of the advisor chain.
 * It has order 1; advisors with lower order values run earlier in the chain.
 */
@Slf4j
public class SGCallAdvisor2 implements CallAdvisor {

    private static final String NAME = "SGCallAdvisor2-name";
    private static final int ORDER = 1;

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public int getOrder() {
        return ORDER;
    }

    /** Logs "start", hands the request to the next advisor, logs "end" and returns its response. */
    @Override
    public ChatClientResponse adviseCall(ChatClientRequest request, CallAdvisorChain chain) {
        log.info("SGCallAdvisor2-adviseCall start");
        ChatClientResponse response = chain.nextCall(request);
        log.info("SGCallAdvisor2-adviseCall end");
        return response;
    }
}
2.2.3、chatMemory
2.2.3.1、自定义chatMemory
java
package com.example.advisor;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Minimal hand-written chat-memory advisor (for learning purposes).
 * <p>
 * Before each call it appends the prompt's messages to the conversation history and sends the
 * whole history to the model; after the call it appends the assistant reply. The conversation is
 * identified by the {@code "conversationId"} advisor parameter (falls back to {@code "default"}).
 * <p>
 * History is kept in a JVM-wide in-memory map: it is lost on restart and never evicted.
 */
public class SimpleMessageChatMemoryAdvisor implements BaseAdvisor {

    /** Advisor-context key under which callers pass the conversation id. */
    private static final String CONVERSATION_ID_KEY = "conversationId";

    /** Conversation id used when the caller did not supply one. */
    private static final String DEFAULT_CONVERSATION_ID = "default";

    /**
     * conversation id -> message history, shared by all advisor instances and accessed from
     * concurrent web request threads. Stored lists are never mutated after being put in the map;
     * every update builds a new list inside {@link ConcurrentHashMap#compute}, which makes each
     * append atomic per conversation.
     */
    private static final Map<String, List<Message>> CHAT_MEMORY = new ConcurrentHashMap<>();

    /**
     * Records this turn's messages and replaces the prompt with the full conversation history.
     */
    @Override
    public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
        String conversationId = resolveConversationId(chatClientRequest.context());
        List<Message> history = append(conversationId, chatClientRequest.prompt().getInstructions());
        // The request is read-only, so build new Prompt/request instances. Pass a copy so the
        // outgoing prompt never aliases the list stored in CHAT_MEMORY.
        Prompt newPrompt = chatClientRequest.prompt().mutate()
                .messages(new ArrayList<>(history))
                .build();
        return chatClientRequest.mutate()
                .prompt(newPrompt)
                .build();
    }

    /**
     * Records the assistant reply in the conversation history; responses without a generation
     * are passed through untouched.
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        // Check for a missing response/generation before dereferencing anything.
        if (chatClientResponse == null || chatClientResponse.chatResponse() == null
                || chatClientResponse.chatResponse().getResult() == null) {
            return chatClientResponse;
        }
        AssistantMessage output = chatClientResponse.chatResponse().getResult().getOutput();
        if (output != null) {
            String conversationId = resolveConversationId(chatClientResponse.context());
            append(conversationId, List.of(output));
        }
        return chatClientResponse;
    }

    @Override
    public int getOrder() {
        return 0;
    }

    /** Reads the conversation id from the advisor context, falling back to the default id. */
    private static String resolveConversationId(Map<String, Object> context) {
        Object id = (context == null) ? null : context.get(CONVERSATION_ID_KEY);
        return (id == null) ? DEFAULT_CONVERSATION_ID : id.toString();
    }

    /** Atomically appends messages to a conversation's history and returns the new history. */
    private static List<Message> append(String conversationId, List<? extends Message> messages) {
        return CHAT_MEMORY.compute(conversationId, (id, previous) -> {
            List<Message> updated = (previous == null) ? new ArrayList<>() : new ArrayList<>(previous);
            updated.addAll(messages);
            return updated;
        });
    }
}
java
/**
 * Chats with the hand-written {@link SimpleMessageChatMemoryAdvisor}.
 * The conversation id is passed as an advisor parameter; the advisor reads it from the request
 * context under the "conversationId" key.
 */
@GetMapping("/simpleMessageChatMemoryAdvisor")
public String simpleMessageChatMemoryAdvisor(@RequestParam(name = "query") String query,
                                             @RequestParam(name = "conversationId") String conversationId) {
    SimpleMessageChatMemoryAdvisor memoryAdvisor = new SimpleMessageChatMemoryAdvisor();
    return chatClient.prompt()
            .user(query)
            .advisors(spec -> spec.param("conversationId", conversationId))
            .advisors(memoryAdvisor)
            .call()
            .content();
}
2.2.3.2、使用SpringAi的chatMemory
java
package com.example.controller;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.function.Consumer;
/**
 * Chat memory using Spring AI's built-in {@link MessageChatMemoryAdvisor} backed by a
 * {@link MessageWindowChatMemory}.
 */
@RestController
@RequestMapping("/chatMemory")
public class ZhipuChatMemoryController {

    private final ChatClient chatClient;

    /**
     * Registers the memory advisor as a default advisor, so every request made through this
     * client reads and writes the conversation history.
     *
     * @param builder the auto-configured ChatClient builder
     */
    public ZhipuChatMemoryController(ChatClient.Builder builder) {
        MessageChatMemoryAdvisor memoryAdvisor =
                MessageChatMemoryAdvisor.builder(MessageWindowChatMemory.builder().build()).build();
        this.chatClient = builder.defaultAdvisors(memoryAdvisor).build();
    }

    /**
     * Chats within the conversation identified by {@code conversationId}.
     * (The path spelling "Adivsor" is kept unchanged for existing callers.)
     */
    @GetMapping("/chatMemoryAdivsor")
    public String simpleMessageChatMemoryAdvisor(@RequestParam(name = "query") String query,
                                                 @RequestParam(name = "conversationId") String conversationId) {
        return chatClient.prompt()
                .user(query)
                .advisors(spec -> spec.param(ChatMemory.CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /** Local demo of {@link PromptTemplate}: renders and prints a user message and a system message. */
    public static void main(String[] args) {
        PromptTemplate userTemplate =
                new PromptTemplate("你是一个AI人工助手, 你的名字是:{name}, 你的语言风格是:{voice}, 用户的问题是:{query}");
        Map<String, Object> userVariables = Map.of("name", "小a", "voice", "幽默", "query", "你叫什么名字?");
        System.out.println(userTemplate.createMessage(userVariables));

        PromptTemplate systemTemplate =
                new PromptTemplate("你是一个AI人工助手, 你的名字是:{name}, 你的语言风格是:{voice}");
        Map<String, Object> systemVariables = Map.of("name", "小a", "voice", "幽默");
        System.out.println(systemTemplate.createMessage(systemVariables));
    }
}
2.3、RAG
2.3.1、RAG介绍、功能以及流程
RAG:检索增强生成(Retrieval-Augmented Generation)。先从外部知识库(如向量数据库)中检索出与问题相关的知识,再把问题和相关知识一起交给大模型生成回答。
作用:当询问 AI 它不知道的知识时,模型容易产生幻觉。RAG 相当于给模型接入了一个外部知识库,既能减少幻觉,也便于构建垂直领域的应用。
流程:
存储:文档切片→进行embedding→存储到向量数据库
查询:用户查询问题→转换成向量→检索向量数据库→添加到prompt→返回给大模型→用户

2.3.2、数据嵌入及相似度搜索
这里选用 Redis Stack 作为向量数据库(方便本地开发使用)
pom:
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits build settings and BOM-managed dependency versions from the springai-examples parent pom. -->
<parent>
<groupId>com.example</groupId>
<artifactId>springai-examples</artifactId>
<version>0.0.1-SNAPSHOT</version>
<!-- relativePath: resolve the parent pom from the parent directory on disk instead of fetching it from a remote repository -->
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>com.example</groupId>
<artifactId>02_springai-alibaba-rag</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>02_springai-alibaba-rag</name>
<description>02_springai-alibaba-rag</description>
<url/>
<licenses>
<license/>
</licenses>
<developers>
<developer/>
</developers>
<scm>
<connection/>
<developerConnection/>
<tag/>
<url/>
</scm>
<properties>
<java.version>17</java.version>
</properties>
<!-- No <version> on these dependencies: they come from the BOMs imported in the parent's dependencyManagement. -->
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Redis vector store starter: provides the VectorStore bean injected by the RAG controllers. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-vector-store-redis</artifactId>
</dependency>
<!-- Zhipu AI starter: chat model plus the embedding model used to vectorize documents. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Spring Boot packaging plugin; its version comes from the parent pom. -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
yaml:
yaml
# Server configuration
server:
  port: 8080  # HTTP port the application listens on (8080 is also Spring Boot's default)
spring:
  # Basic application info
  application:
    name: 02_springai-alibaba-rag
  # Spring AI configuration
  ai:
    # Zhipu AI model configuration
    zhipuai:
      api-key: ${ZHIPU_KEY}  # read from the ZHIPU_KEY environment variable
      chat:
        options:
          model: glm-4.6  # chat model
      embedding:
        options:
          model: embedding-3  # embedding model
          dimensions: 256     # embedding vector size; if changed after the index exists, the index likely must be recreated
    # Vector store configuration
    vectorstore:
      redis:
        initialize-schema: true     # create the Redis vector index on startup (needed on first deployment)
        prefix: sangeng_rag_prefix  # Redis key prefix that separates this application's vectors
        # NOTE(review): Spring AI 1.0 may name this property `index-name`; verify against
        # RedisVectorStoreProperties, because an unknown key is silently ignored.
        index: sangeng_rag_index    # Redis vector index name
  # Redis connection
  data:
    redis:
      host: localhost  # Redis server host
      port: 6379       # Redis server port (default 6379)
java:
java
package com.example.controller;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
 * Minimal vector-store demo: embeds text into Redis and runs similarity searches against it.
 */
@RestController
@RequestMapping("/redis")
public class RedisController {

    /** Maximum number of documents returned by a similarity search. */
    private static final int TOP_K = 3;

    /** Minimum similarity score a document needs in order to be returned. */
    private static final double SIMILARITY_THRESHOLD = 0.5;

    private final VectorStore vectorStore;

    public RedisController(VectorStore vectorStore) {
        this.vectorStore = vectorStore;
    }

    /**
     * Stores {@code content} as a single {@link Document}; the vector store computes its embedding.
     * NOTE(review): a GET endpoint that mutates state; consider POST if the API can change.
     */
    @GetMapping("/import")
    public void importData(@RequestParam(name = "content") String content) {
        Document document = new Document(content);
        vectorStore.add(List.of(document));
    }

    /** Returns at most {@link #TOP_K} documents whose similarity to {@code query} reaches the threshold. */
    @PostMapping("/query")
    public List<Document> query(@RequestParam(name = "query") String query) {
        return vectorStore.similaritySearch(SearchRequest.builder()
                .query(query)
                .topK(TOP_K)
                .similarityThreshold(SIMILARITY_THRESHOLD)
                .build());
    }
}
2.3.3、咖啡店客服实战
- 引入rag以及csv的pom
xml
<!-- Apache Commons CSV: parses the QA.csv knowledge base; version given explicitly, presumably not managed by the imported BOMs. -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.10.0</version>
</dependency>
<!-- Spring AI RAG module: provides RetrievalAugmentationAdvisor and VectorStoreDocumentRetriever. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-rag</artifactId>
</dependency>
- 导入咖啡店智能知识库
- 使用
RetrievalAugmentationAdvisor 进行检索
其原理是:先通过 VectorStore 检索出与问题相关的文档,再把检索结果拼接到 UserMessage 中,一起发给大模型
java
package com.example.controller;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.ClassPathResource;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
@RestController
@RequestMapping("/coffee")
public class CoffeeController {
private final VectorStore vectorStore;
private final ChatClient chatClient;
public CoffeeController(VectorStore vectorStore, ChatClient.Builder chatClient) {
this.vectorStore = vectorStore;
VectorStoreDocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.topK(3)
.similarityThreshold(0.5)
.build();
RetrievalAugmentationAdvisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(documentRetriever)
.build();
this.chatClient = chatClient.defaultAdvisors(retrievalAugmentationAdvisor).build();
}
@GetMapping("/import")
public String importCsv() {
try {
// 读取QA文件
ClassPathResource resource = new ClassPathResource("QA.csv");
InputStreamReader reader = new InputStreamReader(resource.getInputStream());// 解析csv文件
CSVParser csvParser = CSVFormat.DEFAULT.builder()
.setHeader()// 第一行作为标题
.setSkipHeaderRecord(true)// 跳过标题行
.build()
.parse(reader);
List<Document> documents = new ArrayList<>();
// 遍历每一行数据
for (CSVRecord record : csvParser) {
// 获取问题和回答字段
String question = record.get("question");
String answer = record.get("answer");
// 将问题和回答组合成文档内容
String content = "问题:" + question + "\n答案:" + answer;
// 创建Document对象
Document document = new Document(content);
// 添加到文档列表
documents.add(document);
}
// 关闭解析器
csvParser.close();
// 将文档存入向量数据库
vectorStore.add(documents);
return "成功导入" + documents.size() + "条记录到向量数据库";
} catch (IOException e) {
e.printStackTrace();
return "导入失败:" + e.getMessage();
}
}
/**
* 新增RAG问答接口, 明确展示查询向量数据库的过程
*
* @param question 用户的问题
* @return AI基于检索到的信息生成的回答
*/
@GetMapping("/rag-ask")
public String ragAskQuestion(@RequestParam("question") String question) {
// 先从向量数据库中检索相关的信息
// 这里会使用RetrievalAugmentationAdvisor自动检索相关文档
// 将问题和检索的上下文一起发送给AI模型生成回答
return chatClient.prompt()
.system("你是三更咖啡的服务员, 你需要回答用户的问题")
.user(question)
.call()
.content();
}
}

2.4、tool Calling
含义:把工具信息发给大模型,让大模型判断是否需要调用工具,让大模型具有调用工具的能力,比如获取当前时间
流程:用户问题 + tool信息 → LLM → tool calling → LLM → 用户

2.4.1、定义工具
java
package com.example.tools;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
/**
 * Local tool offered to the model: returns the current time in a given time zone.
 */
public class TimeTools {

    /**
     * Immutable and thread-safe, so it is built once instead of on every tool call.
     * The zone abbreviation ({@code z}) is rendered in the JVM default locale.
     */
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Returns the current time in the given zone.
     *
     * @param zoneID time-zone id such as {@code Asia/Shanghai}; surrounding whitespace is ignored
     * @return the current time formatted as {@code yyyy-MM-dd HH:mm:ss z}
     * @throws java.time.DateTimeException if {@code zoneID} is not a valid zone id
     */
    @Tool(description = "通过时区id获取当前时间")
    public String getTimeByZoneID(@ToolParam(description = "时区id, 比如Asia/Shanghai") String zoneID) {
        // ZoneId.of rejects surrounding whitespace, which model-supplied arguments may contain.
        ZoneId zone = ZoneId.of(zoneID.strip());
        return ZonedDateTime.now(zone).format(FORMATTER);
    }
}
2.4.2、把工具信息传递给大模型
java
// Keep the RAG advisor as the default advisor and register TimeTools, so its @Tool methods
// are offered to the model on every request made through this ChatClient.
this.chatClient = chatClient
.defaultAdvisors(retrievalAugmentationAdvisor)
.defaultTools(new TimeTools())
.build();
2.5、MCP
2.5.1、MCP概念
MCP(Model Context Protocol,模型上下文协议)是 Anthropic 公司推出的开放协议,旨在标准化大语言模型与外部工具、数据源之间的交互方式。
官方文档:https://modelcontextprotocol.io/docs/learn/architecture
| 角色 | 说明 |
|---|---|
| MCP Host | 管理和协调 MCP Client 的 AI 应用程序,比如 Claude Code |
| MCP Client | 与 MCP Server 建立连接、从中获取上下文并提供给 MCP Host 使用的组件 |
| MCP Server(本文也称 mcp-service) | 给 AI 模型提供工具等能力的服务 |
大白话:MCP 就是一套供 AI 调用外部能力的协议标准。大家都可以按照这个标准编写 MCP 服务,提供各种各样的工具(比如操作浏览器、操作飞书等)供 AI 调用,从而形成百花齐放的生态。
2.5.2、实战-提供查询当前时间的工具
2.5.2.1、mcp-service
- pom依赖
xml
<!-- Spring AI MCP server starter (WebFlux-based transport); exposes ToolCallbackProvider beans as MCP tools. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-mcp-server-webflux</artifactId>
</dependency>
- 定义工具
java
package com.example.tools;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
/**
 * Spring bean whose {@code @Tool} method is published to MCP clients by the MCP server
 * (via the ToolCallbackProvider bean in McpConfig).
 */
@Component
public class TimeTools {

    /**
     * Immutable and thread-safe, so it is built once instead of on every tool call.
     * The zone abbreviation ({@code z}) is rendered in the JVM default locale.
     */
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss z");

    /**
     * Returns the current time in the given zone.
     *
     * @param zoneId time-zone id such as {@code Asia/Shanghai}; surrounding whitespace is ignored
     * @return the current time formatted as {@code yyyy-MM-dd HH:mm:ss z}
     * @throws java.time.DateTimeException if {@code zoneId} is not a valid zone id
     */
    @Tool(description = "通过时区id获取当前时间")
    public String getTimeByZoneID(@ToolParam(description = "时区id, 比如Asia/Shanghai") String zoneId) {
        // ZoneId.of rejects surrounding whitespace, which model-supplied arguments may contain.
        ZoneId zone = ZoneId.of(zoneId.strip());
        return ZonedDateTime.now(zone).format(FORMATTER);
    }
}
- 配置工具对外提供(ToolCallbackProvider)
java
package com.example.config;
import com.example.tools.TimeTools;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Publishes the {@code @Tool} methods of {@link TimeTools} as a {@link ToolCallbackProvider} bean,
 * which the MCP server exposes to connected MCP clients.
 */
@Configuration
public class McpConfig {

    /**
     * @param timeTools the tool bean whose annotated methods become tool callbacks
     * @return a provider wrapping those callbacks
     */
    @Bean
    public ToolCallbackProvider toolCallbackProvider(TimeTools timeTools) {
        MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
                .toolObjects(timeTools)
                .build();
        return provider;
    }
}
2.5.2.2、mcp-client
- pom依赖
xml
<!-- Spring AI MCP client starter (WebFlux-based transport); connects to the servers under spring.ai.mcp.client. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-mcp-client-webflux</artifactId>
</dependency>
- application.yaml配置
yaml
spring:
  ai:
    mcp:
      client:
        sse:
          connections:
            server1:
              # Base URL of the MCP server. NOTE(review): the client application itself must then
              # listen on a different port than 8080.
              url: http://localhost:8080
- 使用工具
java
// Same setup as before, but tools now come from the MCP client: the injected ToolCallbackProvider
// (presumably the one provided by the MCP client starter) replaces the local TimeTools.
this.chatClient = chatClient
.defaultAdvisors(retrievalAugmentationAdvisor)
// .defaultTools(new TimeTools())
.defaultToolCallbacks(toolCallbackProvider.getToolCallbacks())
.build();
<!-- 图片:调试 ZhiPuAiApi 时的变量视图截图,可以看到发往模型的请求中包含 messages 列表、model(GLM-4-AIR)、stream=false、temperature=0.7 以及 tools 工具列表 -->

2.5.3、网页爬虫MCP使用
2.5.3.1、docker部署mcp
yaml
# Kept for the legacy `docker-compose` (v1) command used below; Docker Compose v2 ignores it with a warning.
version: "3.8"
services:
  fetcher-mcp:
    image: ghcr.io/jae-jae/fetcher-mcp:latest
    container_name: fetcher-mcp
    restart: unless-stopped
    ports:
      - "3000:3000"
    environment:
      - NODE_ENV=production
    # Using host network mode on Linux hosts can improve browser access efficiency
    # network_mode: "host"
    volumes:
      # For Playwright, may need to share certain system paths
      - /tmp:/tmp
    # Health check (NOTE(review): assumes the image ships wget)
    healthcheck:
      test: ["CMD", "wget", "--spider", "-q", "http://localhost:3000"]
      interval: 30s
      timeout: 10s
      retries: 3
启动docker-compose:docker-compose up -d
2.5.3.2、添加MCP服务相关配置
yaml
spring:
  ai:
    mcp:
      client:
        sse:
          connections:
            server1:
              url: http://localhost:8080  # the custom time-tool MCP server
            fetcher-mcp:
              url: http://localhost:3000  # fetcher-mcp container started by docker-compose (port 3000)
java
package com.example.controller;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
 * Chat endpoint that lets the model call every tool supplied by the injected
 * {@link ToolCallbackProvider} (presumably the MCP client's, e.g. the fetcher-mcp web crawler).
 */
@RestController
@RequestMapping("/external")
public class ExternalMcpServiceCall {

    /** Instructs the model to crawl the requested page with the available tools and summarize it. */
    private static final String SYSTEM_PROMPT = "你是一个网页爬虫专家, 你可以调用工具爬取指定网页的内容并进行总结";

    private final ChatClient chatClient;

    public ExternalMcpServiceCall(ChatClient.Builder chatClient, ToolCallbackProvider toolCallbackProvider) {
        ChatClient.Builder configured = chatClient.defaultToolCallbacks(toolCallbackProvider.getToolCallbacks());
        this.chatClient = configured.build();
    }

    /**
     * @param question the user's request, e.g. a URL to crawl and summarize
     * @return the model's answer
     */
    @GetMapping("/fetcher")
    public String fetcher(@RequestParam("question") String question) {
        return chatClient.prompt()
                .system(SYSTEM_PROMPT)
                .user(question)
                .call()
                .content();
    }
}
2.6、Graph
2.6.1、概述
假设你需要处理一个真实的商业需求:比如自动化审核一份商业合同。
拆分:
- 理解内容:先请AI通读合同,提炼摘要和关键条款。
- 合规检查:再让AI根据公司政策库,判断合同是否有合规风险。
- 风险评估:接着,要求AI从法律、财务等多维度给出风险评分。
- 人工介入:最后,必须将AI的分析结果提交给法务专家,由专家做出"批准"、"拒绝"或"要求修改"的最终决定。
- 后续动作:系统根据专家的决策,自动进入不同的处理流程(如生成公文、起草拒绝函等)。
你会发现它变成了一个多步骤、有状态、且需要协调 AI 自动化和人类决策的复杂流程。这就是我们常说的 AI 工作流(AI Workflow)或智能体(AI Agent)要处理的核心问题。
在没有Graph之前,实现这样的流程会非常麻烦,需要我们写很多硬编码去编排实现。
而 Spring AI Alibaba Graph 就是为了优雅地解决这些问题而生的。它是一个强大的工作流编排引擎,让你能像画流程图一样,直观地定义和执行复杂的 AI 应用流程。
State (状态) :负责在不同步骤间安全地传递和共享数据的容器
Node (节点) :代表一个执行单元,可以是一个AI调用、一个数据库操作,或一段业务逻辑
Edge(边):定义节点之间的连接关系和流转方向。可以是固定的普通边,也可以是根据 State 内容决定下一步走向的条件边。
2.6.2、快速入门
- pom依赖
xml
<!-- No <version> on these dependencies: they come from the BOMs imported in the parent pom. -->
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>
<!-- Spring AI Alibaba Graph core (StateGraph, CompiledGraph, KeyStrategy, ...), managed by spring-ai-alibaba-bom. -->
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-graph-core</artifactId>
</dependency>
</dependencies>
- application.yml配置
yaml
# Server configuration
server:
  port: 8889  # HTTP port for this module (Spring Boot's default would be 8080)
spring:
  # Basic application info
  application:
    name: sangeng-graph
  # Spring AI configuration
  ai:
    # Zhipu AI model configuration
    zhipuai:
      api-key: ${ZHIPU_KEY}  # read from the ZHIPU_KEY environment variable
      chat:
        options:
          model: glm-4.6  # chat model
- 定义状态图
java
package com.example.config;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
/**
 * Declares the quick-start graph START -> node1 -> node2 -> END. Its state has two keys,
 * "input1" and "input2", both updated with {@link ReplaceStrategy}.
 */
@Configuration
public class GraphConfiguration {

    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * @return the compiled graph, registered as bean "quickStartGraph"
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean("quickStartGraph")
    public CompiledGraph quickStartGraph() throws GraphStateException {
        // State schema: each key is paired with the strategy used to apply node outputs to it.
        KeyStrategyFactory keyStrategyFactory =
                () -> Map.of("input1", new ReplaceStrategy(), "input2", new ReplaceStrategy());

        StateGraph stateGraph = new StateGraph("quickStartGraph", keyStrategyFactory);

        // Nodes: each logs the state it receives and returns the updates for the state keys.
        stateGraph.addNode("node1", AsyncNodeAction.node_async(state -> {
            log.info("node1 state:{}", state);
            return Map.of("input1", 1, "input2", 1);
        }));
        stateGraph.addNode("node2", AsyncNodeAction.node_async(state -> {
            log.info("node2 state:{}", state);
            return Map.of("input1", 2, "input2", 2);
        }));

        // Edges: a straight line from START through both nodes to END.
        stateGraph.addEdge(StateGraph.START, "node1");
        stateGraph.addEdge("node1", "node2");
        stateGraph.addEdge("node2", StateGraph.END);

        return stateGraph.compile();
    }
}
- 测试调用
java
package com.example.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.Optional;
@RestController
@RequestMapping("/graph")
public class GraphController {

    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph graph;

    /**
     * Constructor injection by type; assumes exactly one CompiledGraph bean is defined.
     */
    public GraphController(CompiledGraph graph) {
        this.graph = graph;
    }

    /**
     * Runs the graph once with an empty input map, logs and returns the resulting state.
     */
    @GetMapping("/test")
    public Optional<OverAllState> test() {
        Map<String, Object> noInput = Map.of();
        Optional<OverAllState> finalState = graph.call(noInput);
        log.info("test call:{}", finalState);
        return finalState;
    }
}
2.6.3、API概述
2.6.3.1、KeyStrategyFactory
用来定义图中的状态有哪些数据,并且定义这些数据的更新策略是什么。
有三种策略分别是:替换(ReplaceStrategy),合并(MergeStrategy),追加(AppendStrategy)
java
KeyStrategyFactory keyStrategyFactory = () -> {
    // One update strategy per state key.
    HashMap<String, KeyStrategy> strategiesByKey = new HashMap<>();
    strategiesByKey.put("input1", new ReplaceStrategy()); // new value replaces the old one
    strategiesByKey.put("input2", new MergeStrategy());   // new Map entries merged into the old Map
    strategiesByKey.put("input3", new AppendStrategy());  // new List items appended to the old List
    return strategiesByKey;
};
替换(ReplaceStrategy): 新值替换掉老值 (常用,更加灵活)
合并(MergeStrategy):适合Map类型的数据,新老Map的数据进行合并
追加(AppendStrategy):适合List类型的数据,新List的数据追加到老List中
2.6.3.2、AsyncNodeAction&NodeAction
NodeAction 是Graph中对节点的抽象。我们只需要实现NodeAction接口,在apply方法中定义节点的执行逻辑即可。
java
/**
 * A synchronous graph node. It reads the current {@link OverAllState} and returns a map of
 * state keys to the values this node produces; how each value is combined with the existing
 * state is decided by the KeyStrategy registered for that key.
 */
@FunctionalInterface
public interface NodeAction {
/**
 * Executes the node logic.
 * @param state the current overall graph state
 * @return the keys/values this node writes back into the state
 * @throws Exception any failure raised by the node logic
 */
Map<String, Object> apply(OverAllState state) throws Exception;
}
AsyncNodeAction 是异步节点接口,它提供了一个静态方法 node_async,可以把 NodeAction 转化成 AsyncNodeAction。从下面的源码可以看到,node_async 会在当前线程中同步执行 NodeAction,再把结果包装进一个已完成的 CompletableFuture。
java
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.graph.action;
import com.alibaba.cloud.ai.graph.OverAllState;
import io.opentelemetry.context.Context;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
/**
 * Represents an asynchronous node action that operates on an agent state and returns
 * a state update wrapped in a {@link CompletableFuture}.
 *
 * <p>NOTE(review): the {@link #node_async(NodeAction)} adapter in this version does not hand
 * work off to another thread; it runs the synchronous action on the calling thread and
 * returns an already-completed future.
 */
@FunctionalInterface
public interface AsyncNodeAction extends Function<OverAllState, CompletableFuture<Map<String, Object>>> {
/**
 * Applies this action to the given agent state.
 * @param state the agent state
 * @return a CompletableFuture representing the result of the action
 */
CompletableFuture<Map<String, Object>> apply(OverAllState state);
/**
 * Creates an asynchronous node action from a synchronous node action.
 * When the returned action is applied, {@code syncAction} is invoked immediately on the
 * caller's thread; its result completes the future normally, and any exception it throws
 * completes the future exceptionally instead of propagating.
 * @param syncAction the synchronous node action
 * @return an asynchronous node action
 */
static AsyncNodeAction node_async(NodeAction syncAction) {
return state -> {
// Captured here but not used anywhere in this version of the adapter.
Context context = Context.current();
CompletableFuture<Map<String, Object>> result = new CompletableFuture<>();
try {
result.complete(syncAction.apply(state));
}
catch (Exception e) {
// Report failures through the future rather than throwing to the caller.
result.completeExceptionally(e);
}
return result;
};
}
}
2.6.3.4、StateGraph
状态图的抽象,需要配置状态(通过KeyStrategyFactory ),节点,边。
配置好后通过compile方法编译成CompiledGraph后才可以供调用。
2.6.3.5、CompiledGraph
CompiledGraph是StateGraph编译后的结果,只有CompiledGraph才能用来执行。
一般我们是把StateGraph定义好后调用其compile方法得到一个CompiledGraph放入Spring容器中。
然后在需要的时候从容器中注入然后再调用。
2.6.4、英语造句翻译小助手实战
- 定义造句node节点
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class SentenceConstructionNode implements NodeAction {

    /** Prompt asking the model for one English sentence that uses {word}. */
    private static final String PROMPT_TEMPLATE =
            "你是一个英语造句专家,能 够基于给定的单词进行造句。"
            + "要求只返回最终造好的句子,不要返回其他信息。 给定的单词:{word}";

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public SentenceConstructionNode(ChatClient.Builder builder) {
        chatClient = builder.build();
    }

    /**
     * Reads "word" from the state (empty string when absent), asks the chat model to make
     * a sentence with it, and returns the model reply under the "sentence" key.
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String word = state.value("word", "");
        PromptTemplate template = new PromptTemplate(PROMPT_TEMPLATE);
        template.add("word", word);
        String userMessage = template.render();
        String sentence = chatClient.prompt()
                .user(userMessage)
                .call()
                .content();
        return Map.of("sentence", sentence);
    }
}
- 定义翻译node节点
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class TranslationNode implements NodeAction {

    /** Prompt asking the model to translate {sentence}, returning only the translation. */
    private static final String PROMPT_TEMPLATE =
            "你是一个英语翻译专家,能 够对句子进行翻译。"
            + "要求只返回翻译的结果不要返回其他信息。要翻译的句子:{sentence}";

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public TranslationNode(ChatClient.Builder builder) {
        chatClient = builder.build();
    }

    /**
     * Reads "sentence" from the state (empty string when absent), asks the chat model to
     * translate it, and returns the model reply under the "translation" key.
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String sentence = state.value("sentence", "");
        PromptTemplate template = new PromptTemplate(PROMPT_TEMPLATE);
        template.add("sentence", sentence);
        String userMessage = template.render();
        String translation = chatClient.prompt()
                .user(userMessage)
                .call()
                .content();
        return Map.of("translation", translation);
    }
}
- 定义Graph
java
package com.example.config;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.SentenceConstructionNode;
import com.example.node.TranslationNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
@Configuration
public class GraphConfiguration {

    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Quick-start demo graph: START -> node1 -> node2 -> END. Each node logs the state it
     * receives and writes new values for "input1"/"input2".
     */
    @Bean("quickStartGraph")
    public CompiledGraph quickStartGraph() throws GraphStateException {
        KeyStrategyFactory strategies = () -> Map.of(
                "input1", new ReplaceStrategy(),
                "input2", new ReplaceStrategy());
        StateGraph graph = new StateGraph("quickStartGraph", strategies);

        NodeAction first = state -> {
            log.info("node1 state:{}", state);
            return Map.of("input1", 1, "input2", 1);
        };
        NodeAction second = state -> {
            log.info("node2 state:{}", state);
            return Map.of("input1", 2, "input2", 2);
        };
        graph.addNode("node1", AsyncNodeAction.node_async(first));
        graph.addNode("node2", AsyncNodeAction.node_async(second));

        graph.addEdge(StateGraph.START, "node1");
        graph.addEdge("node1", "node2");
        graph.addEdge("node2", StateGraph.END);
        return graph.compile();
    }

    /**
     * Sentence-making + translation graph:
     * START -> SentenceConstructionNode -> TranslationNode -> END.
     * Input key is "word"; the nodes write "sentence" and "translation".
     */
    @Bean("simpleGraph")
    public CompiledGraph simpleGraph(ChatClient.Builder chatClient) throws GraphStateException {
        KeyStrategyFactory strategies = () -> Map.of("word", new ReplaceStrategy());
        StateGraph graph = new StateGraph("simpleGraph", strategies);

        NodeAction makeSentence = new SentenceConstructionNode(chatClient);
        NodeAction translate = new TranslationNode(chatClient);
        graph.addNode("SentenceConstructionNode", AsyncNodeAction.node_async(makeSentence));
        graph.addNode("TranslationNode", AsyncNodeAction.node_async(translate));

        graph.addEdge(StateGraph.START, "SentenceConstructionNode");
        graph.addEdge("SentenceConstructionNode", "TranslationNode");
        graph.addEdge("TranslationNode", StateGraph.END);
        return graph.compile();
    }
}
- controller测试调用
java
package com.example.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.Optional;
@RestController
@RequestMapping("/graph")
public class GraphController {

    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;

    /**
     * Injects each compiled graph by bean name, since more than one CompiledGraph bean exists.
     */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
    }

    /**
     * Runs the quick-start graph with no input, logs and returns the resulting state.
     */
    @GetMapping("/test")
    public Optional<OverAllState> test() {
        Map<String, Object> noInput = Map.of();
        Optional<OverAllState> finalState = quickStartGraph.call(noInput);
        log.info("test call:{}", finalState);
        return finalState;
    }

    /**
     * Runs the sentence/translation graph for {@code word}.
     *
     * @return the final state data, or an empty map if the graph produced no state
     */
    @GetMapping("simpleGraph")
    public Map<String, Object> simpleGraph(@RequestParam("word") String word) {
        Map<String, Object> input = Map.of("word", word);
        return simpleGraph.call(input)
                .map(OverAllState::data)
                .orElse(Map.of());
    }
}
2.6.5、条件边-笑话生成实战
功能:生成笑话→笑话评估→优秀则结束,不够优秀则优化笑话再结束
- 生成笑话node
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class GenerateJokeNode implements NodeAction {

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public GenerateJokeNode(ChatClient.Builder builder) {
        chatClient = builder.build();
    }

    /**
     * Reads "topic" from the state (empty string when absent), asks the chat model for a
     * short joke about it, and returns the model reply under the "joke" key.
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String topic = state.value("topic", "");
        PromptTemplate template = new PromptTemplate(
                "你需要写一个关于指定主题 的短笑话。要求返回的结果中只能包含笑话的内容" + "主题:{topic}");
        template.add("topic", topic);
        String joke = chatClient.prompt()
                .user(template.render())
                .call()
                .content();
        return Map.of("joke", joke);
    }
}
- 评估笑话node
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class EvaluateJokesNode implements NodeAction {

    /** Verdict for jokes the model rates as good; must equal the conditional-edge route key. */
    private static final String GOOD = "优秀";

    /** Verdict for every other joke; must equal the conditional-edge route key. */
    private static final String NOT_GOOD = "不够优秀";

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public EvaluateJokesNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * Asks the chat model to rate the "joke" in the state and writes a verdict to "result".
     * The verdict is always exactly {@code "优秀"} or {@code "不够优秀"}, because the graph
     * routes on this value and an unexpected label would have no matching edge.
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");
        // The prompt asks only for the verdict (the old text also demanded a "translation"
        // result, copy-pasted from the translation node).
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于笑话的搞笑程度给出0到10分的打分。然后基于评分结果进行评价。如果大于等于3分 评价:优秀 否则评价:不够优秀\n"
                + "要求结果只返回最后的评价(优秀 或 不够优秀),不要返回其他信息。"
                + "要评分的笑话:{joke}");
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();
        return Map.of("result", normalizeVerdict(content));
    }

    /**
     * Maps a free-form model reply onto one of the two route labels.
     * Replies that are null or mention neither label fall back to {@code "不够优秀"}, so the
     * joke gets one optimization pass instead of the graph failing on an unknown route.
     */
    static String normalizeVerdict(String reply) {
        if (reply == null) {
            return NOT_GOOD;
        }
        // Check the longer label first: "不够优秀" itself contains "优秀".
        if (reply.contains(NOT_GOOD)) {
            return NOT_GOOD;
        }
        if (reply.contains(GOOD)) {
            return GOOD;
        }
        return NOT_GOOD;
    }
}
- 优化笑话node
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class EnhanceJokeQualityNode implements NodeAction {

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public EnhanceJokeQualityNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * Reads "joke" from the state (empty string when absent), asks the chat model for a
     * funnier version, and returns it under the "newJoke" key. This node does not write "joke".
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");
        // Ask explicitly for the optimized joke only (not a "translation"), so the reply
        // can be stored as-is.
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话优化专家,你能够优化笑话,让它更加搞笑。"
                + "要求只返回优化后的笑话,不要返回其他信息。要优化的笑话:{joke}");
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();
        return Map.of("newJoke", content);
    }
}
- graph配置
java
package com.example.config;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
@Configuration
public class GraphConfiguration {

    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Conditional joke graph: generate joke -> evaluate joke -> END if the verdict is
     * "优秀", otherwise optimize the joke once and then END.
     */
    @Bean("conditionalGraph")
    public CompiledGraph conditionalGraph(ChatClient.Builder chatClient) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> Map.of("topic", new ReplaceStrategy());
        StateGraph stateGraph = new StateGraph("conditionalGraph", keyStrategyFactory);

        stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(new GenerateJokeNode(chatClient)));
        stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(new EvaluateJokesNode(chatClient)));
        stateGraph.addNode("优化笑话", AsyncNodeAction.node_async(new EnhanceJokeQualityNode(chatClient)));

        stateGraph.addEdge(StateGraph.START, "生成笑话");
        stateGraph.addEdge("生成笑话", "评估笑话");
        stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(new EdgeAction() {
            // Route on the evaluation result, normalized to one of the two mapped labels.
            @Override
            public String apply(OverAllState state) throws Exception {
                return toRoute(state.value("result", "优秀"));
            }
        }), Map.of("优秀", StateGraph.END, "不够优秀", "优化笑话"));
        stateGraph.addEdge("优化笑话", StateGraph.END);
        return stateGraph.compile();
    }

    /**
     * Maps the (LLM-produced) verdict onto an exact route key. The raw reply may carry extra
     * punctuation or text, and any label absent from the edge map would break routing.
     * Unrecognized or null verdicts go to "不够优秀" (one extra optimization pass).
     */
    private static String toRoute(String verdict) {
        if (verdict == null) {
            return "不够优秀";
        }
        // "不够优秀" contains "优秀", so it must be checked first.
        if (verdict.contains("不够优秀")) {
            return "不够优秀";
        }
        if (verdict.contains("优秀")) {
            return "优秀";
        }
        return "不够优秀";
    }
}
- controller测试
java
package com.example.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.Optional;
@RestController
@RequestMapping("/graph")
public class GraphController {

    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;

    /**
     * Injects each compiled graph by bean name, since several CompiledGraph beans exist.
     */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
    }

    /**
     * Runs the conditional joke graph for {@code topic}.
     *
     * @return the final state data, or an empty map if the graph produced no state
     */
    @GetMapping("conditionalGraph")
    public Map<String, Object> conditionalGraph(@RequestParam("topic") String topic) {
        Map<String, Object> input = Map.of("topic", topic);
        Optional<OverAllState> finalState = conditionalGraph.call(input);
        return finalState.map(OverAllState::data).orElse(Map.of());
    }
}
2.6.6、循环-优化笑话直到达到固定分数或循环次数
功能:生成笑话→评估笑话→如果大于6分或者循环3次则结束,否则优化笑话,优化完成后再回到评估笑话
- 生成笑话节点复用上一节的 GenerateJokeNode
- 循环评估笑话(只需修改提示词以及判断逻辑)
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.example.controller.GraphController;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class LoopEvaluateJokesNode implements NodeAction {

    private static final Logger log = LoggerFactory.getLogger(LoopEvaluateJokesNode.class);

    /** First run of digits in the model reply, e.g. "7", "7分" or "评分: 7". */
    private static final java.util.regex.Pattern FIRST_INTEGER = java.util.regex.Pattern.compile("\\d+");

    private final ChatClient chatClient;

    /** The loop ends once a score is strictly greater than this value. */
    private final int targetScore;

    /** The loop ends once this many evaluations have been performed. */
    private final int maxLoopCount;

    /**
     * @param builder      builder used once to create this node's ChatClient
     * @param targetScore  scores strictly above this value end the loop
     * @param maxLoopScore maximum number of evaluations (a count, despite the name)
     * @throws NullPointerException if either Integer argument is null
     */
    public LoopEvaluateJokesNode(ChatClient.Builder builder, Integer targetScore, Integer maxLoopScore) {
        this.chatClient = builder.build();
        this.targetScore = targetScore;
        this.maxLoopCount = maxLoopScore;
    }

    /**
     * Scores the "joke" in the state and decides whether to keep optimizing.
     * Writes "result" ("loop" or "break"), the incremented "loopCount" and the "score".
     *
     * @throws IllegalStateException if the model reply contains no integer score
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");
        // 1-based number of the current evaluation; the incremented value is written back below.
        int loopCount = state.value("loopCount", 1);
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于笑话的搞笑程度给出0到10分的打分。要求打分只能是整数\n"
                + "要求结果只返回最后的打分,不要其他内容。"
                + "要评分的笑话:{joke}");
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();
        int score = parseScore(content);
        log.info("score: {}, loopCount:{}, joke:{}", score, loopCount, joke);
        // Stop when the score is above the target or the evaluation budget is used up.
        boolean done = score > targetScore || loopCount >= maxLoopCount;
        String result = done ? "break" : "loop";
        // The key must be exactly "loopCount" (the key read above); otherwise the counter
        // never advances and the loop is bounded only by the score.
        return Map.of("result", result, "loopCount", loopCount + 1, "score", score);
    }

    /**
     * Extracts the first integer from the model reply, tolerating surrounding text such as
     * whitespace, "分" or a label.
     *
     * @throws IllegalStateException if the reply is null or contains no digits
     */
    static int parseScore(String content) {
        if (content == null) {
            throw new IllegalStateException("Model returned no content for the joke score");
        }
        java.util.regex.Matcher matcher = FIRST_INTEGER.matcher(content);
        if (!matcher.find()) {
            throw new IllegalStateException("Model reply does not contain an integer score: " + content);
        }
        return Integer.parseInt(matcher.group());
    }
}
- 循环优化笑话
java
package com.example.node;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import java.util.Map;
public class LoopEnhanceJokeQualityNode implements NodeAction {

    private final ChatClient chatClient;

    /**
     * @param builder builder used once to create this node's ChatClient
     */
    public LoopEnhanceJokeQualityNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * Reads "joke" from the state (empty string when absent), asks the chat model for a
     * funnier version, and writes it back under "joke" so the next evaluation scores it.
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String joke = state.value("joke", "");
        // Ask explicitly for the optimized joke only (not a "translation"), because the reply
        // overwrites "joke" and is evaluated again.
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话优化专家,你能够优化笑话,让它更加搞笑。"
                + "要求只返回优化后的笑话,不要返回其他信息。要优化的笑话:{joke}");
        promptTemplate.add("joke", joke);
        String content = chatClient.prompt().user(promptTemplate.render()).call().content();
        return Map.of("joke", content);
    }
}
- graph
java
package com.example.config;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
@Configuration
public class GraphConfiguration {

    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Looping joke graph: START -> 生成笑话 -> 评估笑话; from 评估笑话 the "result" value routes
     * "loop" -> 优化笑话 (which leads back to 评估笑话) and "break" -> END.
     * LoopEvaluateJokesNode is constructed with the arguments (6, 3).
     */
    @Bean("loopGraph")
    public CompiledGraph loopGraph(ChatClient.Builder chatClient) throws GraphStateException {
        final String generate = "生成笑话";
        final String evaluate = "评估笑话";
        final String enhance = "优化笑话";

        KeyStrategyFactory strategies = () -> Map.of("topic", new ReplaceStrategy());
        StateGraph graph = new StateGraph("loopGraph", strategies);

        graph.addNode(generate, AsyncNodeAction.node_async(new GenerateJokeNode(chatClient)));
        graph.addNode(evaluate, AsyncNodeAction.node_async(new LoopEvaluateJokesNode(chatClient, 6, 3)));
        graph.addNode(enhance, AsyncNodeAction.node_async(new LoopEnhanceJokeQualityNode(chatClient)));

        graph.addEdge(StateGraph.START, generate);
        graph.addEdge(generate, evaluate);

        // Route on the "result" written by the evaluation node; "loop" when absent.
        EdgeAction routeOnResult = new EdgeAction() {
            @Override
            public String apply(OverAllState state) throws Exception {
                return state.value("result", "loop");
            }
        };
        graph.addConditionalEdges(evaluate, AsyncEdgeAction.edge_async(routeOnResult),
                Map.of("loop", enhance, "break", StateGraph.END));
        graph.addEdge(enhance, evaluate);
        return graph.compile();
    }
}
- controller测试
java
package com.example.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.Optional;
@RestController
@RequestMapping("/graph")
public class GraphController {

    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;
    private final CompiledGraph loopGraph;

    /**
     * Injects each compiled graph by bean name, since several CompiledGraph beans exist.
     */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph,
                           @Qualifier("loopGraph") CompiledGraph loopGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
        this.loopGraph = loopGraph;
    }

    /**
     * Runs the looping joke graph for {@code topic}.
     *
     * @return the final state data, or an empty map if the graph produced no state
     */
    @GetMapping("loopGraph")
    public Map<String, Object> loopGraph(@RequestParam("topic") String topic) {
        Map<String, Object> input = Map.of("topic", topic);
        Optional<OverAllState> finalState = loopGraph.call(input);
        return finalState.map(OverAllState::data).orElse(Map.of());
    }
}
2.6.7、保存状态
- graph
java
package com.example.config;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.example.node.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.ArrayList;
import java.util.Map;
@Configuration
public class GraphConfiguration {

    private static final Logger log = LoggerFactory.getLogger(GraphConfiguration.class);

    /**
     * Single-node graph that appends the incoming "msg" to the "historyMsg" list:
     * START -> 对话存储 -> END. Also logs a PlantUML representation of the graph.
     */
    @Bean("saveGraph")
    public CompiledGraph saveGraph(ChatClient.Builder chatClient) throws GraphStateException {
        // NOTE(review): no key strategies are registered; "historyMsg" is always written as the
        // complete list, so this relies on the framework's default handling of unregistered
        // keys (presumably replace) - confirm.
        KeyStrategyFactory keyStrategyFactory = Map::of;
        StateGraph stateGraph = new StateGraph("saveGraph", keyStrategyFactory);

        NodeAction appendMessage = state -> {
            String msg = state.value("msg", "");
            java.util.List<Object> previous = state.value("historyMsg", new ArrayList<>());
            // Copy before appending: never mutate the list object already held by the state
            // (or by any saved snapshot of it); the new list is returned as the update instead.
            ArrayList<Object> historyMsg = new ArrayList<>(previous);
            historyMsg.add(msg);
            return Map.of("historyMsg", historyMsg);
        };
        stateGraph.addNode("对话存储", AsyncNodeAction.node_async(appendMessage));

        stateGraph.addEdge(StateGraph.START, "对话存储");
        stateGraph.addEdge("对话存储", StateGraph.END);

        // Print the PlantUML diagram; pass it as an argument, not as the SLF4J format string.
        GraphRepresentation graphRepresentation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "saveGraph");
        log.info("\n===expander UML FLOW ===");
        log.info("{}", graphRepresentation.content());
        log.info("==========================\n");
        return stateGraph.compile();
    }
}
- controller测试
java
package com.example.controller;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;
import java.util.Optional;
@RestController
@RequestMapping("/graph")
public class GraphController {

    private static final Logger log = LoggerFactory.getLogger(GraphController.class);

    private final CompiledGraph quickStartGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;
    private final CompiledGraph loopGraph;
    private final CompiledGraph saveGraph;

    /**
     * Injects each compiled graph by bean name, since several CompiledGraph beans exist.
     */
    public GraphController(@Qualifier("quickStartGraph") CompiledGraph quickStartGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph,
                           @Qualifier("loopGraph") CompiledGraph loopGraph,
                           @Qualifier("saveGraph") CompiledGraph saveGraph) {
        this.quickStartGraph = quickStartGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
        this.loopGraph = loopGraph;
        this.saveGraph = saveGraph;
    }

    /**
     * Runs the save graph for one message of a conversation.
     * The conversation id is passed as the run's thread id; presumably calls sharing an id
     * continue from the same saved state (depends on the graph's checkpoint saver - confirm).
     *
     * @return the final state data, or an empty map if the graph produced no state
     */
    @GetMapping("saveGraph")
    public Map<String, Object> saveGraph(@RequestParam("msg") String msg,
                                         @RequestParam("conversationId") String conversationId) {
        RunnableConfig config = RunnableConfig.builder()
                .threadId(conversationId)
                .build();
        Map<String, Object> input = Map.of("msg", msg);
        return saveGraph.call(input, config)
                .map(OverAllState::data)
                .orElse(Map.of());
    }
}
可视化图地址:http://www.plantuml.com/plantuml/uml/SyfFKj2rKt3CoKnELR1Io4ZDoSa70000
