企业级Agent开发教程(四) 基于 MongoDB 实现持久化记忆缓存

一、前言

在前面几篇文章中,我们构建了 KaFlow-Py 这个配置驱动的 AI Agent 框架。但有个问题一直困扰着我:对话记忆存在内存里,服务重启就全没了

这在开发阶段还能忍,但到了生产环境就不行了。用户跟 Agent 聊了半天,结果服务器重启一次,所有历史都丢了,用户体验会非常糟糕。

所以这次我们要解决的核心问题是:如何实现生产级的持久化记忆缓存

实现后的效果:

20251016100406

我们会先从 LangGraph 的 Checkpointer 原理入手,详细讲解如何实现一个生产级的 MongoDB Checkpointer

二、LangGraph Checkpointer 原理

在实现自己的 Checkpointer 之前,我们需要先深入理解 LangGraph 的设计思路。让我们从顶层抽象开始,逐层剖析。

2.1 BaseCheckpointSaver - 顶层抽象

BaseCheckpointSaver 是 LangGraph 为所有 Checkpointer 定义的抽象基类,它规定了持久化层必须实现的完整接口规范。

完整接口定义

python 复制代码
from abc import ABC, abstractmethod
from typing import Optional, Iterator, AsyncIterator, Any
from langchain_core.runnables import RunnableConfig

class BaseCheckpointSaver(ABC):
    """Checkpoint Saver 抽象基类"""
    
    # ==================== 配置属性 ====================
    
    @property
    def config_specs(self) -> list[ConfigurableFieldSpec]:
        """配置规范(有默认实现,可重写)"""
        return [CheckpointThreadId, CheckpointNS, CheckpointId]
    
    # ==================== 同步接口(必须实现) ====================
    
    @abstractmethod
    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """获取指定的 checkpoint tuple"""
        pass
    
    @abstractmethod
    def list(
        self, 
        config: Optional[RunnableConfig], 
        *, 
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None
    ) -> Iterator[CheckpointTuple]:
        """列出 checkpoints(支持分页和过滤)"""
        pass
    
    @abstractmethod
    def put(
        self, 
        config: RunnableConfig, 
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions
    ) -> RunnableConfig:
        """保存 checkpoint"""
        pass
    
    @abstractmethod
    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = ""
    ) -> None:
        """保存中间写入操作(用于 subgraphs)"""
        pass
    
    @abstractmethod
    def delete_thread(self, thread_id: str) -> None:
        """删除指定 thread 的所有 checkpoints"""
        pass
    
    # ==================== 异步接口(必须实现) ====================
    
    @abstractmethod
    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """异步获取 checkpoint tuple"""
        pass
    
    @abstractmethod
    async def alist(
        self, 
        config: Optional[RunnableConfig], 
        *, 
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None
    ) -> AsyncIterator[CheckpointTuple]:
        """异步列出 checkpoints"""
        pass
    
    @abstractmethod
    async def aput(
        self, 
        config: RunnableConfig, 
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions
    ) -> RunnableConfig:
        """异步保存 checkpoint"""
        pass
    
    @abstractmethod
    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = ""
    ) -> None:
        """异步保存中间写入"""
        pass
    
    @abstractmethod
    async def adelete_thread(self, thread_id: str) -> None:
        """异步删除指定 thread 的所有 checkpoints"""
        pass
    
    # ==================== 可选方法(有默认实现) ====================
    
    def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
        """生成下一个版本号(默认递增整数)"""
        if isinstance(current, str):
            raise NotImplementedError
        elif current is None:
            return 1
        else:
            return current + 1
    
    # ==================== 辅助方法(基于抽象方法实现) ====================
    
    def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """获取 checkpoint(内部调用 get_tuple)"""
        if value := self.get_tuple(config):
            return value.checkpoint
    
    async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """异步获取 checkpoint(内部调用 aget_tuple)"""
        if value := await self.aget_tuple(config):
            return value.checkpoint

接口实现要求

LangGraph 定义了 10 个抽象方法,必须全部实现:

方法 类型 是否必须 说明
get_tuple() 同步 ✅ 必须 获取单个 checkpoint,LangGraph 加载对话历史时调用
list() 同步 ✅ 必须 列出 checkpoints,支持分页和过滤
put() 同步 ✅ 必须 保存 checkpoint,每轮对话后调用
put_writes() 同步 ✅ 必须 保存中间写入(subgraphs),可以空实现
delete_thread() 同步 ✅ 必须 删除 thread,删除对话历史时调用
aget_tuple() 异步 ✅ 必须 异步版本的 get_tuple
alist() 异步 ✅ 必须 异步版本的 list
aput() 异步 ✅ 必须 异步版本的 put
aput_writes() 异步 ✅ 必须 异步版本的 put_writes
adelete_thread() 异步 ✅ 必须 异步版本的 delete_thread
config_specs 属性 ⚪ 可选 配置规范,有默认实现
get_next_version() 方法 ⚪ 可选 生成版本号,有默认实现

核心方法详解

1. get_tuple() - 读取 Checkpoint

python 复制代码
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    """
    获取指定的 checkpoint
    
    Args:
        config: 包含 thread_id 和可选的 checkpoint_id
        
    Returns:
        CheckpointTuple: 包含 checkpoint、metadata、parent_config
        
    调用时机:
        - LangGraph 加载对话历史时
        - 恢复状态时
    """

关键点

  • 如果 checkpoint_id 未提供,返回最新的 checkpoint
  • 必须返回 CheckpointTuple,包含完整的上下文信息

2. put() - 保存 Checkpoint

python 复制代码
def put(
    self, 
    config: RunnableConfig, 
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions
) -> RunnableConfig:
    """
    保存 checkpoint
    
    Args:
        config: 包含 thread_id
        checkpoint: 完整的 checkpoint 对象(包含 channel_values)
        metadata: 元数据(source、step、writes 等)
        new_versions: 通道版本信息
        
    Returns:
        RunnableConfig: 更新后的 config,包含新的 checkpoint_id
        
    调用时机:
        - 每轮对话结束后
        - 状态更新后
    """

关键点

  • 使用 upsert 模式,避免重复插入
  • 必须返回包含 checkpoint_id 的 config

3. list() - 列出 Checkpoints

python 复制代码
def list(
    self, 
    config: Optional[RunnableConfig], 
    *, 
    filter: Optional[dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None
) -> Iterator[CheckpointTuple]:
    """
    列出 checkpoints
    
    Args:
        config: 包含 thread_id(可选,不传则列出所有)
        filter: 元数据过滤条件
        before: 列出此 checkpoint 之前的记录(用于分页)
        limit: 返回的最大数量
        
    Yields:
        CheckpointTuple: checkpoint 迭代器
        
    调用时机:
        - 查询历史对话时
        - 时间旅行(回溯状态)时
    """

关键点

  • 返回 Iterator,支持惰性加载
  • before 参数用于实现分页

4. put_writes() - 保存中间写入

python 复制代码
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[tuple[str, Any]],
    task_id: str,
    task_path: str = ""
) -> None:
    """
    保存中间写入操作(用于 subgraphs)
    
    Args:
        config: 包含 thread_id 和 checkpoint_id
        writes: 写入操作列表 [(channel_name, value), ...]
        task_id: 任务 ID
        task_path: 任务路径(嵌套 subgraph 时使用)
        
    调用时机:
        - Subgraph 执行时
        - 需要记录中间状态时
    """

关键点

  • 可以空实现(如果不需要支持 subgraph 状态恢复)
  • 如果需要完整支持,应存储到单独的集合

5. delete_thread() - 删除 Thread

python 复制代码
def delete_thread(self, thread_id: str) -> None:
    """
    删除指定 thread 的所有 checkpoints
    
    Args:
        thread_id: 线程 ID
        
    调用时机:
        - 用户删除对话历史时
        - 清理过期数据时
    """

关键点

  • 可以实现硬删除或软删除
  • 需要考虑级联删除(writes、blobs 等)

同步 vs 异步接口

LangGraph 要求同时实现同步和异步两套接口。在实际开发中,我们有两种选择:

选择 1:真正的异步驱动

python 复制代码
# 使用 motor(MongoDB 异步驱动)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    doc = await self._collection.find_one({"thread_id": thread_id})
    return self._deserialize(doc)

优点 :真正的非阻塞 I/O,高并发性能好
缺点:代码复杂度高,motor 生态不如 pymongo 成熟

选择 2:同步驱动 + 异步包装

python 复制代码
# 使用 pymongo(同步驱动)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    return self.get_tuple(config)  # 直接调用同步方法

优点 :实现简单,pymongo 生态成熟
缺点:在异步环境中会被放到线程池执行

我的选择 :对于 checkpoint 这种低频操作(每轮对话才保存一次),选择 pymongo + 异步包装 已经足够,而且大大降低了实现复杂度。

2.2 InMemorySaver - 官方参考实现

理解了 BaseCheckpointSaver 的接口规范后,让我们看看 LangGraph 官方提供的 InMemorySaver 是如何实现这些抽象方法的。

InMemorySaver 的核心设计

python 复制代码
class InMemorySaver(BaseCheckpointSaver[str]):
    """内存版 Checkpointer(官方参考实现)"""
    
    def __init__(self, *, serde: Optional[SerializerProtocol] = None):
        super().__init__(serde=serde)
        # 核心数据结构
        self.storage = defaultdict(lambda: defaultdict(dict))
        self.writes = defaultdict(dict)
        self.blobs = {}

核心数据结构解析

python 复制代码
# 1. storage: 存储 checkpoint 核心数据
storage: defaultdict[
    str,  # thread_id
    dict[
        str,  # checkpoint_ns (命名空间)
        dict[
            str,  # checkpoint_id
            tuple[
                tuple[str, bytes],  # checkpoint 序列化数据
                tuple[str, bytes],  # metadata 序列化数据
                Optional[str]       # parent_checkpoint_id
            ]
        ]
    ]
]

# 2. writes: 存储中间写入操作
writes: defaultdict[
    tuple[str, str, str],  # (thread_id, checkpoint_ns, checkpoint_id)
    dict[
        tuple[str, int],   # (task_id, idx)
        tuple[str, str, tuple[str, bytes], str]  # 写入数据
    ]
]

# 3. blobs: 存储 channel values
blobs: dict[
    tuple[str, str, str, Union[str, int, float]],  # (thread_id, ns, channel, version)
    tuple[str, bytes]  # 序列化的 channel value
]

为什么这样设计?

  1. 三层嵌套隔离 - thread_id -> checkpoint_ns -> checkpoint_id,支持多线程、多命名空间
  2. Blobs 分离存储 - channel values 单独存储,避免重复序列化大对象
  3. Writes 增量存储 - 中间写入单独管理,支持 subgraph 状态回溯

InMemorySaver 实现的抽象方法

让我们看看 InMemorySaver 如何实现 BaseCheckpointSaver 的核心方法:

1. get_tuple() 实现

ini 复制代码
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
    """获取 checkpoint(内存版)"""
    thread_id = config["configurable"]["thread_id"]
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    checkpoint_id = get_checkpoint_id(config)
    
    if checkpoint_id:
        # 获取指定的 checkpoint
        if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
            checkpoint, metadata, parent_checkpoint_id = saved
            # 从 blobs 加载 channel values
            checkpoint_ = self.serde.loads_typed(checkpoint)
            checkpoint_["channel_values"] = self._load_blobs(
                thread_id, checkpoint_ns, checkpoint_["channel_versions"]
            )
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint_,
                metadata=self.serde.loads_typed(metadata),
                parent_config=...
            )
    else:
        # 获取最新的 checkpoint
        if checkpoints := self.storage[thread_id][checkpoint_ns]:
            checkpoint_id = max(checkpoints.keys())  # 🎯 取最大 ID(最新)
            # ... 类似上面的逻辑

关键点

  • 使用 三层字典嵌套 快速定位 checkpoint
  • 使用 max(checkpoints.keys()) 获取最新 checkpoint
  • 通过 _load_blobs() 从 blobs 字典加载 channel values

2. put() 实现

python 复制代码
def put(
    self,
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    new_versions: ChannelVersions,
) -> RunnableConfig:
    """保存 checkpoint(内存版)"""
    thread_id = config["configurable"]["thread_id"]
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    
    # 1. 从 checkpoint 中分离 channel_values
    c = checkpoint.copy()
    c.pop("pending_sends")
    values = c.pop("channel_values")
    
    # 2. 将 channel_values 存储到 blobs
    for k, v in new_versions.items():
        self.blobs[(thread_id, checkpoint_ns, k, v)] = (
            self.serde.dumps_typed(values[k]) if k in values else ("empty", b"")
        )
    
    # 3. 存储 checkpoint 核心数据
    self.storage[thread_id][checkpoint_ns][checkpoint["id"]] = (
        self.serde.dumps_typed(c),
        self.serde.dumps_typed(metadata),
        config["configurable"].get("checkpoint_id"),  # parent
    )
    
    return {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint["id"],
        }
    }

关键点

  • 分离存储 - checkpoint 核心数据和 channel_values 分开存储
  • 版本管理 - 通过 new_versions 管理 channel 的版本变化
  • 父子关系 - 记录 parent_checkpoint_id,支持状态回溯

3. list() 实现

python 复制代码
def list(
    self,
    config: Optional[RunnableConfig],
    *,
    filter: Optional[Dict[str, Any]] = None,
    before: Optional[RunnableConfig] = None,
    limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
    """列出 checkpoints(内存版)"""
    thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
    
    for thread_id in thread_ids:
        for checkpoint_ns in self.storage[thread_id].keys():
            for checkpoint_id, (checkpoint, metadata_b, parent_id) in sorted(
                self.storage[thread_id][checkpoint_ns].items(),
                key=lambda x: x[0],  # 按 checkpoint_id 排序
                reverse=True,        # 降序(最新的在前)
            ):
                # 应用过滤条件
                metadata = self.serde.loads_typed(metadata_b)
                if filter and not all(
                    query_value == metadata.get(query_key)
                    for query_key, query_value in filter.items()
                ):
                    continue
                
                # 应用 limit
                if limit is not None and limit <= 0:
                    break
                elif limit is not None:
                    limit -= 1
                
                # 加载并返回 checkpoint
                checkpoint_ = self.serde.loads_typed(checkpoint)
                checkpoint_["channel_values"] = self._load_blobs(...)
                yield CheckpointTuple(...)

关键点

  • 迭代器模式 - 返回 Iterator,支持惰性加载
  • 排序逻辑 - 按 checkpoint_id 降序,最新的在前
  • 过滤和限制 - 支持 metadata 过滤和数量限制

4. delete_thread() 实现

python 复制代码
def delete_thread(self, thread_id: str) -> None:
    """删除 thread(内存版)"""
    # 1. 删除 storage 中的数据
    if thread_id in self.storage:
        del self.storage[thread_id]
    
    # 2. 删除 writes 中的数据
    for k in list(self.writes.keys()):
        if k[0] == thread_id:
            del self.writes[k]
    
    # 3. 删除 blobs 中的数据
    for k in list(self.blobs.keys()):
        if k[0] == thread_id:
            del self.blobs[k]

关键点

  • 级联删除 - 需要删除 storage、writes、blobs 三处的数据
  • 遍历删除 - writes 和 blobs 使用元组 key,需要遍历查找

InMemorySaver 的优势与局限

维度 优势 局限
性能 O(1) 随机访问 无序列化成本 内存级速度 重启后数据丢失 内存有限
设计 三层嵌套隔离清晰 Blobs 分离避免重复 版本管理精细 不支持持久化 无法分布式部署
适用场景 开发测试 临时对话 性能测试 ❌ 生产环境 ❌ 长期存储 ❌ 多实例部署

从 InMemorySaver 到 MongoDB:设计转换

理解了 InMemorySaver 的设计后,我们在实现 MongoDB 版本时需要做出调整:

设计元素 InMemorySaver MongoDBCheckpointer
数据结构 三层嵌套字典 扁平化文档
索引策略 无需索引(内存) 复合索引优化查询
Blobs 存储 分离存储(版本管理) 合并存储(简化设计)
序列化 使用 serde 协议 使用 pickle
查询最新 max(keys()) sort([("created_at", -1)])
级联删除 三处手动删除 MongoDB 单次删除

核心思想

  • InMemorySaver 追求 性能和精细控制
  • MongoDBCheckpointer 追求 持久化和简单可靠

三、MongoDB 记忆缓存实现

3.1 框架设计

KaFlow-Py 的记忆系统采用工厂模式 + 策略模式的设计,支持多种存储后端的无缝切换。

整体架构

bash 复制代码
src/memory/
├── __init__.py              # 模块导出
├── base.py                  # BaseCheckpointer 抽象基类
├── factory.py               # CheckpointerFactory 工厂类
├── memory_checkpointer.py   # MemoryCheckpointer (内存实现)
└── mongodb_checkpointer.py  # MongoDBCheckpointer (MongoDB 实现)

设计模式

1. 抽象基类(BaseCheckpointer)

python 复制代码
# src/memory/base.py
from abc import ABC, abstractmethod
from langgraph.checkpoint.base import BaseCheckpointSaver

class BaseCheckpointer(BaseCheckpointSaver, ABC):
    """
    Checkpointer 抽象基类
    
    继承自 LangGraph 的 BaseCheckpointSaver,添加了 KaFlow-Py 特有的功能
    """
    
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.config = config or {}
        self._is_connected = False
    
    @abstractmethod
    async def setup(self) -> None:
        """初始化连接(子类实现)"""
        pass
    
    @abstractmethod
    async def teardown(self) -> None:
        """清理资源(子类实现)"""
        pass
    
    @property
    def is_connected(self) -> bool:
        """连接状态"""
        return self._is_connected
    
    def validate_config(self) -> bool:
        """验证配置(子类可重写)"""
        return True

2. 工厂类(CheckpointerFactory)

python 复制代码
# src/memory/factory.py
class CheckpointerType(str, Enum):
    """Checkpointer 类型枚举"""
    MEMORY = "memory"
    MONGODB = "mongodb"
    REDIS = "redis"          # 预留
    POSTGRESQL = "postgresql"  # 预留

class CheckpointerFactory:
    """
    Checkpointer 工厂类
    
    负责根据配置创建相应的 checkpointer 实例
    使用单例模式,确保每种类型只创建一次
    """
    
    # 实例缓存
    _instances: Dict[str, BaseCheckpointer] = {}
    
    @classmethod
    def create(
        cls, 
        provider: str, 
        config: Optional[Dict[str, Any]] = None
    ) -> Optional[BaseCheckpointer]:
        """
        创建 Checkpointer 实例
        
        Args:
            provider: 提供商类型 (memory, mongodb, redis, ...)
            config: 配置字典
            
        Returns:
            Checkpointer 实例或 None
        """
        config = config or {}
        
        # 生成缓存 key
        cache_key = f"{provider}_{hash(str(sorted(config.items())))}"
        
        # 检查缓存
        if cache_key in cls._instances:
            logger.debug(f"♻️ 复用已存在的 {provider} checkpointer")
            return cls._instances[cache_key]
        
        # 根据 provider 创建
        if provider == CheckpointerType.MEMORY:
            checkpointer = MemoryCheckpointer(config)
        elif provider == CheckpointerType.MONGODB:
            checkpointer = MongoDBCheckpointer(config)
        else:
            logger.error(f"❌ 不支持的 checkpointer 类型: {provider}")
            return None
        
        # 缓存实例
        cls._instances[cache_key] = checkpointer
        return checkpointer

YAML 配置驱动

KaFlow-Py 通过 YAML 配置文件控制记忆系统的行为:

bash 复制代码
# ops_agent.yaml
global_config:
  memory:
    enabled: true                    # 是否启用记忆
    provider: "mongodb"              # 🎯 存储提供商(切换点)
    connection:                      # 连接配置
      host: "127.0.0.1"
      port: 27017
      database: "kaflow"
      collection: "checkpoints"
      username: "test"
      password: "${MEMORY_PASSWORD}"  # 环境变量
      auth_source: "admin"

切换存储后端只需修改 provider

bash 复制代码
# 切换到内存模式
memory:
  enabled: true
  provider: "memory"  # ✅ 改这里即可

工作流程

markdown 复制代码
1. 加载 YAML 配置
   ↓
2. 解析 memory.provider (如 "mongodb")
   ↓
3. CheckpointerFactory.create("mongodb", config)
   ↓
4. 创建 MongoDBCheckpointer 实例
   ↓
5. 调用 setup() 初始化连接
   ↓
6. 传给 LangGraph: graph.compile(checkpointer=checkpointer)
   ↓
7. LangGraph 调用 get_tuple/put 等方法

3.2 MongoDB Checkpointer 实现

数据结构设计

perl 复制代码
# MongoDB 文档结构
{
    "thread_id": "user@example.com_uuid_config_id",
    "checkpoint_id": "1704636800.123456",
    "parent_checkpoint_id": "1704636700.123456",
    "checkpoint_data": Binary(pickle),  # 完整的 Checkpoint 对象
    "metadata": {
        "source": "update",
        "step": 1,
        "writes": null
    },
    "username": "user@example.com",  # 从 thread_id 解析
    "created_at": ISODate("2025-01-07T10:00:00.000+08:00"),
    "updated_at": ISODate("2025-01-07T10:00:00.000+08:00")
}

设计要点

  1. 扁平化结构 - 不像 InMemorySaver 嵌套三层,而是扁平存储,便于索引
  2. pickle 序列化 - 直接序列化整个 Checkpoint,避免分离 blobs 的复杂性
  3. 提取 username - 从 thread_id 解析用户名,支持按用户查询
  4. 时区感知 - 使用东八区时间,避免时区转换问题

索引策略

php 复制代码
# 创建索引
self._collection.create_index([
    ("thread_id", ASCENDING), 
    ("checkpoint_id", DESCENDING)
])  # 复合索引:快速查询某个 thread 的最新 checkpoint

self._collection.create_index([
    ("created_at", DESCENDING)
])  # 时间索引:按时间排序

self._collection.create_index([
    ("username", ASCENDING), 
    ("created_at", DESCENDING)
])  # 用户索引:按用户查询历史

核心实现

1. 初始化和连接管理

python 复制代码
class MongoDBCheckpointer(BaseCheckpointer):
    # 🎯 类级别的单例客户端(避免线程泄漏)
    _shared_client = None
    _shared_client_uri = None
    _client_lock = None
    
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        
        # 提取配置
        self.host = config.get("host", "localhost")
        self.port = int(config.get("port", 27017))
        self.database_name = config.get("database", "kaflow")
        self.collection_name = config.get("collection", "checkpoints")
        self.username = config.get("username", "")
        self.password = config.get("password", "")
        self.auth_source = config.get("auth_source", "admin")
        
        # 🎯 从环境变量提取密码
        if self.username and self.password:
            password = os.environ.get(
                self.password.replace("${", "").replace("}", ""),
                self.password
            )
            self._uri = f"mongodb://{self.username}:{password}@{self.host}:{self.port}/{self.database_name}?authSource={self.auth_source}"
        else:
            self._uri = f"mongodb://{self.host}:{self.port}/{self.database_name}"
        
        # 初始化锁
        if MongoDBCheckpointer._client_lock is None:
            import threading
            MongoDBCheckpointer._client_lock = threading.Lock()
    
    def _get_or_create_shared_client(self):
        """
        获取或创建共享的 MongoDB 客户端(单例模式)
        
        关键:避免每个实例都创建 MongoClient,导致线程泄漏
        """
        with MongoDBCheckpointer._client_lock:
            # 如果已有客户端且 URI 匹配,直接复用
            if (MongoDBCheckpointer._shared_client is not None and 
                MongoDBCheckpointer._shared_client_uri == self._uri):
                return MongoDBCheckpointer._shared_client
            
            # 创建新的客户端
            from pymongo import MongoClient
            cn_tz = timezone(timedelta(hours=8))
            
            MongoDBCheckpointer._shared_client = MongoClient(
                self._uri,
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=5000,
                maxPoolSize=5,  # Docker 环境减少连接数
                minPoolSize=1,
                connect=False,  # 延迟连接
                tz_aware=True,  # 时区感知
                tzinfo=cn_tz    # 东八区
            )
            MongoDBCheckpointer._shared_client_uri = self._uri
            
            # 测试连接
            MongoDBCheckpointer._shared_client.admin.command('ping')
            logger.info("✅ MongoDB 共享客户端创建成功")
            
            return MongoDBCheckpointer._shared_client

2. get_tuple() - 读取 Checkpoint

python 复制代码
def get_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
    """获取指定的 checkpoint"""
    if not self._ensure_connected():
        return None
    
    thread_id = config.get("configurable", {}).get("thread_id")
    checkpoint_id = config.get("configurable", {}).get("checkpoint_id")
    
    # 构建查询
    query = {"thread_id": thread_id}
    if checkpoint_id:
        query["checkpoint_id"] = checkpoint_id
    
    # 查询最新的 checkpoint(按创建时间降序)
    doc = self._collection.find_one(query, sort=[("created_at", -1)])
    
    if not doc:
        return None
    
    # 反序列化 checkpoint
    checkpoint = pickle.loads(doc["checkpoint_data"])
    metadata = doc.get("metadata", {})
    
    # 构建 CheckpointTuple
    return CheckpointTuple(
        config=config,
        checkpoint=checkpoint,
        metadata=metadata,
        parent_config={
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_id": doc.get("parent_checkpoint_id"),
            }
        } if doc.get("parent_checkpoint_id") else None,
    )

3. put() - 保存 Checkpoint

python 复制代码
def put(
    self, 
    config: Dict[str, Any], 
    checkpoint: Checkpoint, 
    metadata: CheckpointMetadata, 
    new_versions: ChannelVersions
) -> Dict[str, Any]:
    """保存 checkpoint"""
    if not self._ensure_connected():
        return config
    
    thread_id = config.get("configurable", {}).get("thread_id")
    checkpoint_id = checkpoint.get("id", str(self._get_cn_time().timestamp()))
    parent_checkpoint_id = config.get("configurable", {}).get("checkpoint_id")
    
    # 序列化 checkpoint
    checkpoint_data = pickle.dumps(checkpoint)
    
    # 从 thread_id 提取 username
    username = self._extract_username_from_thread_id(thread_id)
    
    # 获取东八区时间
    cn_time = self._get_cn_time()
    
    # 构建文档
    doc = {
        "thread_id": thread_id,
        "checkpoint_id": checkpoint_id,
        "parent_checkpoint_id": parent_checkpoint_id,
        "checkpoint_data": checkpoint_data,
        "metadata": dict(metadata) if metadata else {},
        "username": username,
        "created_at": cn_time,
        "updated_at": cn_time,
    }
    
    # upsert(存在则更新,不存在则插入)
    self._collection.update_one(
        {"thread_id": thread_id, "checkpoint_id": checkpoint_id},
        {"$set": doc},
        upsert=True
    )
    
    logger.debug(f"💾 checkpoint 已保存: {checkpoint_id}")
    
    return {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_id": checkpoint_id,
        }
    }

4. list() - 列出 Checkpoints

python 复制代码
def list(
    self, 
    config: Dict[str, Any], 
    *, 
    limit: Optional[int] = None, 
    before: Optional[Dict[str, Any]] = None
) -> Iterator[CheckpointTuple]:
    """列出指定 thread 的所有 checkpoints"""
    if not self._ensure_connected():
        return
    
    thread_id = config.get("configurable", {}).get("thread_id")
    query = {"thread_id": thread_id}
    
    # before 条件(用于分页)
    if before:
        before_checkpoint_id = before.get("configurable", {}).get("checkpoint_id")
        if before_checkpoint_id:
            before_doc = self._collection.find_one({"checkpoint_id": before_checkpoint_id})
            if before_doc:
                query["created_at"] = {"$lt": before_doc["created_at"]}
    
    # 查询
    cursor = self._collection.find(query).sort("created_at", -1)
    if limit:
        cursor = cursor.limit(limit)
    
    # 迭代返回
    for doc in cursor:
        checkpoint = pickle.loads(doc["checkpoint_data"])
        yield CheckpointTuple(
            config=config,
            checkpoint=checkpoint,
            metadata=doc.get("metadata", {}),
            parent_config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_id": doc.get("parent_checkpoint_id"),
                }
            } if doc.get("parent_checkpoint_id") else None,
        )

5. 异步接口实现

python 复制代码
# 直接调用同步方法(pymongo 是同步驱动)
async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
    return self.get_tuple(config)

async def aput(self, config, checkpoint, metadata, new_versions):
    return self.put(config, checkpoint, metadata, new_versions)

async def alist(self, config, *, filter=None, before=None, limit=None):
    for item in self.list(config, before=before, limit=limit):
        yield item

辅助功能

1. 从 thread_id 提取 username

python 复制代码
@staticmethod
def _extract_username_from_thread_id(thread_id: str) -> Optional[str]:
    """
    从 thread_id 中提取 username
    
    thread_id 格式: username_uuid_config_id
    例如: user@example.com_uuid_1
    """
    if not thread_id:
        return None
    
    parts = thread_id.split('_')
    if len(parts) < 3:
        return None
    
    return parts[0]  # 第一个部分是 username

2. 东八区时间

python 复制代码
@staticmethod
def _get_cn_time() -> datetime:
    """获取东八区(中国)时间"""
    cn_tz = timezone(timedelta(hours=8))
    return datetime.now(cn_tz)

3. 消息去重

python 复制代码
def _format_messages(self, messages: list) -> list:
    """
    格式化消息列表(自动去重)
    
    原因:LangGraph 的 checkpoint 包含完整消息列表,
         每个 checkpoint 都会重复之前的消息
    """
    formatted = []
    seen_human_contents = set()
    
    for msg in messages:
        msg_type = type(msg).__name__
        content = msg.content if hasattr(msg, "content") else ""
        
        # 去重逻辑:跳过重复的 HumanMessage
        if "Human" in msg_type:
            # 子串匹配(因为有些消息可能被截断)
            is_duplicate = any(
                seen_content in content 
                for seen_content in seen_human_contents
            )
            if is_duplicate:
                continue
            seen_human_contents.add(content)
        
        formatted.append({
            "type": msg_type,
            "content": content,
            "role": self._infer_role(msg_type),
        })
    
    return formatted

3.3 历史对话查询 API

除了 LangGraph 要求的接口,MongoDBCheckpointer 还提供了额外的查询功能。

1. 获取会话列表

php 复制代码
def get_thread_list(
    self,
    username: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
    order: str = "desc"
) -> Dict[str, Any]:
    """
    获取会话列表(支持按用户筛选)
    
    Args:
        username: 用户名(可选)
        page: 页码
        page_size: 每页大小
        order: 排序方式(desc/asc)
        
    Returns:
        {
            "total": 总数,
            "threads": [
                {
                    "thread_id": "...",
                    "username": "...",
                    "first_message": "...",
                    "last_updated": "...",
                    "message_count": 10
                }
            ]
        }
    """
    # 使用 MongoDB 聚合管道
    pipeline = []
    
    # 1. 筛选用户
    if username:
        pipeline.append({"$match": {"username": username}})
    
    # 2. 按 thread_id 分组
    pipeline.append({"$sort": {"thread_id": 1, "created_at": 1}})
    pipeline.append({
        "$group": {
            "_id": "$thread_id",
            "username": {"$first": "$username"},
            "last_updated": {"$max": "$updated_at"},
            "message_count": {"$sum": 1},
            "latest_checkpoint": {"$last": "$checkpoint_data"},
        }
    })
    
    # 3. 排序
    pipeline.append({
        "$sort": {"last_updated": -1 if order == "desc" else 1}
    })
    
    # 执行查询
    all_threads = list(self._collection.aggregate(pipeline))
    
    # 分页
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    paginated_threads = all_threads[start_idx:end_idx]
    
    # 格式化结果
    threads = []
    for thread_data in paginated_threads:
        # 从 checkpoint 中提取第一条消息
        first_message = self._extract_first_message(
            thread_data.get("latest_checkpoint")
        )
        
        threads.append({
            "thread_id": thread_data["_id"],
            "username": thread_data.get("username"),
            "first_message": first_message,
            "last_updated": thread_data["last_updated"].isoformat(),
            "message_count": thread_data.get("message_count", 0),
        })
    
    return {
        "total": len(all_threads),
        "threads": threads
    }

2. 获取对话消息

ini 复制代码
def get_flat_messages(
    self,
    thread_id: str,
    page: int = 1,
    page_size: int = 20,
    order: str = "desc"
) -> Dict[str, Any]:
    """
    获取指定 thread 的消息列表(按单条消息分页)
    
    Args:
        thread_id: 线程 ID
        page: 页码
        page_size: 每页大小
        order: 排序方式
        
    Returns:
        {
            "total": 消息总数,
            "messages": [
                {
                    "type": "HumanMessage",
                    "content": "...",
                    "role": "human",
                    "timestamp": "..."
                }
            ]
        }
    """
    # 1. 获取最新的 checkpoint(包含完整对话历史)
    latest_checkpoint = self._collection.find_one(
        {"thread_id": thread_id},
        sort=[("created_at", DESCENDING)]
    )
    
    if not latest_checkpoint:
        return {"total": 0, "messages": []}
    
    # 2. 反序列化并提取消息
    checkpoint = pickle.loads(latest_checkpoint["checkpoint_data"])
    checkpoint_messages = checkpoint.get("channel_values", {}).get("messages", [])
    
    # 3. 格式化所有消息
    all_messages = self._format_messages(checkpoint_messages)
    
    # 4. 根据排序方式调整顺序
    if order == "desc":
        all_messages = list(reversed(all_messages))
    
    # 5. 分页
    total = len(all_messages)
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    paginated_messages = all_messages[start_idx:end_idx]
    
    return {
        "total": total,
        "messages": paginated_messages
    }

3. 获取 Checkpoint 历史

python 复制代码
def get_history_messages(
    self, 
    thread_id: str, 
    page: int = 1, 
    page_size: int = 20,
    order: str = "desc"
) -> Dict[str, Any]:
    """
    获取指定 thread 的历史 checkpoints(按 checkpoint 分页)
    
    用途:查看对话的演变过程,每个 checkpoint 是一个快照
    """
    total = self._collection.count_documents({"thread_id": thread_id})
    skip = (page - 1) * page_size
    
    cursor = self._collection.find(
        {"thread_id": thread_id}
    ).sort("created_at", DESCENDING if order == "desc" else ASCENDING).skip(skip).limit(page_size)
    
    messages = []
    for doc in cursor:
        checkpoint = pickle.loads(doc["checkpoint_data"])
        checkpoint_messages = checkpoint.get("channel_values", {}).get("messages", [])
        
        messages.append({
            "checkpoint_id": doc["checkpoint_id"],
            "messages": self._format_messages(checkpoint_messages),
            "created_at": doc["created_at"].isoformat(),
        })
    
    return {
        "total": total,
        "messages": messages
    }

四、如何使用

4.1 YAML 配置

在场景配置文件中启用 MongoDB:

bash 复制代码
# src/core/config/ops_agent.yaml
global_config:
  memory:
    enabled: true                          # 是否启用记忆
    provider: "mongodb"                    # 存储提供商: memory|mongodb
    connection:
      host: "127.0.0.1"                    # MongoDB 主机
      port: 27017                          # MongoDB 端口
      database: "kaflow"                   # 数据库名称
      collection: "checkpoints"            # 集合名称
      username: "test"                     # 用户名
      password: "${MEMORY_PASSWORD}"       # 密码(环境变量)
      auth_source: "admin"                 # 认证数据库

切换到内存模式

yaml 复制代码
global_config:
  memory:
    enabled: true
    provider: "memory"  # ✅ 只需改这里

4.2 环境变量配置

在项目根目录创建 .env 文件:

ini 复制代码
# .env

# MongoDB 密码(必须)
MEMORY_PASSWORD=your_mongodb_password_here

# LLM API Key
DEEPSEEK_API_KEY=sk-xxx

安全提示

  • ⚠️ 永远不要将 .env 文件提交到 Git 仓库
  • .gitignore 中添加 .env
  • 生产环境使用 Docker Secrets 或云服务的密钥管理

4.3 环境变量提取机制

MongoDBCheckpointer 会自动从环境变量中提取密码:

python 复制代码
# mongodb_checkpointer.py
if self.username and self.password:
    # 如果 password 是 "${MEMORY_PASSWORD}",则从环境变量中获取
    password = os.environ.get(
        self.password.replace("${", "").replace("}", ""),
        self.password  # 如果环境变量不存在,使用原值
    )

工作原理

  1. YAML 中配置 password: "${MEMORY_PASSWORD}"
  2. 提取环境变量名:MEMORY_PASSWORD
  3. os.environ 中获取实际密码值
  4. 构建 MongoDB 连接 URI

五、总结

核心收获

经过这次实现,让我对 LangGraph 的 Checkpoint 机制有了更深入的理解:

  1. 理解接口契约 - LangGraph 定义了 10 个抽象方法,必须全部实现才能正常工作
  2. 参考官方实现 - InMemorySaver 的代码质量很高,设计思路可以直接借鉴
  3. 单例模式很重要 - 在涉及数据库连接时,必须考虑单例,避免资源泄漏
  4. 工厂模式的优雅 - 通过工厂模式实现存储后端的无缝切换,配置驱动非常灵活

设计亮点

1. 工厂模式 + 策略模式

通过 CheckpointerFactory 实现存储后端的无缝切换,用户只需修改 YAML 配置中的 provider 即可。

2. 类级别单例客户端

所有 MongoDBCheckpointer 实例共享同一个 MongoClient,避免线程泄漏和资源浪费。

3. 环境变量注入

敏感信息(如密码)通过环境变量传入,支持 ${ENV_VAR} 语法,安全又灵活。

4. 扁平化存储 + 索引优化

不像 InMemorySaver 嵌套三层,而是扁平存储,配合索引实现高效查询。

5. 扩展查询 API

除了 LangGraph 要求的接口,还提供了 get_thread_listget_flat_messages 等实用功能。


项目地址

如果这篇文章对你有帮助,欢迎点赞、转发、Star 支持!🙏

相关推荐
hnode20 小时前
🚀 前端开发者的 AI 入门指南:5 分钟搭建你的第一个 RAG 智能问答系统
langchain
阿里云云原生20 小时前
一起聊聊大规模 AI Agent 部署与运维实战
agent
大模型真好玩1 天前
LangChain1.0实战之多模态RAG系统(二)——多模态RAG系统图片分析与语音转写功能实现
人工智能·langchain·mcp
大模型教程1 天前
谷歌AI Agent技术指南深度解读,从概念到生产
langchain·llm·agent
大模型教程1 天前
一张图拆解 AI Agent 的“五脏六腑”,从感知到进化的完整逻辑!
程序员·llm·agent
爱装代码的小瓶子1 天前
【初识AI】大模型和LangChain?
人工智能·langchain
AI大模型1 天前
一文了解LLM应用架构:从Prompt到Multi-Agent
程序员·llm·agent
AI大模型1 天前
LangChain、LangGraph、LangSmith这些AI开发框架有什么区别?一篇文章解释清楚
langchain·llm·agent
爬点儿啥1 天前
[Ai Agent] 09 LangGraph 进阶:构建可控、可协作的多智能体系统
人工智能·ai·langchain·大模型·agent·langgraph
聚梦小课堂1 天前
在n8n中,让AI Agent能正常使用tools的方法
人工智能·agent·n8n