SpringAi基于PgSQL数据库存储扩展ChatMemory

一、环境准备

SpringAI入门学习

XML 复制代码
        <!-- Spring AI Alibaba starter -->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter</artifactId>
            <version>1.0.0-M6.1</version>
        </dependency>
        <!-- PostgreSQL JDBC driver -->
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.3.1</version>
        </dependency>
        <!-- Spring AI pgvector vector-store starter -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
            <version>1.0.0-M6</version>
        </dependency>
        <!-- Spring Boot JDBC starter; the DAO code in this article injects JdbcTemplate -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>

上文中实现了基于内存的会话保持功能,现在要基于数据库进行扩展:

使用JdbcTemplate+PgSQL数据库扩展实现ChatMemory。

二、扩展实现

java 复制代码
package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

/**
 * Contract for storing, reading back and clearing the messages of a conversation.
 * Conversations are addressed by a string identifier.
 */
public interface ChatMemory {

    /**
     * Appends a batch of messages to a conversation.
     *
     * @param conversationId identifier of the conversation to append to
     * @param messages       messages to append
     */
    void add(String conversationId, List<Message> messages);

    /**
     * Appends a single message by wrapping it in a one-element list and delegating to
     * {@link #add(String, List)}.
     *
     * @param conversationId identifier of the conversation to append to
     * @param message        message to append; {@code List.of} rejects {@code null}
     */
    default void add(String conversationId, Message message) {
        List<Message> single = List.of(message);
        add(conversationId, single);
    }

    /**
     * Reads messages of a conversation.
     *
     * @param conversationId identifier of the conversation to read
     * @param lastN          number of trailing messages requested; the exact semantics are
     *                       defined by the implementation
     * @return the messages selected by the implementation
     */
    List<Message> get(String conversationId, int lastN);

    /**
     * Removes the stored messages of a conversation.
     *
     * @param conversationId identifier of the conversation to clear
     */
    void clear(String conversationId);
}
  • 表结构准备
sql 复制代码
-- One row per chat message; rows of the same conversation share conversation_id.
CREATE TABLE chat_messages (
    id              BIGSERIAL    PRIMARY KEY,
    conversation_id VARCHAR(255) NOT NULL,
    message_type    VARCHAR(50)  NOT NULL,
    content         TEXT         NOT NULL,
    created_at      TIMESTAMP    NOT NULL
);

-- Index supporting lookups and deletes by conversation_id.
CREATE INDEX idx_conversation_id ON chat_messages (conversation_id);
  • ChatDao接口定义
java 复制代码
package org.spring.springaiprojet.dao;

import org.spring.springaiprojet.entity.ChatMessageEntity;

import java.util.List;

/**
 * Data-access contract for persisted chat messages.
 */
public interface ChatDao {
    /**
     * Saves the given messages.
     *
     * @param messages message entities to persist
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * Queries the most recent N messages of a conversation.
     *
     * @param conversationId identifier of the conversation to read
     * @param lastN          maximum number of messages to return
     * @return the matching message entities
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);


    /**
     * Deletes a conversation.
     *
     * @param conversationId identifier of the conversation whose messages are removed
     */
    void deleteByConversationId(String conversationId);
}
  • ChatDaoImpl接口实现
java 复制代码
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

/**
 * {@link ChatDao} implementation that reads and writes the {@code chat_messages} table
 * through {@link JdbcTemplate}.
 */
@Repository
public class ChatDaoImpl implements ChatDao {

    private static final String INSERT_SQL =
            "insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)";

    // id is used as a tie-breaker: rows written in the same batch can carry identical
    // created_at values, and without it their relative order would be arbitrary.
    private static final String SELECT_LAST_N_SQL =
            "select id, conversation_id, message_type, content, created_at from chat_messages "
                    + "where conversation_id = ? order by created_at desc, id desc limit ?";

    private static final String DELETE_SQL = "delete from chat_messages where conversation_id = ?";

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /**
     * Inserts all given messages with a single JDBC batch.
     *
     * @param messages entities to insert; a {@code null} or empty list is a no-op
     */
    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        jdbcTemplate.batchUpdate(INSERT_SQL, messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    /**
     * Loads up to {@code lastN} messages of a conversation, newest first.
     *
     * @param conversationId conversation whose messages are read
     * @param lastN          maximum number of rows; values below zero are treated as zero
     *                       so that the query never receives a negative LIMIT
     * @return the matching rows ordered by {@code created_at} and then {@code id}, both descending
     */
    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        int limit = Math.max(lastN, 0);
        return jdbcTemplate.query(
                SELECT_LAST_N_SQL,
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setId(rs.getLong("id"));
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                limit
        );
    }

    /**
     * Deletes every row that belongs to the given conversation.
     *
     * @param conversationId conversation whose rows are removed
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update(DELETE_SQL, conversationId);
    }
}
复制代码
ChatMessageEntity
java 复制代码
package org.spring.springaiprojet.entity;

import java.time.LocalDateTime;

/**
 * Mutable holder for one chat message record: identifier, owning conversation,
 * message type, text content and creation time.
 */
public class ChatMessageEntity {

    /** Record identifier. */
    private Long id;

    /** Identifier of the conversation this message belongs to. */
    private String conversationId;

    /** Type of the message, stored as a string. */
    private String messageType;

    /** Text content of the message. */
    private String content;

    /** Time at which the message was created. */
    private LocalDateTime createdAt;

    // ---- getters ----

    public Long getId() {
        return this.id;
    }

    public String getConversationId() {
        return this.conversationId;
    }

    public String getMessageType() {
        return this.messageType;
    }

    public String getContent() {
        return this.content;
    }

    public LocalDateTime getCreatedAt() {
        return this.createdAt;
    }

    // ---- setters ----

    public void setId(Long id) {
        this.id = id;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }
}
复制代码
PgSQLChatMemory
java 复制代码
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link ChatMemory} implementation that persists conversation history through {@link ChatDao}.
 *
 * <p>Only {@link UserMessage} and {@link AssistantMessage} can be stored and restored; any
 * other message type is rejected with an {@link IllegalArgumentException}.
 */
@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;

    /**
     * Persists the given messages for a conversation.
     *
     * @param conversationId conversation the messages belong to
     * @param messages       messages to store; a {@code null} or empty list is a no-op
     * @throws IllegalArgumentException if a message is neither a user nor an assistant message;
     *                                  raised before anything is written
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        // All messages are converted first, so an unsupported one aborts the call
        // before the DAO is invoked.
        List<ChatMessageEntity> entities = messages.stream()
                .map(msg -> toEntity(conversationId, msg))
                .collect(Collectors.toList());
        chatDao.insertMessages(entities);
    }

    /**
     * Returns up to {@code lastN} most recent messages of a conversation, oldest first.
     *
     * @param conversationId conversation to read
     * @param lastN          maximum number of messages; values {@code <= 0} yield an empty list
     * @return the restored messages, or an empty list when nothing is stored
     * @throws IllegalArgumentException if a stored row carries an unknown message type
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        if (lastN <= 0) {
            return Collections.emptyList();
        }
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // Reverse the rows returned by the DAO before mapping them to messages.
        Collections.reverse(entities);
        return entities.stream()
                .map(this::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Deletes the stored history of a conversation.
     *
     * @param conversationId conversation to clear; {@code null} is ignored
     */
    @Override
    public void clear(String conversationId) {
        if (conversationId == null) {
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

    /** Converts a chat message into its persistent form, stamping it with the current time. */
    private ChatMessageEntity toEntity(String conversationId, Message msg) {
        String type;
        if (msg instanceof UserMessage) {
            type = MessageEnum.USER.getValue();
        } else if (msg instanceof AssistantMessage) {
            type = MessageEnum.ASSISTANT.getValue();
        } else {
            // Without a type the row could not satisfy the NOT NULL message_type column.
            String actual = (msg == null) ? "null" : msg.getClass().getName();
            throw new IllegalArgumentException("未知的消息类型!(" + actual + ")");
        }
        ChatMessageEntity entity = new ChatMessageEntity();
        entity.setConversationId(conversationId);
        entity.setContent(msg.getText());
        entity.setMessageType(type);
        entity.setCreatedAt(LocalDateTime.now());
        return entity;
    }

    /** Rebuilds a chat message from a stored row, using the same type values that are written. */
    private Message toMessage(ChatMessageEntity entity) {
        String type = entity.getMessageType();
        if (MessageEnum.USER.getValue().equals(type)) {
            return new UserMessage(entity.getContent());
        }
        if (MessageEnum.ASSISTANT.getValue().equals(type)) {
            return new AssistantMessage(entity.getContent());
        }
        throw new IllegalArgumentException("未知的消息类型!(" + type + ")");
    }

}
  • AiConfig实现
java 复制代码
    /**
     * Conversation memory implemented on top of PgSQL.
     */
    @Autowired
    private PgSQLChatMemory pgSQLChatMemory;

    /**
     * Creates the {@link ChatClient} bean with a {@link PromptChatMemoryAdvisor} that uses
     * {@code pgSQLChatMemory} as its conversation memory.
     *
     * @param builder builder used to assemble the client
     * @return the configured chat client
     */
    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        // Conversation memory is provided by the PgSQL-backed ChatMemory.
        PromptChatMemoryAdvisor memoryAdvisor = new PromptChatMemoryAdvisor(pgSQLChatMemory);

        // Left disabled, as in the original snippet: a default system prompt via
        // builder.defaultSystem(...) and a SimpleLoggerAdvisor as default advisor.
        return builder
                .defaultAdvisors(memoryAdvisor)
                .build();
    }
  • 控制器实现
java 复制代码
    /**
     * Plain chat mode: sends the question as the user message and returns the content
     * of the model's reply.
     *
     * @param question the user's question
     * @return the reply content
     */
    @RequestMapping("/qwen/chat/api")
    public String chat(String question) {
        String answer = chatClient.prompt()
                .user(question)
                .call()
                .content();
        return answer;
    }

三、测试调用

  • 第一次调用如下接口
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用

已经写入数据表中了

  • 第二次基于上文会话内容,写入表中
XML 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表,发现已经基于会话保存回答问题了

  • 第三次提问,会话记忆
java 复制代码
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库,已经基本保存进来,基于数据库内容实现了上下文会话回答。

相关推荐
m0_561359671 天前
使用Python处理计算机图形学(PIL/Pillow)
jvm·数据库·python
山岚的运维笔记1 天前
SQL Server笔记 -- 第14章:CASE语句
数据库·笔记·sql·microsoft·sqlserver
Data_Journal1 天前
如何使用 Python 解析 JSON 数据
大数据·开发语言·前端·数据库·人工智能·php
ASS-ASH1 天前
AI时代之向量数据库概览
数据库·人工智能·python·llm·embedding·向量数据库·vlm
xixixi777771 天前
互联网和数据分析中的核心指标 DAU (日活跃用户数)
大数据·网络·数据库·数据·dau·mau·留存率
范纹杉想快点毕业1 天前
状态机设计与嵌入式系统开发完整指南从面向过程到面向对象,从理论到实践的全面解析
linux·服务器·数据库·c++·算法·mongodb·mfc
这周也會开心1 天前
Redis与MySQL回写中的数据类型存储设计
数据库·redis·mysql
Aaron_Wjf1 天前
PG Vector测试
数据库·postgresql
Aaron_Wjf1 天前
PG逻辑复制槽应用
数据库·postgresql
一碗面4211 天前
SQL性能优化:让数据库飞起来
数据库·sql·性能优化