Spring AI 自定义数据库持久化的ChatMemory

若要实现自定义Redis持久化ChatMemory,可以参考这篇文章 Spring AI 自定义Redis持久化的ChatMemory

自定义实现ChatMemory

Spring AI的对话记忆实现非常巧妙,解耦了"储存"和"记忆算法",

  • 存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。

  • 记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作

    • MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
    • PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
    • VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。

我们可以单独修改ChatMemory储存来改变对话记忆的保存位置,而无需修改保存对话记忆的流程.

虽然官方文档没有给我们自定义ChatMemory实现的示例,但是我们可以直接去阅读默认实现类 InMemoryChatMemory 的源码

其本质是实现了ChatMemory的增删查接口

所以我们想实现自己的持久化,修改对应的储存实现就行了.

参考 InMemoryChatMemory 的源码,其实就是通过 ConcurrentHashMap 来维护对话信息,key 是对话 id(相当于房间号),value 是该对话 id 对应的消息列表。

自定义MYSQL持久化ChatMemory

本质是将数据储存到MySQL中,同样的,由于List<Message>中Message是一个接口,虽然需要实现的接口不多,但是实现起来还是有一定复杂度的,一个最主要的问题是 消息和文本的转换。我们在保存消息时,要将消息从 Message 对象转为文件内的文本;读取消息时,要将文件内的文本转换为 Message 对象。也就是对象的序列化和反序列化。

我们本能地会想到通过 JSON 进行序列化,但实际操作中,我们发现这并不容易。原因是:

  1. 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
  2. 每种子类所拥有的字段都不一样,结构不统一
  3. 子类没有无参构造函数,而且没有实现 Serializable 序列化接口

在这里有两个方案:

  1. 自己结构化数据库,使用结构化的数据库然后自己手动创建Message的实现对象来序列化.这里参考这篇文章:

SpringAI--基于MySQL的持久化对话记忆实现

java 复制代码
@Override
    public List<Message> get(String conversationId, int lastN) {
        // 分页查询最近N条记录
        Page<AiChatMemory> page = new Page<>(1, lastN);
        QueryWrapper<AiChatMemory> wrapper = new QueryWrapper<>();
        wrapper.eq("conversation_id", conversationId)
                .orderByDesc("create_time");

        List<AiChatMemory> aiChatMemories = mapper.selectList(wrapper);
        // 反转列表,使得最新的消息在最后
        Collections.reverse(aiChatMemories);
//------------------------------------------------------
		// 转换为Message对象
    	List<Message> messages = new ArrayList<>();
    	for (AiChatMemory aiChatMemory : aiChatMemories) {
        	String type = aiChatMemory.getType();//数据库储存的Message实现类类型
        	switch (type) {//根据类型手动创建出实现类
            	case "user" -> messages.add(new UserMessage(aiChatMemory.getContent()));
            	case "assistant" -> messages.add(new AssistantMessage(aiChatMemory.getContent()));
            	case "system" -> messages.add(new SystemMessage(aiChatMemory.getContent()));
            	default -> throw new IllegalArgumentException("Unknown message type: " + type);
        	}
    	}
//------------------------------------------------------
    	return messages;
    }
  

其原理如下: 只储存文本和Message的实现类的Type,并不将整个对象信息存入,在取出时手动创建相应的Message实现类.

  1. 方案二: 使用序列化库来实现,这里我分别尝试了jackson和Kryo序列化库,最终选择了Kryo序列化库,其可以动态注册,减少代码量.

1)先创建数据库表:

sql 复制代码
create table logger
(
    id      varchar(255)                       not null,
    userId  bigint                             not null,
    message text                               not null,
    time    datetime default CURRENT_TIMESTAMP not null
);

create table request
(
    id     varchar(255) not null,
    userId bigint       not null,
    name   varchar(255) not null
);#会话

create table user
(
    id     bigint       not null
        primary key,
    name   varchar(255) not null,
    status tinyint      not null comment '用户身份
0 - 无ai权限
1 - 有ai权限'
);
2)引入相关依赖:

(这里我使用的是Spring Boot 3.4.4 和 Java 21)

xml 复制代码
		<!-- 自定义持久化的序列化库-->
		<dependency>
    		<groupId>com.esotericsoftware</groupId>
    		<artifactId>kryo</artifactId>
    		<version>5.6.2</version>
		</dependency>
		<!-- mybatis plus-->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
            <version>3.5.12</version>
        </dependency>
        <!-- 3.5.9及以上版本想使用mybatis plus分页配置需要单独引入-->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-jsqlparser</artifactId>
            <version>3.5.12</version> <!-- 确保版本和 MyBatis Plus 主包一致 -->
        </dependency>

3)在application.yml配置:

yaml 复制代码
spring:
  #数据库连接
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    username: root
    password: 123456
    url: jdbc:mysql://127.0.0.1:3306/aiapp
  ai:
    dashscope:
      api-key: {API - KEY}
    chat:
      options:
      model: qwq-plus
mybatis-plus:
  configuration:
    map-underscore-to-camel-case: false

4)配置mybatisPlu分页插件:

java 复制代码
@Configuration
@MapperScan("你的mapper包")
public class MybatisPlusConfig {

    /**
     * 添加分页插件
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); // 如果配置多个插件, 切记分页最后添加
        // 如果有多数据源可以不配具体类型, 否则都建议配上具体的 DbType
        return interceptor;
    }
}

5)创建序列化工具类:

java 复制代码
@Component
public class MessageSerializer {

    // ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)
    private static final Kryo kryo = new Kryo();

    static {
        kryo.setRegistrationRequired(false);
        // 设置实例化策略(需确保兼容所有 Message 实现类)
        kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
    }

    /**
     * 使用 Kryo 将 Message 序列化为 Base64 字符串
     */
    public static String serialize(Message message) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             Output output = new Output(baos)) {
            kryo.writeClassAndObject(output, message);  // ⚠️ 依赖动态注册和实例化策略
            output.flush();
            return Base64.getEncoder().encodeToString(baos.toByteArray());
        } catch (IOException e) {
            throw new RuntimeException("序列化失败", e);
        }
    }

    /**
     * 使用 Kryo 将 Base64 字符串反序列化为 Message 对象
     */
    public static Message deserialize(String base64) {
        try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));
             Input input = new Input(bais)) {
            return (Message) kryo.readClassAndObject(input);  // ⚠️ 依赖动态注册和实例化策略
        } catch (IOException e) {
            throw new RuntimeException("反序列化失败", e);
        }
    }
}

6)创建自定义的数据库持久化ChatMemory:

java 复制代码
/**
 * 自定义数据库持久化
 */
@Service
@Slf4j
public class MySQLChatMemory implements ChatMemory {


    private final LoggerMapper loggerMapper;

    public MySQLChatMemory(LoggerMapper mapper){
        this.loggerMapper = mapper;
    }
    

            

    /**
     * 添加一个数据到数据库中
     * @param conversationId
     * @param message
     */
    @Override
    public void add(String conversationId, Message message) {
        Long userId = parseUserId(conversationId);
        Logger logger = new Logger();
        logger.setId(conversationId);
        logger.setUserId(userId);
        logger.setTime(new Date());
        logger.setMessage(MessageSerializer.serialize(message));
        loggerMapper.insert(logger);


    }

    /**
     * 添加多条数据到数据库中
     * @param conversationId
     * @param messages
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        Long userId = parseUserId(conversationId);
        List<Logger> loggerList = new ArrayList<>();
        for (Message message : messages) {
            Logger logger = new Logger();
            logger.setId(conversationId);
            logger.setUserId(userId);
            logger.setTime(new Date());
            logger.setMessage(MessageSerializer.serialize(message));
            loggerList.add(logger);
        }
        loggerMapper.insert(loggerList);


    }

    /**
     * 从数据库中获取数据
     * 从数据库中获取倒数lastN条数据
     * @param conversationId
     * @param lastN
     * @return
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        Long userId = parseUserId(conversationId);
        Page<Logger> page = new Page<>(1, lastN);
        QueryWrapper<Logger> wrapper = new QueryWrapper<>();
        wrapper.eq("id", conversationId)
                .eq("userId", userId) // 添加用户 ID 过滤
                .orderByDesc("time"); // 按时间倒序

        // 使用 selectPage 而非 selectList
        List<Logger> loggerList = loggerMapper.selectPage(page, wrapper).getRecords();

        List<Message> messages = new ArrayList<>();
        for (Logger logger : loggerList) {
            messages.add(MessageSerializer.deserialize(logger.getMessage()));
        }
        return messages;
    }

    /**
     * 清空数据
     * @param conversationId
     */
    @Override
    public void clear(String conversationId) {
        Long userId = parseUserId(conversationId);
        QueryWrapper<Logger> loggerQueryWrapper = new QueryWrapper<>();
        loggerQueryWrapper.eq("id",conversationId);
        loggerMapper.deleteById(loggerQueryWrapper);

    }


    // 从 conversationId 解析用户 ID(格式:chat-{userId})
    private long parseUserId(String conversationId) {
        String[] parts = conversationId.split("-");
        if (parts.length == 2 && "chat".equals(parts[0])) {
            return Long.parseLong(parts[1]);
        }
        throw new IllegalArgumentException("无效的 conversationId 格式: " + conversationId);
    }



}

7)创建MyApp:

java 复制代码
/**
 * AI应用程序
 * 提供应用程序的调用功能
 */
@Component
@Slf4j
public class MyApp {

    private final ChatClient chatClient;

    private final MySQLChatMemory mySQLChatMemory;

    private static final String SYSTEM_PROMPT= "你是抖音电商 \"厨意生活旗舰店\" 的智能客服小厨," +
            "专注于为顾客提供专业、贴心的厨具购物咨询服务。" +
            "同时要通过温暖亲切的语言传递品牌温度。";

    public MyApp(ChatModel dashscopeChatModel,MySQLChatMemory mySQLChatMemory){
        //初始化基于文件持久化的记忆(自定义实现)
//        String fileDir = System.getProperty("user.dir")+"/chat-memory";
//        ChatMemory chatMemory = new FileBasedChatMemory(fileDir);

        //初始化基于内存的对话记忆
//        ChatMemory chatMemory = new InMemoryChatMemory();

        //初始化基于数据库持久化的记忆(自定义实现)
        this.mySQLChatMemory = mySQLChatMemory;

        this.chatClient = ChatClient.builder(dashscopeChatModel)
                .defaultSystem(SYSTEM_PROMPT)
                .defaultAdvisors(
                        new MessageChatMemoryAdvisor(mySQLChatMemory),
                        //自定义日志Advisor,可按需开启(自定议)
//                        new MyLoggerAdvisor(),
                        //自定义权限校验,校验登录用户是否可以使用AI(自定义)
//                        new AuthAdvisor(),
                        //违禁词校验,校验是否存在违禁词(自定义)
//                        new BannedWordsAdvisor()
                )
                .build();
    }


    /**
     * AI 基础对话
     * @param message
     * @param chatId
     * @return
     */
    public String doChat(String message,String chatId){
        ChatResponse chatResponse = chatClient
                .prompt()
                .user(message)
                .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
                        .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10))
                .call()
                .chatResponse();
        String content = chatResponse.getResult().getOutput().getText();
//        log.info("content:{}",content);
        return content;
    }


}

代码测试

现在可以测试了,创建测试类,也可以使用接口测试,这里我是用请求测试的

创建AIController

java 复制代码
@RestController
@RequestMapping("/ai")
public class AIController {

    @Resource
    private MyApp myApp;

    @Resource
    private RequestService requestService;

    
     /**
     * AI 对话
     * @param message 对话信息
     * @param name 会话名字(方便操作,没用ID)
     * @param req HttpServletRequest
     * @return 对话的结果
     */
    @GetMapping("/doChat")
    public String doChat(String message, String name,HttpServletRequest req){
        User user = (User)req.getSession().getAttribute(UserConstant.USER_LOGIN_STATE);
        if(user == null){
            throw new RuntimeException("未登录");
        }
        QueryWrapper<Request> requestQueryWrapper = new QueryWrapper<>();

        requestQueryWrapper.eq("name",name).eq("userId",user.getId());
        Request request = requestService.getOne(requestQueryWrapper);
        return myApp.doChat(message, request.getId());
    }


}

先进行对话

查看数据库

关闭项目重启,再提问

成功!!!

在AI操作时可以手动记录到数据库中,自定义持久化到RedisChatMemory,这样可以更快的响应,同时数据库也保存了长期的对话信息.

相关推荐
考虑考虑29 分钟前
Springboot3.5.x结构化日志新属性
spring boot·后端·spring
涡能增压发动积31 分钟前
一起来学 Langgraph [第三节]
后端
sky_ph44 分钟前
JAVA-GC浅析(二)G1(Garbage First)回收器
java·后端
涡能增压发动积1 小时前
一起来学 Langgraph [第二节]
后端
hello早上好1 小时前
Spring不同类型的ApplicationContext的创建方式
java·后端·架构
roman_日积跬步-终至千里1 小时前
【Go语言基础【20】】Go的包与工程
开发语言·后端·golang
00后程序员2 小时前
提升移动端网页调试效率:WebDebugX 与常见工具组合实践
后端
HyggeBest3 小时前
Mysql的数据存储结构
后端·架构
TobyMint3 小时前
golang 实现雪花算法
后端
G探险者3 小时前
【案例解析】一次 TIME_WAIT 导致 TPS 断崖式下降的排查与优化
后端