SpringAi基于PgSQL数据库存储扩展ChatMemory

一、环境准备

SpringAI入门学习

XML 复制代码
        <!-- SpringAI-->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter</artifactId>
            <version>1.0.0-M6.1</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.3.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
            <version>1.0.0-M6</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>

上文中实现了基于内存的会话保持功能,现在要基于数据库扩展实现

JdbcTemplate+PgSQL数据库实现扩展ChatMemory实现。

二、扩展实现

java 复制代码
package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

public interface ChatMemory {
    default void add(String conversationId, Message message) {
        this.add(conversationId, List.of(message));
    }

    void add(String conversationId, List<Message> messages);

    List<Message> get(String conversationId, int lastN);

    void clear(String conversationId);
}
  • 表结构准备
sql 复制代码
CREATE TABLE chat_messages (
                               id BIGSERIAL PRIMARY KEY,
                               conversation_id VARCHAR(255) NOT NULL,
                               message_type VARCHAR(50) NOT NULL,
                               content TEXT NOT NULL,
                               created_at TIMESTAMP NOT NULL
);
-- 创建索引
CREATE INDEX idx_conversation_id ON chat_messages (conversation_id);
  • ChatDao接口定义
java 复制代码
package org.spring.springaiprojet.dao;

import org.spring.springaiprojet.entity.ChatMessageEntity;

import java.util.List;

public interface ChatDao {
    /**
     * 保存表
     * @param messages
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * 查询最近的N条消息
     * @param conversationId
     * @param lastN
     * @return
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);


    /**
     * 删除会话
     * @param conversationId
     */
    void deleteByConversationId(String conversationId);
}
  • ChatImpl接口实现
java 复制代码
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

@Repository
public class ChatDaoImpl implements ChatDao {

    @Autowired
    private JdbcTemplate jdbcTemplate;

    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
         jdbcTemplate.batchUpdate("insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)", messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        return jdbcTemplate.query(
                "select * from chat_messages where conversation_id = ? order by created_at desc limit ?",
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                lastN
        );
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update("delete from chat_messages where conversation_id = ?", conversationId);
    }
}
复制代码
ChatMessageEntity
java 复制代码
package org.spring.springaiprojet.entity;

import java.time.LocalDateTime;

public class ChatMessageEntity {
    private Long id;
    private String conversationId;
    private String messageType;
    private String content;
    private LocalDateTime createdAt;

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getConversationId() {
        return conversationId;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public String getMessageType() {
        return messageType;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }
}
复制代码
PgSQLChatMemory
java 复制代码
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entities = messages.stream()
                .map(msg -> {
                    ChatMessageEntity entity = new ChatMessageEntity();
                    entity.setConversationId(conversationId);
                    entity.setContent(msg.getText());

                    if (msg instanceof UserMessage) {
                        entity.setMessageType(MessageEnum.USER.getValue());
                    } else if (msg instanceof AssistantMessage) {
                        entity.setMessageType(MessageEnum.ASSISTANT.getValue());
                    }
                    entity.setCreatedAt(LocalDateTime.now());
                    return entity;
                })
                .collect(Collectors.toList());
       chatDao.insertMessages(entities);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // 倒序
        Collections.reverse(entities);
        return entities.stream()
                .map(entity -> {
                    switch (entity.getMessageType()) {
                        case "user":
                            return new UserMessage(entity.getContent());
                        case "assistant":
                            return new AssistantMessage(entity.getContent());
                        default:
                            throw new IllegalArgumentException("未知的消息类型!");
                    }
                })
                .collect(Collectors.toList());
    }

    @Override
    public void clear(String conversationId) {
        if (conversationId == null){
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

}
  • AiConfig实现
java 复制代码
    /**
     * 基于PgSQL实现会话记忆
     */
    @Autowired
    private PgSQLChatMemory pgSQLChatMemory;

    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        // defaultSystem,默认系统角色,带有对话身份
        return builder
//                .defaultSystem("请以通俗开发者角度介绍")
                // 增加会话记忆,基于
                .defaultAdvisors(new PromptChatMemoryAdvisor(pgSQLChatMemory)).build();
//                .defaultAdvisors(new SimpleLoggerAdvisor()).build();
    }
  • 控制器实现
java 复制代码
    /**
     * 普通对话模式
     * @param question
     * @return
     */
    @RequestMapping("/qwen/chat/api")
    public String chat(String question) {
        return chatClient.prompt().user(question).call().content();
    }

三、测试调用

  • 第一次调用如下接口
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用

已经写入数据表中了

  • 第二次基于上文会话内容,写入表中
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表,发现已经基于会话保存回答问题了

  • 第三次提问,会话记忆
java 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库,以及基本保存进来,基于数据库内容实现上下文会话回答了。

相关推荐
剩下了什么7 小时前
MySQL JSON_SET() 函数
数据库·mysql·json
山峰哥8 小时前
数据库工程与SQL调优——从索引策略到查询优化的深度实践
数据库·sql·性能优化·编辑器
较劲男子汉8 小时前
CANN Runtime零拷贝传输技术源码实战 彻底打通Host与Device的数据传输壁垒
运维·服务器·数据库·cann
java搬砖工-苤-初心不变8 小时前
MySQL 主从复制配置完全指南:从原理到实践
数据库·mysql
山岚的运维笔记10 小时前
SQL Server笔记 -- 第18章:Views
数据库·笔记·sql·microsoft·sqlserver
roman_日积跬步-终至千里11 小时前
【LangGraph4j】LangGraph4j 核心概念与图编排原理
java·服务器·数据库
汇智信科11 小时前
打破信息孤岛,重构企业效率:汇智信科企业信息系统一体化运营平台
数据库·重构
野犬寒鸦11 小时前
从零起步学习并发编程 || 第六章:ReentrantLock与synchronized 的辨析及运用
java·服务器·数据库·后端·学习·算法
晚霞的不甘12 小时前
揭秘 CANN 内存管理:如何让大模型在小设备上“轻装上阵”?
前端·数据库·经验分享·flutter·3d
市场部需要一个软件开发岗位13 小时前
JAVA开发常见安全问题:纵向越权
java·数据库·安全