SpringAi基于PgSQL数据库存储扩展ChatMemory

一、环境准备

SpringAI入门学习

XML 复制代码
        <!-- SpringAI-->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter</artifactId>
            <version>1.0.0-M6.1</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.3.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
            <version>1.0.0-M6</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>

上文实现了基于内存的会话记忆功能，本文改为基于数据库进行扩展实现。

本文使用 JdbcTemplate + PostgreSQL 数据库来实现自定义的 ChatMemory。

二、扩展实现

java 复制代码
package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

/**
 * Chat-memory abstraction from Spring AI (package org.springframework.ai.chat.memory),
 * shown here for reference; PgSQLChatMemory implements it.
 */
public interface ChatMemory {
    /**
     * Convenience overload: wraps a single message in a one-element list and
     * delegates to {@link #add(String, List)}.
     */
    default void add(String conversationId, Message message) {
        this.add(conversationId, List.of(message));
    }

    /** Adds the given messages to the memory of {@code conversationId}; storage is up to the implementation. */
    void add(String conversationId, List<Message> messages);

    /**
     * Returns messages of {@code conversationId}. Presumably the most recent {@code lastN};
     * the exact ordering/limit semantics are defined by each implementation.
     */
    List<Message> get(String conversationId, int lastN);

    /** Clears the memory of {@code conversationId}. */
    void clear(String conversationId);
}
  • 表结构准备
sql 复制代码
-- One row per persisted chat message of the JDBC-backed ChatMemory.
CREATE TABLE chat_messages (
    id              BIGSERIAL PRIMARY KEY,          -- database-generated, increases with insertion order
    conversation_id VARCHAR(255) NOT NULL,
    message_type    VARCHAR(50)  NOT NULL,
    content         TEXT         NOT NULL,
    created_at      TIMESTAMP    NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Composite index: serves "WHERE conversation_id = ? ORDER BY created_at DESC[, id DESC] LIMIT ?"
-- with a backward index scan instead of sorting every row of the conversation.
-- Table-qualified name, because index names share the schema namespace with other relations.
CREATE INDEX idx_chat_messages_conversation_created
    ON chat_messages (conversation_id, created_at, id);
  • ChatDao接口定义
java 复制代码
package org.spring.springaiprojet.dao;

import org.spring.springaiprojet.entity.ChatMessageEntity;

import java.util.List;

public interface ChatDao {

    // ---- write operations ----

    /**
     * Persists a batch of chat message rows.
     *
     * @param messages the rows to insert
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * Removes every stored message belonging to one conversation.
     *
     * @param conversationId identifier of the conversation to delete
     */
    void deleteByConversationId(String conversationId);

    // ---- read operations ----

    /**
     * Loads the most recent messages of a conversation.
     *
     * @param conversationId identifier of the conversation to read
     * @param lastN          maximum number of messages to return
     * @return the matching message rows
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);
}
  • ChatDaoImpl接口实现
java 复制代码
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

/**
 * {@link ChatDao} implementation backed by {@link JdbcTemplate} and the {@code chat_messages} table.
 */
@Repository
public class ChatDaoImpl implements ChatDao {

    /** The {@code id} column is generated by the database, so it is not part of the insert. */
    private static final String INSERT_SQL =
            "insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)";

    /**
     * Newest-first page of one conversation. Explicit columns instead of {@code select *};
     * {@code id desc} breaks ties between rows with equal {@code created_at}, so that
     * LIMIT always picks the same rows in the same order.
     */
    private static final String SELECT_LAST_N_SQL =
            "select id, conversation_id, message_type, content, created_at from chat_messages"
                    + " where conversation_id = ? order by created_at desc, id desc limit ?";

    private static final String DELETE_BY_CONVERSATION_SQL =
            "delete from chat_messages where conversation_id = ?";

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /**
     * Inserts all given messages in a single JDBC batch.
     *
     * @param messages rows to insert; {@code null} or an empty list is a no-op
     */
    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
        if (messages == null || messages.isEmpty()) {
            // Nothing to write; also avoids handing batchUpdate a batch size of 0.
            return;
        }
        jdbcTemplate.batchUpdate(INSERT_SQL, messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    /**
     * Returns up to {@code lastN} messages of a conversation, newest first.
     *
     * @param conversationId conversation to read; {@code null} yields an empty list
     * @param lastN          maximum number of rows; values {@code <= 0} yield an empty list
     *                       (PostgreSQL rejects a negative LIMIT)
     * @return rows ordered by {@code created_at} descending, then {@code id} descending
     */
    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        if (conversationId == null || lastN <= 0) {
            return List.of();
        }
        return jdbcTemplate.query(
                SELECT_LAST_N_SQL,
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setId(rs.getLong("id"));
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                lastN
        );
    }

    /**
     * Deletes every message stored for the given conversation.
     *
     * @param conversationId conversation whose rows are removed
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update(DELETE_BY_CONVERSATION_SQL, conversationId);
    }
}
复制代码
ChatMessageEntity
java 复制代码
package org.spring.springaiprojet.entity;

import java.time.LocalDateTime;

/**
 * Row model for the {@code chat_messages} table: one persisted chat message.
 * A plain mutable JavaBean; every property is {@code null} until set.
 */
public class ChatMessageEntity {

    /** Primary key ({@code id} column). */
    private Long id;
    /** Conversation this message belongs to ({@code conversation_id}). */
    private String conversationId;
    /** Stored message type discriminator ({@code message_type}). */
    private String messageType;
    /** Message text ({@code content}). */
    private String content;
    /** Creation time ({@code created_at}). */
    private LocalDateTime createdAt;

    public void setId(Long id) {
        this.id = id;
    }

    public Long getId() {
        return this.id;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public String getConversationId() {
        return this.conversationId;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public String getMessageType() {
        return this.messageType;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public String getContent() {
        return this.content;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }

    public LocalDateTime getCreatedAt() {
        return this.createdAt;
    }
}
复制代码
PgSQLChatMemory
java 复制代码
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link ChatMemory} implementation that persists conversation history in PostgreSQL
 * through {@link ChatDao}.
 * <p>
 * Only user and assistant messages are stored, because those are the only types
 * {@link #get(String, int)} can rebuild. Message types are written and read through
 * {@link MessageEnum}, so both directions always agree on the stored value.
 */
@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;

    /**
     * Stores the user and assistant messages of {@code messages} under {@code conversationId}.
     * <p>
     * Other message types (and {@code null} elements) are skipped. Otherwise they would be
     * written with a {@code null} message_type, which violates the column's NOT NULL
     * constraint and fails the whole batch.
     *
     * @param conversationId conversation the messages belong to
     * @param messages       messages to store; {@code null} or empty is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entities = messages.stream()
                .filter(PgSQLChatMemory::isSupported)
                .map(msg -> toEntity(conversationId, msg))
                .collect(Collectors.toList());
        if (entities.isEmpty()) {
            return;
        }
        chatDao.insertMessages(entities);
    }

    /**
     * Returns up to {@code lastN} most recent messages of a conversation in chronological
     * (oldest-first) order, as expected when replaying them into a prompt.
     *
     * @param conversationId conversation to read; {@code null} yields an empty list
     * @param lastN          maximum number of messages; values {@code <= 0} yield an empty list
     * @return the rebuilt messages, never {@code null}
     * @throws IllegalArgumentException if a stored row has an unknown message type
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        if (conversationId == null || lastN <= 0) {
            return Collections.emptyList();
        }
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // The DAO returns rows newest-first. Reverse a copy to get chronological order
        // without mutating the list the DAO handed back (it may be unmodifiable).
        List<ChatMessageEntity> chronological = new ArrayList<>(entities);
        Collections.reverse(chronological);
        return chronological.stream()
                .map(PgSQLChatMemory::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Deletes all stored messages of a conversation.
     *
     * @param conversationId conversation to clear; {@code null} is ignored
     */
    @Override
    public void clear(String conversationId) {
        if (conversationId == null) {
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

    /** Whether {@code msg} is a type this memory can persist and later rebuild. */
    private static boolean isSupported(Message msg) {
        return msg instanceof UserMessage || msg instanceof AssistantMessage;
    }

    /** Converts a supported message into a row; callers must filter with {@link #isSupported} first. */
    private static ChatMessageEntity toEntity(String conversationId, Message msg) {
        ChatMessageEntity entity = new ChatMessageEntity();
        entity.setConversationId(conversationId);
        // getText() can presumably be null (e.g. an assistant message carrying only tool calls);
        // the content column is NOT NULL, so store an empty string instead.
        String text = msg.getText();
        entity.setContent(text == null ? "" : text);
        entity.setMessageType(msg instanceof UserMessage
                ? MessageEnum.USER.getValue()
                : MessageEnum.ASSISTANT.getValue());
        entity.setCreatedAt(LocalDateTime.now());
        return entity;
    }

    /** Rebuilds a message from a stored row, using the same MessageEnum values that {@link #toEntity} writes. */
    private static Message toMessage(ChatMessageEntity entity) {
        String type = entity.getMessageType();
        if (MessageEnum.USER.getValue().equals(type)) {
            return new UserMessage(entity.getContent());
        }
        if (MessageEnum.ASSISTANT.getValue().equals(type)) {
            return new AssistantMessage(entity.getContent());
        }
        throw new IllegalArgumentException("未知的消息类型: " + type);
    }
}
  • AiConfig实现
java 复制代码
    /**
     * PostgreSQL-backed chat memory (PgSQLChatMemory). The chatClient bean wraps it in a
     * PromptChatMemoryAdvisor so that conversation history is read from and written to the database.
     */
    @Autowired
    private PgSQLChatMemory pgSQLChatMemory;

    /**
     * Builds the application's ChatClient with a default advisor that injects
     * conversation history from the PostgreSQL-backed chat memory.
     *
     * @param builder the auto-configured ChatClient builder
     * @return the configured ChatClient
     */
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        // Memory advisor backed by the PgSQL ChatMemory implementation.
        PromptChatMemoryAdvisor memoryAdvisor = new PromptChatMemoryAdvisor(pgSQLChatMemory);
        // Optional alternatives, kept for reference:
        //   .defaultSystem("请以通俗开发者角度介绍")   -- default system role / persona
        //   .defaultAdvisors(new SimpleLoggerAdvisor()) -- request/response logging
        ChatClient.Builder configured = builder.defaultAdvisors(memoryAdvisor);
        return configured.build();
    }
  • 控制器实现
java 复制代码
    /**
     * Plain (blocking) chat endpoint: sends {@code question} as the user message and
     * returns the model's reply text.
     * <p>
     * NOTE(review): no conversation id is passed to the memory advisor here, so presumably
     * every caller shares the advisor's default conversation. Confirm before multi-user use.
     *
     * @param question the user's question, bound from the request parameter of the same name
     * @return the assistant's reply content
     */
    @RequestMapping("/qwen/chat/api")
    public String chat(String question) {
        var request = chatClient.prompt().user(question);
        var response = request.call();
        return response.content();
    }

三、测试调用

  • 第一次调用如下接口
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用

已经写入数据表中了

  • 第二次调用：基于上文会话内容继续提问，问答同样写入表中
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表，可以看到本轮问答也已按会话保存。

  • 第三次提问,会话记忆
java 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库，消息已经保存进来，说明已基于数据库内容实现了上下文会话回答。

相关推荐
TDengine (老段)9 小时前
金融风控系统中的实时数据库技术实践
大数据·数据库·物联网·时序数据库·tdengine·涛思数据
看我干嘛!9 小时前
第三次python作业
服务器·数据库·python
2501_936960369 小时前
ROS快速入门教程
数据库·mongodb
知识分享小能手9 小时前
Oracle 19c入门学习教程,从入门到精通,Oracle 的闪回技术 — 语法知识点与使用方法详解(19)
数据库·学习·oracle
踢足球09299 小时前
寒假打卡:2026-01-31
数据库·sql
是小崔啊9 小时前
PostgreSQL快速入门
数据库·postgresql
xxxmine9 小时前
Redis 持久化详解:RDB、AOF 与混合模式
数据库·redis·缓存
yufuu989 小时前
使用Scikit-learn进行机器学习模型评估
jvm·数据库·python
MMME~9 小时前
Ansible模块速查指南:高效定位与实战技巧
大数据·运维·数据库
甘露s10 小时前
深入理解 Redis:事务、持久化与过期策略全解析
数据库·redis