LangGraph 记忆机制:基于 Checkpointer 的状态持久化

前言

在构建 AI 对话系统时,记忆能力是衡量用户体验的关键指标之一。一个没有记忆的 AI 助手,每次对话都像是第一次认识用户,这显然无法满足真实场景的需求。

LangGraph 提供了优雅的解决方案 ------ Checkpointer(检查点)机制。本文将深入解析这一机制,并通过完整的代码示例,带你实现一个具备多轮对话记忆能力的 AI 助手。


核心概念解析

1. Checkpointer:状态持久化的核心

Checkpointer 是 LangGraph 实现状态持久化的基础设施,它承担着三个核心职责:

功能 说明
状态保存 在节点执行后自动持久化当前状态
状态恢复 根据配置加载历史状态,实现上下文延续
会话隔离 通过唯一标识符区分不同对话会话

2. Thread ID:会话的唯一标识

thread_id 是区分不同对话会话的关键:

  • 同一会话 :相同的 thread_id 共享完整对话历史
  • 会话隔离 :不同的 thread_id 之间数据完全隔离
  • 多场景支持:可支持多用户、多会话的复杂场景

3. MemorySaver:内存检查点实现

MemorySaver 是 LangGraph 内置的检查点实现:

python 复制代码
from langgraph.checkpoint.memory import MemorySaver

# 内存存储,适合开发和测试
memory = MemorySaver()

# 生产环境建议使用持久化存储
# from langgraph.checkpoint.sqlite import SqliteSaver
# from langgraph.checkpoint.postgres import PostgresSaver

完整代码实现

python 复制代码
"""
LangGraph 教程 - 为聊天机器人添加记忆

本示例在基础聊天机器人的基础上,添加持久化检查点(checkpointer)实现多轮对话记忆。
通过 MemorySaver 和 thread_id,聊天机器人可以记住之前的交互上下文。

官方教程地址:https://langchain-ai.github.io/langgraph/tutorials/introduction/
"""

# 过滤警告信息
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_core.messages import ToolMessage, AIMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from dotenv import load_dotenv
import json
import os

# 加载环境变量
load_dotenv()

# 检查 Tavily API Key 是否设置
if not os.getenv("TAVILY_API_KEY"):
    raise ValueError("TAVILY_API_KEY 未设置,请在 .env 文件中配置")


# ==================== 1. 定义状态 ====================
class State(TypedDict):
    """
    定义图的状态结构。
    
    messages: 消息列表,使用 add_messages reducer 函数
              确保新消息追加到列表,而不是覆盖
    """
    messages: Annotated[list, add_messages]


# ==================== 2. 创建图 ====================
def create_graph():
    """
    创建并编译 StateGraph,添加 checkpointer 实现记忆功能。
    
    Returns:
        编译后的图对象
    """
    # 创建图构建器
    graph_builder = StateGraph(State)
    
    # 初始化模型
    llm = ChatOpenAI(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        openai_api_key=os.getenv("SILICONFLOW_API_KEY"),
        openai_api_base="https://api.siliconflow.cn/v1",
        temperature=0.7
    )
    
    # 创建 Tavily 搜索工具
    tool = TavilySearchResults(max_results=2)
    tools = [tool]
    
    # 绑定工具到 LLM
    llm_with_tools = llm.bind_tools(tools)
    
    # 定义聊天机器人节点
    def chatbot(state: State):
        """
        聊天机器人节点。
        
        Args:
            state: 当前状态
            
        Returns:
            包含 LLM 响应的字典
        """
        return {"messages": [llm_with_tools.invoke(state["messages"])]}
    
    # 添加节点
    graph_builder.add_node("chatbot", chatbot)
    
    # 使用 LangGraph 预定义的 ToolNode(替代 BasicToolNode)
    tool_node = ToolNode(tools=tools)
    graph_builder.add_node("tools", tool_node)
    
    # 添加边
    graph_builder.add_edge(START, "chatbot")
    
    # 添加条件边:从 chatbot 到 tools 或 END
    # 使用 LangGraph 预定义的 tools_condition
    graph_builder.add_conditional_edges(
        "chatbot",
        tools_condition,
    )
    
    # 添加边:从 tools 回到 chatbot(形成循环)
    graph_builder.add_edge("tools", "chatbot")
    
    # ==================== 关键:添加 MemorySaver 检查点 ====================
    # 创建内存中的检查点(生产环境可使用 SqliteSaver 或 PostgresSaver)
    memory = MemorySaver()
    
    # 编译图时传入 checkpointer
    return graph_builder.compile(checkpointer=memory)


# ==================== 3. 运行聊天机器人 ====================
def stream_graph_updates(graph, user_input: str, config: dict):
    """
    流式处理图更新。
    
    Args:
        graph: 编译后的图对象
        user_input: 用户输入的消息
        config: 包含 thread_id 的配置字典
    """
    # 注意:config 是第二个位置参数!
    for event in graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values"
    ):
        if "messages" in event:
            last_message = event["messages"][-1]
            # 只打印 AI 消息(助手回复),不打印用户消息和工具消息
            if isinstance(last_message, AIMessage) and last_message.content:
                print("助手:", last_message.content)


def main():
    """主函数 - 运行交互式聊天机器人。"""
    print("🤖 LangGraph 带记忆功能的聊天机器人已启动!")
    print("=" * 50)
    print("提示:")
    print("  - 输入 'quit'、'exit' 或 'q' 退出对话")
    print("  - 输入 'new' 开始新会话(切换 thread_id)")
    print("  - 聊天机器人可以记住之前的对话内容\n")
    
    # 创建图
    graph = create_graph()
    
    # ==================== 关键:使用 thread_id 配置 ====================
    # 默认会话 ID
    current_thread_id = "1"
    config = {"configurable": {"thread_id": current_thread_id}}
    
    print(f"当前会话 ID: {current_thread_id}")
    print("-" * 50)
    
    while True:
        try:
            # 获取用户输入
            user_input = input("用户: ")
            
            # 检查退出命令
            if user_input.lower() in ["quit", "exit", "q"]:
                print("\n👋 再见!")
                break
            
            # 检查是否切换会话
            if user_input.lower() == "new":
                current_thread_id = str(int(current_thread_id) + 1)
                config = {"configurable": {"thread_id": current_thread_id}}
                print(f"\n📝 已切换到新会话,会话 ID: {current_thread_id}")
                print("-" * 50)
                continue
            
            # 检查当前状态(可选,用于调试)
            if user_input.lower() == "state":
                snapshot = graph.get_state(config)
                print(f"\n📊 当前会话状态 (thread_id={current_thread_id}):")
                print(f"  消息数量: {len(snapshot.values.get('messages', []))}")
                print(f"  下一步: {snapshot.next}")
                print("-" * 50)
                continue
            
            # 处理用户输入并获取响应
            # 传入 config 使图能够加载和保存状态
            stream_graph_updates(graph, user_input, config)
            print()  # 空行分隔对话
            
        except KeyboardInterrupt:
            print("\n\n👋 再见!")
            break
        except Exception as e:
            print(f"发生错误: {e}")
            break


# ==================== 4. 演示记忆功能 ====================
def demo_memory():
    """
    演示记忆功能的示例函数。
    展示同一个 thread_id 能记住上下文,不同 thread_id 无法共享记忆。
    """
    print("🧪 演示记忆功能")
    print("=" * 50)
    
    graph = create_graph()
    
    # 会话 1:建立记忆
    print("\n【会话 1 - thread_id='1'】")
    config_1 = {"configurable": {"thread_id": "1"}}
    
    user_input = "你好,我叫张三"
    print(f"用户: {user_input}")
    for event in graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config_1,
        stream_mode="values"
    ):
        if "messages" in event:
            last_message = event["messages"][-1]
            # 只打印 AI 消息(助手回复)
            if isinstance(last_message, AIMessage) and last_message.content:
                print(f"助手: {last_message.content}")

    # 会话 1:测试记忆
    user_input = "我叫什么名字?"
    print(f"\n用户: {user_input}")
    for event in graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config_1,
        stream_mode="values"
    ):
        if "messages" in event:
            last_message = event["messages"][-1]
            # 只打印 AI 消息(助手回复)
            if isinstance(last_message, AIMessage) and last_message.content:
                print(f"助手: {last_message.content}")

    # 会话 2:没有记忆
    print("\n【会话 2 - thread_id='2'】")
    config_2 = {"configurable": {"thread_id": "2"}}

    user_input = "我叫什么名字?"
    print(f"用户: {user_input}")
    for event in graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config_2,
        stream_mode="values"
    ):
        if "messages" in event:
            last_message = event["messages"][-1]
            # 只打印 AI 消息(助手回复)
            if isinstance(last_message, AIMessage) and last_message.content:
                print(f"助手: {last_message.content}")
    
    # 检查状态
    print("\n【检查状态】")
    snapshot = graph.get_state(config_1)
    print(f"会话 1 的消息数量: {len(snapshot.values.get('messages', []))}")
    
    snapshot = graph.get_state(config_2)
    print(f"会话 2 的消息数量: {len(snapshot.values.get('messages', []))}")


if __name__ == "__main__":
    # 运行交互式聊天机器人
    # main()
    
    # 如需运行演示,取消下面这行的注释:
    demo_memory()

关键技术点详解

1. 集成 MemorySaver

python 复制代码
from langgraph.checkpoint.memory import MemorySaver

# 创建检查点实例
memory = MemorySaver()

# 编译时注入 checkpointer
graph = graph_builder.compile(checkpointer=memory)

核心逻辑compile(checkpointer=memory) 将检查点机制绑定到图实例,使其具备状态持久化能力。

2. 配置会话标识

python 复制代码
# 构建配置字典
config = {"configurable": {"thread_id": "session_001"}}

# 流式执行时传入配置
for event in graph.stream(input_data, config, stream_mode="values"):
    # 处理事件...

重要提示configstream()第二个位置参数,这个细节很容易出错。

3. 状态快照查询

python 复制代码
# 获取指定会话的完整状态
snapshot = graph.get_state(config)

# 访问状态数据
messages = snapshot.values.get("messages", [])
next_step = snapshot.next  # 下一个待执行节点

运行效果展示

执行 demo_memory() 后的输出:

现象解读

  • 相同 thread_id 的会话能够记住上下文信息
  • 不同 thread_id 的会话数据完全隔离
  • 消息计数反映了对话轮次的累积

踩坑记录

stream_mode="values" 的消息类型问题

使用 stream_mode="values" 时,event["messages"] 包含三种消息类型:

  1. HumanMessage - 用户输入
  2. AIMessage - AI 回复
  3. ToolMessage - 工具执行结果

踩坑场景 :直接输出 last_message 可能打印工具调用结果,影响用户体验。

解决方案

python 复制代码
last_message = event["messages"][-1]

# 类型过滤 + 内容校验
if isinstance(last_message, AIMessage) and last_message.content:
    print(f"助手: {last_message.content}")

总结

本文深入讲解了 LangGraph 的记忆机制实现:

核心要点

组件 作用 使用方式
MemorySaver 状态持久化 compile(checkpointer=memory)
thread_id 会话标识 config["configurable"]["thread_id"]
get_state 状态查询 graph.get_state(config)

进阶方向

  1. 持久化存储:使用 SQLite/PostgreSQL 替代内存存储
  2. 历史管理:实现滑动窗口,限制上下文长度
  3. 多租户架构 :将 thread_id 与用户体系关联

相关资源


作者:AI探索者 如果本文对你有帮助,欢迎点赞、收藏、分享!

相关推荐
over6973 小时前
从 LLM 到全栈 Agent:MCP 协议 × RAG 技术如何重构 AI 的“做事能力”
面试·llm·mcp
UIUV4 小时前
RAG技术学习笔记(含实操解析)
javascript·langchain·llm
神秘的猪头10 小时前
🚀 拒绝“一本正经胡说八道”!手把手带你用 LangChain 实现 RAG,打造你的专属 AI 知识库
langchain·llm·openai
栀秋66610 小时前
重塑 AI 交互边界:基于 LangChain 与 MCP 协议的全栈实践
langchain·llm·mcp
EdisonZhou1 天前
MAF快速入门(18)Agent Skill 快速开始
llm·aigc·agent
会写代码的柯基犬1 天前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
神秘的猪头1 天前
🔌 给 AI 装上“三头六臂”!实战大模型接入第三方 MCP 全攻略
langchain·llm·mcp
神秘的猪头2 天前
🔌 把 MCP 装进大脑!手把手带你构建能“热插拔”工具的 AI Agent
langchain·llm·mcp
智泊AI2 天前
一文讲清:Agent、Workflow、MCP的区别是啥?
llm