多轮对话的记忆心脏:ChatMemory 滑动窗口原理

Spring AI 源码解读 - 第 4 篇:ChatMemory 记忆管理

多轮对话的上下文维护机制

📖 开篇引言

多轮对话的关键是 AI 能够记住之前的对话内容。但如何高效地存储和检索这些消息?如何避免消息堆积导致的 Token 浪费?

本篇将深入 ChatMemory 的设计与实现,理解记忆管理的核心机制。


一、ChatMemory 接口设计

1.1 ChatMemory 接口

java 复制代码
// org.springframework.ai.chat.memory.ChatMemory
public interface ChatMemory {
    
    // 添加消息到指定会话
    void add(ConversationId conversationId, Message message);
    
    // 获取指定会话的所有消息
    List<Message> get(ConversationId conversationId);
    
    // 获取指定会话的消息数量
    int getMessageCount(ConversationId conversationId);
    
    // 清空指定会话的消息
    void clear(ConversationId conversationId);
}

// 会话 ID 包装类
public record ConversationId(String id) {}

1.2 ChatMemory 的职责

复制代码
ChatMemory
├── 存储:将消息持久化
├── 检索:根据会话 ID 获取消息
├── 管理:清空、统计消息
└── 隔离:不同会话的消息相互独立

二、MessageWindowChatMemory 实现

2.1 滑动窗口的概念

css 复制代码
消息窗口大小 = 3

第 1 轮
┌─────────────────────┐
│ [msg1, msg2, msg3]  │ ← 窗口满
└─────────────────────┘

第 2 轮(添加 msg4)
┌─────────────────────┐
│ [msg2, msg3, msg4]  │ ← msg1 被移出
└─────────────────────┘

第 3 轮(添加 msg5)
┌─────────────────────┐
│ [msg3, msg4, msg5]  │ ← msg2 被移出
└─────────────────────┘

2.2 MessageWindowChatMemory 源码

java 复制代码
// org.springframework.ai.chat.memory.MessageWindowChatMemory
public class MessageWindowChatMemory implements ChatMemory {
    
    private final int maxMessages;  // 最大消息数
    private final Map<ConversationId, List<Message>> conversationHistory;
    
    // 构造方法
    public MessageWindowChatMemory(int maxMessages) {
        this.maxMessages = maxMessages;
        this.conversationHistory = new ConcurrentHashMap<>();
    }
    
    // 工厂方法:创建默认实例(最多 100 条消息)
    public static MessageWindowChatMemory create() {
        return new MessageWindowChatMemory(100);
    }
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        // 1. 获取或创建该会话的消息列表
        List<Message> messages = conversationHistory
            .computeIfAbsent(conversationId, k -> new ArrayList<>());
        
        // 2. 添加新消息
        messages.add(message);
        
        // 3. 如果超过窗口大小,移除最旧的消息
        if (messages.size() > this.maxMessages) {
            messages.remove(0);  // 移除第一条(最旧)
        }
    }
    
    @Override
    public List<Message> get(ConversationId conversationId) {
        // 返回该会话的所有消息(已在窗口内)
        return conversationHistory.getOrDefault(conversationId, List.of());
    }
    
    @Override
    public int getMessageCount(ConversationId conversationId) {
        return conversationHistory
            .getOrDefault(conversationId, List.of())
            .size();
    }
    
    @Override
    public void clear(ConversationId conversationId) {
        conversationHistory.remove(conversationId);
    }
}

2.3 滑动窗口的优缺点

优点 缺点
实现简单 可能丢失重要的早期消息
内存占用固定 无法区分消息重要性
性能高 对长对话支持不足

三、Token 计数与消息截断

3.1 为什么需要 Token 计数?

复制代码
模型的上下文窗口大小是有限的
例如:Ollama qwen2.5:14b 的上下文窗口 = 32K tokens

如果消息总 Token 数超过上下文窗口,模型会报错

3.2 TokenTextSplitter 的 Token 计数

java 复制代码
// org.springframework.ai.document.TokenTextSplitter
public class TokenTextSplitter implements TextSplitter {
    
    private final int chunkSize;      // 每个块的 Token 数
    private final int chunkOverlap;   // 块之间的重叠 Token 数
    private final Tokenizer tokenizer; // Token 计数器
    
    // 计算文本的 Token 数
    public int countTokens(String text) {
        return this.tokenizer.countTokens(text);
    }
    
    // 分割文本
    public List<String> split(String text) {
        List<String> chunks = new ArrayList<>();
        List<Integer> tokenCounts = new ArrayList<>();
        
        // 1. 计算每个块的 Token 数
        for (String chunk : text.split("\n")) {
            int tokens = countTokens(chunk);
            if (tokens > this.chunkSize) {
                // 如果单个块超过大小,继续分割
                chunks.addAll(splitLargeChunk(chunk));
            } else {
                chunks.add(chunk);
            }
        }
        
        return chunks;
    }
}

3.3 消息截断策略

java 复制代码
// 在 ChatMemory 中实现 Token 限制
public class TokenLimitedChatMemory implements ChatMemory {
    
    private final int maxTokens;  // 最大 Token 数
    private final Tokenizer tokenizer;
    private final Map<ConversationId, List<Message>> conversationHistory;
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        List<Message> messages = conversationHistory
            .computeIfAbsent(conversationId, k -> new ArrayList<>());
        
        messages.add(message);
        
        // 检查总 Token 数
        while (getTotalTokens(messages) > this.maxTokens) {
            // 移除最旧的消息
            messages.remove(0);
        }
    }
    
    // 计算消息列表的总 Token 数
    private int getTotalTokens(List<Message> messages) {
        return messages.stream()
            .mapToInt(msg -> tokenizer.countTokens(msg.getContent()))
            .sum();
    }
}

四、ConversationId 与会话隔离

4.1 ConversationId 的作用

java 复制代码
// 不同用户的会话需要隔离
ChatMemory memory = new MessageWindowChatMemory(100);

// 用户 A 的会话
ConversationId sessionA = new ConversationId("user-a-session-1");
memory.add(sessionA, new UserMessage("我想学 Java"));
memory.add(sessionA, new AssistantMessage("Java 是..."));

// 用户 B 的会话
ConversationId sessionB = new ConversationId("user-b-session-1");
memory.add(sessionB, new UserMessage("我想学 Python"));
memory.add(sessionB, new AssistantMessage("Python 是..."));

// 获取消息时完全隔离
List<Message> messagesA = memory.get(sessionA);  // 只包含 A 的消息
List<Message> messagesB = memory.get(sessionB);  // 只包含 B 的消息

4.2 会话 ID 的生成策略

java 复制代码
// 策略 1:基于线程 ID(单线程场景)
String sessionId = String.valueOf(Thread.currentThread().getId());

// 策略 2:基于用户 ID
String sessionId = "user-" + userId;

// 策略 3:基于 HTTP Session ID
String sessionId = httpSession.getId();

// 策略 4:基于 UUID(每次对话新建)
String sessionId = UUID.randomUUID().toString();

五、ChatMemory 在 Advisor 中的使用

5.1 MessageChatMemoryAdvisor 的完整流程

java 复制代码
public class MessageChatMemoryAdvisor 
    implements CallAroundAdvisor, StreamAroundAdvisor {
    
    private final ChatMemory chatMemory;
    
    // 前置处理:注入历史消息
    @Override
    public AdvisedRequest before(AdvisedRequest advisedRequest) {
        
        // 1. 获取会话 ID
        String sessionId = getSessionId(advisedRequest);
        ConversationId conversationId = new ConversationId(sessionId);
        
        // 2. 从 ChatMemory 获取历史消息
        List<Message> historyMessages = this.chatMemory.get(conversationId);
        
        // 3. 将历史消息注入到请求中
        // 注入顺序:SystemMessage → HistoryMessages → CurrentUserMessage
        List<Message> allMessages = new ArrayList<>();
        
        // 添加系统消息(如果有)
        if (advisedRequest.getSystemText() != null) {
            allMessages.add(new SystemMessage(advisedRequest.getSystemText()));
        }
        
        // 添加历史消息
        allMessages.addAll(historyMessages);
        
        // 添加当前用户消息
        allMessages.addAll(advisedRequest.getUserMessage());
        
        // 4. 更新请求
        return AdvisedRequest.from(advisedRequest)
            .userMessage(allMessages)
            .build();
    }
    
    // 后置处理:保存消息到 ChatMemory
    @Override
    public ChatResponse after(AdvisedRequest advisedRequest, 
                               AdvisedResponse<ChatResponse> advisedResponse) {
        
        String sessionId = getSessionId(advisedRequest);
        ConversationId conversationId = new ConversationId(sessionId);
        
        // 1. 保存用户消息
        for (Message msg : advisedRequest.getUserMessage()) {
            if (msg instanceof UserMessage) {
                this.chatMemory.add(conversationId, msg);
            }
        }
        
        // 2. 保存 AI 回复
        ChatResponse response = advisedResponse.getChatResponse();
        AssistantMessage assistantMessage = new AssistantMessage(
            response.getResult().getOutput().getContent()
        );
        this.chatMemory.add(conversationId, assistantMessage);
        
        return response;
    }
}

5.2 记忆注入的完整示例

scss 复制代码
第 1 轮调用
┌─────────────────────────────────────────┐
│ before()                                 │
│ ChatMemory.get(sessionId) → []           │
│ 注入消息:[SystemMessage, UserMessage]   │
└─────────────────────────────────────────┘
         ↓
┌─────────────────────────────────────────┐
│ doChat()                                 │
│ chatModel.call(prompt)                   │
│ 返回 AssistantMessage                    │
└─────────────────────────────────────────┘
         ↓
┌─────────────────────────────────────────┐
│ after()                                  │
│ ChatMemory.add(sessionId, UserMessage)   │
│ ChatMemory.add(sessionId, AssistantMsg)  │
└─────────────────────────────────────────┘

第 2 轮调用
┌─────────────────────────────────────────┐
│ before()                                 │
│ ChatMemory.get(sessionId) →              │
│   [UserMessage, AssistantMessage]        │
│ 注入消息:[SystemMessage,                ││   UserMessage(历史),                     ││   AssistantMessage(历史),                ││   UserMessage(当前)]                     │
└─────────────────────────────────────────┘

六、分布式记忆存储

6.1 为什么需要分布式记忆?

diff 复制代码
单机 ChatMemory 的问题:
- 应用重启后消息丢失
- 多实例部署时消息不共享
- 无法跨应用访问

解决方案:使用 Redis 等分布式存储

6.2 Redis 实现的 ChatMemory

java 复制代码
// 基于 Redis 的 ChatMemory 实现
public class RedisChatMemory implements ChatMemory {
    
    private final RedisTemplate<String, Message> redisTemplate;
    private final String keyPrefix = "chat:memory:";
    
    @Override
    public void add(ConversationId conversationId, Message message) {
        
        // 1. 构建 Redis key
        String key = keyPrefix + conversationId.id();
        
        // 2. 将消息序列化后存储
        redisTemplate.opsForList().rightPush(key, message);
        
        // 3. 设置过期时间(24 小时)
        redisTemplate.expire(key, Duration.ofHours(24));
    }
    
    @Override
    public List<Message> get(ConversationId conversationId) {
        
        String key = keyPrefix + conversationId.id();
        
        // 从 Redis 获取所有消息
        return redisTemplate.opsForList()
            .range(key, 0, -1);
    }
    
    @Override
    public void clear(ConversationId conversationId) {
        
        String key = keyPrefix + conversationId.id();
        redisTemplate.delete(key);
    }
}

6.3 Redis 中的数据结构

css 复制代码
Redis 中的存储结构:

chat:memory:user-a-session-1
├── [0] UserMessage("我想学 Java")
├── [1] AssistantMessage("Java 是...")
├── [2] UserMessage("它有哪些特点?")
└── [3] AssistantMessage("Java 的特点是...")

chat:memory:user-b-session-1
├── [0] UserMessage("我想学 Python")
└── [1] AssistantMessage("Python 是...")

七、ChatMemory 的生命周期

7.1 创建

java 复制代码
// 方式 1:默认实现
ChatMemory memory = MessageWindowChatMemory.create();

// 方式 2:自定义大小
ChatMemory memory = new MessageWindowChatMemory(50);

// 方式 3:Redis 实现
ChatMemory memory = new RedisChatMemory(redisTemplate);

7.2 使用

java 复制代码
// 在 ChatClient 中使用
ChatClient chatClient = ChatClient.builder(chatModel)
    .defaultAdvisors(new MessageChatMemoryAdvisor(memory))
    .build();

// 自动注入记忆
chatClient.prompt()
    .user("你好")
    .call();

7.3 清理

java 复制代码
// 清空特定会话的记忆
memory.clear(new ConversationId("user-a-session-1"));

// 或者依赖过期时间自动清理(Redis)

八、小结

8.1 本篇要点

主题 核心要点
ChatMemory 接口 add / get / clear 三个核心方法
MessageWindowChatMemory 滑动窗口实现,固定消息数量
Token 计数 防止消息堆积导致的 Token 溢出
ConversationId 会话隔离,不同用户消息独立
Advisor 集成 before() 注入记忆,after() 保存消息
分布式存储 Redis 实现跨应用、跨实例的记忆共享

8.2 关键类清单

类 / 接口 职责
ChatMemory 记忆接口
MessageWindowChatMemory 滑动窗口实现
RedisChatMemory Redis 分布式实现
ConversationId 会话 ID 包装类
MessageChatMemoryAdvisor 记忆注入拦截器
TokenTextSplitter Token 计数与分割

系列目录

  • 第 1 篇:整体架构与核心抽象
  • 第 2 篇:ChatClient 调用链路
  • 第 3 篇:Prompt 与 Message 体系
  • 第 4 篇:ChatMemory 记忆管理(本篇)

需要Spring AI系列学习代码的同学 欢迎关注公众号「AI日撰」,点击菜单「获取源码」获取完整代码(Gitee 仓库)。


相关推荐
薛定猫AI2 小时前
【深度解析】Hermes Agent:用“提示反向传播”打造可自我进化的 AI 智能体
人工智能
AAAAA92402 小时前
物联网BOM成本管理:精准化、智能化与可持续化
java·物联网·struts
AI成长日志2 小时前
【GitHub开源项目专栏】AI推理优化框架深度解析(下):TGI与TensorRT-LLM对比实战
人工智能·开源·github
96772 小时前
springMVC请求处理全过程
java
特别关注外国供应商2 小时前
SSH 的 PrivX OT 工业安全远程访问 (ISRA) 被 分析机构 Industrial Cyber 认可
人工智能·网络安全·ssh·特权访问管理·工业安全远程访问·privx·ot 访问安全
gelald2 小时前
Spring - 事务管理
java·后端·spring
橘子编程2 小时前
编译原理:从理论到实战全解析
java·linux·python·ubuntu
xuhaoyu_cpp_java2 小时前
Maven学习(一)
java·经验分享·笔记·学习·maven
sibylyue2 小时前
Nginx\Tomcat\Jetty\Netty
java·nginx·http