若要实现自定义Redis持久化ChatMemory,可以参考这篇文章 Spring AI 自定义Redis持久化的ChatMemory
自定义实现ChatMemory
Spring AI的对话记忆实现非常巧妙,解耦了"储存"和"记忆算法",
-
存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。
-
记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作
- MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
- PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
- VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。
我们可以单独修改ChatMemory储存来改变对话记忆的保存位置,而无需修改保存对话记忆的流程.
虽然官方文档没有给我们自定义ChatMemory实现的示例,但是我们可以直接去阅读默认实现类 InMemoryChatMemory 的源码

其本质是实现了ChatMemory
的增删查接口

所以我们想实现自己的持久化,修改对应的储存实现就行了.
参考 InMemoryChatMemory 的源码,其实就是通过 ConcurrentHashMap 来维护对话信息,key 是对话 id(相当于房间号),value 是该对话 id 对应的消息列表。
自定义MYSQL持久化ChatMemory
本质是将数据储存到MySQL中,同样的,由于List<Message>
中Message是一个接口,虽然需要实现的接口不多,但是实现起来还是有一定复杂度的,一个最主要的问题是 消息和文本的转换。我们在保存消息时,要将消息从 Message 对象转为文件内的文本;读取消息时,要将文件内的文本转换为 Message 对象。也就是对象的序列化和反序列化。
我们本能地会想到通过 JSON 进行序列化,但实际操作中,我们发现这并不容易。原因是:
- 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
- 每种子类所拥有的字段都不一样,结构不统一
- 子类没有无参构造函数,而且没有实现 Serializable 序列化接口
在这里有两个方案:
- 自己结构化数据库,使用结构化的数据库然后自己手动创建
Message
的实现对象来序列化.这里参考这篇文章:
java
@Override
public List<Message> get(String conversationId, int lastN) {
// 分页查询最近N条记录
Page<AiChatMemory> page = new Page<>(1, lastN);
QueryWrapper<AiChatMemory> wrapper = new QueryWrapper<>();
wrapper.eq("conversation_id", conversationId)
.orderByDesc("create_time");
List<AiChatMemory> aiChatMemories = mapper.selectList(wrapper);
// 反转列表,使得最新的消息在最后
Collections.reverse(aiChatMemories);
//------------------------------------------------------
// 转换为Message对象
List<Message> messages = new ArrayList<>();
for (AiChatMemory aiChatMemory : aiChatMemories) {
String type = aiChatMemory.getType();//数据库储存的Message实现类类型
switch (type) {//根据类型手动创建出实现类
case "user" -> messages.add(new UserMessage(aiChatMemory.getContent()));
case "assistant" -> messages.add(new AssistantMessage(aiChatMemory.getContent()));
case "system" -> messages.add(new SystemMessage(aiChatMemory.getContent()));
default -> throw new IllegalArgumentException("Unknown message type: " + type);
}
}
//------------------------------------------------------
return messages;
}
其原理如下: 只储存文本和Message
的实现类的Type,并不将整个对象信息存入,在取出时手动创建相应的Message
实现类.
- 方案二: 使用序列化库来实现,这里我分别尝试了jackson和Kryo序列化库,最终选择了Kryo序列化库,其可以动态注册,减少代码量.
1)先创建数据库表:
sql
create table logger
(
id varchar(255) not null,
userId bigint not null,
message text not null,
time datetime default CURRENT_TIMESTAMP not null
);
create table request
(
id varchar(255) not null,
userId bigint not null,
name varchar(255) not null
);#会话
create table user
(
id bigint not null
primary key,
name varchar(255) not null,
status tinyint not null comment '用户身份
0 - 无ai权限
1 - 有ai权限'
);
2)引入相关依赖:
(这里我使用的是Spring Boot 3.4.4 和 Java 21)
xml
<!-- 自定义持久化的序列化库-->
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.6.2</version>
</dependency>
<!-- mybatis plus-->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-spring-boot3-starter</artifactId>
<version>3.5.12</version>
</dependency>
<!-- 3.5.9及以上版本想使用mybatis plus分页配置需要单独引入-->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-jsqlparser</artifactId>
<version>3.5.12</version> <!-- 确保版本和 MyBatis Plus 主包一致 -->
</dependency>
3)在application.yml
配置:
yaml
spring:
#数据库连接
datasource:
driver-class-name: com.mysql.cj.jdbc.Driver
username: root
password: 123456
url: jdbc:mysql://127.0.0.1:3306/aiapp
ai:
dashscope:
api-key: {API - KEY}
chat:
options:
model: qwq-plus
mybatis-plus:
configuration:
map-underscore-to-camel-case: false
4)配置mybatisPlu分页插件:
java
@Configuration
@MapperScan("你的mapper包")
public class MybatisPlusConfig {
/**
* 添加分页插件
*/
@Bean
public MybatisPlusInterceptor mybatisPlusInterceptor() {
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); // 如果配置多个插件, 切记分页最后添加
// 如果有多数据源可以不配具体类型, 否则都建议配上具体的 DbType
return interceptor;
}
}
5)创建序列化工具类:
java
@Component
public class MessageSerializer {
// ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)
private static final Kryo kryo = new Kryo();
static {
kryo.setRegistrationRequired(false);
// 设置实例化策略(需确保兼容所有 Message 实现类)
kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
}
/**
* 使用 Kryo 将 Message 序列化为 Base64 字符串
*/
public static String serialize(Message message) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
Output output = new Output(baos)) {
kryo.writeClassAndObject(output, message); // ⚠️ 依赖动态注册和实例化策略
output.flush();
return Base64.getEncoder().encodeToString(baos.toByteArray());
} catch (IOException e) {
throw new RuntimeException("序列化失败", e);
}
}
/**
* 使用 Kryo 将 Base64 字符串反序列化为 Message 对象
*/
public static Message deserialize(String base64) {
try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));
Input input = new Input(bais)) {
return (Message) kryo.readClassAndObject(input); // ⚠️ 依赖动态注册和实例化策略
} catch (IOException e) {
throw new RuntimeException("反序列化失败", e);
}
}
}
6)创建自定义的数据库持久化ChatMemory:
java
/**
* 自定义数据库持久化
*/
@Service
@Slf4j
public class MySQLChatMemory implements ChatMemory {
private final LoggerMapper loggerMapper;
public MySQLChatMemory(LoggerMapper mapper){
this.loggerMapper = mapper;
}
/**
* 添加一个数据到数据库中
* @param conversationId
* @param message
*/
@Override
public void add(String conversationId, Message message) {
Long userId = parseUserId(conversationId);
Logger logger = new Logger();
logger.setId(conversationId);
logger.setUserId(userId);
logger.setTime(new Date());
logger.setMessage(MessageSerializer.serialize(message));
loggerMapper.insert(logger);
}
/**
* 添加多条数据到数据库中
* @param conversationId
* @param messages
*/
@Override
public void add(String conversationId, List<Message> messages) {
Long userId = parseUserId(conversationId);
List<Logger> loggerList = new ArrayList<>();
for (Message message : messages) {
Logger logger = new Logger();
logger.setId(conversationId);
logger.setUserId(userId);
logger.setTime(new Date());
logger.setMessage(MessageSerializer.serialize(message));
loggerList.add(logger);
}
loggerMapper.insert(loggerList);
}
/**
* 从数据库中获取数据
* 从数据库中获取倒数lastN条数据
* @param conversationId
* @param lastN
* @return
*/
@Override
public List<Message> get(String conversationId, int lastN) {
Long userId = parseUserId(conversationId);
Page<Logger> page = new Page<>(1, lastN);
QueryWrapper<Logger> wrapper = new QueryWrapper<>();
wrapper.eq("id", conversationId)
.eq("userId", userId) // 添加用户 ID 过滤
.orderByDesc("time"); // 按时间倒序
// 使用 selectPage 而非 selectList
List<Logger> loggerList = loggerMapper.selectPage(page, wrapper).getRecords();
List<Message> messages = new ArrayList<>();
for (Logger logger : loggerList) {
messages.add(MessageSerializer.deserialize(logger.getMessage()));
}
return messages;
}
/**
* 清空数据
* @param conversationId
*/
@Override
public void clear(String conversationId) {
Long userId = parseUserId(conversationId);
QueryWrapper<Logger> loggerQueryWrapper = new QueryWrapper<>();
loggerQueryWrapper.eq("id",conversationId);
loggerMapper.deleteById(loggerQueryWrapper);
}
// 从 conversationId 解析用户 ID(格式:chat-{userId})
private long parseUserId(String conversationId) {
String[] parts = conversationId.split("-");
if (parts.length == 2 && "chat".equals(parts[0])) {
return Long.parseLong(parts[1]);
}
throw new IllegalArgumentException("无效的 conversationId 格式: " + conversationId);
}
}
7)创建MyApp:
java
/**
* AI应用程序
* 提供应用程序的调用功能
*/
@Component
@Slf4j
public class MyApp {
private final ChatClient chatClient;
private final MySQLChatMemory mySQLChatMemory;
private static final String SYSTEM_PROMPT= "你是抖音电商 \"厨意生活旗舰店\" 的智能客服小厨," +
"专注于为顾客提供专业、贴心的厨具购物咨询服务。" +
"同时要通过温暖亲切的语言传递品牌温度。";
public MyApp(ChatModel dashscopeChatModel,MySQLChatMemory mySQLChatMemory){
//初始化基于文件持久化的记忆(自定义实现)
// String fileDir = System.getProperty("user.dir")+"/chat-memory";
// ChatMemory chatMemory = new FileBasedChatMemory(fileDir);
//初始化基于内存的对话记忆
// ChatMemory chatMemory = new InMemoryChatMemory();
//初始化基于数据库持久化的记忆(自定义实现)
this.mySQLChatMemory = mySQLChatMemory;
this.chatClient = ChatClient.builder(dashscopeChatModel)
.defaultSystem(SYSTEM_PROMPT)
.defaultAdvisors(
new MessageChatMemoryAdvisor(mySQLChatMemory),
//自定义日志Advisor,可按需开启(自定议)
// new MyLoggerAdvisor(),
//自定义权限校验,校验登录用户是否可以使用AI(自定义)
// new AuthAdvisor(),
//违禁词校验,校验是否存在违禁词(自定义)
// new BannedWordsAdvisor()
)
.build();
}
/**
* AI 基础对话
* @param message
* @param chatId
* @return
*/
public String doChat(String message,String chatId){
ChatResponse chatResponse = chatClient
.prompt()
.user(message)
.advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
.call()
.chatResponse();
String content = chatResponse.getResult().getOutput().getText();
// log.info("content:{}",content);
return content;
}
}
代码测试
现在可以测试了,创建测试类,也可以使用接口测试,这里我是用请求测试的
创建AIController
java
@RestController
@RequestMapping("/ai")
public class AIController {
@Resource
private MyApp myApp;
@Resource
private RequestService requestService;
/**
* AI 对话
* @param message 对话信息
* @param name 会话名字(方便操作,没用ID)
* @param req HttpServletRequest
* @return 对话的结果
*/
@GetMapping("/doChat")
public String doChat(String message, String name,HttpServletRequest req){
User user = (User)req.getSession().getAttribute(UserConstant.USER_LOGIN_STATE);
if(user == null){
throw new RuntimeException("未登录");
}
QueryWrapper<Request> requestQueryWrapper = new QueryWrapper<>();
requestQueryWrapper.eq("name",name).eq("userId",user.getId());
Request request = requestService.getOne(requestQueryWrapper);
return myApp.doChat(message, request.getId());
}
}
先进行对话

查看数据库

关闭项目重启,再提问

成功!!!
在AI操作时可以手动记录到数据库中,自定义持久化到RedisChatMemory,这样可以更快的响应,同时数据库也保存了长期的对话信息.