一、为什么需要记忆?
想象一下,你和一个朋友聊天:
你: 我叫小明
朋友: 你好小明!
你: 我叫什么名字?
朋友: 我不知道... 😅
这就尴尬了!如果没有记忆,每次对话都是"失忆"状态,AI 根本不记得刚才说过什么。
LangGraph 的记忆功能就是解决这个问题的------让图能够"记住"之前的状态,实现多轮对话、累计计算等场景。
二、两种模式对比
2.1 无记忆模式(失忆症患者)
第一次调用: count = 0 → 1
第二次调用: count = 0 → 1 (又从 0 开始了!)
第三次调用: count = 0 → 1 (还是从 0 开始...)
就像每次见面都重新认识一样,累!
2.2 有记忆模式(正常人)
第一次调用: count = 0 → 1
第二次调用: count = 1 → 2 (记住了上次的结果!)
第三次调用: count = 2 → 3 (继续累计!)
这才是我们想要的效果!
三、无记忆示例
先来看看"没有记忆"是什么情况。
3.1 代码实现
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
# 定义状态
class State(TypedDict):
count: int # 调用次数
total: int # 累计值
history: list # 历史记录
# 定义节点
def counter_node(state: State) -> dict:
new_count = state['count'] + 1
new_total = state['total'] + 10
new_history = state['history'] + [f"第{new_count}次调用"]
return {
"count": new_count,
"total": new_total,
"history": new_history
}
# 构建图
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)
# 编译图 - 注意:没有 checkpointer 参数!
graph = graph_builder.compile()
3.2 执行结果
# 初始状态
initial_state = {"count": 0, "total": 0, "history": []}
# 第一次调用
result1 = graph.invoke(initial_state)
# 结果: count=1, total=10
# 第二次调用 - 传入相同的初始状态
result2 = graph.invoke(initial_state)
# 结果: count=1, total=10 ← 还是 1,没有累计!
# 第三次调用 - 手动传入上一次的结果
result3 = graph.invoke(result2)
# 结果: count=2, total=20 ← 必须手动传才能累计
3.3 问题分析
┌─────────────────────────────────────────────────────────────┐
│ 无记忆模式的问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 调用1: initial_state ──→ 节点 ──→ result1 (count=1) │
│ ↓ │
│ 丢失了! │
│ │
│ 调用2: initial_state ──→ 节点 ──→ result2 (count=1) │
│ ↓ │
│ 又丢失了! │
│ │
│ 调用3: result2 ────────→ 节点 ──→ result3 (count=2) │
│ ↑ │
│ 手动传入才能继续 │
│ │
└─────────────────────────────────────────────────────────────┘
简单说:每次调用都是"失忆"的,你得手动把上次的结果喂给它,它才能继续。
四、有记忆示例
现在来看看"有记忆"是怎么实现的。
4.1 核心代码
from langgraph.checkpoint.memory import MemorySaver
# 创建记忆组件
memory = MemorySaver()
# 编译图 - 添加 checkpointer 参数
graph = graph_builder.compile(checkpointer=memory)
就这么简单!加上 checkpointer=memory,图就有了记忆能力。
4.2 完整实现
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
# 定义状态
class State(TypedDict):
count: int
total: int
history: list
# 定义节点
def counter_node(state: State) -> dict:
new_count = state['count'] + 1
new_total = state['total'] + 10
new_history = state['history'] + [f"第{new_count}次调用"]
return {
"count": new_count,
"total": new_total,
"history": new_history
}
# 构建图
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)
# ⭐ 关键:创建记忆组件
memory = MemorySaver()
# ⭐ 关键:编译时添加 checkpointer
graph = graph_builder.compile(checkpointer=memory)
4.3 执行方式
# 配置:使用 thread_id 区分会话
config = {"configurable": {"thread_id": "session_001"}}
# 初始状态
initial_state = {"count": 0, "total": 0, "history": []}
# 第一次调用
result1 = graph.invoke(initial_state, config)
# 结果: count=1, total=10
# 第二次调用 - 传入空字典就行!
result2 = graph.invoke({}, config)
# 结果: count=2, total=20 ← 自动从上次继续!
# 第三次调用
result3 = graph.invoke({}, config)
# 结果: count=3, total=30 ← 继续累计!
4.4 执行流程图
┌─────────────────────────────────────────────────────────────┐
│ 有记忆模式的流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 调用1: initial_state ──→ 节点 ──→ result1 (count=1) │
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Memory │ │
│ │ Saver │ 保存状态 │
│ └─────┬─────┘ │
│ │ │
│ 调用2: {} ──────────────────────┘ │
│ ↑ │
│ 传入空字典,系统自动从 Memory 恢复 │
│ │ │
│ └───────→ 节点 ──→ result2 (count=2) │
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Memory │ 更新状态 │
│ │ Saver │ │
│ └─────┬─────┘ │
│ │ │
│ 调用3: {} ──────────────────────┘ │
│ │ │
│ └───────→ 节点 ──→ result3 (count=3) │
│ │
└─────────────────────────────────────────────────────────────┘
简单说:MemorySaver 就像一个"存档点",每次执行完自动存档,下次调用自动读档继续。
五、thread_id:区分不同会话
5.1 什么是 thread_id?
想象你在玩一个游戏,可以创建多个存档:
- 存档1:小明玩到第3关
- 存档2:小红玩到第5关
- 存档3:小刚刚开始玩
thread_id 就是这个"存档槽"的概念,不同的 thread_id 对应不同的会话状态。
5.2 代码示例
# 会话1:小明的会话
config_ming = {"configurable": {"thread_id": "xiaoming"}}
result = graph.invoke({"count": 0, ...}, config_ming)
# 小明: count=1
result = graph.invoke({}, config_ming)
# 小明: count=2
# 会话2:小红的会话
config_hong = {"configurable": {"thread_id": "xiaohong"}}
result = graph.invoke({"count": 0, ...}, config_hong)
# 小红: count=1(新会话,从头开始)
# 切回小明的会话
result = graph.invoke({}, config_ming)
# 小明: count=3(继续之前的进度)
5.3 会话隔离示意图
┌─────────────────────────────────────────────────────────────┐
│ MemorySaver 内部结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ thread_id: "xiaoming" │
│ ├── count: 3 │
│ ├── total: 30 │
│ └── history: [第1次, 第2次, 第3次] │
│ │
│ thread_id: "xiaohong" │
│ ├── count: 1 │
│ ├── total: 10 │
│ └── history: [第1次] │
│ │
│ thread_id: "xiaogang" │
│ ├── count: 0 │
│ └── ... │
│ │
└─────────────────────────────────────────────────────────────┘
六、get_state():查看当前状态
有时候我们需要看看当前保存的状态是什么:
# 查看当前状态
snapshot = graph.get_state(config)
print(snapshot.values)
# 输出: {'count': 3, 'total': 30, 'history': [...]}
print(snapshot.next)
# 输出: () 空元组表示执行完毕
6.1 实际应用场景
# 场景:检查用户当前进度
state = graph.get_state(config)
if state.values['count'] > 10:
print("您已经是老用户了!")
else:
print("欢迎新用户!")
七、对比总结
7.1 代码对比
| 特性 | 无记忆 | 有记忆 |
|---|---|---|
| 导入 | 无需额外导入 | from langgraph.checkpoint.memory import MemorySaver |
| 创建 | graph = builder.compile() |
memory = MemorySaver() graph = builder.compile(checkpointer=memory) |
| 调用 | graph.invoke(state) |
graph.invoke(state, config) |
| 状态保持 | ❌ 不保持 | ✅ 自动保持 |
| 会话隔离 | ❌ 无 | ✅ 通过 thread_id |
7.2 使用场景对比
| 场景 | 无记忆 | 有记忆 |
|---|---|---|
| 单次任务 | ✅ 适合 | 可用 |
| 多轮对话 | ❌ 不适合 | ✅ 适合 |
| 累计计算 | ❌ 需手动传 | ✅ 自动累计 |
| 聊天机器人 | ❌ 不适合 | ✅ 适合 |
| 审批流程 | ❌ 不适合 | ✅ 必须 |
7.3 一句话总结
- 无记忆:每次都是"新的一天",适合单次任务
- 有记忆:像正常聊天,记得住之前说过什么
八、常见问题
Q1: MemorySaver 会持久化到磁盘吗?
不会 。MemorySaver 是内存存储,程序重启后数据就丢失了。
如果需要持久化,可以使用:
-
SqliteSaver:SQLite 数据库 -
PostgresSaver:PostgreSQL 数据库from langgraph.checkpoint.sqlite import SqliteSaver
持久化到 SQLite 数据库
memory = SqliteSaver.from_conn_string("checkpoints.db")
graph = builder.compile(checkpointer=memory)
Q2: thread_id 可以随便设置吗?
是的 。thread_id 只是一个标识符,可以是任意字符串:
- 用户ID:
"user_12345" - 会话ID:
"session_abc" - 任意字符串:
"my_chat_session"
Q3: 传入空字典 {} 是什么意思?
graph.invoke({}, config)
传入空字典表示"不提供新的输入,从保存的状态继续"。系统会:
- 从 MemorySaver 中读取
thread_id对应的状态 - 用这个状态作为输入执行图
九、完整代码
9.1 无记忆示例
python
"""
LangGraph 无记忆示例 - 演示没有 MemorySaver 时无法记住状态
本示例演示:
- 没有记忆组件时,每次调用 graph.invoke() 都是独立的
- 上一轮的状态不会被保留
- 每次都从初始状态开始
图结构: START -> counter_node -> END
"""
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
# ==================== 状态定义 ====================
class State(TypedDict):
"""图的状态结构"""
count: int # 调用次数
total: int # 累计值
history: list # 历史记录
# ==================== 节点函数 ====================
def counter_node(state: State) -> dict:
"""计数器节点:递增计数并记录历史"""
new_count = state['count'] + 1
new_total = state['total'] + 10
new_history = state['history'] + [f"第{new_count}次调用,累计值={new_total}"]
print(f"\n 📊 节点执行中...")
print(f" 接收到的 count: {state['count']}")
print(f" 修改后的 count: {new_count}")
print(f" 累计值 total: {new_total}")
return {
"count": new_count,
"total": new_total,
"history": new_history
}
# ==================== 构建图(无记忆) ====================
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)
# 编译图 - 注意:没有 checkpointer 参数
graph = graph_builder.compile()
# ==================== 主程序 ====================
if __name__ == "__main__":
print("=" * 70)
print("📌 示例:无记忆组件 (No MemorySaver)")
print("=" * 70)
print("\n图结构: START -> counter_node -> END")
print("\n说明: 每次调用都是独立的,不会记住之前的状态")
print("=" * 70)
# 初始状态
initial_state = {
"count": 0,
"total": 0,
"history": []
}
# 第一次调用
print("\n" + "-" * 70)
print("🔄 第一次调用 graph.invoke()")
print("-" * 70)
result1 = graph.invoke(initial_state)
print(f"\n✅ 第一次调用结果:")
print(f" count: {result1['count']}")
print(f" total: {result1['total']}")
print(f" history: {result1['history']}")
# 第二次调用 - 使用相同的初始状态
print("\n" + "-" * 70)
print("🔄 第二次调用 graph.invoke()")
print("-" * 70)
result2 = graph.invoke(initial_state) # 注意:传入的是相同的初始状态
print(f"\n✅ 第二次调用结果:")
print(f" count: {result2['count']}")
print(f" total: {result2['total']}")
print(f" history: {result2['history']}")
# 第三次调用 - 尝试传入上一次的结果
print("\n" + "-" * 70)
print("🔄 第三次调用 graph.invoke()")
print(" (尝试传入上一次的结果)")
print("-" * 70)
result3 = graph.invoke(result2) # 手动传入上一次的结果
print(f"\n✅ 第三次调用结果:")
print(f" count: {result3['count']}")
print(f" total: {result3['total']}")
print(f" history: {result3['history']}")
# 总结
print("\n" + "=" * 70)
print("📊 总结")
print("=" * 70)
print("""
❌ 没有记忆组件的问题:
- 每次调用 graph.invoke() 都是独立的
- 必须手动传入上一次的状态才能"记住"
- 无法自动保持对话历史或累计数据
💡 解决方案:
- 使用 MemorySaver 作为 checkpointer
- 使用 thread_id 来区分不同的会话
""")
9.2 有记忆示例
python
"""
LangGraph 有记忆示例 - 演示使用 MemorySaver 记住状态
本示例演示:
- 使用 MemorySaver 作为 checkpointer
- 通过 thread_id 区分不同会话
- 自动记住上一轮的状态
- 使用 get_state() 查看当前状态
- 使用 update_state() 更新状态
图结构: START -> counter_node -> END
"""
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
# ==================== 状态定义 ====================
class State(TypedDict):
"""图的状态结构"""
count: int # 调用次数
total: int # 累计值
history: list # 历史记录
# ==================== 节点函数 ====================
def counter_node(state: State) -> dict:
"""计数器节点:递增计数并记录历史"""
new_count = state['count'] + 1
new_total = state['total'] + 10
new_history = state['history'] + [f"第{new_count}次调用,累计值={new_total}"]
print(f"\n 📊 节点执行中...")
print(f" 接收到的 count: {state['count']}")
print(f" 修改后的 count: {new_count}")
print(f" 累计值 total: {new_total}")
return {
"count": new_count,
"total": new_total,
"history": new_history
}
# ==================== 构建图(有记忆) ====================
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)
# 创建记忆组件
memory = MemorySaver()
# 编译图 - 添加 checkpointer 参数
graph = graph_builder.compile(checkpointer=memory)
# ==================== 主程序 ====================
if __name__ == "__main__":
print("=" * 70)
print("📌 示例:有记忆组件 (With MemorySaver)")
print("=" * 70)
print("\n图结构: START -> counter_node -> END")
print("\n说明: 使用 MemorySaver 自动记住状态")
print("=" * 70)
# 配置:使用 thread_id 区分会话
config = {"configurable": {"thread_id": "session_001"}}
# 初始状态
initial_state = {
"count": 0,
"total": 0,
"history": []
}
# ==================== 第一次调用 ====================
print("\n" + "-" * 70)
print("🔄 第一次调用 graph.invoke()")
print(f" thread_id: {config['configurable']['thread_id']}")
print("-" * 70)
result1 = graph.invoke(initial_state, config)
print(f"\n✅ 第一次调用结果:")
print(f" count: {result1['count']}")
print(f" total: {result1['total']}")
print(f" history: {result1['history']}")
# 查看当前保存的状态
saved_state = graph.get_state(config)
print(f"\n💾 当前保存的状态:")
print(f" {saved_state.values}")
# ==================== 第二次调用 ====================
print("\n" + "-" * 70)
print("🔄 第二次调用 graph.invoke()")
print(" 传入空字典 {},系统会自动从保存的状态恢复")
print(f" thread_id: {config['configurable']['thread_id']}")
print("-" * 70)
# 传入空字典,系统会从保存的状态恢复并继续执行
result2 = graph.invoke({}, config)
print(f"\n✅ 第二次调用结果:")
print(f" count: {result2['count']} (从 1 增加到 2)")
print(f" total: {result2['total']} (从 10 增加到 20)")
print(f" history: {result2['history']}")
# 查看当前保存的状态
saved_state = graph.get_state(config)
print(f"\n💾 当前保存的状态:")
print(f" {saved_state.values}")
# ==================== 第三次调用 ====================
print("\n" + "-" * 70)
print("🔄 第三次调用 graph.invoke()")
print(f" thread_id: {config['configurable']['thread_id']}")
print("-" * 70)
result3 = graph.invoke({}, config)
print(f"\n✅ 第三次调用结果:")
print(f" count: {result3['count']} (从 2 增加到 3)")
print(f" total: {result3['total']} (从 20 增加到 30)")
print(f" history: {result3['history']}")
# ==================== 演示不同 thread_id ====================
print("\n" + "=" * 70)
print("🔄 演示不同的 thread_id(新会话)")
print("=" * 70)
config_new = {"configurable": {"thread_id": "session_002"}}
print(f"\n使用新的 thread_id: {config_new['configurable']['thread_id']}")
print("这是一个全新的会话,状态从头开始")
result_new = graph.invoke(initial_state, config_new)
print(f"\n✅ 新会话结果:")
print(f" count: {result_new['count']} (从头开始,不是 3)")
print(f" total: {result_new['total']}")
print(f" history: {result_new['history']}")
# ==================== 演示 get_state() ====================
print("\n" + "=" * 70)
print("🔍 演示 get_state() - 查看两个会话的状态")
print("=" * 70)
state_001 = graph.get_state(config)
state_002 = graph.get_state(config_new)
print(f"\n会话 session_001 的状态:")
print(f" count: {state_001.values['count']}")
print(f" total: {state_001.values['total']}")
print(f"\n会话 session_002 的状态:")
print(f" count: {state_002.values['count']}")
print(f" total: {state_002.values['total']}")
# ==================== 总结 ====================
print("\n" + "=" * 70)
print("📊 总结")
print("=" * 70)
print("""
✅ 有记忆组件的优势:
- 自动记住上一轮的状态
- 每次调用都会从保存的状态继续
- 通过 thread_id 区分不同会话
- 可以使用 get_state() 查看当前状态
📝 使用方法:
1. 创建 MemorySaver: memory = MemorySaver()
2. 编译时添加: graph.compile(checkpointer=memory)
3. 调用时指定 thread_id: config = {"configurable": {"thread_id": "xxx"}}
4. 后续调用: graph.invoke({}, config) # 自动恢复并继续
5. 查看状态: graph.get_state(config)
""")
十、延伸阅读
附录:核心 API 速查
# 1. 创建记忆组件
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()
# 2. 编译时添加
graph = builder.compile(checkpointer=memory)
# 3. 调用时指定 thread_id
config = {"configurable": {"thread_id": "xxx"}}
result = graph.invoke(state, config)
# 4. 后续调用(自动恢复)
result = graph.invoke({}, config)
# 5. 查看当前状态
snapshot = graph.get_state(config)
print(snapshot.values)