Spring AI 基于 PostgreSQL 数据库存储扩展 ChatMemory

一、环境准备

前置阅读：SpringAI入门学习（上一篇，基于内存的会话记忆）

XML 复制代码
        <!-- SpringAI-->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter</artifactId>
            <version>1.0.0-M6.1</version>
        </dependency>
        <!-- PostgreSQL JDBC driver. 42.3.1 is affected by published CVEs
             (e.g. CVE-2022-21724, CVE-2022-31197, CVE-2024-1597); use a current 42.7.x release. -->
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.7.4</version>
        </dependency>
        <!-- NOTE(review): the JDBC-based ChatMemory in this article does not use the pgvector store;
             presumably kept for vector-store features from earlier articles - confirm before removing. -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
            <version>1.0.0-M6</version>
        </dependency>
        <!-- Provides JdbcTemplate, used by ChatDaoImpl. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>

上一篇中实现了基于内存的会话记忆（ChatMemory）功能，本文改为基于数据库进行扩展实现，使会话记录可以持久化保存。

实现思路：使用 JdbcTemplate + PostgreSQL 数据库，自定义实现 ChatMemory 接口。

二、扩展实现

Spring AI 提供的 ChatMemory 接口如下，自定义实现需要实现 add(String, List)、get、clear 三个方法：
package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

/**
 * Storage abstraction for chat conversation history, keyed by a conversation id.
 */
public interface ChatMemory {

    /**
     * Appends a batch of messages to a conversation.
     *
     * @param conversationId identifier of the conversation
     * @param messages messages to append
     */
    void add(String conversationId, List<Message> messages);

    /**
     * Appends a single message to a conversation by wrapping it in a one-element list and
     * delegating to {@link #add(String, List)}.
     *
     * @param conversationId identifier of the conversation
     * @param message message to append
     */
    default void add(String conversationId, Message message) {
        List<Message> single = List.of(message);
        add(conversationId, single);
    }

    /**
     * Returns (up to) the last {@code lastN} messages of a conversation.
     *
     * @param conversationId identifier of the conversation
     * @param lastN how many of the most recent messages to return
     * @return the requested messages
     */
    List<Message> get(String conversationId, int lastN);

    /**
     * Removes the stored history of a conversation.
     *
     * @param conversationId identifier of the conversation
     */
    void clear(String conversationId);
}
  • 表结构准备
sql 复制代码
CREATE TABLE chat_messages (
    id              BIGSERIAL PRIMARY KEY,
    conversation_id VARCHAR(255) NOT NULL,
    message_type    VARCHAR(50)  NOT NULL,
    content         TEXT         NOT NULL,
    created_at      TIMESTAMP    NOT NULL
);
-- Composite index: serves "WHERE conversation_id = ? ORDER BY created_at DESC[, id DESC] LIMIT ?"
-- directly via a backward index scan, instead of fetching and sorting every row of the conversation.
CREATE INDEX idx_conversation_id ON chat_messages (conversation_id, created_at, id);
  • ChatDao接口定义
java 复制代码
package org.spring.springaiprojet.dao;

import org.spring.springaiprojet.entity.ChatMessageEntity;

import java.util.List;

/**
 * Data-access operations for persisted chat messages (one {@code chat_messages} row per message).
 */
public interface ChatDao {

    /**
     * Deletes every stored message that belongs to the given conversation.
     *
     * @param conversationId identifier of the conversation to remove
     */
    void deleteByConversationId(String conversationId);

    /**
     * Persists the given message rows.
     *
     * @param messages the messages to insert
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * Loads the most recent messages of a conversation.
     *
     * @param conversationId identifier of the conversation to read
     * @param lastN maximum number of messages to return
     * @return up to {@code lastN} messages, newest first (callers reverse them into chronological order)
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);
}
  • ChatDaoImpl接口实现
java 复制代码
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

/**
 * {@link ChatDao} implementation backed by {@link JdbcTemplate} and the {@code chat_messages} table.
 */
@Repository
public class ChatDaoImpl implements ChatDao {

    private static final String INSERT_SQL =
            "insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)";

    /**
     * Explicit column list instead of {@code select *}. {@code id desc} breaks ties between rows that share
     * the same {@code created_at} (e.g. messages written in one batch); without it their order is undefined.
     */
    private static final String SELECT_LAST_N_SQL =
            "select id, conversation_id, message_type, content, created_at from chat_messages "
                    + "where conversation_id = ? order by created_at desc, id desc limit ?";

    private static final String DELETE_SQL = "delete from chat_messages where conversation_id = ?";

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /**
     * Inserts all messages in a single JDBC batch.
     *
     * @param messages rows to insert; {@code null} or empty is a no-op
     */
    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        // batchSize == messages.size(): everything is sent as one batch.
        jdbcTemplate.batchUpdate(INSERT_SQL, messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    /**
     * Loads the newest {@code lastN} messages of a conversation.
     *
     * @param conversationId conversation to read
     * @param lastN maximum number of rows; values {@code <= 0} return an empty list without querying
     *              (PostgreSQL rejects a negative LIMIT)
     * @return matching rows, newest first
     */
    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        if (lastN <= 0) {
            return List.of();
        }
        return jdbcTemplate.query(
                SELECT_LAST_N_SQL,
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setId(rs.getLong("id"));
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                lastN
        );
    }

    /**
     * Deletes all rows of a conversation.
     *
     * @param conversationId conversation to delete
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update(DELETE_SQL, conversationId);
    }
}
复制代码
ChatMessageEntity
java 复制代码
package org.spring.springaiprojet.entity;

import java.time.LocalDateTime;

/**
 * Row model for the {@code chat_messages} table: one persisted chat message.
 */
public class ChatMessageEntity {

    /** Primary key ({@code id} column). */
    private Long id;
    /** Conversation the message belongs to ({@code conversation_id}). */
    private String conversationId;
    /** Message role stored as text ({@code message_type}). */
    private String messageType;
    /** Message text ({@code content}). */
    private String content;
    /** When the message was recorded ({@code created_at}). */
    private LocalDateTime createdAt;

    // ---- accessors ----

    public Long getId() {
        return this.id;
    }

    public String getConversationId() {
        return this.conversationId;
    }

    public String getMessageType() {
        return this.messageType;
    }

    public String getContent() {
        return this.content;
    }

    public LocalDateTime getCreatedAt() {
        return this.createdAt;
    }

    // ---- mutators ----

    public void setId(Long id) {
        this.id = id;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }
}
复制代码
PgSQLChatMemory
java 复制代码
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link ChatMemory} implementation that persists conversation history in PostgreSQL through {@link ChatDao}.
 *
 * <p>Only {@link UserMessage} and {@link AssistantMessage} instances are persisted. They are the only types
 * {@link #get} can rebuild. Storing anything else would either violate the {@code message_type NOT NULL}
 * constraint (the type would be null) or make every later read of that conversation fail.
 */
@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;

    /**
     * Persists the user/assistant messages of {@code messages} under {@code conversationId}.
     *
     * @param conversationId conversation the messages belong to
     * @param messages messages to store; {@code null}/empty, or containing no user/assistant message, is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entities = messages.stream()
                .filter(PgSQLChatMemory::isPersistable)
                .map(msg -> toEntity(conversationId, msg))
                .collect(Collectors.toList());
        if (entities.isEmpty()) {
            return;
        }
        chatDao.insertMessages(entities);
    }

    /**
     * Returns the last {@code lastN} messages of a conversation in chronological order.
     *
     * @param conversationId conversation to read
     * @param lastN maximum number of messages; values {@code <= 0} yield an empty list
     * @return the messages, oldest first; empty when nothing is stored
     * @throws IllegalArgumentException if a stored row has an unrecognized message type
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        if (lastN <= 0) {
            return Collections.emptyList();
        }
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // The DAO returns newest first; reverse so the prompt sees the history in chronological order.
        Collections.reverse(entities);
        return entities.stream()
                .map(PgSQLChatMemory::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Deletes the stored history of a conversation.
     *
     * @param conversationId conversation to clear; {@code null} is ignored
     */
    @Override
    public void clear(String conversationId) {
        if (conversationId == null) {
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

    /** Whether {@code msg} is a type that {@link #toMessage} can later rebuild. */
    private static boolean isPersistable(Message msg) {
        return msg instanceof UserMessage || msg instanceof AssistantMessage;
    }

    /** Converts a user/assistant message into a row for {@code conversationId}. */
    private static ChatMessageEntity toEntity(String conversationId, Message msg) {
        ChatMessageEntity entity = new ChatMessageEntity();
        entity.setConversationId(conversationId);
        // The content column is NOT NULL. NOTE(review): getText() may presumably be null for some
        // assistant replies (e.g. tool-call-only responses), so store "" instead.
        String text = msg.getText();
        entity.setContent(text != null ? text : "");
        entity.setMessageType(msg instanceof UserMessage
                ? MessageEnum.USER.getValue()
                : MessageEnum.ASSISTANT.getValue());
        entity.setCreatedAt(LocalDateTime.now());
        return entity;
    }

    /**
     * Rebuilds a {@link Message} from a stored row.
     *
     * <p>Compares against the same {@link MessageEnum} values used when writing, so reads and writes cannot
     * drift apart. The original code compared against the literals "user"/"assistant".
     */
    private static Message toMessage(ChatMessageEntity entity) {
        String type = entity.getMessageType();
        if (MessageEnum.USER.getValue().equals(type)) {
            return new UserMessage(entity.getContent());
        }
        if (MessageEnum.ASSISTANT.getValue().equals(type)) {
            return new AssistantMessage(entity.getContent());
        }
        throw new IllegalArgumentException("未知的消息类型! type=" + type);
    }
}
  • AiConfig实现
java 复制代码
    /**
     * PostgreSQL-backed chat memory (based on PgSQL), injected here and passed to the
     * PromptChatMemoryAdvisor registered in {@code chatClient}.
     */
    @Autowired
    private PgSQLChatMemory pgSQLChatMemory;

    /**
     * Builds the application's {@link ChatClient} with a {@link PromptChatMemoryAdvisor} backed by
     * {@code pgSQLChatMemory}, so conversation history is loaded from and saved to PostgreSQL.
     *
     * <p>Alternatives used elsewhere in this series (not active here): a default system prompt via
     * {@code .defaultSystem("请以通俗开发者角度介绍")}, and request/response logging via
     * {@code .defaultAdvisors(new SimpleLoggerAdvisor())}.
     *
     * @param builder the auto-configured client builder
     * @return the configured chat client
     */
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        PromptChatMemoryAdvisor memoryAdvisor = new PromptChatMemoryAdvisor(pgSQLChatMemory);
        ChatClient.Builder configured = builder.defaultAdvisors(memoryAdvisor);
        return configured.build();
    }
  • 控制器实现
java 复制代码
    /**
     * Plain (blocking) chat endpoint.
     *
     * <p>Sends {@code question} as the user message and returns the model's reply text. Any conversation
     * memory comes from the default advisors configured on the injected {@code chatClient}.
     *
     * @param question the user's question, bound from the request parameter of the same name
     * @return the reply content
     */
    @RequestMapping("/qwen/chat/api")
    public String chat(String question) {
        var request = chatClient.prompt().user(question);
        return request.call().content();
    }

三、测试调用

  • 第一次调用如下接口
HTTP
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用

调用后查看数据库，可以看到本轮问答已经写入 chat_messages 表中。

  • 第二次调用：基于上一轮的回答内容继续提问
HTTP
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表，可以看到第二轮的问题和回答也已按会话保存下来。

  • 第三次提问，验证会话记忆
HTTP
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库，第三轮问答也已保存进来，说明已经基于数据库中的历史记录实现了上下文会话回答。

相关推荐
猫豆~2 小时前
ceph分布式存储——1day
java·linux·数据库·sql·云计算
有想法的py工程师2 小时前
PostgreSQL 分区表 + Debezium CDC:为什么 REPLICA IDENTITY FULL 不生效?
数据库·postgresql
倔强的石头1062 小时前
金仓数据库(KingbaseES) 开发实战:常见迁移挑战与技术解析
数据库·kingbase
TDengine (老段)2 小时前
TDengine IDMP 地图展示数据功能快速上手
大数据·数据库·物联网·时序数据库·tdengine·涛思数据
档案宝档案管理2 小时前
电子会计档案管理系统:档案宝如何发挥会计档案的价值?
大数据·数据库·人工智能·档案·档案管理
° 安如少年初如梦6622 小时前
DataGrip/DBeaver/官方工具 连接瀚高数据库教程
数据库·瀚高·highgo
云和恩墨2 小时前
告别“头痛医头”:SQL性能优化的闭环构建,从被动修复到主动掌控
数据库·oracle
独自归家的兔2 小时前
通义千问3-VL-Plus - 界面交互(坐标改进)
数据库·microsoft·交互
p&f°2 小时前
PostgreSQL 执行计划控制参数详解
数据库·postgresql·oracle