Langgraph: Human-in-the-Loop 实现机制

本文档展示如何在实际应用中监听 interrupt 异常,自动获取用户输入,然后恢复图执行。

核心概念

interrupt() 函数

interrupt() 是 LangGraph 中用于实现 human-in-the-loop 的核心函数,用于在节点内部暂停执行并等待用户输入。

工作原理

  1. 首次调用 :抛出 GraphInterrupt 异常,停止图执行

    • 异常中包含传入的 value(通常是问题或提示信息)
    • 图状态被保存到 checkpointer
  2. 恢复后调用 :返回用户通过 Command(resume=...) 提供的值

    • 节点会重新执行(从节点开始处)
    • interrupt() 检测到 resume 值,返回该值而不是抛出异常

重要特性

  • 必须启用 checkpointerinterrupt() 依赖 checkpointer 保存状态
  • 节点会重新执行:恢复时从节点开始处重新运行,而不是从中断点继续
  • 支持多个中断 :一个节点内可以有多个 interrupt() 调用,按顺序匹配 resume 值

Command 类

Command 是用于控制图执行和恢复的指令对象,主要用于恢复被 interrupt() 中断的执行。

主要参数

  • resume:恢复中断执行的值

    • 单个值:Command(resume="answer") - 恢复下一个中断
    • 字典映射:Command(resume={interrupt_id: "answer"}) - 恢复指定ID的中断
  • update:更新图状态

    python 复制代码
    Command(update={"key": "value"})
  • goto:跳转到指定节点

    python 复制代码
    Command(goto="node_name")

Checkpointer(检查点)

Checkpointer 是 LangGraph 的持久化层,负责保存和恢复图的状态。它是实现 interrupt() 功能的基础。

核心作用

  1. 保存图状态:在每个执行步骤(superstep)保存图的状态快照
  2. 支持恢复执行:可以从任意检查点恢复图执行
  3. 管理多个会话 :通过 thread_id 区分不同的执行会话

基本用法

python 复制代码
from langgraph.checkpoint.memory import InMemorySaver

# 创建 checkpointer
checkpointer = InMemorySaver()

# 编译图时启用 checkpointer
graph = builder.compile(checkpointer=checkpointer)

# 执行时需要提供 thread_id
config = {"configurable": {"thread_id": "thread_1"}}
graph.stream(input, config)

为什么 interrupt() 需要 checkpointer

  • interrupt() 暂停执行时,需要保存当前状态以便后续恢复
  • 恢复执行时,需要从保存的检查点读取状态
  • 没有 checkpointer,无法实现状态持久化和恢复

常用实现

  • InMemorySaver:内存存储,适用于开发和测试
  • PostgresSaver:PostgreSQL 存储,适用于生产环境
  • SqliteSaver:SQLite 存储,适用于轻量级应用

Thread 和 Checkpoint ID

  • thread_id:会话标识符,用于区分不同的执行会话(必需)
  • checkpoint_id:检查点标识符,用于从特定检查点恢复(可选)
python 复制代码
# 基本配置
config = {"configurable": {"thread_id": "user_123"}}

# 从特定检查点恢复
config = {
    "configurable": {
        "thread_id": "user_123",
        "checkpoint_id": "checkpoint_abc"
    }
}

完整示例

以下示例展示了一个完整的流程:Agent 在执行过程中调用 interrupt() 询问用户,监听中断事件,自动获取用户输入,然后恢复执行。

python 复制代码
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.types import Command, interrupt
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel
import uuid

# 定义工具和模型
@tool
def search(query: str):
    return f"搜索结果: {query}"

class AskHuman(BaseModel):
    question: str

model = ChatAnthropic(model="claude-3-5-sonnet-latest")
model = model.bind_tools([search, AskHuman])

# 定义节点
def call_model(state):
    messages = state["messages"]
    response = model.invoke(messages)
    return {"messages": [response]}

def ask_human(state):
    tool_call = state["messages"][-1].tool_calls[0]
    ask = AskHuman.model_validate(tool_call["args"])
    answer = interrupt(ask.question)  # 中断执行,等待用户输入
    return {
        "messages": [{
            "tool_call_id": tool_call["id"],
            "type": "tool",
            "content": answer
        }]
    }

# 构建图
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode([search]))
workflow.add_node("ask_human", ask_human)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
    "agent",
    lambda state: (
        "ask_human" if state["messages"][-1].tool_calls 
        and state["messages"][-1].tool_calls[0]["name"] == "AskHuman"
        else "tools" if state["messages"][-1].tool_calls
        else END
    )
)
workflow.add_edge("tools", "agent")
workflow.add_edge("ask_human", "agent")
app = workflow.compile(checkpointer=InMemorySaver())

# 获取用户输入的方法(可从数据库、API、消息队列等获取)
def get_user_input(question: str, interrupt_id: str) -> str:
    user_inputs = {
        "Where are you located?": "San Francisco",
        "What is your name?": "Alice",
        "What is your age?": "25"
    }
    return user_inputs.get(question, "Unknown")

# 监听中断并自动恢复执行
def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config, stream_mode="updates"):
        if "__interrupt__" in event:
            interrupt_info = event["__interrupt__"][0]
            user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
            return run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return []

# 使用
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
initial_input = {
    "messages": [("user", "Ask the user where they are, then look up the weather there")]
}
run_with_auto_resume(app, initial_input, config)

工作原理

执行流程

  1. 初始执行 :图开始执行,Agent 调用工具,路由到 ask_human 节点,interrupt() 被调用并抛出异常

  2. 检测中断 :监听流事件,检测到 __interrupt__ 键,提取中断信息(问题和中断ID)

  3. 获取用户输入 :调用 get_user_input() 方法,从外部系统获取用户输入

  4. 恢复执行 :使用 Command(resume=user_answer) 恢复执行,递归调用继续处理

  5. 完成执行:图执行完成,返回所有事件

关键机制

中断检测

python 复制代码
if "__interrupt__" in event:
    interrupt_info = event["__interrupt__"][0]
    # interrupt_info.value: 问题或提示信息
    # interrupt_info.id: 中断的唯一标识符

恢复执行

python 复制代码
Command(resume=user_answer)  # 单个值
Command(resume={interrupt_id: user_answer})  # 字典映射(多个中断)

递归处理

python 复制代码
def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config):
        if "__interrupt__" in event:
            # 获取用户输入并递归恢复
            return run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return []

实际应用

从数据库获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import sqlite3
    
    conn = sqlite3.connect('user_inputs.db')
    cursor = conn.cursor()
    cursor.execute(
        "SELECT answer FROM user_inputs WHERE interrupt_id = ?",
        (interrupt_id,)
    )
    result = cursor.fetchone()
    conn.close()
    
    if result:
        return result[0]
    else:
        raise ValueError(f"未找到中断ID {interrupt_id} 对应的用户输入")

从 API 获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import requests
    
    response = requests.post(
        "https://api.example.com/get-user-input",
        json={"interrupt_id": interrupt_id, "question": question}
    )
    
    if response.status_code == 200:
        return response.json()["answer"]
    else:
        raise ValueError(f"API 调用失败: {response.status_code}")

从消息队列获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import redis
    import time
    
    r = redis.Redis(host='localhost', port=6379, db=0)
    
    # 轮询等待用户输入
    while True:
        answer = r.get(f"user_input:{interrupt_id}")
        if answer:
            return answer.decode('utf-8')
        time.sleep(0.1)  # 等待100ms后重试

从 WebSocket 获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import asyncio
    import websockets
    import json
    
    async def wait_for_input():
        async with websockets.connect("ws://localhost:8765") as websocket:
            await websocket.send(json.dumps({
                "interrupt_id": interrupt_id,
                "question": question
            }))
            response = await websocket.recv()
            return json.loads(response)["answer"]
    
    return asyncio.run(wait_for_input())

最佳实践

1. 监听中断事件

通过检查流事件中的 __interrupt__ 键来检测中断:

python 复制代码
for event in app.stream(input, config, stream_mode="updates"):
    if "__interrupt__" in event:
        interrupt_info = event["__interrupt__"][0]
        # 处理中断

2. 处理多个中断

如果图执行过程中可能有多个中断,使用循环而不是递归,避免无限循环:

python 复制代码
def run_with_auto_resume(app, initial_input, config, max_iterations=10):
    iteration = 0
    
    while iteration < max_iterations:
        iteration += 1
        interrupt_detected = False
        
        input_data = initial_input if iteration == 1 else Command(resume=user_answer)
        
        for event in app.stream(input_data, config, stream_mode="updates"):
            if "__interrupt__" in event:
                interrupt_info = event["__interrupt__"][0]
                user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
                interrupt_detected = True
                break
        
        if not interrupt_detected:
            return events
    
    raise RuntimeError(f"达到最大迭代次数 {max_iterations}")

3. 错误处理

在实际应用中,应该添加适当的错误处理和超时机制:

python 复制代码
def get_user_input(question: str, interrupt_id: str, timeout: float = 30.0) -> str:
    import time
    
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # 尝试获取用户输入
            answer = fetch_from_external_system(interrupt_id)
            if answer:
                return answer
        except Exception as e:
            logger.error(f"获取用户输入失败: {e}")
        
        time.sleep(0.1)
    
    raise TimeoutError(f"获取用户输入超时: {timeout}秒")

4. 状态管理

使用 thread_id 管理不同的执行会话:

python 复制代码
# 为每个用户创建独立的 thread_id
config = {"configurable": {"thread_id": f"user_{user_id}"}}

# 可以从特定检查点恢复
config = {
    "configurable": {
        "thread_id": f"user_{user_id}",
        "checkpoint_id": checkpoint_id
    }
}
相关推荐
猪猪拆迁队5 小时前
虚拟工厂仿真引擎的架构设计:让一条产线可编程、可观测、可干预
后端·ai编程
字节跳动数据库5 小时前
文章分享——相似函数处理方法
人工智能·后端·程序员
云技纵横5 小时前
@Transactional 失效的 7 种场景:第 5 种最难排查
后端
用户6757049885025 小时前
你知道 Go 结构体和结构体指针调用的区别吗?一文带你彻底搞懂!
后端·go
程序员cxuan6 小时前
读懂 Claude Code 架构分析系列,第一篇,开始!
人工智能·后端·架构
用户6757049885026 小时前
面试官问“装饰器模式”,这样回答薪资多要 3000!
后端
tntxia6 小时前
Geo Scene域名修改引起的一些问题
后端
用户298698530146 小时前
Java 实现 Word 文档加密与权限解除
java·后端
vanuan6 小时前
给你的A2A-Agent加把锁-认证鉴权实战指南
后端
Yeats_Liao6 小时前
14:Servlet中的页面跳转-Java Web
java·后端·架构