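While scrolling on my phone on the way somewhere today, I saw someone in a group chat share a link about using Spring Cloud Alibaba AI and Spring AI, so I decided to try it out. It isn't difficult, but the documentation has a few pitfalls, so I'm recording them here. I'll keep updating this series; in the age of AI, we Spring developers should catch the ride too.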
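I. Reading the Documentation
The documentation describes its capabilities in an overview figure, but here we only try text-to-text generation, which in plain terms means chat.
II. Quick Start
We set up the environment by following the documentation.
1. Setting up the environment from the documentation
- 1. Add the dependencies
Following the documentation, add these dependencies: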
xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-alibaba-dependencies</artifactId>
            <version>2023.0.1.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
<dependencies>
    <dependency>
        <groupId>com.alibaba.cloud</groupId>
        <artifactId>spring-cloud-starter-alibaba-ai</artifactId>
    </dependency>
</dependencies>
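- 2. Add the following configuration to application.yml.
You need to obtain this API key from Alibaba Cloud; for the steps, see the guide on obtaining a Tongyi Qianwen API key.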
yml
spring:
  cloud:
    ai:
      tongyi:
        chat:
          options:
            # Replace the following key with a valid API-KEY.
            api-key: sk-a3d73b1709bf4a178c28ed7c8b3b5axx
- 3. Write the chat service implementation. Spring AI auto-injects ChatClient and StreamingChatClient, and ChatClient hides the details of talking to the underlying Tongyi model.
java
@Service
public class TongYiSimpleServiceImpl extends AbstractTongYiServiceImpl {

    private final ChatClient chatClient;
    private final StreamingChatClient streamingChatClient;

    @Autowired
    public TongYiSimpleServiceImpl(ChatClient chatClient, StreamingChatClient streamingChatClient) {
        this.chatClient = chatClient;
        this.streamingChatClient = streamingChatClient;
    }
}
Then implement the chat logic:
java
@Service
public class TongYiSimpleServiceImpl extends AbstractTongYiServiceImpl {

    // ......

    @Override
    public String completion(String message) {
        // Blocking call: send one user message and return the model's reply text
        Prompt prompt = new Prompt(new UserMessage(message));
        return chatClient.call(prompt).getResult().getOutput().getContent();
    }

    @Override
    public Map<String, String> streamCompletion(String message) {
        StringBuilder fullContent = new StringBuilder();
        // Streaming call: append each chunk, then block() until the stream has ended
        Map<String, String> result = streamingChatClient.stream(new Prompt(message))
                .flatMap(chatResponse -> Flux.fromIterable(chatResponse.getResults()))
                .map(content -> content.getOutput().getContent())
                .doOnNext(fullContent::append)
                .last()
                .map(lastContent -> Map.of(message, fullContent.toString()))
                .block();
        // Assumes a logger named log is declared in the elided part (e.g. Lombok @Slf4j)
        log.info(fullContent.toString());
        return result;
    }
}
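streamCompletion above flattens every streamed response into its generations, takes each generation's text, and concatenates the pieces into one reply. The same data flow can be sketched with plain java.util.stream; the records Response and Generation below are hypothetical stand-ins for Spring AI's ChatResponse and Generation, not the real classes:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical stand-ins for Spring AI's ChatResponse / Generation, for illustration only
record Generation(String content) {}

record Response(List<Generation> results) {}

final class StreamAggregation {

    // Mirrors flatMap(getResults) -> map(getContent) -> concatenate,
    // and returns the same message -> full-reply map as streamCompletion
    static Map<String, String> aggregate(String message, List<Response> chunks) {
        String full = chunks.stream()
                .flatMap(response -> response.results().stream())
                .map(Generation::content)
                .collect(Collectors.joining());
        return Map.of(message, full);
    }

    public static void main(String[] args) {
        List<Response> chunks = List.of(
                new Response(List.of(new Generation("Hello"))),
                new Response(List.of(new Generation(", "), new Generation("world"))));
        System.out.println(aggregate("hi", chunks)); // {hi=Hello, world}
    }
}
```

The difference in the real service is only that the chunks arrive over time from a Flux, which is why the Reactor version needs block() (or a returned Mono) before the full text is available.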
Write the Spring Boot entry class and start the application:
java
@SpringBootApplication
public class TongYiApplication {

    public static void main(String[] args) {
        // Pass args through so command-line properties (e.g. --server.port) take effect
        SpringApplication.run(TongYiApplication.class, args);
    }
}
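2. Pitfalls
2.1 Maven import failure
After setting up the environment as above, reloading Maven failed with an error:
the dependencies could not be downloaded through the Aliyun Maven mirror. My mirror configuration is: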
xml
<mirror>
    <id>alimaven</id>
    <name>aliyun maven</name>
    <url>http://maven.aliyun.com/nexus/content/groups/public/</url>
    <mirrorOf>central</mirrorOf>
</mirror>
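Having hit a pitfall, I went back to the documentation, assuming I wasn't the only one. After some digging I found the Spring Cloud Alibaba Q&A area,
and in it a post explaining how to fix Spring AI dependencies that Maven cannot import.
The likely cause is that the Spring AI artifacts that spring-cloud-starter-alibaba-ai depends on are published to the Spring milestone repository rather than Maven Central, so a mirror of central cannot serve them. Following that post, add these repositories to the pom: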
xml
<repositories>
    <repository>
        <id>spring-milestones</id>
        <name>Spring Milestones</name>
        <url>https://repo.spring.io/milestone</url>
        <snapshots>
            <enabled>false</enabled>
        </snapshots>
    </repository>
    <repository>
        <id>spring-snapshots</id>
        <name>Spring Snapshots</name>
        <url>https://repo.spring.io/snapshot</url>
        <releases>
            <enabled>false</enabled>
        </releases>
    </repository>
</repositories>
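Whether these extra repositories are actually reached depends on the mirror's mirrorOf value, which Maven matches against each repository's id. The class below is a simplified model of those rules, written for illustration (MirrorOfMatcher is a hypothetical name); it handles `*`, exact ids, comma-separated lists and `!id` exclusions, and ignores patterns such as `external:*`:

```java
import java.util.List;

// Simplified model of Maven's <mirrorOf> matching, for illustration only:
// supports "*", exact ids, comma-separated lists and "!id" exclusions,
// but not "external:*" or other special patterns.
final class MirrorOfMatcher {

    static boolean matches(String mirrorOf, String repoId) {
        if (mirrorOf.equals("*") || mirrorOf.equals(repoId)) {
            return true;
        }
        boolean result = false;
        for (String raw : mirrorOf.split(",")) {
            String part = raw.trim();
            if (part.length() > 1 && part.startsWith("!")) {
                // An explicit exclusion wins immediately
                if (part.substring(1).equals(repoId)) {
                    return false;
                }
            } else if (part.equals(repoId)) {
                return true;
            } else if (part.equals("*")) {
                // Keep scanning: a later "!id" may still exclude this repository
                result = true;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        for (String repo : List.of("central", "spring-milestones")) {
            System.out.println(repo + " mirrored by 'central'? " + matches("central", repo));
            System.out.println(repo + " mirrored by '*,!spring-milestones'? "
                    + matches("*,!spring-milestones", repo));
        }
    }
}
```

With mirrorOf set to central, as in the mirror above, spring-milestones is not matched, so Maven downloads it from repo.spring.io directly. If your settings.xml uses `*` instead, the mirror would capture these repositories too, and an exclusion list such as `*,!spring-milestones,!spring-snapshots` would be needed.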
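2.2 YAML configuration problem
Starting the service with the configuration above failed with an error saying the api-key property could not be found.
The property path seemed to be wrong, even though I had copied it from the official documentation, so I went to the source code.
The source binds the API key under the prefix spring.cloud.ai.tongyi, which is not the nesting level the documentation uses (spring.cloud.ai.tongyi.chat.options), so the configuration has to change.
I moved api-key directly under spring.cloud.ai.tongyi (the complete file is in section 3.4).
With that change, the application starts.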
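The mismatch is easier to see once the YAML is viewed as flat dotted keys, which is roughly the form Spring Boot produces from a YAML file before binding properties to a prefix. The helper below is a hypothetical sketch (YamlKeys is not part of Spring) that flattens nested maps, standing in for what a YAML parser returns:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: flattens nested maps (as a YAML parser would return them)
// into dotted property keys, roughly the form Spring Boot binds against.
final class YamlKeys {

    static Map<String, Object> flatten(Map<String, Object> tree) {
        Map<String, Object> out = new LinkedHashMap<>();
        flatten("", tree, out);
        return out;
    }

    @SuppressWarnings("unchecked")
    private static void flatten(String prefix, Map<String, Object> node, Map<String, Object> out) {
        for (Map.Entry<String, Object> e : node.entrySet()) {
            String key = prefix.isEmpty() ? e.getKey() : prefix + "." + e.getKey();
            if (e.getValue() instanceof Map<?, ?> child) {
                flatten(key, (Map<String, Object>) child, out); // descend into nested section
            } else {
                out.put(key, e.getValue());                      // leaf value
            }
        }
    }

    public static void main(String[] args) {
        // Layout from the documentation
        Map<String, Object> documented = Map.of("spring", Map.of("cloud", Map.of("ai", Map.of("tongyi",
                Map.of("chat", Map.of("options", Map.of("api-key", "sk-xxx")))))));
        // Layout that starts successfully (section 3.4)
        Map<String, Object> working = Map.of("spring", Map.of("cloud", Map.of("ai", Map.of("tongyi",
                Map.of("api-key", "sk-xxx")))));
        String wanted = "spring.cloud.ai.tongyi.api-key";
        System.out.println(flatten(documented).keySet());            // [spring.cloud.ai.tongyi.chat.options.api-key]
        System.out.println(flatten(documented).containsKey(wanted)); // false
        System.out.println(flatten(working).containsKey(wanted));    // true
    }
}
```

A property class bound to the prefix spring.cloud.ai.tongyi looks for spring.cloud.ai.tongyi.api-key, which only the second layout provides.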
3. Complete Code
3.1 Controller
java
package com.test.controller;

import com.test.service.TongYiSimpleServiceImpl;
import com.test.utils.SseEmitterUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

@RestController
@RequestMapping("/chat")
public class TongYiSimpleController {

    // Inject the service bean instead of creating a new instance on every request
    private final TongYiSimpleServiceImpl tongYiSimpleService;

    public TongYiSimpleController(TongYiSimpleServiceImpl tongYiSimpleService) {
        this.tongYiSimpleService = tongYiSimpleService;
    }

    /**
     * The user opens an SSE connection; replies are pushed over it.
     */
    @GetMapping("/connect")
    public SseEmitter connect(@RequestParam(name = "username") String username) {
        return SseEmitterUtils.getConnection(username);
    }

    /**
     * The user sends a message; the model's reply is streamed to the user's SSE connection.
     */
    @PostMapping("/send")
    public void sendMessage(@RequestParam(name = "username") String username,
                            @RequestParam(name = "message") String message) {
        tongYiSimpleService.sendMsg(username, message);
    }
}
3.2 Service
java
package com.test.service;

import com.test.utils.SseEmitterUtils;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.util.Map;

@Service
public class TongYiSimpleServiceImpl {

    private final StreamingChatClient streamingChatClient;

    @Autowired
    public TongYiSimpleServiceImpl(StreamingChatClient streamingChatClient) {
        this.streamingChatClient = streamingChatClient;
    }

    /**
     * Calls the model in streaming mode and pushes each chunk to the user's SSE connection.
     * subscribe() is asynchronous, so this method returns before the reply is complete;
     * that is why it returns nothing instead of a (still empty) accumulated reply.
     */
    public void sendMsg(String userName, String message) {
        streamingChatClient.stream(new Prompt(message))
                .flatMap(chatResponse -> Flux.fromIterable(chatResponse.getResults()))
                .map(generation -> generation.getOutput().getContent())
                .subscribe(chunk -> {
                    try {
                        SseEmitterUtils.sendMsg(userName, chunk);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                });
    }

    /**
     * Collects the whole streamed reply into a map of message -> full reply.
     */
    public Mono<Map<String, String>> streamCompletion(String message) {
        return streamingChatClient.stream(new Prompt(message))
                .flatMapIterable(chatResponse -> chatResponse.getResults())
                .map(generation -> generation.getOutput().getContent())
                // reduceWith creates a fresh StringBuilder for every subscription
                .reduceWith(StringBuilder::new, (builder, content) -> builder.append(content))
                .map(builder -> Map.of(message, builder.toString()));
    }
}
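Because subscribe() only registers a subscriber and returns immediately, any code that reads the accumulated text right after calling it sees an empty or partial reply. If a caller does need the full text, one option is to return a handle that completes only when the stream ends. The sketch below is a JDK-only analogy, not Reactor: java.util.concurrent.SubmissionPublisher stands in for the model's Flux, the Consumer stands in for SseEmitterUtils.sendMsg, and StreamToSse/relay are hypothetical names:

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.SubmissionPublisher;
import java.util.function.Consumer;

// JDK-only analogy, not Reactor: SubmissionPublisher stands in for the model's Flux
// and the Consumer stands in for SseEmitterUtils.sendMsg. Names are hypothetical.
final class StreamToSse {

    // Forwards each chunk to the sink as it arrives and returns a future that
    // completes with the full reply only after the stream has ended.
    static CompletableFuture<String> relay(List<String> chunks, Consumer<String> sink) {
        StringBuilder full = new StringBuilder();
        SubmissionPublisher<String> publisher = new SubmissionPublisher<>();
        // consume() subscribes and returns at once, much like Flux.subscribe()
        CompletableFuture<Void> done = publisher.consume(chunk -> {
            sink.accept(chunk);   // push the chunk to the client
            full.append(chunk);   // accumulate, like doOnNext(fullContent::append)
        });
        chunks.forEach(publisher::submit);
        publisher.close();        // end of stream: the consumer receives onComplete
        // Reading 'full' here directly would race with the consumer thread;
        // thenApply runs only after every chunk has been appended.
        return done.thenApply(ignored -> full.toString());
    }

    public static void main(String[] args) {
        String reply = relay(List.of("Hel", "lo"), chunk -> System.out.println("sent: " + chunk)).join();
        System.out.println("full: " + reply);
    }
}
```

In Reactor terms, the equivalent is to keep the pipeline as a Mono (as streamCompletion does) and let the caller subscribe or block, rather than reading a StringBuilder that another thread is still filling.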
3.3 SSE utility class
java
package com.test.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class SseEmitterUtils {

    private static final Logger logger = LoggerFactory.getLogger(SseEmitterUtils.class);

    // Counter that gives each pool thread a distinct name
    private static final AtomicInteger THREAD_COUNTER = new AtomicInteger();

    private static final ThreadPoolExecutor ssePool = new ThreadPoolExecutor(
            20,
            200,
            30,
            TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(1000),
            runnable -> new Thread(runnable, "sse-sendMsg-pool-" + THREAD_COUNTER.incrementAndGet()),
            new ThreadPoolExecutor.AbortPolicy()
    );

    // Delay before closing an SSE connection
    private static final Integer EMITTER_COMPLETE_DELAY_MILLISECONDS = 500;

    // Default timeout of a new SSE connection
    private static final Long EMITTER_TIME_OUT_MILLISECONDS = 600_000L;

    // Cache of SSE connections, keyed by client ID
    private static final Map<String, SseEmitter> SSE_CACHE = new ConcurrentHashMap<>();

    /**
     * Gets (or creates) an SSE connection with the default timeout EMITTER_TIME_OUT_MILLISECONDS.
     *
     * @param clientId client ID
     * @return the connection
     */
    public static SseEmitter getConnection(String clientId) {
        return getConnection(clientId, EMITTER_TIME_OUT_MILLISECONDS);
    }

    /**
     * Gets (or creates) an SSE connection.
     *
     * @param clientId client ID
     * @param timeout  connection timeout in milliseconds
     * @return the connection
     */
    public static SseEmitter getConnection(String clientId, Long timeout) {
        // computeIfAbsent makes check-and-create atomic, so two concurrent
        // requests for the same client cannot create two emitters
        return SSE_CACHE.computeIfAbsent(clientId, id -> {
            SseEmitter emitter = new SseEmitter(timeout);
            // Register the timeout/completion/error callbacks
            initSseEmitter(emitter, id);
            logger.info("[SseEmitter] connection established, clientId = {}", id);
            return emitter;
        });
    }

    /**
     * Closes the given client's stream.
     *
     * @param clientId client ID
     */
    public static void closeConnection(String clientId) {
        final SseEmitter sseEmitter = SSE_CACHE.remove(clientId);
        logger.info("[stream-stop] client asked to close the connection, emitter {}, clientId = {}",
                sseEmitter == null ? "NOT-Exist" : "Exist", clientId);
        if (Objects.nonNull(sseEmitter)) {
            sseEmitter.complete();
        }
        try {
            TimeUnit.MILLISECONDS.sleep(EMITTER_COMPLETE_DELAY_MILLISECONDS);
        } catch (InterruptedException ex) {
            logger.error("Streaming response interrupted", ex);
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Pushes a message synchronously.
     *
     * @param clientId client ID
     * @param msg      message
     * @return true if the connection exists and the message was sent
     * @throws IOException declared for callers; send failures are caught, logged and reported as false
     */
    public static boolean sendMsg(String clientId, String msg) throws IOException {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.isNull(sseEmitter)) {
            return false;
        }
        try {
            sseEmitter.send(msg);
            return true;
        } catch (Exception e) {
            logger.error("[stream-stop] send failed, clientId = {}", clientId, e);
            return false;
        }
    }

    /**
     * Pushes a message asynchronously. TODO: no completion callback yet.
     * Note: with a multi-threaded pool, consecutive messages for the same
     * client may be sent out of order.
     *
     * @param clientId client ID
     * @param msg      message
     * @return true if the connection exists and the pool accepted the task
     */
    public static boolean sendMsgAsync(String clientId, String msg) {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.isNull(sseEmitter)) {
            return false;
        }
        try {
            ssePool.submit(() -> {
                try {
                    sseEmitter.send(msg);
                } catch (IOException e) {
                    logger.error("[stream-stop] async send failed, clientId = {}", clientId, e);
                }
            });
            return true;
        } catch (Exception e) {
            // e.g. RejectedExecutionException when the queue is full (AbortPolicy)
            logger.error("[stream-stop] async task rejected, clientId = {}", clientId, e);
            return false;
        }
    }

    /**
     * Closes the SseEmitter immediately; the last chunks may not reach the client, so use with care.
     *
     * @param clientId client ID
     */
    public static void complete(String clientId) {
        completeDelay(clientId, 0);
    }

    /**
     * Closes the SseEmitter after the default delay, to give the last push a chance to reach the client.
     *
     * @param clientId client ID
     */
    public static void completeDelay(String clientId) {
        completeDelay(clientId, EMITTER_COMPLETE_DELAY_MILLISECONDS);
    }

    /**
     * Closes the SseEmitter after the given delay, to give the last push a chance to reach the client.
     *
     * @param clientId          client ID
     * @param delayMilliSeconds delay in milliseconds
     */
    public static void completeDelay(String clientId, Integer delayMilliSeconds) {
        final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(sseEmitter)) {
            try {
                TimeUnit.MILLISECONDS.sleep(delayMilliSeconds);
                sseEmitter.complete();
            } catch (InterruptedException ex) {
                logger.error("Streaming response interrupted", ex);
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Initializes an SSE connection: registers its callbacks.
     *
     * @param emitter  the connection
     * @param clientId client ID
     */
    private static void initSseEmitter(SseEmitter emitter, String clientId) {
        // Timeout callback; remove(key, value) only evicts this emitter, not a newer one
        emitter.onTimeout(() -> {
            logger.info("[SseEmitter] connection timed out, closing, clientId = {}", clientId);
            SSE_CACHE.remove(clientId, emitter);
        });
        // Completion callback
        emitter.onCompletion(() -> {
            logger.info("[SseEmitter] connection released, clientId = {}", clientId);
            SSE_CACHE.remove(clientId, emitter);
        });
        // Error callback
        emitter.onError(throwable -> {
            logger.error("[SseEmitter] connection error, closing, clientId = {}", clientId, throwable);
            SSE_CACHE.remove(clientId, emitter);
        });
    }
}
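sendMsgAsync hands every message to a pool with 20 core threads, so two chunks for the same client can be picked up by different threads and reach the browser out of order. One way to keep per-client order while still sending asynchronously is to chain each client's tasks. The class below is a hypothetical sketch (KeyedOrderedExecutor is not part of the original code or of any library):

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical helper: runs tasks asynchronously on a shared executor while keeping
// the tasks of each key (e.g. each clientId) strictly in submission order.
final class KeyedOrderedExecutor {

    private final Executor executor;
    private final Map<String, CompletableFuture<Void>> tails = new ConcurrentHashMap<>();

    KeyedOrderedExecutor(Executor executor) {
        this.executor = executor;
    }

    CompletableFuture<Void> submit(String key, Runnable task) {
        // compute() is atomic per key, so each task is chained after the previous one
        return tails.compute(key, (k, tail) ->
                (tail == null ? CompletableFuture.<Void>completedFuture(null) : tail)
                        .exceptionally(ex -> null)          // a failed send must not block later ones
                        .thenRunAsync(task, executor));
    }

    // Drop a key's chain once its connection is closed, so the map does not grow forever
    void forget(String key) {
        tails.remove(key);
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        KeyedOrderedExecutor sender = new KeyedOrderedExecutor(pool);
        CompletableFuture<Void> last = CompletableFuture.completedFuture(null);
        for (String chunk : List.of("Hel", "lo", ", ", "world")) {
            last = sender.submit("alice", () -> System.out.print(chunk));
        }
        last.join();
        System.out.println();
        pool.shutdown();
    }
}
```

Different clients still run in parallel on the shared pool; only the tasks for one clientId are serialized. Wiring this into SseEmitterUtils would mean replacing ssePool.submit(...) with a call like submit(clientId, ...) on an instance backed by ssePool.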
3.4 Configuration file
yml
spring:
  cloud:
    ai:
      tongyi:
        # Replace the following key with a valid API-KEY.
        api-key: replace-with-your-key
        model: qwen-max
3.5 Maven configuration
xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.example</groupId>
    <artifactId>test</artifactId>
    <version>1.0-SNAPSHOT</version>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>com.alibaba.cloud</groupId>
                <artifactId>spring-cloud-alibaba-dependencies</artifactId>
                <version>2023.0.1.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-starter-alibaba-ai</artifactId>
        </dependency>
    </dependencies>
    <repositories>
        <repository>
            <id>spring-milestones</id>
            <name>Spring Milestones</name>
            <url>https://repo.spring.io/milestone</url>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
        <repository>
            <id>spring-snapshots</id>
            <name>Spring Snapshots</name>
            <url>https://repo.spring.io/snapshot</url>
            <releases>
                <enabled>false</enabled>
            </releases>
        </repository>
    </repositories>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <!-- Spring Cloud Alibaba 2023.0.x builds on Spring Boot 3, which needs Java 17 or later -->
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
3.6 A simple front-end page
html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>SSE Chat</title>
</head>
<body>
<h1>橘子-Chat</h1>
<div id="chat-messages"></div>
<form id="message-form">
    <input type="text" id="message-input" placeholder="Type a message">
    <button type="submit">Send</button>
</form>
<script>
    const chatMessages = document.getElementById('chat-messages');
    const messageForm = document.getElementById('message-form');
    const messageInput = document.getElementById('message-input');

    // Connect to the chat room
    const connectToChat = () => {
        const username = prompt('Enter your username:');
        const eventSource = new EventSource(`/chat/connect?username=${encodeURIComponent(username)}`);

        // Receive messages from the server
        eventSource.onmessage = function (event) {
            const message = event.data;
            displayMessage(message);
        };

        // Handle connection errors
        eventSource.onerror = function (event) {
            console.error('EventSource error:', event);
            eventSource.close();
        };

        // Submit the message form
        messageForm.addEventListener('submit', function (event) {
            event.preventDefault();
            const message = messageInput.value.trim();
            if (message !== '') {
                sendMessage(username, message);
                messageInput.value = '';
            }
        });
    };

    // Send a message to the server
    const sendMessage = (username, message) => {
        fetch(`/chat/send?username=${encodeURIComponent(username)}&message=${encodeURIComponent(message)}`, {
            method: 'POST'
        })
            .catch(error => console.error('Error sending message:', error));
    };

    // Display a message on the page
    const displayMessage = (message) => {
        const messageElement = document.createElement('div');
        messageElement.textContent = message;
        chatMessages.appendChild(messageElement);
    };

    // Open the connection
    connectToChat();
</script>
</body>
</html>
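The page relies on EventSource, which parses the server's text/event-stream response and hands each event's data to onmessage. Model output often contains newlines and leading spaces, and the SSE format treats both specially: every line of the payload must be its own `data:` line, a blank line ends the event, and the browser strips exactly one space after the colon. The helper below is a hypothetical sketch of that wire format (SseFormat is not part of Spring); whether your server library already encodes chunks this way is worth checking before relying on it:

```java
// Hypothetical helper: encodes one Server-Sent Event as the SSE format expects it
// on the wire. Each payload line becomes its own "data: " line and a blank line
// ends the event; EventSource joins the data lines back together with "\n".
final class SseFormat {

    static String encode(String data) {
        StringBuilder sb = new StringBuilder();
        // limit -1 keeps trailing empty strings, so a trailing newline survives
        for (String line : data.split("\r\n|\r|\n", -1)) {
            // Always write one space after the colon: the browser removes exactly
            // one, so a chunk that itself starts with a space keeps it
            sb.append("data: ").append(line).append('\n');
        }
        return sb.append('\n').toString();
    }

    public static void main(String[] args) {
        System.out.print(encode("Hello\nworld"));
        // data: Hello
        // data: world
        // (blank line)
    }
}
```

On the browser side, the two data lines above arrive in onmessage as the single string "Hello\nworld".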
That completes the integration.