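This document shows how, in a real application, to detect the interrupts raised by `interrupt()`, fetch the user's input automatically, and then resume graph execution.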
Core concepts
The interrupt() function
`interrupt()` is LangGraph's core function for human-in-the-loop workflows. Called inside a node, it pauses execution and waits for user input.
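How it works:
- First call: raises a `GraphInterrupt` exception, which stops graph execution.
  - The exception carries the `value` passed to `interrupt()` (usually a question or prompt).
  - The graph state is saved to the checkpointer.
- Call after resuming: returns the value the user supplied through `Command(resume=...)`.
  - The node is re-executed from its beginning.
  - `interrupt()` detects the resume value and returns it instead of raising.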
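Key properties:
- A checkpointer is required: `interrupt()` relies on the checkpointer to save state.
- The node re-executes: on resume, the node runs again from the top rather than continuing from the interrupt point. Any code before `interrupt()` therefore runs more than once and should be safe to repeat.
- Multiple interrupts are supported: a node may call `interrupt()` several times, and resume values are matched to those calls in order.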
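The replay behavior is easier to see in a toy model. The sketch below is plain Python and is not LangGraph's implementation. The "node" is re-run from the top on every resume. Each `interrupt()` call either returns the resume value already stored for its position or raises to pause. `ToyInterrupt`, `run_node`, `ask_two` and `drive` are hypothetical names.

```python
class ToyInterrupt(Exception):
    """Raised to pause the node; carries the prompt and the call's position."""

    def __init__(self, value, index):
        super().__init__(value)
        self.value = value
        self.index = index


def run_node(node, resume_values):
    """Run `node` from the top, replaying already-supplied resume values by position."""
    counter = {"i": 0}

    def interrupt(value):
        i = counter["i"]
        counter["i"] += 1
        if i < len(resume_values):
            return resume_values[i]  # answered earlier: return it instead of pausing
        raise ToyInterrupt(value, i)  # first time this call is reached: pause

    return node(interrupt)


executions = []


def ask_two(interrupt):
    executions.append("start")  # side effect that runs again on every resume
    name = interrupt("What is your name?")
    age = interrupt("What is your age?")
    return f"{name} is {age}"


def drive(node, answers):
    """Keep resuming until the node finishes, answering prompts from `answers`."""
    resume_values = []
    while True:
        try:
            return run_node(node, resume_values)
        except ToyInterrupt as exc:
            resume_values.append(answers[exc.value])


result = drive(ask_two, {"What is your name?": "Alice", "What is your age?": "25"})
print(result)      # Alice is 25
print(executions)  # ['start', 'start', 'start']
```

The three `start` entries show that the node body ran once per pause plus once to finish. This is why side effects placed before `interrupt()` need to tolerate repetition.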
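The Command class
`Command` is an instruction object for controlling and resuming graph execution. Its main use is resuming an execution that `interrupt()` paused.
Main parameters:
- `resume`: the value used to resume an interrupted execution.
  - A single value: `Command(resume="answer")` resumes the next pending interrupt.
  - A mapping: `Command(resume={interrupt_id: "answer"})` resumes the interrupt with the given ID.
- `update`: updates the graph state, e.g. `Command(update={"key": "value"})`.
- `goto`: jumps to the given node, e.g. `Command(goto="node_name")`.

In current LangGraph versions, `resume` is the form meant to be passed to `invoke()`/`stream()`. `update` and `goto` are mainly used in a `Command` returned from a node.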
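One pause can produce several interrupts at once, for example from parallel nodes. In that case the mapping form answers all of them in one call. The minimal sketch below picks the right form for you.

- `PendingInterrupt` is a hypothetical stand-in for the objects found under `__interrupt__`. It only mimics their `value` and `id` attributes.
- `build_resume` is illustrative and is not a LangGraph API.
- Its result would be passed as `Command(resume=...)`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class PendingInterrupt:
    """Stand-in for an interrupt record: the prompt and its unique ID."""
    value: Any
    id: str


def build_resume(pending: Sequence[PendingInterrupt], answer: Callable[[Any], Any]) -> Any:
    """Return a payload for Command(resume=...).

    One pending interrupt -> a bare value (resumes the next interrupt).
    Several -> an {interrupt_id: value} mapping, one entry per interrupt.
    """
    if not pending:
        raise ValueError("nothing to resume")
    if len(pending) == 1:
        return answer(pending[0].value)
    return {p.id: answer(p.value) for p in pending}


answers = {"Where are you located?": "San Francisco", "What is your name?": "Alice"}
single = build_resume([PendingInterrupt("What is your name?", "a1")], answers.get)
many = build_resume(
    [
        PendingInterrupt("Where are you located?", "a1"),
        PendingInterrupt("What is your name?", "b2"),
    ],
    answers.get,
)
print(single)  # Alice
print(many)    # {'a1': 'San Francisco', 'b2': 'Alice'}
```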
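Checkpointer
A checkpointer is LangGraph's persistence layer. It saves and restores graph state, and it is what makes `interrupt()` possible.
Core responsibilities:
- Saving graph state: it stores a snapshot of the graph state at each execution step (superstep).
- Resuming execution: the graph can be resumed from any checkpoint.
- Managing multiple sessions: separate executions are distinguished by `thread_id`.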
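A toy in-memory store illustrates these three responsibilities. It is plain Python, not LangGraph's `InMemorySaver`, and `ToyCheckpointer`, `put`, `latest` and `get` are hypothetical names. Each thread keeps an ordered list of snapshots, one per step, so any earlier snapshot can be fetched by its ID.

```python
import copy
import itertools


class ToyCheckpointer:
    """Keeps one list of (checkpoint_id, state) snapshots per thread_id."""

    def __init__(self):
        self._threads = {}
        self._ids = itertools.count(1)

    def put(self, thread_id, state):
        """Save a deep copy of the state after a step; return its checkpoint ID."""
        checkpoint_id = f"ckpt-{next(self._ids)}"
        self._threads.setdefault(thread_id, []).append((checkpoint_id, copy.deepcopy(state)))
        return checkpoint_id

    def latest(self, thread_id):
        """Most recent snapshot of a thread (what a plain resume starts from)."""
        return copy.deepcopy(self._threads[thread_id][-1][1])

    def get(self, thread_id, checkpoint_id):
        """A specific earlier snapshot (what checkpoint_id selects)."""
        for cid, state in self._threads[thread_id]:
            if cid == checkpoint_id:
                return copy.deepcopy(state)
        raise KeyError(checkpoint_id)


saver = ToyCheckpointer()
first = saver.put("user_123", {"messages": ["hi"]})
saver.put("user_123", {"messages": ["hi", "Where are you located?"]})
saver.put("user_456", {"messages": ["hello"]})

print(saver.latest("user_123"))      # {'messages': ['hi', 'Where are you located?']}
print(saver.get("user_123", first))  # {'messages': ['hi']}
print(saver.latest("user_456"))      # {'messages': ['hello']}
```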
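Basic usage:
```python
from langgraph.checkpoint.memory import InMemorySaver

# `builder` is a StateGraph defined elsewhere; `inputs` is its initial input
checkpointer = InMemorySaver()
# Enable the checkpointer when compiling the graph
graph = builder.compile(checkpointer=checkpointer)
# A thread_id must be supplied at run time
config = {"configurable": {"thread_id": "thread_1"}}
# stream() is lazy: iterate over it to actually run the graph
for chunk in graph.stream(inputs, config):
    print(chunk)
```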
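Why `interrupt()` needs a checkpointer:
- When `interrupt()` pauses execution, the current state must be saved so it can be restored later.
- On resume, the state is read back from the saved checkpoint.
- Without a checkpointer, there is nowhere to persist the state, so execution cannot be resumed.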
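Common implementations:
- `InMemorySaver`: in-memory storage, for development and testing.
- `PostgresSaver`: PostgreSQL storage, for production (package `langgraph-checkpoint-postgres`).
- `SqliteSaver`: SQLite storage, for lightweight applications (package `langgraph-checkpoint-sqlite`).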
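Thread and checkpoint IDs:
- `thread_id`: session identifier that distinguishes separate executions (required).
- `checkpoint_id`: checkpoint identifier, used to resume from a specific checkpoint (optional).
```python
# Basic configuration
config = {"configurable": {"thread_id": "user_123"}}
# Resume from a specific checkpoint
config = {
    "configurable": {
        "thread_id": "user_123",
        # placeholder: real IDs come from the saved checkpoints,
        # e.g. graph.get_state_history(config)
        "checkpoint_id": "checkpoint_abc",
    }
}
```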
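`thread_id` is required and `checkpoint_id` is optional. A small helper can therefore build both config shapes and reject an empty thread ID early. `make_config` is a hypothetical helper, not part of LangGraph.

```python
from typing import Optional


def make_config(thread_id: str, checkpoint_id: Optional[str] = None) -> dict:
    """Build a LangGraph-style run config; checkpoint_id is added only when given."""
    if not thread_id:
        raise ValueError("thread_id is required")
    configurable = {"thread_id": thread_id}
    if checkpoint_id is not None:
        configurable["checkpoint_id"] = checkpoint_id
    return {"configurable": configurable}


print(make_config("user_123"))
# {'configurable': {'thread_id': 'user_123'}}
print(make_config("user_123", "checkpoint_abc"))
# {'configurable': {'thread_id': 'user_123', 'checkpoint_id': 'checkpoint_abc'}}
```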
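Complete example
The following example shows the full flow:
- The agent calls `interrupt()` mid-run to ask the user a question.
- The caller listens for the interrupt event and fetches the user's input automatically.
- Execution resumes with that input.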
```python
import uuid

from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import Command, interrupt
from pydantic import BaseModel


# Define the tools and the model
@tool
def search(query: str):
    """Search for information."""  # @tool requires a docstring (or a description)
    return f"Search results: {query}"


class AskHuman(BaseModel):
    """Ask the user a question."""
    question: str


model = ChatAnthropic(model="claude-3-5-sonnet-latest")
model = model.bind_tools([search, AskHuman])


# Define the nodes
def call_model(state: MessagesState):
    response = model.invoke(state["messages"])
    return {"messages": [response]}


def ask_human(state: MessagesState):
    tool_call = state["messages"][-1].tool_calls[0]
    ask = AskHuman.model_validate(tool_call["args"])
    answer = interrupt(ask.question)  # pause here and wait for user input
    # Answer the AskHuman tool call with a tool message
    return {
        "messages": [{
            "tool_call_id": tool_call["id"],
            "type": "tool",
            "content": answer,
        }]
    }


def route(state: MessagesState):
    last = state["messages"][-1]
    if not last.tool_calls:
        return END
    if last.tool_calls[0]["name"] == "AskHuman":
        return "ask_human"
    return "tools"


# Build the graph
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode([search]))
workflow.add_node("ask_human", ask_human)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", route, ["ask_human", "tools", END])
workflow.add_edge("tools", "agent")
workflow.add_edge("ask_human", "agent")
app = workflow.compile(checkpointer=InMemorySaver())


# Fetch the user's input (could come from a database, API, message queue, ...).
# This demo looks answers up by question text; the model phrases the question
# itself, so an unmatched question falls back to "Unknown".
def get_user_input(question: str, interrupt_id: str) -> str:
    user_inputs = {
        "Where are you located?": "San Francisco",
        "What is your name?": "Alice",
        "What is your age?": "25",
    }
    return user_inputs.get(question, "Unknown")


# Listen for interrupts and resume automatically
def run_with_auto_resume(app, graph_input, config):
    events = []
    for event in app.stream(graph_input, config, stream_mode="updates"):
        events.append(event)
        if "__interrupt__" in event:
            interrupt_info = event["__interrupt__"][0]
            user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
            # Resume on the same thread and keep the events collected so far
            return events + run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return events


# Usage
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
initial_input = {
    "messages": [("user", "Ask the user where they are, then look up the weather there")]
}
for event in run_with_auto_resume(app, initial_input, config):
    print(event)
```
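The driver can be exercised without an Anthropic API key by feeding it a scripted stand-in for the compiled graph. The block runs with the standard library alone because of two hypothetical stand-ins:

- `FakeApp`: its `stream()` only mimics the shape of `stream_mode="updates"` output, a dict per step with pending interrupts as a tuple under `__interrupt__`.
- `Command`: a minimal local stand-in for `langgraph.types.Command`.

```python
import uuid
from dataclasses import dataclass
from typing import Any


@dataclass
class Command:
    """Local stand-in for langgraph.types.Command (only `resume` is modeled)."""
    resume: Any = None


@dataclass(frozen=True)
class FakeInterrupt:
    value: Any
    id: str


class FakeApp:
    """Pauses once to ask for a location, then 'answers' using the resume value."""

    def stream(self, graph_input, config, stream_mode="updates"):
        if not isinstance(graph_input, Command):
            yield {"agent": {"messages": ["AskHuman tool call"]}}
            yield {"__interrupt__": (FakeInterrupt("Where are you located?", "int-1"),)}
            return
        yield {"ask_human": {"messages": [graph_input.resume]}}
        yield {"agent": {"messages": [f"Weather lookup for {graph_input.resume}"]}}


def get_user_input(question: str, interrupt_id: str) -> str:
    return {"Where are you located?": "San Francisco"}.get(question, "Unknown")


def run_with_auto_resume(app, graph_input, config):
    events = []
    for event in app.stream(graph_input, config, stream_mode="updates"):
        events.append(event)
        if "__interrupt__" in event:
            info = event["__interrupt__"][0]
            answer = get_user_input(info.value, info.id)
            return events + run_with_auto_resume(app, Command(resume=answer), config)
    return events


config = {"configurable": {"thread_id": str(uuid.uuid4())}}
events = run_with_auto_resume(FakeApp(), {"messages": [("user", "hi")]}, config)
for event in events:
    print(event)
```

To get back to the original flow, replace `FakeApp()` with the compiled `app` and the local `Command` with the real import.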
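How it works
Execution flow
- Initial run: the graph starts and the agent emits an `AskHuman` tool call. The graph routes to the `ask_human` node, where `interrupt()` is called and raises its exception.
- Interrupt detection: the caller never sees the exception itself. LangGraph handles it and emits a streamed event with the `__interrupt__` key. The caller detects that key and extracts the interrupt information (the question and the interrupt ID).
- Fetching user input: `get_user_input()` is called to obtain the answer from an external system.
- Resuming: execution resumes with `Command(resume=user_answer)`, and a recursive call processes the rest of the run.
- Completion: the graph finishes and all collected events are returned.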
Key mechanisms
Interrupt detection:
```python
if "__interrupt__" in event:
    interrupt_info = event["__interrupt__"][0]
    # interrupt_info.value: the question or prompt
    # interrupt_info.id: the interrupt's unique identifier
```
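`event["__interrupt__"][0]` reads only the first pending interrupt. If parallel branches can pause in the same step, the tuple may hold several. The sketch below gathers all of them from a list of `updates` events. `collect_interrupts` is a hypothetical helper, and `SimpleNamespace` objects stand in for the real interrupt objects.

```python
from types import SimpleNamespace


def collect_interrupts(events):
    """Return (value, id) for every pending interrupt across `updates` events."""
    pending = []
    for event in events:
        for item in event.get("__interrupt__", ()):
            pending.append((item.value, item.id))
    return pending


events = [
    {"agent": {"messages": ["..."]}},
    {"__interrupt__": (
        SimpleNamespace(value="Where are you located?", id="a1"),
        SimpleNamespace(value="What is your name?", id="b2"),
    )},
]
print(collect_interrupts(events))
# [('Where are you located?', 'a1'), ('What is your name?', 'b2')]
```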
Resuming:
```python
Command(resume=user_answer)  # a single value
Command(resume={interrupt_id: user_answer})  # mapping by ID (multiple interrupts)
```
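Recursive processing:
```python
def run_with_auto_resume(app, graph_input, config):
    events = []
    for event in app.stream(graph_input, config, stream_mode="updates"):
        events.append(event)
        if "__interrupt__" in event:
            interrupt_info = event["__interrupt__"][0]
            # Fetch the user's input, then resume recursively
            user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
            return events + run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return events
```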
Practical applications
Fetching user input from a database
```python
import sqlite3
from contextlib import closing


def get_user_input(question: str, interrupt_id: str) -> str:
    # Look up the answer stored for this interrupt ID; closing() ensures the
    # connection is closed even if the query raises
    with closing(sqlite3.connect("user_inputs.db")) as conn:
        row = conn.execute(
            "SELECT answer FROM user_inputs WHERE interrupt_id = ?",
            (interrupt_id,),
        ).fetchone()
    if row is None:
        raise ValueError(f"No user input found for interrupt ID {interrupt_id}")
    return row[0]
```
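The query assumes a `user_inputs` table that some other part of the system, such as a web form handler, fills in. Below is a self-contained sketch of that table and both sides of the exchange, using an in-memory SQLite database. The schema and the names `save_user_input` and `load_user_input` are assumptions; the example above does not define them.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_inputs (interrupt_id TEXT PRIMARY KEY, answer TEXT NOT NULL)"
)


def save_user_input(conn, interrupt_id: str, answer: str) -> None:
    """Called by whatever collects the answer (e.g. a form handler)."""
    with conn:  # commits the transaction on success
        conn.execute(
            "INSERT OR REPLACE INTO user_inputs (interrupt_id, answer) VALUES (?, ?)",
            (interrupt_id, answer),
        )


def load_user_input(conn, interrupt_id: str) -> str:
    """Same lookup as get_user_input, but on a connection passed in."""
    row = conn.execute(
        "SELECT answer FROM user_inputs WHERE interrupt_id = ?", (interrupt_id,)
    ).fetchone()
    if row is None:
        raise ValueError(f"No user input found for interrupt ID {interrupt_id}")
    return row[0]


save_user_input(conn, "a1", "San Francisco")
print(load_user_input(conn, "a1"))  # San Francisco
```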
Fetching user input from an API
```python
import requests


def get_user_input(question: str, interrupt_id: str) -> str:
    response = requests.post(
        "https://api.example.com/get-user-input",
        json={"interrupt_id": interrupt_id, "question": question},
        timeout=10,  # don't hang forever on an unresponsive API
    )
    if response.status_code != 200:
        raise ValueError(f"API call failed: {response.status_code}")
    return response.json()["answer"]
```
Fetching user input from a message queue
```python
import time

import redis


def get_user_input(question: str, interrupt_id: str) -> str:
    r = redis.Redis(host="localhost", port=6379, db=0)
    # Poll until the answer appears under user_input:<interrupt_id>
    while True:
        answer = r.get(f"user_input:{interrupt_id}")
        if answer is not None:
            return answer.decode("utf-8")
        time.sleep(0.1)  # wait 100 ms before retrying
```
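Polling every 100 ms works, but it wakes up constantly and never gives up. Within a single process, a blocking wait with a timeout does the same job. The sketch below uses the standard library's `queue.Queue`, one queue per interrupt ID, as an in-process stand-in for the Redis key. With Redis itself, a blocking command such as `BLPOP` plays that role, assuming the answer is pushed onto a list rather than stored with `SET`. `answer_queues`, `submit_answer` and `wait_for_answer` are hypothetical names.

```python
import queue
import threading
from collections import defaultdict

answer_queues = defaultdict(queue.Queue)  # one queue per interrupt ID
_lock = threading.Lock()


def _queue_for(interrupt_id):
    with _lock:  # guard defaultdict insertion so both threads get the same queue
        return answer_queues[interrupt_id]


def submit_answer(interrupt_id: str, answer: str) -> None:
    """Producer side: e.g. a web handler that received the user's reply."""
    _queue_for(interrupt_id).put(answer)


def wait_for_answer(interrupt_id: str, timeout: float = 30.0) -> str:
    """Consumer side: block until an answer arrives or the timeout expires."""
    try:
        return _queue_for(interrupt_id).get(timeout=timeout)
    except queue.Empty:
        raise TimeoutError(f"No answer for {interrupt_id} within {timeout} s") from None


# Simulate a user answering 0.2 s later from another thread
threading.Timer(0.2, submit_answer, args=("a1", "San Francisco")).start()
print(wait_for_answer("a1", timeout=5))  # San Francisco
```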
Fetching user input over a WebSocket
```python
import asyncio
import json

import websockets


def get_user_input(question: str, interrupt_id: str) -> str:
    async def wait_for_input():
        # Send the question, then wait for the reply on the same connection
        async with websockets.connect("ws://localhost:8765") as websocket:
            await websocket.send(json.dumps({
                "interrupt_id": interrupt_id,
                "question": question,
            }))
            response = await websocket.recv()
            return json.loads(response)["answer"]

    # asyncio.run() cannot be called while an event loop is already running
    return asyncio.run(wait_for_input())
```
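This synchronous wrapper suits scripts and worker threads. It does not work from code already running inside an event loop, such as an async web server, because `asyncio.run()` raises `RuntimeError` there. In an async application, one alternative is to keep a future per interrupt ID and have the WebSocket handler resolve it. The sketch below uses only `asyncio`, and `pending`, `deliver_answer` and `wait_for_answer` are hypothetical names. The graph would then be driven through the async API, e.g. `app.astream`.

```python
import asyncio

pending: dict[str, asyncio.Future] = {}


def deliver_answer(interrupt_id: str, answer: str) -> None:
    """Called by the WebSocket handler when the user's reply arrives."""
    future = pending.get(interrupt_id)
    if future is not None and not future.done():
        future.set_result(answer)


async def wait_for_answer(interrupt_id: str, timeout: float = 30.0) -> str:
    """Register a future for this interrupt and wait until it is resolved."""
    future = asyncio.get_running_loop().create_future()
    pending[interrupt_id] = future
    try:
        return await asyncio.wait_for(future, timeout)
    finally:
        pending.pop(interrupt_id, None)


async def main():
    # Simulate the handler receiving the reply shortly after the question is sent
    asyncio.get_running_loop().call_later(0.1, deliver_answer, "a1", "San Francisco")
    return await wait_for_answer("a1", timeout=5)


print(asyncio.run(main()))  # San Francisco
```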
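Best practices
1. Listening for interrupt events
Detect interrupts by checking streamed events for the `__interrupt__` key:
```python
for event in app.stream(inputs, config, stream_mode="updates"):
    if "__interrupt__" in event:
        interrupt_info = event["__interrupt__"][0]
        # handle the interrupt here
```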
2. Handling multiple interrupts
If a run may hit several interrupts in a row, use a bounded loop instead of recursion. The loop avoids growing the call stack, and the iteration cap stops a run that keeps interrupting forever:
```python
def run_with_auto_resume(app, initial_input, config, max_iterations=10):
    events = []
    graph_input = initial_input
    for _ in range(max_iterations):
        interrupt_info = None
        for event in app.stream(graph_input, config, stream_mode="updates"):
            events.append(event)
            if "__interrupt__" in event:
                interrupt_info = event["__interrupt__"][0]
        if interrupt_info is None:
            return events  # no interrupt: the graph ran to completion
        user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
        graph_input = Command(resume=user_answer)
    raise RuntimeError(f"Reached the maximum of {max_iterations} iterations")
```
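3. Error handling
Real applications should add proper error handling and a timeout:
```python
import logging
import time

logger = logging.getLogger(__name__)


def get_user_input(question: str, interrupt_id: str, timeout: float = 30.0) -> str:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # fetch_from_external_system is a placeholder for your own lookup
            answer = fetch_from_external_system(interrupt_id)
            if answer:
                return answer
        except Exception as e:
            logger.error(f"Failed to fetch user input: {e}")
        time.sleep(0.1)
    raise TimeoutError(f"Timed out after {timeout} seconds waiting for user input")
```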
4. State management
Use `thread_id` to keep execution sessions separate:
```python
# Give each user an independent thread_id
config = {"configurable": {"thread_id": f"user_{user_id}"}}
# Optionally resume from a specific checkpoint
config = {
    "configurable": {
        "thread_id": f"user_{user_id}",
        "checkpoint_id": checkpoint_id,
    }
}
```
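`f"user_{user_id}"` gives each user exactly one thread, so a user's second conversation would continue from the first one's saved state. If a user can have several independent conversations, one option is to derive the thread ID from both values. The sketch below uses `uuid.uuid5`. The namespace URL and the `conversation_id` argument are assumptions.

```python
import uuid

# Hypothetical application-specific namespace; any fixed UUID works
THREAD_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "https://example.com/threads")


def thread_id_for(user_id: str, conversation_id: str) -> str:
    """Same (user, conversation) pair -> same thread ID; different pairs -> different IDs."""
    return str(uuid.uuid5(THREAD_NAMESPACE, f"{user_id}:{conversation_id}"))


config = {"configurable": {"thread_id": thread_id_for("42", "trip-planning")}}
print(config)
```

Because `uuid5` is deterministic, the same IDs are recomputed after a restart without storing a mapping. The trade-off is that anyone who knows both inputs can derive the thread ID.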