Langgraph: Human-in-the-Loop 实现机制

本文档展示如何在实际应用中监听 interrupt 异常,自动获取用户输入,然后恢复图执行。

核心概念

interrupt() 函数

interrupt() 是 LangGraph 中用于实现 human-in-the-loop 的核心函数,用于在节点内部暂停执行并等待用户输入。

工作原理

  1. 首次调用 :抛出 GraphInterrupt 异常,停止图执行

    • 异常中包含传入的 value(通常是问题或提示信息)
    • 图状态被保存到 checkpointer
  2. 恢复后调用 :返回用户通过 Command(resume=...) 提供的值

    • 节点会重新执行(从节点开始处)
    • interrupt() 检测到 resume 值,返回该值而不是抛出异常

重要特性

  • 必须启用 checkpointerinterrupt() 依赖 checkpointer 保存状态
  • 节点会重新执行:恢复时从节点开始处重新运行,而不是从中断点继续
  • 支持多个中断 :一个节点内可以有多个 interrupt() 调用,按顺序匹配 resume 值

Command 类

Command 是用于控制图执行和恢复的指令对象,主要用于恢复被 interrupt() 中断的执行。

主要参数

  • resume:恢复中断执行的值

    • 单个值:Command(resume="answer") - 恢复下一个中断
    • 字典映射:Command(resume={interrupt_id: "answer"}) - 恢复指定ID的中断
  • update:更新图状态

    python 复制代码
    Command(update={"key": "value"})
  • goto:跳转到指定节点

    python 复制代码
    Command(goto="node_name")

Checkpointer(检查点)

Checkpointer 是 LangGraph 的持久化层,负责保存和恢复图的状态。它是实现 interrupt() 功能的基础。

核心作用

  1. 保存图状态:在每个执行步骤(superstep)保存图的状态快照
  2. 支持恢复执行:可以从任意检查点恢复图执行
  3. 管理多个会话 :通过 thread_id 区分不同的执行会话

基本用法

python 复制代码
from langgraph.checkpoint.memory import InMemorySaver

# 创建 checkpointer
checkpointer = InMemorySaver()

# 编译图时启用 checkpointer
graph = builder.compile(checkpointer=checkpointer)

# 执行时需要提供 thread_id
config = {"configurable": {"thread_id": "thread_1"}}
graph.stream(input, config)

为什么 interrupt() 需要 checkpointer

  • interrupt() 暂停执行时,需要保存当前状态以便后续恢复
  • 恢复执行时,需要从保存的检查点读取状态
  • 没有 checkpointer,无法实现状态持久化和恢复

常用实现

  • InMemorySaver:内存存储,适用于开发和测试
  • PostgresSaver:PostgreSQL 存储,适用于生产环境
  • SqliteSaver:SQLite 存储,适用于轻量级应用

Thread 和 Checkpoint ID

  • thread_id:会话标识符,用于区分不同的执行会话(必需)
  • checkpoint_id:检查点标识符,用于从特定检查点恢复(可选)
python 复制代码
# 基本配置
config = {"configurable": {"thread_id": "user_123"}}

# 从特定检查点恢复
config = {
    "configurable": {
        "thread_id": "user_123",
        "checkpoint_id": "checkpoint_abc"
    }
}

完整示例

以下示例展示了一个完整的流程:Agent 在执行过程中调用 interrupt() 询问用户,监听中断事件,自动获取用户输入,然后恢复执行。

python 复制代码
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.types import Command, interrupt
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel
import uuid

# 定义工具和模型
@tool
def search(query: str):
    return f"搜索结果: {query}"

class AskHuman(BaseModel):
    question: str

model = ChatAnthropic(model="claude-3-5-sonnet-latest")
model = model.bind_tools([search, AskHuman])

# 定义节点
def call_model(state):
    messages = state["messages"]
    response = model.invoke(messages)
    return {"messages": [response]}

def ask_human(state):
    tool_call = state["messages"][-1].tool_calls[0]
    ask = AskHuman.model_validate(tool_call["args"])
    answer = interrupt(ask.question)  # 中断执行,等待用户输入
    return {
        "messages": [{
            "tool_call_id": tool_call["id"],
            "type": "tool",
            "content": answer
        }]
    }

# 构建图
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode([search]))
workflow.add_node("ask_human", ask_human)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
    "agent",
    lambda state: (
        "ask_human" if state["messages"][-1].tool_calls 
        and state["messages"][-1].tool_calls[0]["name"] == "AskHuman"
        else "tools" if state["messages"][-1].tool_calls
        else END
    )
)
workflow.add_edge("tools", "agent")
workflow.add_edge("ask_human", "agent")
app = workflow.compile(checkpointer=InMemorySaver())

# 获取用户输入的方法(可从数据库、API、消息队列等获取)
def get_user_input(question: str, interrupt_id: str) -> str:
    user_inputs = {
        "Where are you located?": "San Francisco",
        "What is your name?": "Alice",
        "What is your age?": "25"
    }
    return user_inputs.get(question, "Unknown")

# 监听中断并自动恢复执行
def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config, stream_mode="updates"):
        if "__interrupt__" in event:
            interrupt_info = event["__interrupt__"][0]
            user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
            return run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return []

# 使用
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
initial_input = {
    "messages": [("user", "Ask the user where they are, then look up the weather there")]
}
run_with_auto_resume(app, initial_input, config)

工作原理

执行流程

  1. 初始执行 :图开始执行,Agent 调用工具,路由到 ask_human 节点,interrupt() 被调用并抛出异常

  2. 检测中断 :监听流事件,检测到 __interrupt__ 键,提取中断信息(问题和中断ID)

  3. 获取用户输入 :调用 get_user_input() 方法,从外部系统获取用户输入

  4. 恢复执行 :使用 Command(resume=user_answer) 恢复执行,递归调用继续处理

  5. 完成执行:图执行完成,返回所有事件

关键机制

中断检测

python 复制代码
if "__interrupt__" in event:
    interrupt_info = event["__interrupt__"][0]
    # interrupt_info.value: 问题或提示信息
    # interrupt_info.id: 中断的唯一标识符

恢复执行

python 复制代码
Command(resume=user_answer)  # 单个值
Command(resume={interrupt_id: user_answer})  # 字典映射(多个中断)

递归处理

python 复制代码
def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config):
        if "__interrupt__" in event:
            # 获取用户输入并递归恢复
            return run_with_auto_resume(
                app, Command(resume=user_answer), config
            )
    return []

实际应用

从数据库获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import sqlite3
    
    conn = sqlite3.connect('user_inputs.db')
    cursor = conn.cursor()
    cursor.execute(
        "SELECT answer FROM user_inputs WHERE interrupt_id = ?",
        (interrupt_id,)
    )
    result = cursor.fetchone()
    conn.close()
    
    if result:
        return result[0]
    else:
        raise ValueError(f"未找到中断ID {interrupt_id} 对应的用户输入")

从 API 获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import requests
    
    response = requests.post(
        "https://api.example.com/get-user-input",
        json={"interrupt_id": interrupt_id, "question": question}
    )
    
    if response.status_code == 200:
        return response.json()["answer"]
    else:
        raise ValueError(f"API 调用失败: {response.status_code}")

从消息队列获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import redis
    import time
    
    r = redis.Redis(host='localhost', port=6379, db=0)
    
    # 轮询等待用户输入
    while True:
        answer = r.get(f"user_input:{interrupt_id}")
        if answer:
            return answer.decode('utf-8')
        time.sleep(0.1)  # 等待100ms后重试

从 WebSocket 获取用户输入

python 复制代码
def get_user_input(question: str, interrupt_id: str) -> str:
    import asyncio
    import websockets
    import json
    
    async def wait_for_input():
        async with websockets.connect("ws://localhost:8765") as websocket:
            await websocket.send(json.dumps({
                "interrupt_id": interrupt_id,
                "question": question
            }))
            response = await websocket.recv()
            return json.loads(response)["answer"]
    
    return asyncio.run(wait_for_input())

最佳实践

1. 监听中断事件

通过检查流事件中的 __interrupt__ 键来检测中断:

python 复制代码
for event in app.stream(input, config, stream_mode="updates"):
    if "__interrupt__" in event:
        interrupt_info = event["__interrupt__"][0]
        # 处理中断

2. 处理多个中断

如果图执行过程中可能有多个中断,使用循环而不是递归,避免无限循环:

python 复制代码
def run_with_auto_resume(app, initial_input, config, max_iterations=10):
    iteration = 0
    
    while iteration < max_iterations:
        iteration += 1
        interrupt_detected = False
        
        input_data = initial_input if iteration == 1 else Command(resume=user_answer)
        
        for event in app.stream(input_data, config, stream_mode="updates"):
            if "__interrupt__" in event:
                interrupt_info = event["__interrupt__"][0]
                user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
                interrupt_detected = True
                break
        
        if not interrupt_detected:
            return events
    
    raise RuntimeError(f"达到最大迭代次数 {max_iterations}")

3. 错误处理

在实际应用中,应该添加适当的错误处理和超时机制:

python 复制代码
def get_user_input(question: str, interrupt_id: str, timeout: float = 30.0) -> str:
    import time
    
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # 尝试获取用户输入
            answer = fetch_from_external_system(interrupt_id)
            if answer:
                return answer
        except Exception as e:
            logger.error(f"获取用户输入失败: {e}")
        
        time.sleep(0.1)
    
    raise TimeoutError(f"获取用户输入超时: {timeout}秒")

4. 状态管理

使用 thread_id 管理不同的执行会话:

python 复制代码
# 为每个用户创建独立的 thread_id
config = {"configurable": {"thread_id": f"user_{user_id}"}}

# 可以从特定检查点恢复
config = {
    "configurable": {
        "thread_id": f"user_{user_id}",
        "checkpoint_id": checkpoint_id
    }
}
相关推荐
墨风如雪9 分钟前
OpenAI亮剑医疗:ChatGPT Health正式发布,你的私人健康参谋上线
aigc
勇哥java实战分享28 分钟前
短信平台 Pro 版本 ,比开源版本更强大
后端
学历真的很重要32 分钟前
LangChain V1.0 Context Engineering(上下文工程)详细指南
人工智能·后端·学习·语言模型·面试·职场和发展·langchain
计算机毕设VX:Fegn089535 分钟前
计算机毕业设计|基于springboot + vue二手家电管理系统(源码+数据库+文档)
vue.js·spring boot·后端·课程设计
上进小菜猪1 小时前
基于 YOLOv8 的智能杂草检测识别实战 [目标检测完整源码]
后端
工藤学编程1 小时前
零基础学AI大模型之LangChain智能体执行引擎AgentExecutor
人工智能·langchain
Miku162 小时前
使用 Claude Code 的 pptx-skills 技能生成精美 EVA 主题 PPT 完整指南
aigc·agent·claude
韩师傅2 小时前
前端开发消亡史:AI也无法掩盖没有设计创造力的真相
前端·人工智能·后端
栈与堆2 小时前
LeetCode-1-两数之和
java·数据结构·后端·python·算法·leetcode·rust
superman超哥2 小时前
双端迭代器(DoubleEndedIterator):Rust双向遍历的优雅实现
开发语言·后端·rust·双端迭代器·rust双向遍历