AI开发之LangGraph教程4~记忆 (Memory)

一、为什么需要记忆?

想象一下,你和一个朋友聊天:

复制代码
你: 我叫小明
朋友: 你好小明!
你: 我叫什么名字?
朋友: 我不知道...  😅

这就尴尬了!如果没有记忆,每次对话都是"失忆"状态,AI 根本不记得刚才说过什么。

LangGraph 的记忆功能就是解决这个问题的------让图能够"记住"之前的状态,实现多轮对话、累计计算等场景。


二、两种模式对比

2.1 无记忆模式(失忆症患者)

复制代码
第一次调用: count = 0 → 1
第二次调用: count = 0 → 1  (又从 0 开始了!)
第三次调用: count = 0 → 1  (还是从 0 开始...)

就像每次见面都重新认识一样,累!

2.2 有记忆模式(正常人)

复制代码
第一次调用: count = 0 → 1
第二次调用: count = 1 → 2  (记住了上次的结果!)
第三次调用: count = 2 → 3  (继续累计!)

这才是我们想要的效果!


三、无记忆示例

先来看看"没有记忆"是什么情况。

3.1 代码实现

复制代码
from typing import TypedDict
from langgraph.graph import StateGraph, START, END

# 定义状态
class State(TypedDict):
    count: int      # 调用次数
    total: int      # 累计值
    history: list   # 历史记录

# 定义节点
def counter_node(state: State) -> dict:
    new_count = state['count'] + 1
    new_total = state['total'] + 10
    new_history = state['history'] + [f"第{new_count}次调用"]

    return {
        "count": new_count,
        "total": new_total,
        "history": new_history
    }

# 构建图
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)

# 编译图 - 注意:没有 checkpointer 参数!
graph = graph_builder.compile()

3.2 执行结果

复制代码
# 初始状态
initial_state = {"count": 0, "total": 0, "history": []}

# 第一次调用
result1 = graph.invoke(initial_state)
# 结果: count=1, total=10

# 第二次调用 - 传入相同的初始状态
result2 = graph.invoke(initial_state)
# 结果: count=1, total=10  ← 还是 1,没有累计!

# 第三次调用 - 手动传入上一次的结果
result3 = graph.invoke(result2)
# 结果: count=2, total=20  ← 必须手动传才能累计

3.3 问题分析

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    无记忆模式的问题                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  调用1: initial_state ──→ 节点 ──→ result1 (count=1)       │
│                         ↓                                   │
│                      丢失了!                                │
│                                                             │
│  调用2: initial_state ──→ 节点 ──→ result2 (count=1)       │
│                         ↓                                   │
│                      又丢失了!                              │
│                                                             │
│  调用3: result2 ────────→ 节点 ──→ result3 (count=2)       │
│         ↑                                                   │
│      手动传入才能继续                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

简单说:每次调用都是"失忆"的,你得手动把上次的结果喂给它,它才能继续。


四、有记忆示例

现在来看看"有记忆"是怎么实现的。

4.1 核心代码

复制代码
from langgraph.checkpoint.memory import MemorySaver

# 创建记忆组件
memory = MemorySaver()

# 编译图 - 添加 checkpointer 参数
graph = graph_builder.compile(checkpointer=memory)

就这么简单!加上 checkpointer=memory,图就有了记忆能力。

4.2 完整实现

复制代码
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

# 定义状态
class State(TypedDict):
    count: int
    total: int
    history: list

# 定义节点
def counter_node(state: State) -> dict:
    new_count = state['count'] + 1
    new_total = state['total'] + 10
    new_history = state['history'] + [f"第{new_count}次调用"]

    return {
        "count": new_count,
        "total": new_total,
        "history": new_history
    }

# 构建图
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)

# ⭐ 关键:创建记忆组件
memory = MemorySaver()

# ⭐ 关键:编译时添加 checkpointer
graph = graph_builder.compile(checkpointer=memory)

4.3 执行方式

复制代码
# 配置:使用 thread_id 区分会话
config = {"configurable": {"thread_id": "session_001"}}

# 初始状态
initial_state = {"count": 0, "total": 0, "history": []}

# 第一次调用
result1 = graph.invoke(initial_state, config)
# 结果: count=1, total=10

# 第二次调用 - 传入空字典就行!
result2 = graph.invoke({}, config)
# 结果: count=2, total=20  ← 自动从上次继续!

# 第三次调用
result3 = graph.invoke({}, config)
# 结果: count=3, total=30  ← 继续累计!

4.4 执行流程图

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    有记忆模式的流程                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  调用1: initial_state ──→ 节点 ──→ result1 (count=1)       │
│                                   │                         │
│                                   ▼                         │
│                            ┌───────────┐                    │
│                            │  Memory   │                    │
│                            │  Saver    │ 保存状态           │
│                            └─────┬─────┘                    │
│                                  │                          │
│  调用2: {} ──────────────────────┘                          │
│         ↑                                                    │
│      传入空字典,系统自动从 Memory 恢复                       │
│         │                                                    │
│         └───────→ 节点 ──→ result2 (count=2)                │
│                                   │                         │
│                                   ▼                         │
│                            ┌───────────┐                    │
│                            │  Memory   │ 更新状态           │
│                            │  Saver    │                    │
│                            └─────┬─────┘                    │
│                                  │                          │
│  调用3: {} ──────────────────────┘                          │
│         │                                                    │
│         └───────→ 节点 ──→ result3 (count=3)                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

简单说:MemorySaver 就像一个"存档点",每次执行完自动存档,下次调用自动读档继续。


五、thread_id:区分不同会话

5.1 什么是 thread_id?

想象你在玩一个游戏,可以创建多个存档:

  • 存档1:小明玩到第3关
  • 存档2:小红玩到第5关
  • 存档3:小刚刚开始玩

thread_id 就是这个"存档槽"的概念,不同的 thread_id 对应不同的会话状态。

5.2 代码示例

复制代码
# 会话1:小明的会话
config_ming = {"configurable": {"thread_id": "xiaoming"}}
result = graph.invoke({"count": 0, ...}, config_ming)
# 小明: count=1

result = graph.invoke({}, config_ming)
# 小明: count=2

# 会话2:小红的会话
config_hong = {"configurable": {"thread_id": "xiaohong"}}
result = graph.invoke({"count": 0, ...}, config_hong)
# 小红: count=1(新会话,从头开始)

# 切回小明的会话
result = graph.invoke({}, config_ming)
# 小明: count=3(继续之前的进度)

5.3 会话隔离示意图

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    MemorySaver 内部结构                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  thread_id: "xiaoming"                                      │
│  ├── count: 3                                               │
│  ├── total: 30                                              │
│  └── history: [第1次, 第2次, 第3次]                          │
│                                                             │
│  thread_id: "xiaohong"                                      │
│  ├── count: 1                                               │
│  ├── total: 10                                              │
│  └── history: [第1次]                                        │
│                                                             │
│  thread_id: "xiaogang"                                      │
│  ├── count: 0                                               │
│  └── ...                                                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

六、get_state():查看当前状态

有时候我们需要看看当前保存的状态是什么:

复制代码
# 查看当前状态
snapshot = graph.get_state(config)

print(snapshot.values)
# 输出: {'count': 3, 'total': 30, 'history': [...]}

print(snapshot.next)
# 输出: ()  空元组表示执行完毕

6.1 实际应用场景

复制代码
# 场景:检查用户当前进度
state = graph.get_state(config)

if state.values['count'] > 10:
    print("您已经是老用户了!")
else:
    print("欢迎新用户!")

七、对比总结

7.1 代码对比

特性 无记忆 有记忆
导入 无需额外导入 from langgraph.checkpoint.memory import MemorySaver
创建 graph = builder.compile() memory = MemorySaver() graph = builder.compile(checkpointer=memory)
调用 graph.invoke(state) graph.invoke(state, config)
状态保持 ❌ 不保持 ✅ 自动保持
会话隔离 ❌ 无 ✅ 通过 thread_id

7.2 使用场景对比

场景 无记忆 有记忆
单次任务 ✅ 适合 可用
多轮对话 ❌ 不适合 ✅ 适合
累计计算 ❌ 需手动传 ✅ 自动累计
聊天机器人 ❌ 不适合 ✅ 适合
审批流程 ❌ 不适合 ✅ 必须

7.3 一句话总结

  • 无记忆:每次都是"新的一天",适合单次任务
  • 有记忆:像正常聊天,记得住之前说过什么

八、常见问题

Q1: MemorySaver 会持久化到磁盘吗?

不会MemorySaver 是内存存储,程序重启后数据就丢失了。

如果需要持久化,可以使用:

  • SqliteSaver:SQLite 数据库

  • PostgresSaver:PostgreSQL 数据库

    from langgraph.checkpoint.sqlite import SqliteSaver

    持久化到 SQLite 数据库

    memory = SqliteSaver.from_conn_string("checkpoints.db")
    graph = builder.compile(checkpointer=memory)

Q2: thread_id 可以随便设置吗?

是的thread_id 只是一个标识符,可以是任意字符串:

  • 用户ID:"user_12345"
  • 会话ID:"session_abc"
  • 任意字符串:"my_chat_session"

Q3: 传入空字典 {} 是什么意思?

复制代码
graph.invoke({}, config)

传入空字典表示"不提供新的输入,从保存的状态继续"。系统会:

  1. 从 MemorySaver 中读取 thread_id 对应的状态
  2. 用这个状态作为输入执行图

九、完整代码

9.1 无记忆示例

python 复制代码
"""
LangGraph 无记忆示例 - 演示没有 MemorySaver 时无法记住状态

本示例演示:
- 没有记忆组件时,每次调用 graph.invoke() 都是独立的
- 上一轮的状态不会被保留
- 每次都从初始状态开始

图结构: START -> counter_node -> END
"""

from typing import TypedDict
from langgraph.graph import StateGraph, START, END


# ==================== 状态定义 ====================
class State(TypedDict):
    """图的状态结构"""
    count: int          # 调用次数
    total: int          # 累计值
    history: list       # 历史记录


# ==================== 节点函数 ====================
def counter_node(state: State) -> dict:
    """计数器节点:递增计数并记录历史"""
    new_count = state['count'] + 1
    new_total = state['total'] + 10
    new_history = state['history'] + [f"第{new_count}次调用,累计值={new_total}"]
    
    print(f"\n  📊 节点执行中...")
    print(f"     接收到的 count: {state['count']}")
    print(f"     修改后的 count: {new_count}")
    print(f"     累计值 total: {new_total}")
    
    return {
        "count": new_count,
        "total": new_total,
        "history": new_history
    }


# ==================== 构建图(无记忆) ====================
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)

# 编译图 - 注意:没有 checkpointer 参数
graph = graph_builder.compile()


# ==================== 主程序 ====================
if __name__ == "__main__":
    print("=" * 70)
    print("📌 示例:无记忆组件 (No MemorySaver)")
    print("=" * 70)
    print("\n图结构: START -> counter_node -> END")
    print("\n说明: 每次调用都是独立的,不会记住之前的状态")
    print("=" * 70)
    
    # 初始状态
    initial_state = {
        "count": 0,
        "total": 0,
        "history": []
    }
    
    # 第一次调用
    print("\n" + "-" * 70)
    print("🔄 第一次调用 graph.invoke()")
    print("-" * 70)
    result1 = graph.invoke(initial_state)
    print(f"\n✅ 第一次调用结果:")
    print(f"   count: {result1['count']}")
    print(f"   total: {result1['total']}")
    print(f"   history: {result1['history']}")
    
    # 第二次调用 - 使用相同的初始状态
    print("\n" + "-" * 70)
    print("🔄 第二次调用 graph.invoke()")
    print("-" * 70)
    result2 = graph.invoke(initial_state)  # 注意:传入的是相同的初始状态
    print(f"\n✅ 第二次调用结果:")
    print(f"   count: {result2['count']}")
    print(f"   total: {result2['total']}")
    print(f"   history: {result2['history']}")
    
    # 第三次调用 - 尝试传入上一次的结果
    print("\n" + "-" * 70)
    print("🔄 第三次调用 graph.invoke()")
    print("   (尝试传入上一次的结果)")
    print("-" * 70)
    result3 = graph.invoke(result2)  # 手动传入上一次的结果
    print(f"\n✅ 第三次调用结果:")
    print(f"   count: {result3['count']}")
    print(f"   total: {result3['total']}")
    print(f"   history: {result3['history']}")
    
    # 总结
    print("\n" + "=" * 70)
    print("📊 总结")
    print("=" * 70)
    print("""
❌ 没有记忆组件的问题:
   - 每次调用 graph.invoke() 都是独立的
   - 必须手动传入上一次的状态才能"记住"
   - 无法自动保持对话历史或累计数据

💡 解决方案:
   - 使用 MemorySaver 作为 checkpointer
   - 使用 thread_id 来区分不同的会话
""")

9.2 有记忆示例

python 复制代码
"""
LangGraph 有记忆示例 - 演示使用 MemorySaver 记住状态

本示例演示:
- 使用 MemorySaver 作为 checkpointer
- 通过 thread_id 区分不同会话
- 自动记住上一轮的状态
- 使用 get_state() 查看当前状态
- 使用 update_state() 更新状态

图结构: START -> counter_node -> END
"""

from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver


# ==================== 状态定义 ====================
class State(TypedDict):
    """图的状态结构"""
    count: int          # 调用次数
    total: int          # 累计值
    history: list       # 历史记录


# ==================== 节点函数 ====================
def counter_node(state: State) -> dict:
    """计数器节点:递增计数并记录历史"""
    new_count = state['count'] + 1
    new_total = state['total'] + 10
    new_history = state['history'] + [f"第{new_count}次调用,累计值={new_total}"]
    
    print(f"\n  📊 节点执行中...")
    print(f"     接收到的 count: {state['count']}")
    print(f"     修改后的 count: {new_count}")
    print(f"     累计值 total: {new_total}")
    
    return {
        "count": new_count,
        "total": new_total,
        "history": new_history
    }


# ==================== 构建图(有记忆) ====================
graph_builder = StateGraph(State)
graph_builder.add_node("counter_node", counter_node)
graph_builder.add_edge(START, "counter_node")
graph_builder.add_edge("counter_node", END)

# 创建记忆组件
memory = MemorySaver()

# 编译图 - 添加 checkpointer 参数
graph = graph_builder.compile(checkpointer=memory)


# ==================== 主程序 ====================
if __name__ == "__main__":
    print("=" * 70)
    print("📌 示例:有记忆组件 (With MemorySaver)")
    print("=" * 70)
    print("\n图结构: START -> counter_node -> END")
    print("\n说明: 使用 MemorySaver 自动记住状态")
    print("=" * 70)
    
    # 配置:使用 thread_id 区分会话
    config = {"configurable": {"thread_id": "session_001"}}
    
    # 初始状态
    initial_state = {
        "count": 0,
        "total": 0,
        "history": []
    }
    
    # ==================== 第一次调用 ====================
    print("\n" + "-" * 70)
    print("🔄 第一次调用 graph.invoke()")
    print(f"   thread_id: {config['configurable']['thread_id']}")
    print("-" * 70)
    result1 = graph.invoke(initial_state, config)
    print(f"\n✅ 第一次调用结果:")
    print(f"   count: {result1['count']}")
    print(f"   total: {result1['total']}")
    print(f"   history: {result1['history']}")
    
    # 查看当前保存的状态
    saved_state = graph.get_state(config)
    print(f"\n💾 当前保存的状态:")
    print(f"   {saved_state.values}")
    
    # ==================== 第二次调用 ====================
    print("\n" + "-" * 70)
    print("🔄 第二次调用 graph.invoke()")
    print("   传入空字典 {},系统会自动从保存的状态恢复")
    print(f"   thread_id: {config['configurable']['thread_id']}")
    print("-" * 70)
    
    # 传入空字典,系统会从保存的状态恢复并继续执行
    result2 = graph.invoke({}, config)
    print(f"\n✅ 第二次调用结果:")
    print(f"   count: {result2['count']} (从 1 增加到 2)")
    print(f"   total: {result2['total']} (从 10 增加到 20)")
    print(f"   history: {result2['history']}")
    
    # 查看当前保存的状态
    saved_state = graph.get_state(config)
    print(f"\n💾 当前保存的状态:")
    print(f"   {saved_state.values}")
    
    # ==================== 第三次调用 ====================
    print("\n" + "-" * 70)
    print("🔄 第三次调用 graph.invoke()")
    print(f"   thread_id: {config['configurable']['thread_id']}")
    print("-" * 70)
    result3 = graph.invoke({}, config)
    print(f"\n✅ 第三次调用结果:")
    print(f"   count: {result3['count']} (从 2 增加到 3)")
    print(f"   total: {result3['total']} (从 20 增加到 30)")
    print(f"   history: {result3['history']}")
    
    # ==================== 演示不同 thread_id ====================
    print("\n" + "=" * 70)
    print("🔄 演示不同的 thread_id(新会话)")
    print("=" * 70)
    
    config_new = {"configurable": {"thread_id": "session_002"}}
    print(f"\n使用新的 thread_id: {config_new['configurable']['thread_id']}")
    print("这是一个全新的会话,状态从头开始")
    
    result_new = graph.invoke(initial_state, config_new)
    print(f"\n✅ 新会话结果:")
    print(f"   count: {result_new['count']} (从头开始,不是 3)")
    print(f"   total: {result_new['total']}")
    print(f"   history: {result_new['history']}")
    
    # ==================== 演示 get_state() ====================
    print("\n" + "=" * 70)
    print("🔍 演示 get_state() - 查看两个会话的状态")
    print("=" * 70)
    
    state_001 = graph.get_state(config)
    state_002 = graph.get_state(config_new)
    
    print(f"\n会话 session_001 的状态:")
    print(f"   count: {state_001.values['count']}")
    print(f"   total: {state_001.values['total']}")
    
    print(f"\n会话 session_002 的状态:")
    print(f"   count: {state_002.values['count']}")
    print(f"   total: {state_002.values['total']}")
    
    # ==================== 总结 ====================
    print("\n" + "=" * 70)
    print("📊 总结")
    print("=" * 70)
    print("""
✅ 有记忆组件的优势:
   - 自动记住上一轮的状态
   - 每次调用都会从保存的状态继续
   - 通过 thread_id 区分不同会话
   - 可以使用 get_state() 查看当前状态

📝 使用方法:
   1. 创建 MemorySaver: memory = MemorySaver()
   2. 编译时添加: graph.compile(checkpointer=memory)
   3. 调用时指定 thread_id: config = {"configurable": {"thread_id": "xxx"}}
   4. 后续调用: graph.invoke({}, config)  # 自动恢复并继续
   5. 查看状态: graph.get_state(config)
""")

十、延伸阅读


附录:核心 API 速查

复制代码
# 1. 创建记忆组件
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()

# 2. 编译时添加
graph = builder.compile(checkpointer=memory)

# 3. 调用时指定 thread_id
config = {"configurable": {"thread_id": "xxx"}}
result = graph.invoke(state, config)

# 4. 后续调用(自动恢复)
result = graph.invoke({}, config)

# 5. 查看当前状态
snapshot = graph.get_state(config)
print(snapshot.values)
相关推荐
2301_812539671 小时前
Tailwind CSS如何设置不同断点的内边距_使用p-4 md-p-8类.txt
jvm·数据库·python
m0_596749091 小时前
CSS实现动态悬浮菜单位置_JS计算配合CSS绝对定位
jvm·数据库·python
2301_812539671 小时前
golang如何实现最小堆定时器_golang最小堆定时器实现总结
jvm·数据库·python
lyc87801 小时前
【Qwen3.5-2B-Base】本地模型部署和验证联动千帆api
大数据·python
m0_690825821 小时前
检测三位随机数中重复数字的Python实现方法
jvm·数据库·python
谙弆悕博士1 小时前
GPT-5.5 Instant 免费开放背后的技术跃迁与战略阳谋
人工智能·python·gpt·chatgpt·学习方法·业界资讯
WL_Aurora1 小时前
备战蓝桥杯国赛【Day 6】
python·算法·蓝桥杯
阿正呀1 小时前
Redis如何处理数据持久化与主从切换的冲突_确保选主期间的数据安全落盘.txt
jvm·数据库·python
AI精钢1 小时前
把 Markdown 笔记变成可问答的知识图谱:本地 Graph RAG 工具 Kwipu 实测
人工智能·笔记·python·aigc·知识图谱