官方文档:Chat Memory :: Spring AI Reference
1. 引言
SpringAI 1.0.0 改动了很多地方,本文根据官方的InMemoryChatMemoryRepository实现了自定义的RedisChatMemoryRepository,并使用MessageWindowChatMemory创建ChatMemory
2. 实现
2.1. 添加依赖
<!-- Spring AI 1.0.0 OpenAI chat-model starter (1.0.0 renamed it from spring-ai-openai-spring-boot-starter);
     it is also used for DeepSeek, whose API is OpenAI-compatible -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
<version>1.0.0</version>
</dependency>
<!-- Redis support (provides StringRedisTemplate); no version here, presumably managed by the Spring Boot parent/BOM -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
注意：Spring AI 1.0.0 调整了 Maven 依赖的命名，OpenAI starter 的 artifactId 由旧版的 spring-ai-openai-spring-boot-starter 改为 spring-ai-starter-model-openai。
2.2. 配置文件
yaml
server:
  port: 8080
spring:
  ai:
    openai:
      api-key: xxx # put your own API key here (a DeepSeek key); consider ${DEEPSEEK_API_KEY} instead of committing it
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
  data:
    redis:
      host: localhost
      port: 6379
      password: 123456
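将 Redis 的 host、port、password 改成自己环境中的实际配置即可。
api-key 可以填写 DeepSeek 的 API Key（需要充值，1 块钱能用挺久）。DeepSeek 兼容 OpenAI 的接口格式，所以这里可以直接使用 OpenAI 的 starter。只需把 base-url 指向 https://api.deepseek.com，并把 model 设为 deepseek-chat。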
2.3. RedisChatMemoryRepository
RedisChatMemoryRepository 负责把会话消息存储到 Redis 中。
实现参考了官方的 InMemoryChatMemoryRepository，以及 CSDN 博客《【SpringAI 1.0.0】ChatMemory 转换为 Redis 存储》。
java
package com.njust.repository;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.content.Media;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.util.MimeType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Redis-backed {@link ChatMemoryRepository}, modelled on Spring AI's {@code InMemoryChatMemoryRepository}.
 *
 * <p>Storage layout:
 * <ul>
 *   <li>{@code prefix + conversationId}: a Redis LIST holding one JSON document per message, in order;</li>
 *   <li>{@code conversationIdsSet}: a Redis ZSET whose members are conversation ids, scored by the
 *       epoch-millis time of the last non-empty {@link #saveAll} for that id.</li>
 * </ul>
 *
 * <p>Messages are written with {@link ObjectMapper#writeValueAsString(Object)} and read back by
 * {@link #deserializeMessage(String)}. Reading is manual because {@link Message} is an interface that
 * Jackson cannot instantiate directly.
 */
public class RedisChatMemoryRepository implements ChatMemoryRepository {

    /** Default prefix of the per-conversation LIST keys. */
    private static final String DEFAULT_PREFIX = "chat:conversation:";

    /** Default key of the ZSET that indexes conversation ids. */
    private static final String DEFAULT_CONVERSATION_IDS_SET = "chat:all_conversation_ids";

    private final StringRedisTemplate stringRedisTemplate;
    private final ObjectMapper objectMapper;
    /** Prefix prepended to a conversation id to form its LIST key. */
    private final String prefix;
    /** Key of the ZSET indexing all conversation ids. */
    private final String conversationIdsSet;

    /** Creates a repository with the default key prefix and conversation-id index key. */
    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) {
        this(stringRedisTemplate, objectMapper, DEFAULT_PREFIX, DEFAULT_CONVERSATION_IDS_SET);
    }

    /** Creates a repository with a custom key prefix and the default conversation-id index key. */
    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper,
                                     String prefix) {
        this(stringRedisTemplate, objectMapper, prefix, DEFAULT_CONVERSATION_IDS_SET);
    }

    /**
     * @param stringRedisTemplate Redis client used for all reads and writes
     * @param objectMapper        mapper used to serialize and parse messages
     * @param prefix              prefix prepended to a conversation id to form its LIST key
     * @param conversationIdsSet  key of the ZSET indexing conversation ids
     * @throws NullPointerException if any argument is {@code null}
     */
    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper,
                                     String prefix, String conversationIdsSet) {
        this.stringRedisTemplate = Objects.requireNonNull(stringRedisTemplate, "stringRedisTemplate");
        this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper");
        this.prefix = Objects.requireNonNull(prefix, "prefix");
        this.conversationIdsSet = Objects.requireNonNull(conversationIdsSet, "conversationIdsSet");
    }

    /**
     * Returns all indexed conversation ids, most recently saved first.
     * The ids come from the ZSET index; no KEYS scan is performed.
     */
    @Override
    public List<String> findConversationIds() {
        Set<String> conversationIds = stringRedisTemplate.opsForZSet().reverseRange(conversationIdsSet, 0, -1);
        if (conversationIds == null || conversationIds.isEmpty()) {
            return List.of();
        }
        return new ArrayList<>(conversationIds);
    }

    /**
     * Loads the stored messages of a conversation in insertion order.
     *
     * @throws IllegalArgumentException if {@code conversationId} is null or empty
     * @throws UncheckedIOException     if a stored entry is not valid JSON
     */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        requireConversationId(conversationId);
        List<String> jsonMessages = stringRedisTemplate.opsForList().range(keyOf(conversationId), 0, -1);
        if (jsonMessages == null || jsonMessages.isEmpty()) {
            return List.of();
        }
        return jsonMessages.stream()
                .map(this::readMessage)
                .collect(Collectors.toList());
    }

    /**
     * Replaces the whole stored history of a conversation with {@code messages}.
     * A null or empty list just clears the history.
     *
     * @throws IllegalArgumentException if {@code conversationId} is null or empty,
     *                                  or {@code messages} contains null
     * @throws UncheckedIOException     if a message cannot be serialized (the stored history is left untouched)
     */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        requireConversationId(conversationId);
        // Serialize everything BEFORE touching Redis, so a serialization failure cannot wipe the old history.
        List<String> serialized = new ArrayList<>();
        if (messages != null) {
            for (Message message : messages) {
                if (message == null) {
                    throw new IllegalArgumentException("messages cannot contain null elements");
                }
                serialized.add(writeMessage(message));
            }
        }
        String key = keyOf(conversationId);
        // NOTE(review): DEL, RPUSH and ZADD are separate commands. A concurrent reader may briefly see an
        // empty history; wrap them in MULTI/EXEC (e.g. a SessionCallback) if that matters.
        stringRedisTemplate.delete(key);
        if (serialized.isEmpty()) {
            return;
        }
        stringRedisTemplate.opsForList().rightPushAll(key, serialized);
        // Refresh the conversation's score so findConversationIds lists it first.
        stringRedisTemplate.opsForZSet().add(conversationIdsSet, conversationId, System.currentTimeMillis());
    }

    /**
     * Removes a conversation's messages and its entry in the id index.
     *
     * @throws IllegalArgumentException if {@code conversationId} is null or empty
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        requireConversationId(conversationId);
        stringRedisTemplate.delete(keyOf(conversationId));
        stringRedisTemplate.opsForZSet().remove(conversationIdsSet, conversationId);
    }

    /**
     * Rebuilds a {@link Message} from the JSON written by {@link #saveAll}.
     *
     * <p>The concrete type is chosen from the {@code messageType} field. Both the enum name
     * ({@code "USER"}) and its lower-case value ({@code "user"}) are accepted. System messages are
     * restored from their text only; their metadata is not carried over.
     *
     * @param json one stored message document
     * @return the reconstructed message
     * @throws IOException              if {@code json} is not valid JSON or a media URL is malformed
     * @throws IllegalArgumentException if {@code messageType} is missing, unknown or unsupported
     */
    public Message deserializeMessage(String json) throws IOException {
        JsonNode root = objectMapper.readTree(json);
        if (root == null || !root.hasNonNull("messageType")) {
            throw new IllegalArgumentException("Missing or invalid messageType field");
        }
        String messageType = root.get("messageType").asText();
        // hasNonNull, not has: JsonNode#asText() on a JSON null returns the literal string "null".
        String text = root.hasNonNull("text") ? root.get("text").asText() : "";
        Map<String, Object> metadata = getMetadata(root);
        List<Media> mediaList = getMediaList(root);
        return switch (MessageType.valueOf(messageType.toUpperCase(Locale.ROOT))) {
            case SYSTEM -> new SystemMessage(text);
            case USER -> UserMessage.builder()
                    .text(text)
                    .media(mediaList)
                    .metadata(metadata)
                    .build();
            case ASSISTANT -> new AssistantMessage(text, metadata, getToolCalls(root), mediaList);
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    /** Builds the Redis LIST key of a conversation. */
    private String keyOf(String conversationId) {
        return prefix + conversationId;
    }

    /** Rejects null or empty conversation ids. */
    private static void requireConversationId(String conversationId) {
        if (conversationId == null || conversationId.isEmpty()) {
            throw new IllegalArgumentException("conversationId cannot be null or empty");
        }
    }

    /** Serializes one message, rethrowing Jackson failures unchecked. */
    private String writeMessage(Message message) {
        try {
            return objectMapper.writeValueAsString(message);
        } catch (JsonProcessingException e) {
            throw new UncheckedIOException("Failed to serialize Message", e);
        }
    }

    /** Deserializes one stored message, rethrowing I/O failures unchecked so it can be used in streams. */
    private Message readMessage(String json) {
        try {
            return deserializeMessage(json);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to deserialize Message", e);
        }
    }

    /**
     * Rebuilds one {@link Media} entry.
     *
     * <p>{@code mimeType} may be the bean form {@code {"type":..,"subtype":..}} or a plain string such as
     * {@code "image/png"}. {@code data} is treated as a URL when it starts with http(s), otherwise as
     * base64-encoded bytes. {@code dataAsByteArray}, when present, overrides {@code data}.
     */
    private Media deserializeMedia(JsonNode mediaNode) throws IOException {
        Media.Builder builder = Media.builder();
        JsonNode mimeNode = mediaNode.get("mimeType");
        if (mimeNode != null && !mimeNode.isNull()) {
            builder.mimeType(mimeNode.isTextual()
                    ? MimeType.valueOf(mimeNode.asText())
                    : new MimeType(mimeNode.path("type").asText(), mimeNode.path("subtype").asText()));
        }
        if (mediaNode.hasNonNull("data")) {
            String data = mediaNode.get("data").asText();
            if (data.startsWith("http://") || data.startsWith("https://")) {
                builder.data(new URL(data));
            } else {
                byte[] bytes = Base64.getDecoder().decode(data);
                builder.data(bytes);
            }
        }
        if (mediaNode.hasNonNull("dataAsByteArray")) {
            byte[] bytes = Base64.getDecoder().decode(mediaNode.get("dataAsByteArray").asText());
            builder.data(bytes);
        }
        if (mediaNode.hasNonNull("id")) {
            builder.id(mediaNode.get("id").asText());
        }
        if (mediaNode.hasNonNull("name")) {
            builder.name(mediaNode.get("name").asText());
        }
        return builder.build();
    }

    /** Reads the {@code metadata} object, or returns an empty mutable map when it is absent or null. */
    private Map<String, Object> getMetadata(JsonNode root) {
        if (root.hasNonNull("metadata")) {
            return objectMapper.convertValue(root.get("metadata"), new TypeReference<Map<String, Object>>() {});
        }
        return new HashMap<>();
    }

    /** Reads the {@code media} array, or returns an empty list when it is absent or not an array. */
    private List<Media> getMediaList(JsonNode root) throws IOException {
        List<Media> mediaList = new ArrayList<>();
        JsonNode mediaArray = root.get("media");
        if (mediaArray != null && mediaArray.isArray()) {
            for (JsonNode mediaNode : mediaArray) {
                mediaList.add(deserializeMedia(mediaNode));
            }
        }
        return mediaList;
    }

    /** Reads the assistant's {@code toolCalls} array, or returns an empty list when it is absent or null. */
    private List<AssistantMessage.ToolCall> getToolCalls(JsonNode root) {
        if (root.hasNonNull("toolCalls")) {
            return objectMapper.convertValue(root.get("toolCalls"),
                    new TypeReference<List<AssistantMessage.ToolCall>>() {});
        }
        return Collections.emptyList();
    }
}
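关键部分都已写了注释，应该比较好理解。
需要注意的是，Message 的反序列化必须手动完成。Message 是接口，存入 Redis 的 JSON 中也没有 Jackson 能识别的具体类型信息，所以直接用 ObjectMapper 反序列化成 Message 会报错。因此这里先读取 messageType 字段，再据此构造 UserMessage、AssistantMessage 或 SystemMessage。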
2.4. 注册Bean
java
package com.njust.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.njust.repository.ChatHistoryRepository;
import com.njust.repository.RedisChatHistoryRepository;
import com.njust.repository.RedisChatMemoryRepository;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.StringRedisTemplate;
/**
 * Spring configuration that wires Redis-backed chat memory into a DeepSeek {@link ChatClient}.
 */
@Configuration
public class CommonConfiguration {

    /** How many messages the message window keeps per conversation. */
    private static final int MAX_MEMORY_MESSAGES = 20;

    /** Default system prompt for every request made through the DeepSeek client. */
    private static final String SYSTEM_PROMPT = "你是南京理工大学计算机科学与工程学院的一名研究生,你的名字叫小兰";

    /**
     * Chat-memory storage backed by Redis.
     *
     * <p>Spring AI auto-configures an {@code InMemoryChatMemoryRepository} only when no other
     * {@link ChatMemoryRepository} bean is defined. Declaring this bean replaces it with the Redis implementation.
     */
    @Bean
    public ChatMemoryRepository chatMemoryRepository(StringRedisTemplate stringRedisTemplate) {
        ObjectMapper messageMapper = new ObjectMapper();
        return new RedisChatMemoryRepository(stringRedisTemplate, messageMapper);
    }

    /**
     * Sliding-window chat memory over the Redis repository, capped at
     * {@value #MAX_MEMORY_MESSAGES} messages per conversation.
     */
    @Bean
    public ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {
        ChatMemory memory = MessageWindowChatMemory.builder()
                .maxMessages(MAX_MEMORY_MESSAGES)
                .chatMemoryRepository(chatMemoryRepository)
                .build();
        return memory;
    }

    /**
     * DeepSeek chat client. The injected {@link OpenAiChatModel} reaches DeepSeek because
     * {@code spring.ai.openai.base-url} points at DeepSeek's OpenAI-compatible API.
     */
    @Bean
    public ChatClient deepseekChatClient(OpenAiChatModel openAiChatModel, ChatMemory chatMemory) {
        var loggerAdvisor = new SimpleLoggerAdvisor();                             // request/response logging
        var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build(); // binds conversation memory
        return ChatClient.builder(openAiChatModel)
                .defaultSystem(SYSTEM_PROMPT)
                .defaultAdvisors(loggerAdvisor, memoryAdvisor)
                .build();
    }
}
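这里用 MessageWindowChatMemory 创建 ChatMemory，用来限制上下文记忆的条数：maxMessages(20) 表示每个会话最多保留最近 20 条消息。
注意：上面导入的 ChatHistoryRepository、RedisChatHistoryRepository 在本文中没有用到，也没有给出实现（下文 Controller 中导入的 ChatHistoryRepository 同理）。直接复制代码时请删掉这些 import，否则无法编译。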
2.5. Controller
java
package com.njust.controller;
import com.njust.repository.ChatHistoryRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.content.Media;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Objects;
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;
/**
 * REST endpoint that streams DeepSeek chat replies. Conversation memory is keyed by the {@code chatId} parameter.
 */
@RestController
@RequestMapping("/ai")
public class ChatController {

    private final ChatClient deepseekChatClient;

    /** Constructor injection of the {@code deepseekChatClient} bean. */
    public ChatController(ChatClient deepseekChatClient) {
        this.deepseekChatClient = deepseekChatClient;
    }

    /**
     * Streams the model's reply to {@code prompt}.
     *
     * @param prompt user message text
     * @param chatId conversation id, handed to the memory advisor as {@code CONVERSATION_ID}
     * @return the reply text, chunk by chunk as it is generated
     */
    @RequestMapping(value = "/chat", produces = "text/html;charset=utf-8")
    public Flux<String> chat(@RequestParam("prompt") String prompt,
                             @RequestParam("chatId") String chatId) {
        var request = deepseekChatClient.prompt()
                .user(prompt)
                .advisors(advisorSpec -> advisorSpec.param(CONVERSATION_ID, chatId));
        return request.stream().content();
    }
}
3. 效果
redis中存储的内容为:

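如果想记录某个用户的会话 id，也可以存到 Redis 中。上面的 RedisChatMemoryRepository 在 saveAll 时，已经用 ZSet（默认 key 为 chat:all_conversation_ids）以时间戳为 score 记录了所有会话 id，findConversationIds 会按最近更新时间倒序返回。如果要按用户区分，可以仿照这个思路自己写个方法：在 controller 的 chat 方法中，以用户维度的 key 建一个 ZSet，把 chatId 作为 member、当前时间戳作为 score 存入。每次对话都更新一下 score，这样最近发送过消息的会话就能排在最前面，和 DeepSeek 的效果一样。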