动手实现一个微型的状态图执行器,深入理解有状态图、检查点、中断恢复这些抽象是如何协同工作的。
阶段 预估耗时 建议 搭建主循环 ( invoke)45--60 min 直接实现"动态超步循环",无需静态编译执行计划 > 实现条件边与循环 15--20 min 条件边仅需在下一步收集时调用 condition并查表 >并行执行 15--20 min 使用 ThreadPoolExecutor.map,注意给每个节点传副本流式 ( stream)10--15 min 在节点完成后 yield state检查点与线程隔离 15--20 min 在开始时加载,结束时保存 中断恢复(最难点) 30--45 min 考虑在检查点中额外保存"剩余待执行节点列表" > 异步 ainvoke15--20 min 判断节点是否为 async def,用asyncio.gather并行调试与通过测试 20--30 min 建议先跑测试 1--3,再补充分支和并行 > 总计 3--4 小时
一. 实验目标
-
将复杂多步 AI 工作流建模为有向状态图(节点 = 纯函数,边 = 控制流),理解图结构如何自然表达分支、循环和并行。
-
实现一个基于**超步(superstep)**的动态图执行器,掌握节点调度、状态合并和条件路由的内部原理。
-
亲手构建内存检查点机制,并支持节点抛出中断、外部恢复执行------这正是 Human‑in‑the‑loop 工作流的根基。
-
实现
stream方法,在每个节点完成后立即产出状态快照,而不仅是最终结果。 -
统一同步与异步执行接口,体会图结构与执行策略的解耦。
二. 前置知识
-
Python 基础:类、生成器(
yield)、异常处理、asyncio初步。 -
基本图论概念(有向边、环路)。
-
了解
ThreadPoolExecutor的基本用法。 -
浏览 LangGraph 概念文档 中的 "StateGraph" 一节。
三. 项目脚手架
langgraph_core_lab/
├── src/
│ ├── graph.py # StateGraph 类与 CompiledGraph 类
│ ├── checkpoint.py # MemorySaver
│ └── exceptions.py # InterruptError
├── tests/
│ └── test_graph.py # 全部 8 个测试(已为你准备好)
└── requirements.txt # pytest, pytest-asyncio
你必须实现的接口(严禁修改签名):
src/graph.py
python
class StateGraph:
def __init__(self): ...
def add_node(self, name: str, func: Callable[[dict], dict]) -> None: ...
def add_edge(self, start: str, end: str) -> None: ...
def add_conditional_edges(self, start: str,
condition: Callable[[dict], str],
mapping: dict[str, str]) -> None: ...
def set_entry_point(self, name: str) -> None: ...
def compile(self, checkpointer=None) -> 'CompiledGraph': ...
class CompiledGraph:
def invoke(self, state: dict, config: dict = None) -> dict: ...
def stream(self, state: dict, config: dict = None) -> Iterator[dict]: ...
async def ainvoke(self, state: dict, config: dict = None) -> dict: ...
src/exceptions.py
python
class InterruptError(Exception):
"""节点抛出此异常以暂停图执行,等待外部干预。"""
pass
src/checkpoint.py
python
class MemorySaver:
def __init__(self):
# 内部存储:thread_id → 状态字典(需拷贝)
...
def put(self, config: dict, state: dict) -> None:
"""保存 thread_id 对应的最新状态。"""
...
def get(self, config: dict) -> Optional[dict]:
"""获取 thread_id 对应的已保存状态,若不存在则返回 None。"""
...
!请遵循以下约束:
只使用 Python 标准库与
concurrent.futures,不依赖任何其他框架。禁止使用
langgraph或langchain库。节点函数签名为
def node(state: dict) -> dict,返回部分状态更新 ;引擎负责合并(使用{**state, **update})。条件函数签名为
def condition(state: dict) -> str,返回下一个节点名(或"__end__"表示终止)。
config为可选字典,结构为{"configurable": {"thread_id": "..."}},用于检查点。
四. 核心设计指引
1. 图编译与动态超步执行
LangGraph 的执行模型基于 超步(superstep):每个超步包含一组可并行执行的节点,超步间严格串行。编译时你无需生成静态"执行计划",而应在运行时直接使用图结构动态调度。
你的主循环可以这样设计(伪代码):
current_nodes ← [入口节点]
state ← 初始状态(如果检查点中有则加载)
while current_nodes 非空:
并行执行 current_nodes 中的所有节点(每个节点接收 state 副本)
收集所有节点返回的更新,合并到 state
为每个节点确定下一个要去的节点:
· 普通边 → 固定的后继节点
· 条件边 → 调用 condition(state) 并查 mapping 得到后继
去重,得到下一批 current_nodes
如果某个节点的后继被标记为 "__end__",则不加入
返回 state
思考:为什么给每个节点传入 state 的副本很重要(尤其在并行时)?
2. 并行执行
当同一超步内有多个节点时,它们之间没有读写依赖,可使用concurrent.futures.ThreadPoolExecutor 并发执行。注意每个节点都应获得当前状态的一个独立副本,避免线程间互相干扰。在所有节点完成后,合并它们返回的更新字典。
3. 流式输出 (stream)
stream 方法是一个生成器。最简单且符合测试预期的策略是:每完成一个节点(或每完成一个超步)就 yield 一份累积状态的快照。如果你选择按超步 yield,注意测试要求产生 3 个中间状态(线性图共 3 个节点,按节点 yield 即可通过)。
4. 检查点与线程隔离
MemorySaver 按 thread_id 存储状态快照。在 invoke 开始时,尝试从检查点加载状态作为初始状态(若不存在则用传入的 state)。执行结束后,务必 将最终状态保存到检查点。不同 thread_id 的状态应完全隔离------这正是测试 6 要验证的。
5. 中断与恢复(最挑战部分)
当节点抛出 InterruptError 时,图必须立刻暂停,并保留足够的信息以便后续恢复。你需要决定在检查点中保存什么。提示:
-
仅保存状态足够吗?恢复时应该从哪个节点接着执行?(思考:可能需要保存尚未运行的节点列表)
-
恢复时,调用者会再次调用
invoke并传入额外的状态(如{"approved": True})。你的引擎应能合并这些新输入,并从断点继续,而不是从头开始。
实现此机制后,测试 7 的行为应为:
第一次 invoke({}, config) → 在 step1 中 raise InterruptError
第二次 invoke({"approved": True}, config) → 合并状态,继续执行 step1、step2,得到完整结果
6. 异步支持
ainvoke 需要处理节点可能是 async def 的情况。你可以用 asyncio.iscoroutinefunction 判断节点函数类型,并在超步内使用 asyncio.gather 并发执行异步节点。为简化,也可以复用同步循环,但需注意线程池与异步的兼容性。
7. 常见阻碍
如果卡住超过 15 分钟,请思考以下关键问题:
-
状态合并是否使用了
{**state, **node_result}? -
条件边映射中是否包含了
"__end__"特殊键? -
并行节点是否收到了同一个
dict引用?(必须传入副本) -
中断恢复时,引擎是重新从入口点开始,还是从保存的剩余节点列表继续?
五. 测试用例
请将以下代码放入 tests/test_graph.py。每个测试均附有目标 和通过标准,帮助你明确要验证的内核特性。
测试 1:线性图执行与状态传播
目标 :验证 invoke 能按图结构顺序执行节点,且最终状态是各节点输出的累积。
通过标准 :app.invoke({}) 返回 {"a": 1, "b": 2, "c": 3}。
python
def test_linear_graph():
from src.graph import StateGraph
def a(s): return {"a": 1}
def b(s): return {"b": 2}
def c(s): return {"c": 3}
g = StateGraph()
g.add_node("a", a); g.add_node("b", b); g.add_node("c", c)
g.set_entry_point("a")
g.add_edge("a", "b"); g.add_edge("b", "c")
app = g.compile()
assert app.invoke({}) == {"a": 1, "b": 2, "c": 3}
测试 2:条件分支
目标:验证图能根据运行时状态动态选择不同执行路径。
通过标准:
-
invoke({"value": 2})进入double节点,结果为{"value": 4, "path": "double"} -
invoke({"value": -1})进入inc节点,结果为{"value": 0, "path": "inc"}
python
def test_conditional_branching():
from src.graph import StateGraph
def start(s): return s
def double(s): return {"value": s["value"]*2, "path": "double"}
def inc(s): return {"value": s["value"]+1, "path": "inc"}
def route(s):
return "double" if s["value"] > 0 else "inc"
g = StateGraph()
g.add_node("start", start)
g.add_node("double", double)
g.add_node("inc", inc)
g.set_entry_point("start")
g.add_conditional_edges("start", route, {"double": "double", "inc": "inc"})
app = g.compile()
assert app.invoke({"value": 2}) == {"value": 4, "path": "double"}
assert app.invoke({"value": -1}) == {"value": 0, "path": "inc"}
测试 3:循环
目标:验证条件边指向前序节点可形成循环,直到满足终止条件才离开。
通过标准:
-
最终
total == 7(初始 1,每次 +3,共执行 2 次,1+3+3=7) -
节点
accumulate被调用 2 次
python
def test_loop():
from src.graph import StateGraph
call_counter = {"cnt": 0}
def acc(state):
call_counter["cnt"] += 1
return {"total": state["total"] + state.get("inc", 2)}
def should_continue(state):
return "accumulate" if state["total"] < 5 else "__end__"
g = StateGraph()
g.add_node("accumulate", acc)
g.set_entry_point("accumulate")
g.add_conditional_edges("accumulate", should_continue,
{"accumulate": "accumulate", "__end__": None})
app = g.compile()
result = app.invoke({"total": 1, "inc": 3})
assert result["total"] == 7
assert call_counter["cnt"] == 2
测试 4:并行执行
目标:验证从一个节点分叉到多个无依赖节点时,这些节点可以并发执行以缩短总耗时。
通过标准:
-
result包含{"x": 10, "a": 11, "b": 12} -
执行耗时 < 0.15 秒(两个节点分别 sleep 0.1s,若并行则总耗时约 0.1s,串行则 0.2s 以上)
python
import time
def test_parallel_execution():
from src.graph import StateGraph
def split(s): return s
def ta(s):
time.sleep(0.1)
return {"a": s.get("x",0)+1}
def tb(s):
time.sleep(0.1)
return {"b": s.get("x",0)+2}
def merge(s): return s
g = StateGraph()
g.add_node("split", split)
g.add_node("ta", ta)
g.add_node("tb", tb)
g.add_node("merge", merge)
g.set_entry_point("split")
g.add_edge("split", "ta")
g.add_edge("split", "tb")
g.add_edge("ta", "merge")
g.add_edge("tb", "merge")
app = g.compile()
start = time.perf_counter()
res = app.invoke({"x": 10})
elapsed = time.perf_counter() - start
assert res == {"x": 10, "a": 11, "b": 12}
assert elapsed < 0.15, f"Not parallel: {elapsed:.2f}s"
测试 5:流式输出
目标 :验证 stream 方法能在每个节点完成后立即产出当前状态快照。
通过标准:
-
收集到的中间状态数量为 3
-
最后一个状态为
{"step": 3}
python
def test_stream():
from src.graph import StateGraph
def s1(s): return {"step": 1}
def s2(s): return {"step": 2}
def s3(s): return {"step": 3}
g = StateGraph()
g.add_node("s1", s1)
g.add_node("s2", s2)
g.add_node("s3", s3)
g.set_entry_point("s1")
g.add_edge("s1", "s2")
g.add_edge("s2", "s3")
app = g.compile()
states = list(app.stream({}))
assert len(states) == 3
assert states[-1] == {"step": 3}
测试 6:线程隔离与检查点恢复
目标 :验证使用检查点后,不同 thread_id 的状态相互隔离;同一 thread_id 再次调用时会从上次保存的状态继续。
通过标准:
-
线程 1 首次调用:
counter从 0 变为 1 -
线程 2 调用:
counter从 10 变为 11 -
线程 1 第二次调用(不传初始值):基于上次保存的 1 再增加 1,得到 2
python
def test_thread_isolation():
from src.graph import StateGraph
from src.checkpoint import MemorySaver
def counter(state):
c = state.get("counter", 0) + 1
return {"counter": c}
g = StateGraph()
g.add_node("count", counter)
g.set_entry_point("count")
g.add_edge("count", "__end__")
app = g.compile(checkpointer=MemorySaver())
c1 = {"configurable": {"thread_id": "t1"}}
c2 = {"configurable": {"thread_id": "t2"}}
assert app.invoke({"counter": 0}, c1)["counter"] == 1
assert app.invoke({"counter": 10}, c2)["counter"] == 11
assert app.invoke({}, c1)["counter"] == 2
测试 7:中断与恢复
目标 :验证节点可抛出 InterruptError 暂停执行,外部可在提供额外状态后从断点继续,最终获得完整结果。
通过标准:
-
首次调用抛出
InterruptError -
第二次调用(传入
{"approved": True})正常完成,返回{"approved": True, "step1_done": True, "step2_done": True}
python
import pytest
from src.graph import StateGraph
from src.checkpoint import MemorySaver
from src.exceptions import InterruptError
def test_interrupt_and_resume():
def step1(state):
if "approved" not in state:
raise InterruptError("Need approval")
return {"step1_done": True}
def step2(state):
return {"step2_done": True}
g = StateGraph()
g.add_node("step1", step1)
g.add_node("step2", step2)
g.set_entry_point("step1")
g.add_edge("step1", "step2")
app = g.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "irq1"}}
with pytest.raises(InterruptError):
app.invoke({}, config)
result = app.invoke({"approved": True}, config)
assert result == {"approved": True, "step1_done": True, "step2_done": True}
测试 8:异步接口等价性
目标 :验证 ainvoke 与同步 invoke 结果完全一致,且可并发执行异步节点。
通过标准 :同步调用与异步调用都返回 {"a": "async", "b": "works"}。
python
import asyncio
import pytest
from src.graph import StateGraph
@pytest.mark.asyncio
async def test_async_equivalence():
async def a(s):
await asyncio.sleep(0.01)
return {"a": "async"}
def b(s):
return {"b": "works"}
g = StateGraph()
g.add_node("a", a)
g.add_node("b", b)
g.set_entry_point("a")
g.add_edge("a", "b")
app = g.compile()
sync_res = app.invoke({})
async_res = await app.ainvoke({})
assert sync_res == async_res == {"a": "async", "b": "works"}
六. 思考检验
1. 图结构与执行模型
请用伪代码描述你的 invoke 主循环的关键步骤(尤其如何收集下一批节点、如何处理条件边)。相比于简单的链式调用 (a | b | c),状态图在支持分支、循环和并行时带来了哪些结构上的优势?
2. 检查点与中断恢复
在测试 7 中,第一次调用抛出 InterruptError 后,检查点中保存了哪些信息?第二次调用时,你的引擎是如何利用这些信息从断点继续的?如果 step1 内部有一个循环,恢复后循环计数器会丢失吗?为什么?
3. 流式与状态传播
如果你的 stream 方法在每个节点完成后就产出状态,这对于展示工作进度很有效。但若某个节点本身需要逐步产出大量数据(如 LLM 的 token 流)而不只是最终状态更新,当前的图引擎能否支持?你认为需要如何扩展?LangGraph 实际为此提供了什么机制?
资源附录
- 实现源码:agent-grok-labs
- 实验文档:密码:84k2