Spring AI 1.0.0 GA: Storing Conversation History (ChatMemory) in Redis

Official documentation: Chat Memory :: Spring AI Reference

1. Introduction

Spring AI 1.0.0 changed a lot of APIs. Modeled on the official InMemoryChatMemoryRepository, this article implements a custom RedisChatMemoryRepository and uses MessageWindowChatMemory to create the ChatMemory.
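To make the repository contract concrete, here is a simplified, hypothetical stand-in for the ChatMemoryRepository interface with a map-backed implementation in the spirit of InMemoryChatMemoryRepository. Messages are plain Strings, and the names SimpleChatMemoryRepository and MapChatMemoryRepository are invented for illustration; this is not Spring AI's API. It shows the four operations the Redis version has to provide, and that saveAll replaces a conversation's history rather than appending to it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Map-backed repository (hypothetical), similar in spirit to InMemoryChatMemoryRepository.
// RedisChatMemoryRepository replaces the map with one Redis list per conversation.
class MapChatMemoryRepository implements SimpleChatMemoryRepository {
    private final Map<String, List<String>> store = new ConcurrentHashMap<>();

    @Override
    public List<String> findConversationIds() {
        return new ArrayList<>(store.keySet());
    }

    @Override
    public List<String> findByConversationId(String conversationId) {
        List<String> messages = store.get(conversationId);
        return messages == null ? List.of() : List.copyOf(messages);   // unknown ID -> empty list
    }

    @Override
    public void saveAll(String conversationId, List<String> messages) {
        store.put(conversationId, new ArrayList<>(messages));   // overwrite the whole history
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        store.remove(conversationId);
    }

    public static void main(String[] args) {
        SimpleChatMemoryRepository repo = new MapChatMemoryRepository();
        repo.saveAll("c1", List.of("user: hi", "assistant: hello"));
        repo.saveAll("c1", List.of("user: bye"));             // replaces, does not append
        System.out.println(repo.findByConversationId("c1"));  // [user: bye]
    }
}

// Hypothetical, simplified version of the ChatMemoryRepository contract (Strings instead of Message).
interface SimpleChatMemoryRepository {
    List<String> findConversationIds();
    List<String> findByConversationId(String conversationId);
    void saveAll(String conversationId, List<String> messages);
    void deleteByConversationId(String conversationId);
}
```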

2. Implementation

2.1. Add the dependencies

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-openai</artifactId>
    <version>1.0.0</version>
</dependency>

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

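Note: the Maven coordinates changed in Spring AI 1.0.0. The starter artifactIds were renamed; for example, the OpenAI starter is now spring-ai-starter-model-openai (previously spring-ai-openai-spring-boot-starter).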

2.2. Configuration file

server:
  port: 8080
spring:
  ai:
    openai:
      api-key: xxx     # your own API key
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
  data:
    redis:
      host: localhost
      port: 6379
      password: 123456

Just make sure the Redis connection settings (host, port, password) are correct.
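If you are unsure whether the Redis settings are right, one way to check them outside Spring is to speak the Redis protocol (RESP) directly: send AUTH with the configured password, then PING, and expect +PONG back. The sketch below assumes the host, port and password from the YAML above; RedisPingCheck and its methods are hypothetical names, and `redis-cli -a 123456 ping` does the same job if redis-cli is installed.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;

class RedisPingCheck {
    // Encodes a command as a RESP array of bulk strings, e.g. PING -> "*1\r\n$4\r\nPING\r\n".
    static String encode(String... parts) {
        StringBuilder sb = new StringBuilder("*").append(parts.length).append("\r\n");
        for (String p : parts) {
            // Bulk string length is the byte length, not the char count
            sb.append('$').append(p.getBytes(StandardCharsets.UTF_8).length).append("\r\n")
              .append(p).append("\r\n");
        }
        return sb.toString();
    }

    // Sends AUTH (when a password is given) and then PING; returns the first reply line received.
    static String ping(String host, int port, String password) throws IOException {
        try (Socket socket = new Socket(host, port)) {
            socket.setSoTimeout(2000);
            OutputStream out = socket.getOutputStream();
            BufferedReader in = new BufferedReader(
                    new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8));
            if (password != null && !password.isEmpty()) {
                out.write(encode("AUTH", password).getBytes(StandardCharsets.UTF_8));
                out.flush();
                String auth = in.readLine();   // "+OK" on success, an error line starting with "-" otherwise
                if (auth == null || !auth.startsWith("+")) {
                    return String.valueOf(auth);
                }
            }
            out.write(encode("PING").getBytes(StandardCharsets.UTF_8));
            out.flush();
            return in.readLine();              // "+PONG" when the settings are correct
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(encode("PING").replace("\r\n", "\\r\\n"));   // *1\r\n$4\r\nPING\r\n
        // System.out.println(ping("localhost", 6379, "123456"));     // needs a running Redis
    }
}
```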

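For api-key you can use a DeepSeek key (it has to be purchased, but 1 yuan lasts a long time). DeepSeek's API is OpenAI-compatible, which is why the OpenAI starter works once base-url points to https://api.deepseek.com.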

2.3. RedisChatMemoryRepository

RedisChatMemoryRepository stores the conversation data in Redis.

It is modeled on InMemoryChatMemoryRepository and on the CSDN blog post 【SpringAI 1.0.0】 ChatMemory 转换为 Redis 存储 ("Converting ChatMemory to Redis storage").

package com.njust.repository;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.content.Media;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.util.MimeType;
import java.io.IOException;
import java.net.URL;
import java.util.*;
import java.util.stream.Collectors;

public class RedisChatMemoryRepository implements ChatMemoryRepository {
    private final StringRedisTemplate stringRedisTemplate;  // Redis access
    private final ObjectMapper objectMapper;    // JSON serialization / deserialization
    private final String PREFIX;       // key prefix of each conversation's message list
    private final String CONVERSATION_IDS_SET;  // key of the ZSet holding all conversation IDs

    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) {
        this(stringRedisTemplate, objectMapper, "chat:conversation:", "chat:all_conversation_ids");
    }

    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper, String PREFIX) {
        this(stringRedisTemplate, objectMapper, PREFIX, "chat:all_conversation_ids");
    }

    public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper, String PREFIX, String CONVERSATION_IDS_SET) {
        this.stringRedisTemplate = stringRedisTemplate;
        this.objectMapper = objectMapper;
        this.PREFIX = PREFIX;
        this.CONVERSATION_IDS_SET = CONVERSATION_IDS_SET;
    }

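    // Return all conversation IDs (read from the ZSet index instead of scanning keys with KEYS chat:*)
    @Override
    public List<String> findConversationIds() {
        // The IDs live in a ZSet, which avoids a full key scan.
        // reverseRange returns them by descending score; saveAll refreshes the score (a timestamp)
        // on every save, so the most recently updated conversation comes first.
        Set<String> conversationIds = stringRedisTemplate.opsForZSet().reverseRange(CONVERSATION_IDS_SET, 0, -1);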

        if (conversationIds == null || conversationIds.isEmpty()) {
            return List.of();
        }

        return new ArrayList<>(conversationIds);
    }

    // Load the Message list of a conversationId
    @Override
    public List<Message> findByConversationId(String conversationId) {
        // Validate arguments
        if (conversationId == null || conversationId.isEmpty()) {
            throw new IllegalArgumentException("conversationId cannot be null or empty");
        }

        List<String> list = stringRedisTemplate.opsForList().range(PREFIX + conversationId, 0, -1);
        if (list == null || list.isEmpty()) {
            return List.of();
        }
        return list.stream()
                .map(json -> {
                    try {
                        // return objectMapper.convertValue(json, Message.class);   // fails: Message is an interface and cannot be instantiated directly
                        return deserializeMessage(json);    // deserialize manually based on messageType
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                })
                .collect(Collectors.toList());
    }

    // Save the complete Message list of a conversationId, replacing what was stored before
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // Validate arguments
        if (conversationId == null || conversationId.isEmpty()) {
            throw new IllegalArgumentException("conversationId cannot be null or empty");
        }
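        // Clear the existing messages of this conversation first
        stringRedisTemplate.delete(PREFIX + conversationId);

        if (messages == null || messages.isEmpty()) {
            // Nothing left to store, so also drop the ID from the index
            stringRedisTemplate.opsForZSet().remove(CONVERSATION_IDS_SET, conversationId);
            return;
        }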
        List<String> list = messages.stream()
                .map(message -> {
                    try {
                        return objectMapper.writeValueAsString(message);
                    } catch (JsonProcessingException e) {
                        throw new RuntimeException("Failed to serialize Message", e);
                    }
                })
                .collect(Collectors.toList());

        stringRedisTemplate.opsForList().rightPushAll(PREFIX + conversationId, list);
        // Add or refresh the conversation ID in the index; score = time of the last save
        stringRedisTemplate.opsForZSet().add(CONVERSATION_IDS_SET, conversationId, System.currentTimeMillis());
    }

    // Delete all data of a conversationId
    @Override
    public void deleteByConversationId(String conversationId) {
        if (conversationId == null || conversationId.isEmpty()) {
            throw new IllegalArgumentException("conversationId cannot be null or empty");
        }

        stringRedisTemplate.delete(PREFIX + conversationId);
        stringRedisTemplate.opsForZSet().remove(CONVERSATION_IDS_SET, conversationId);
    }

    // Manually deserialize a Message
    public Message deserializeMessage(String json) throws IOException {
        // Parse the JSON string into a JsonNode
        JsonNode jsonNode = objectMapper.readTree(json);
        // Read the messageType field
        if (!jsonNode.has("messageType")) {
            throw new IllegalArgumentException("Missing or invalid messageType field");
        }
        String messageType = jsonNode.get("messageType").asText();

        // Read the text field (a JSON null would otherwise become the string "null")
        String text = jsonNode.hasNonNull("text") ? jsonNode.get("text").asText() : "";

        // Read the metadata field
        Map<String, Object> metadata = getMetadata(jsonNode);

        // Read the media field
        List<Media> mediaList = getMediaList(jsonNode);

        return switch (MessageType.valueOf(messageType)) {
            case SYSTEM -> new SystemMessage(text);
            case USER -> UserMessage.builder()
                            .text(text)
                            .media(mediaList)
                            .metadata(metadata)
                            .build();
            case ASSISTANT -> {
                List<AssistantMessage.ToolCall> toolCalls = getToolCalls(jsonNode);
                yield new AssistantMessage(text, metadata, toolCalls, mediaList);
            }
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    private Media deserializeMedia(ObjectMapper mapper, JsonNode mediaNode) throws IOException {
        Media.Builder builder = Media.builder();

        // Handle MIME type
        if (mediaNode.has("mimeType")) {
            JsonNode mimeNode = mediaNode.get("mimeType");
            String type = mimeNode.get("type").asText();
            String subtype = mimeNode.get("subtype").asText();
            builder.mimeType(new MimeType(type, subtype));
        }

        // Handle data - could be either URL string or byte array
        if (mediaNode.has("data")) {
            String data = mediaNode.get("data").asText();
            if (data.startsWith("http://") || data.startsWith("https://")) {
                builder.data(new URL(data));
            } else {
                // Assume it's base64 encoded binary data
                byte[] bytes = Base64.getDecoder().decode(data);
                builder.data(bytes);
            }
        }

        // Handle dataAsByteArray if present (overrides data if both exist)
        if (mediaNode.has("dataAsByteArray")) {
            byte[] bytes = Base64.getDecoder().decode(mediaNode.get("dataAsByteArray").asText());
            builder.data(bytes);
        }

        // Handle optional fields
        if (mediaNode.has("id")) {
            builder.id(mediaNode.get("id").asText());
        }

        if (mediaNode.has("name")) {
            builder.name(mediaNode.get("name").asText());
        }

        return builder.build();
    }

    private Map<String, Object> getMetadata(JsonNode jsonNode) {
        if (jsonNode.has("metadata")) {
            return objectMapper.convertValue(jsonNode.get("metadata"), new TypeReference<>() {});
        }
        return new HashMap<>();
    }

    private List<Media> getMediaList(JsonNode jsonNode) throws IOException {
        List<Media> mediaList = new ArrayList<>();
        if (jsonNode.has("media")) {
            for (JsonNode mediaNode : jsonNode.get("media")) {
                mediaList.add(deserializeMedia(objectMapper, mediaNode));
            }
        }
        return mediaList;
    }

    private List<AssistantMessage.ToolCall> getToolCalls(JsonNode jsonNode) {
        if (jsonNode.has("toolCalls")) {
            return objectMapper.convertValue(jsonNode.get("toolCalls"), new TypeReference<>() {});
        }
        return Collections.emptyList();
    }
}

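The key parts are commented, so the code should be easy to follow.

Note that Message has to be deserialized by hand: Message is an interface, so Jackson cannot instantiate it directly. deserializeMessage reads the messageType field and builds the matching SystemMessage, UserMessage or AssistantMessage; any other type (such as TOOL) is rejected.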
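The dispatch on messageType is a general pattern for deserializing an interface type: the stored data carries a discriminator field, and the reader switches on it to pick the concrete class. A JDK-only sketch of that pattern (Java 17), with a Map standing in for Jackson's JsonNode and hypothetical record types in place of SystemMessage, UserMessage and AssistantMessage:

```java
import java.util.Map;

class MessageTypeDispatch {
    // Hypothetical message types mirroring the three cases handled by deserializeMessage
    sealed interface Msg permits SystemMsg, UserMsg, AssistantMsg {}
    record SystemMsg(String text) implements Msg {}
    record UserMsg(String text) implements Msg {}
    record AssistantMsg(String text) implements Msg {}

    // A generic reader only knows "some Msg"; the messageType field decides which record to build.
    static Msg fromFields(Map<String, Object> fields) {
        Object type = fields.get("messageType");
        if (type == null) {
            throw new IllegalArgumentException("Missing messageType field");
        }
        String text = String.valueOf(fields.getOrDefault("text", ""));
        return switch (type.toString()) {
            case "SYSTEM" -> new SystemMsg(text);
            case "USER" -> new UserMsg(text);
            case "ASSISTANT" -> new AssistantMsg(text);
            default -> throw new IllegalArgumentException("Unknown message type: " + type);
        };
    }

    public static void main(String[] args) {
        Msg m = fromFields(Map.of("messageType", "USER", "text", "hi"));
        System.out.println(m);   // UserMsg[text=hi]
    }
}
```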

2.4. Register the beans

package com.njust.config;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.njust.repository.RedisChatMemoryRepository;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.StringRedisTemplate;

@Configuration
public class CommonConfiguration {

    @Bean
    public ChatMemoryRepository chatMemoryRepository(StringRedisTemplate stringRedisTemplate) {
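        // By default, if no other repository is configured, Spring AI auto-configures a ChatMemoryRepository bean
        // of type InMemoryChatMemoryRepository that the application can use directly.
        // Here we register our Redis-backed repository instead.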
        return new RedisChatMemoryRepository(stringRedisTemplate, new ObjectMapper());
    }

    @Bean
    public ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {
        // Register the conversation memory
        return MessageWindowChatMemory
                .builder()
                .chatMemoryRepository(chatMemoryRepository)
                .maxMessages(20)   // maximum number of messages kept in the memory window
                .build();
    }

    @Bean
    // DeepSeek model, accessed through the OpenAI-compatible chat model
    public ChatClient deepseekChatClient(OpenAiChatModel openAiChatModel, ChatMemory chatMemory) {
        return ChatClient
                .builder(openAiChatModel)
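                .defaultSystem("You are a graduate student at the School of Computer Science and Engineering, Nanjing University of Science and Technology. Your name is Xiaolan.")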
                .defaultAdvisors(
                        new SimpleLoggerAdvisor(),  // logging advisor (logs at DEBUG level)
                        MessageChatMemoryAdvisor.builder(chatMemory).build()    // attach the conversation memory
                )
                .build();
    }
}

Here MessageWindowChatMemory builds the ChatMemory and caps how many messages are kept as context.
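Reading the Spring AI 1.0.0 source, MessageWindowChatMemory.add(...) appears to load the stored history with findByConversationId, append the new messages, trim the list and write it back through saveAll, which would explain why saveAll above replaces the whole Redis list. The trimming step can be sketched as follows, with a plain record standing in for Spring AI's Message (a simplified illustration, not the library code): the oldest non-system messages are dropped first and system messages are always kept. With maxMessages(20), that is roughly the last 10 user/assistant exchanges.

```java
import java.util.ArrayList;
import java.util.List;

class MessageWindowSketch {
    // Hypothetical stand-in for a chat message
    record Msg(String role, String text) {
        boolean isSystem() { return "system".equals(role); }
    }

    // Evicts the oldest non-system messages until at most maxMessages remain.
    static List<Msg> trim(List<Msg> history, int maxMessages) {
        int toRemove = history.size() - maxMessages;
        if (toRemove <= 0) {
            return new ArrayList<>(history);   // window not full yet
        }
        List<Msg> kept = new ArrayList<>();
        int removed = 0;
        for (Msg m : history) {
            if (m.isSystem() || removed >= toRemove) {
                kept.add(m);                   // system messages and newer messages survive
            } else {
                removed++;                     // drop this older message
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Msg> history = List.of(
                new Msg("system", "persona"),
                new Msg("user", "q1"), new Msg("assistant", "a1"),
                new Msg("user", "q2"), new Msg("assistant", "a2"));
        System.out.println(trim(history, 3).size());   // 3 (system, q2, a2)
    }
}
```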

2.5. Controller

package com.njust.controller;

import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

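// @RequiredArgsConstructor (Lombok) generates a constructor for all final fields and fields annotated with @NonNull,
// giving concise, safe constructor injection. Lombok must be on the classpath; it is not in the dependency list in 2.1.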
@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class ChatController {

    private final ChatClient deepseekChatClient;

    @RequestMapping(value = "/chat", produces = "text/html;charset=utf-8")
    public Flux<String> chat(
            @RequestParam("prompt") String prompt,
            @RequestParam("chatId") String chatId) {
        
        return deepseekChatClient.prompt()
                .user(prompt)
                .advisors(a -> a.param(CONVERSATION_ID, chatId))
                .stream()
                .content();
    }
}
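To try the endpoint, a client only needs to send prompt and chatId as query parameters to /ai/chat on port 8080 and read the streamed text as it arrives. Below is a minimal sketch with the JDK's java.net.http.HttpClient (Java 11+), assuming the application runs locally; ChatEndpointClient and its methods are hypothetical names. The values are URL-encoded so that Chinese text or an & inside a prompt survives.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

class ChatEndpointClient {
    // Builds GET {baseUrl}/ai/chat?prompt=...&chatId=... with URL-encoded values
    static URI chatUri(String baseUrl, String prompt, String chatId) {
        String query = "prompt=" + URLEncoder.encode(prompt, StandardCharsets.UTF_8)
                + "&chatId=" + URLEncoder.encode(chatId, StandardCharsets.UTF_8);
        return URI.create(baseUrl + "/ai/chat?" + query);
    }

    // Prints the streamed answer chunk by chunk; requires the application to be running.
    static void streamChat(String baseUrl, String prompt, String chatId)
            throws IOException, InterruptedException {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder(chatUri(baseUrl, prompt, chatId)).GET().build();
        HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
        // InputStreamReader decodes UTF-8 correctly even when a character spans two chunks
        try (Reader reader = new InputStreamReader(response.body(), StandardCharsets.UTF_8)) {
            char[] buf = new char[256];
            int n;
            while ((n = reader.read(buf)) != -1) {
                System.out.print(new String(buf, 0, n));
                System.out.flush();
            }
        }
        System.out.println();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(chatUri("http://localhost:8080", "hello there", "demo-1"));
        // streamChat("http://localhost:8080", "What is your name?", "demo-1");
    }
}
```

Sending several prompts with the same chatId should let the model see the earlier turns, since MessageChatMemoryAdvisor loads the history stored under that ID; a new chatId starts a fresh conversation.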

3. Result

The data stored in Redis looks like this:

[Figure: screenshot of the conversation data stored in Redis]

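If you also want to keep track of each user's conversation IDs, you can store them in Redis too: write your own method that uses a ZSet, and in the controller's chat method add chatId as a member with the current timestamp as its score. If you update the score on every message, conversations are ordered by latest activity rather than by creation time, so the conversation with the most recent message appears at the top, just like DeepSeek. The repository above already does this globally in chat:all_conversation_ids (saveAll refreshes the score on every save); a per-user ZSet would work the same way.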
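The ZSet behaviour described above can be modeled in plain Java to show why refreshing the score reorders conversations: adding an existing member overwrites its score (as ZADD does), and listing by descending score mirrors ZREVRANGE key 0 -1, which is what reverseRange(CONVERSATION_IDS_SET, 0, -1) issues. This is an in-memory illustration only; the class name and the per-user key format chat:user:{userId}:conversation_ids are hypothetical, and ties (which Redis breaks by member order) are not handled.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// In-memory model of the ZSet index: member = conversation ID, score = timestamp
class ConversationIndexSketch {
    private final Map<String, Long> scores = new HashMap<>();

    // Like ZADD: re-adding an existing member only updates its score
    void touch(String conversationId, long timestampMillis) {
        scores.put(conversationId, timestampMillis);
    }

    // Like ZREVRANGE key 0 -1: highest score (latest activity) first
    List<String> newestFirst() {
        List<String> ids = new ArrayList<>(scores.keySet());
        ids.sort(Comparator.comparingLong((String id) -> scores.get(id)).reversed());
        return ids;
    }

    // Hypothetical per-user key; the repository in 2.3 uses one global key instead
    static String userIndexKey(String userId) {
        return "chat:user:" + userId + ":conversation_ids";
    }

    public static void main(String[] args) {
        ConversationIndexSketch index = new ConversationIndexSketch();
        index.touch("a", 1000);
        index.touch("b", 2000);
        index.touch("a", 3000);                     // a new message in "a" moves it back to the top
        System.out.println(index.newestFirst());    // [a, b]
        System.out.println(userIndexKey("42"));     // chat:user:42:conversation_ids
    }
}
```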
