Source Code Analysis of LangGraph's react_agent

python
def create_react_agent(
    model: Union[str, LanguageModelLike],
    tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
    *,
    prompt: Optional[Prompt] = None,
    response_format: Optional[
        Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]
    ] = None,
    pre_model_hook: Optional[RunnableLike] = None,      # processing node run before each LLM call
    post_model_hook: Optional[RunnableLike] = None,     # processing node run after each LLM call
    state_schema: Optional[StateSchemaType] = None,
    config_schema: Optional[Type[Any]] = None,
    checkpointer: Optional[Checkpointer] = None,        # state persistence (e.g. chat memory / per-conversation checkpoints)
    store: Optional[BaseStore] = None,
    interrupt_before: Optional[list[str]] = None,       # interrupt before the listed nodes
    interrupt_after: Optional[list[str]] = None,        # interrupt after the listed nodes
    debug: bool = False,
    version: Literal["v1", "v2"] = "v2",                # v1: one tool node runs all tool calls in parallel; v2: one Send per tool call
    name: Optional[str] = None,
) -> CompiledGraph:
    """Creates an agent graph that calls tools in a loop until a stopping condition is met.

    For more details on using `create_react_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.

    Args:
        model: The `LangChain` chat model that supports tool calling.
        tools: A list of tools or a ToolNode instance.
            If an empty list is provided, the agent will consist of a single LLM node without tool calling.
        prompt: An optional prompt for the LLM. Can take a few different forms:

            - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
            - SystemMessage: this is added to the beginning of the list of messages in state["messages"].
            - Callable: This function should take in full graph state and the output is then passed to the language model.
            - Runnable: This runnable should take in full graph state and the output is then passed to the language model.

        response_format: An optional schema for the final agent output.

            If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
            If not provided, `structured_response` will not be present in the output state.
            Can be passed in as:

                - an OpenAI function/tool schema,
                - a JSON Schema,
                - a TypedDict class,
                - or a Pydantic class.
                - a tuple (prompt, schema), where schema is one of the above.
                    The prompt will be used together with the model that is being used to generate the structured response.

            !!! Important
                `response_format` requires the model to support `.with_structured_output`

            !!! Note
                The graph will make a separate call to the LLM to generate the structured response after the agent loop is finished.
                This is not the only strategy to get structured responses, see more options in [this guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/).

        pre_model_hook: An optional node to add before the `agent` node (i.e., the node that calls the LLM).
            Useful for managing long message histories (e.g., message trimming, summarization, etc.).
            Pre-model hook must be a callable or a runnable that takes in current graph state and returns a state update in the form of
                ```python
                # At least one of `messages` or `llm_input_messages` MUST be provided
                {
                    # If provided, will UPDATE the `messages` in the state
                    "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), ...],
                    # If provided, will be used as the input to the LLM,
                    # and will NOT UPDATE `messages` in the state
                    "llm_input_messages": [...],
                    # Any other state keys that need to be propagated
                    ...
                }
                ```

            !!! Important
                At least one of `messages` or `llm_input_messages` MUST be provided and will be used as an input to the `agent` node.
                The rest of the keys will be added to the graph state.

            !!! Warning
                If you are returning `messages` in the pre-model hook, you should OVERWRITE the `messages` key by doing the following:

                ```python
                {
                    "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages]
                    ...
                }
                ```
        post_model_hook: An optional node to add after the `agent` node (i.e., the node that calls the LLM).
            Useful for implementing human-in-the-loop, guardrails, validation, or other post-processing.
            Post-model hook must be a callable or a runnable that takes in current graph state and returns a state update.

            !!! Note
                Only available with `version="v2"`.
        state_schema: An optional state schema that defines graph state.
            Must have `messages` and `remaining_steps` keys.
            Defaults to `AgentState` that defines those two keys.
        config_schema: An optional schema for configuration.
            Use this to expose configurable parameters via agent.config_specs.
        checkpointer: An optional checkpoint saver object. This is used for persisting
            the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
        store: An optional store object. This is used for persisting data
            across multiple threads (e.g., multiple conversations / users).
        interrupt_before: An optional list of node names to interrupt before.
            Should be one of the following: "agent", "tools".
            This is useful if you want to add a user confirmation or other interrupt before taking an action.
        interrupt_after: An optional list of node names to interrupt after.
            Should be one of the following: "agent", "tools".
            This is useful if you want to return directly or run additional processing on an output.
        debug: A flag indicating whether to enable debug mode.
        version: Determines the version of the graph to create.
            Can be one of:

            - `"v1"`: The tool node processes a single message. All tool
                calls in the message are executed in parallel within the tool node.
            - `"v2"`: The tool node processes a tool call.
                Tool calls are distributed across multiple instances of the tool
                node using the [Send](https://langchain-ai.github.io/langgraph/concepts/low_level/#send)
                API.
        name: An optional name for the CompiledStateGraph.
            This name will be automatically used when adding ReAct agent graph to another graph as a subgraph node -
            particularly useful for building multi-agent systems.

    Returns:
        A compiled LangChain runnable that can be used for chat interactions.

    The "agent" node calls the language model with the messages list (after applying the prompt).
    If the resulting AIMessage contains `tool_calls`, the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
    The "tools" node executes the tools (1 tool per `tool_call`) and adds the responses to the messages list
    as `ToolMessage` objects. The agent node then calls the language model again.
    The process repeats until no more `tool_calls` are present in the response.
    The agent then returns the full list of messages as a dictionary containing the key "messages".

    ```mermaid
        sequenceDiagram
            participant U as User
            participant A as LLM
            participant T as Tools
            U->>A: Initial input
            Note over A: Prompt + LLM
            loop while tool_calls present
                A->>T: Execute tools
                T-->>A: ToolMessage for each tool_calls
            end
            A->>U: Return final state
    ```

    Example:
        ```python
        from langgraph.prebuilt import create_react_agent

        def check_weather(location: str) -> str:
            '''Return the weather forecast for the specified location.'''
            return f"It's always sunny in {location}"

        graph = create_react_agent(
            "anthropic:claude-3-7-sonnet-latest",
            tools=[check_weather],
            prompt="You are a helpful assistant",
        )
        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
        for chunk in graph.stream(inputs, stream_mode="updates"):
            print(chunk)
        ```
    """
    # Validate the version argument: v1 executes all tool calls in parallel inside a single tool node; v2 dispatches one Send event per tool_call (distributed fan-out)
    if version not in ("v1", "v2"):
        raise ValueError(
            f"Invalid version {version}. Supported versions are 'v1' and 'v2'."
        )

    # The user supplied a custom state_schema.
    if state_schema is not None:
        # messages: the agent's conversation history. remaining_steps: the maximum number of reasoning steps still allowed
        required_keys = {"messages", "remaining_steps"}
        # If structured output is requested, the state must also carry the corresponding key
        if response_format is not None:
            required_keys.add("structured_response")

        # Use get_type_hints() to collect all field names declared on the schema
        schema_keys = set(get_type_hints(state_schema))
        # Raise if any of the required_keys defined above is missing
        if missing_keys := required_keys - set(schema_keys):
            raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")

    # Fall back to a default state schema when the user did not provide one
    if state_schema is None:
        # No custom state_schema: use a default. If response_format is set, the final output must be structured -> AgentStateWithStructuredResponse; otherwise the basic AgentState (messages + remaining_steps only)
        state_schema = (
            AgentStateWithStructuredResponse
            if response_format is not None
            else AgentState
        )
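
    # ------------------------------------------------------------------
    # Illustrative sketch, not part of the LangGraph source: a custom
    # state_schema only has to keep the required keys, mirroring the
    # prebuilt AgentState (import paths below are assumptions):
    #
    #     from typing import Annotated
    #     from typing_extensions import TypedDict
    #     from langchain_core.messages import AnyMessage
    #     from langgraph.graph.message import add_messages
    #     from langgraph.managed import RemainingSteps
    #
    #     class MyAgentState(TypedDict):
    #         messages: Annotated[list[AnyMessage], add_messages]
    #         remaining_steps: RemainingSteps
    #         user_profile: str   # extra, app-specific key (hypothetical)
    # ------------------------------------------------------------------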

    # Tool preparation: normalize `tools` into a ToolNode
    llm_builtin_tools: list[dict] = []
    if isinstance(tools, ToolNode):
        # A pre-built ToolNode instance (LangGraph's tool node) was passed in: just pull out the registered tools
        tool_classes = list(tools.tools_by_name.values())
        tool_node = tools
    # Otherwise the user passed a plain list of tools (functions, BaseTool instances, dict schemas, ...)
    else:
        # dict entries (OpenAI-style tool schemas / provider built-in tools) go into llm_builtin_tools and are bound to the LLM later
        llm_builtin_tools = [t for t in tools if isinstance(t, dict)]
        # The rest are wrapped in a ToolNode so the agent can execute them
        tool_node = ToolNode([t for t in tools if not isinstance(t, dict)])
        # Then extract the actual tool objects from the ToolNode for binding
        tool_classes = list(tool_node.tools_by_name.values())
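
    # ------------------------------------------------------------------
    # Illustrative sketch, not part of the LangGraph source: with a mixed
    # tools list such as
    #
    #     tools = [
    #         check_weather,                    # plain callable -> executed by the "tools" node
    #         {"type": "web_search_preview"},   # provider built-in tool as a dict (e.g. OpenAI web search; assumed example)
    #     ]
    #
    # the dict entry only ends up in llm_builtin_tools (bound to the LLM
    # and handled by the provider), while the callable becomes part of
    # tool_node and is executed by this graph.
    # ------------------------------------------------------------------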

    # Model initialization: load the actual model from a string identifier
    if isinstance(model, str):
        # If the user passed a string (e.g. "openai:gpt-4"), try to initialize a LangChain-supported model via init_chat_model
        try:
            from langchain.chat_models import (  # type: ignore[import-not-found]
                init_chat_model,
            )
        # If langchain is not installed, raise a clear error message
        except ImportError:
            raise ImportError(
                "Please install langchain (`pip install langchain`) to use '<provider>:<model>' string syntax for `model` parameter."
            )

        model = cast(BaseChatModel, init_chat_model(model))
    # Decide whether tool calling is enabled, and bind the tools to the model
    tool_calling_enabled = len(tool_classes) > 0

    # If the model supports tool binding (e.g. OpenAI or Claude family models)
    if (
        _should_bind_tools(model, tool_classes)
        and len(tool_classes) > 0
        or (len(llm_builtin_tools) > 0)
    ):
        # bind tool_classes + llm_builtin_tools to the model (so it can emit ReAct-style tool_calls)
        model = cast(BaseChatModel, model).bind_tools(tool_classes + llm_builtin_tools)  # type: ignore[operator]

    # Build model_runnable (prompt applied before the model)
    model_runnable = _get_prompt_runnable(prompt) | model
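
    # ------------------------------------------------------------------
    # Rough sketch of the prompt step (simplified, not the real
    # _get_prompt_runnable): per the docstring, a string prompt behaves
    # roughly like
    #
    #     def _apply_prompt(state):
    #         return [SystemMessage(content=prompt)] + state["messages"]
    #
    # so model_runnable is "build the message list from state" piped into
    # the (tool-bound) chat model.
    # ------------------------------------------------------------------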

    # Collect tools configured with return_direct
    # Some tools are declared with return_direct=True: once such a tool runs, the agent stops immediately and returns its result
    # If any of the tools are configured to return_directly after running,
    # our graph needs to check if these were called    
    should_return_direct = {t.name for t in tool_classes if t.return_direct}

    # Check whether the response would require more steps than the remaining step budget allows
    ## i.e. given the current state and the latest model response, decide whether the agent can no longer afford another round (tool call + LLM) and must bail out
    def _are_more_steps_needed(state: StateSchema, response: BaseMessage) -> bool:
        # First check whether the response is an AIMessage (only AIMessage carries tool_calls)
        has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
        # Check whether every tool_call in the response targets a tool from the should_return_direct set collected above
        all_tools_return_direct = (
            all(call["name"] in should_return_direct for call in response.tool_calls)
            if isinstance(response, AIMessage)      # if so, these tools should end the agent run right after they execute
            else False                              # not an AIMessage: no tool calls, so False
        )
        # Remaining step budget
        remaining_steps = _get_state_value(state, "remaining_steps", None)  # pull remaining_steps (allowed step budget) from the state so the agent cannot loop forever
        # Is this the last step (is_last_step)?
        is_last_step = _get_state_value(state, "is_last_step", False)       # some flows mark the final step explicitly (e.g. set externally); used together with remaining_steps=None
        
        # Combine the conditions: does the response need more steps than the budget allows?
        return (
            # Case 1: no remaining_steps configured, but this is the last step and there are still tool calls
            (remaining_steps is None and is_last_step and has_tool_calls)
            # Case 2: no steps left and every requested tool is return_direct
            or (
                remaining_steps is not None     
                and remaining_steps < 1
                and all_tools_return_direct
            )
            # Case 3: only one step left but there are still tool calls
            or (remaining_steps is not None and remaining_steps < 2 and has_tool_calls)
        )
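
    # ------------------------------------------------------------------
    # Worked example (not from the source): with the default AgentState,
    # remaining_steps is derived from the recursion limit. If
    # remaining_steps == 1 and the model still asks for a tool, Case 3
    # (remaining_steps < 2 and has_tool_calls) fires, and call_model below
    # answers with the "Sorry, need more steps..." placeholder instead of
    # starting another tool round it could not finish.
    # ------------------------------------------------------------------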

    # Extract the messages the LLM should see from the current agent state and make sure they are well-formed.
    def _get_model_input_state(state: StateSchema) -> StateSchema: 
        # If a pre_model_hook is configured (e.g. for history compression or trimming), prefer the fields it produced.
        if pre_model_hook is not None:
            # Prefer llm_input_messages, which the pre_model_hook produced specifically as LLM input; otherwise fall back to the original messages
            messages = (
                _get_state_value(state, "llm_input_messages")
            ) or _get_state_value(state, "messages")
            # Error message to make debugging easier
            error_msg = f"Expected input to call_model to have 'llm_input_messages' or 'messages' key, but got {state}"
        # Without a pre_model_hook, read the input from the messages key only
        else:
            messages = _get_state_value(state, "messages")
            error_msg = (
                f"Expected input to call_model to have 'messages' key, but got {state}"
            )

        # If even messages cannot be found, the state structure is wrong: raise
        if messages is None:
            raise ValueError(error_msg)

        # Validate the chat history (e.g. that every AIMessage tool call has a matching ToolMessage)
        _validate_chat_history(messages)
        # If the state schema is a Pydantic model (BaseModel), assign via attribute access
        # we're passing messages under `messages` key, as this is expected by the prompt
        if isinstance(state_schema, type) and issubclass(state_schema, BaseModel):
            state.messages = messages  # type: ignore
        # For TypedDict or plain dict schemas, assign via item access
        else:
            state["messages"] = messages  # type: ignore

        # Return the state with its messages field updated
        return state
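
    # ------------------------------------------------------------------
    # Illustrative sketch, not part of the LangGraph source: a
    # pre_model_hook that trims history and feeds only the trimmed list to
    # the LLM, leaving state["messages"] untouched (uses langchain_core's
    # trim_messages; token_counter=len simply counts messages instead of
    # tokens, a simplification):
    #
    #     from langchain_core.messages import trim_messages
    #
    #     def my_pre_model_hook(state):
    #         trimmed = trim_messages(
    #             state["messages"],
    #             strategy="last",
    #             token_counter=len,
    #             max_tokens=10,
    #         )
    #         return {"llm_input_messages": trimmed}
    # ------------------------------------------------------------------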

    # Define the model node call_model and its async counterpart acall_model
    def call_model(state: StateSchema, config: RunnableConfig) -> StateSchema:
        # First normalize the input state (make sure the messages field is correct)
        state = _get_model_input_state(state)
        # Invoke the model with the full state and the run config (stop, timeout, ...); the result is cast to AIMessage
        response = cast(AIMessage, model_runnable.invoke(state, config))
        # add agent name to the AIMessage
        # Attach the agent's name to the AIMessage (helps distinguish message sources in multi-agent systems)
        response.name = name

        # If the response needs more steps than the remaining budget allows (e.g. further tool calls), return a placeholder message instead of continuing
        if _are_more_steps_needed(state, response):
            return {
                "messages": [
                    AIMessage(
                        id=response.id,     # keep the same id as the original response for traceability
                        content="Sorry, need more steps to process this request.",
                    )
                ]
            }
        # Otherwise return the model response as-is (wrapped in a list)
        # We return a list, because this will get added to the existing list
        return {"messages": [response]}

    async def acall_model(state: StateSchema, config: RunnableConfig) -> StateSchema:
        state = _get_model_input_state(state)
        response = cast(AIMessage, await model_runnable.ainvoke(state, config))
        # add agent name to the AIMessage
        response.name = name
        if _are_more_steps_needed(state, response):
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, need more steps to process this request.",
                    )
                ]
            }
        # We return a list, because this will get added to the existing list
        return {"messages": [response]}

    # Dynamically build input_schema (the input state type of the agent node)
    input_schema: StateSchemaType
    # With a pre_model_hook, the agent node must also accept an llm_input_messages field as model input
    if pre_model_hook is not None:
        # For Pydantic schemas: check whether state_schema is a Pydantic model
        # Dynamically create a schema that inherits from state_schema and adds 'llm_input_messages'
        if isinstance(state_schema, type) and issubclass(state_schema, BaseModel):
            # For Pydantic schemas
            # Import create_model to extend the schema with an extra field dynamically
            from pydantic import create_model

            input_schema = create_model(        # create a new Pydantic model that inherits the user-provided state schema
                "CallModelInputSchema",
                llm_input_messages=(list[AnyMessage], ...),     # add an llm_input_messages field of type list[AnyMessage]
                __base__=state_schema,          # so the pre_model_hook's intermediate output can flow through without touching the original messages
            )
        else:
            # For TypedDict schemas: not a Pydantic model, so subclass dynamically and add the new field
            class CallModelInputSchema(state_schema):  # type: ignore
                llm_input_messages: list[AnyMessage]        # define the class inside the function and bolt on the extra field dynamically

            input_schema = CallModelInputSchema
    # Without a pre_model_hook, use the original state schema as the input schema
    else:
        input_schema = state_schema

    def generate_structured_response(
        state: StateSchema, config: RunnableConfig
    ) -> StateSchema:
        messages = _get_state_value(state, "messages")
        structured_response_schema = response_format
        if isinstance(response_format, tuple):
            system_prompt, structured_response_schema = response_format
            messages = [SystemMessage(content=system_prompt)] + list(messages)    
        
        model_with_structured_output = _get_model(model).with_structured_output(
            cast(StructuredResponseSchema, structured_response_schema)
        )
        
        response = model_with_structured_output.invoke(messages, config)
        return {"structured_response": response}

    async def agenerate_structured_response(
        state: StateSchema, config: RunnableConfig
    ) -> StateSchema:
        messages = _get_state_value(state, "messages")
        structured_response_schema = response_format
        if isinstance(response_format, tuple):
            system_prompt, structured_response_schema = response_format
            messages = [SystemMessage(content=system_prompt)] + list(messages)
        
        model_with_structured_output = _get_model(model).with_structured_output(
            cast(StructuredResponseSchema, structured_response_schema)
        )
        response = await model_with_structured_output.ainvoke(messages, config)
        return {"structured_response": response}
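
    # ------------------------------------------------------------------
    # Illustrative usage sketch, not part of the LangGraph source: with a
    # Pydantic schema as response_format,
    #
    #     from pydantic import BaseModel
    #
    #     class WeatherReport(BaseModel):
    #         city: str
    #         forecast: str
    #
    #     agent = create_react_agent(model, tools, response_format=WeatherReport)
    #     result = agent.invoke({"messages": [("user", "weather in sf?")]})
    #     result["structured_response"]   # -> WeatherReport instance
    #
    # the extra LLM call made here happens once, after the ReAct loop ends.
    # ------------------------------------------------------------------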

    if not tool_calling_enabled:
        # Define a new graph
        workflow = StateGraph(state_schema, config_schema=config_schema)
        workflow.add_node(
            "agent",
            RunnableCallable(call_model, acall_model),
            input=input_schema,
        )
        if pre_model_hook is not None:
            workflow.add_node("pre_model_hook", pre_model_hook)
            workflow.add_edge("pre_model_hook", "agent")
            entrypoint = "pre_model_hook"
        else:
            entrypoint = "agent"

        workflow.set_entry_point(entrypoint)

        if post_model_hook is not None:
            workflow.add_node("post_model_hook", post_model_hook)
            workflow.add_edge("agent", "post_model_hook")

        if response_format is not None:
            workflow.add_node(
                "generate_structured_response",
                RunnableCallable(
                    generate_structured_response, agenerate_structured_response
                ),
            )
            if post_model_hook is not None:
                workflow.add_edge("post_model_hook", "generate_structured_response")
            else:
                workflow.add_edge("agent", "generate_structured_response")

        return workflow.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            debug=debug,
            name=name,
        )
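
    # ------------------------------------------------------------------
    # Shape of the graph built above (summary, not from the source): with
    # no tools there is no loop, just a straight chain
    #
    #     [pre_model_hook ->] agent [-> post_model_hook] [-> generate_structured_response]
    #
    # where the bracketed nodes appear only if the corresponding argument
    # was provided.
    # ------------------------------------------------------------------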

    # Define the function that determines whether to continue or not
    def should_continue(state: StateSchema) -> Union[str, list[Send]]:
        messages = _get_state_value(state, "messages")
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            if post_model_hook is not None:
                return "post_model_hook"
            elif response_format is not None:
                return "generate_structured_response"
            else:
                return END
        # Otherwise if there is, we continue
        else:
            if version == "v1":
                return "tools"
            elif version == "v2":
                if post_model_hook is not None:
                    return "post_model_hook"
                tool_calls = [
                    tool_node.inject_tool_args(call, state, store)  # type: ignore[arg-type]
                    for call in last_message.tool_calls
                ]
                return [Send("tools", [tool_call]) for tool_call in tool_calls]
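
    # ------------------------------------------------------------------
    # Example of the v2 fan-out (not from the source): if the last
    # AIMessage contains two tool calls, should_continue returns
    #
    #     [Send("tools", [call_1]), Send("tools", [call_2])]
    #
    # i.e. two independent invocations of the "tools" node, one per call,
    # instead of a single "tools" invocation handling both (v1 behaviour).
    # ------------------------------------------------------------------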

    # Define a new graph
    workflow = StateGraph(state_schema or AgentState, config_schema=config_schema)

    # Define the two nodes we will cycle between
    workflow.add_node(
        "agent", RunnableCallable(call_model, acall_model), input=input_schema
    )
    workflow.add_node("tools", tool_node)

    # Optionally add a pre-model hook node that will be called
    # every time before the "agent" (LLM-calling node)
    if pre_model_hook is not None:
        workflow.add_node("pre_model_hook", pre_model_hook)
        workflow.add_edge("pre_model_hook", "agent")
        entrypoint = "pre_model_hook"
    else:
        entrypoint = "agent"

    # Set the entrypoint as `agent`
    # This means that this node is the first one called
    workflow.set_entry_point(entrypoint)

    agent_paths = []
    post_model_hook_paths = [entrypoint, "tools"]

    # Add a post model hook node if post_model_hook is provided
    if post_model_hook is not None:
        workflow.add_node("post_model_hook", post_model_hook)
        agent_paths.append("post_model_hook")
        workflow.add_edge("agent", "post_model_hook")
    else:
        agent_paths.append("tools")

    # Add a structured output node if response_format is provided
    if response_format is not None:
        workflow.add_node(
            "generate_structured_response",
            RunnableCallable(
                generate_structured_response, agenerate_structured_response
            ),
        )
        if post_model_hook is not None:
            post_model_hook_paths.append("generate_structured_response")
        else:
            agent_paths.append("generate_structured_response")
    else:
        if post_model_hook is not None:
            post_model_hook_paths.append(END)
        else:
            agent_paths.append(END)
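
    # ------------------------------------------------------------------
    # Resulting path maps (summary, not from the source):
    #   - without post_model_hook: agent -> "tools" | "generate_structured_response"/END
    #   - with post_model_hook:    agent -> "post_model_hook", and
    #     post_model_hook -> entrypoint | "tools" | "generate_structured_response"/END
    # ------------------------------------------------------------------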

    if post_model_hook is not None:

        def post_model_hook_router(state: StateSchema) -> Union[str, list[Send]]:
            """Route to the next node after post_model_hook.

            Routes to one of:
            * "tools": if there are pending tool calls without a corresponding message.
            * "generate_structured_response": if no pending tool calls exist and response_format is specified.
            * END: if no pending tool calls exist and no response_format is specified.
            """

            messages = _get_state_value(state, "messages")
            tool_messages = [
                m.tool_call_id for m in messages if isinstance(m, ToolMessage)
            ]
            last_ai_message = next(
                m for m in reversed(messages) if isinstance(m, AIMessage)
            )
            pending_tool_calls = [
                c for c in last_ai_message.tool_calls if c["id"] not in tool_messages
            ]

            if pending_tool_calls:
                pending_tool_calls = [
                    tool_node.inject_tool_args(call, state, store)  # type: ignore[arg-type]
                    for call in pending_tool_calls
                ]
                return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
            elif isinstance(messages[-1], ToolMessage):
                return entrypoint
            elif response_format is not None:
                return "generate_structured_response"
            else:
                return END

        workflow.add_conditional_edges(
            "post_model_hook",
            post_model_hook_router,  # type: ignore[arg-type]
            path_map=post_model_hook_paths,
        )
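
    # ------------------------------------------------------------------
    # Illustrative sketch, not part of the LangGraph source: a
    # post_model_hook that pauses for human review whenever the LLM wants
    # to call a tool, using langgraph's interrupt() primitive (requires a
    # checkpointer to resume):
    #
    #     from langgraph.types import interrupt
    #
    #     def my_post_model_hook(state):
    #         last = state["messages"][-1]
    #         if getattr(last, "tool_calls", None):
    #             decision = interrupt({"tool_calls": last.tool_calls})
    #             # execution resumes here with the reviewer's decision
    #         return {}
    # ------------------------------------------------------------------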

    workflow.add_conditional_edges(
        "agent",
        should_continue,  # type: ignore[arg-type]
        path_map=agent_paths,
    )

    def route_tool_responses(state: StateSchema) -> str:
        for m in reversed(_get_state_value(state, "messages")):
            if not isinstance(m, ToolMessage):
                break
            if m.name in should_return_direct:
                return END

        # handle a case of parallel tool calls where
        # the tool w/ `return_direct` was executed in a different `Send`
        if isinstance(m, AIMessage) and m.tool_calls:
            if any(call["name"] in should_return_direct for call in m.tool_calls):
                return END

        return entrypoint
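
    # ------------------------------------------------------------------
    # Example (not from the source): with a tool declared as
    #
    #     from langchain_core.tools import tool
    #
    #     @tool(return_direct=True)
    #     def book_flight(destination: str) -> str:
    #         ...
    #
    # its name lands in should_return_direct, so once its ToolMessage is
    # produced, route_tool_responses sends the graph straight to END
    # instead of looping back to the agent.
    # ------------------------------------------------------------------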

    if should_return_direct:
        workflow.add_conditional_edges(
            "tools", route_tool_responses, path_map=[entrypoint, END]
        )
    else:
        workflow.add_edge("tools", entrypoint)

    # Finally, we compile it!
    # This compiles it into a LangChain Runnable,
    # meaning you can use it as you would any other runnable
    return workflow.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        name=name,
    )


# Keep for backwards compatibility
create_tool_calling_executor = create_react_agent

__all__ = [
    "create_react_agent",
    "create_tool_calling_executor",
    "AgentState",
    "AgentStatePydantic",
    "AgentStateWithStructuredResponse",
    "AgentStateWithStructuredResponsePydantic",
]
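
# ----------------------------------------------------------------------
# Illustrative usage sketch, not part of the LangGraph source: enabling
# per-thread chat memory via the checkpointer argument (assumes
# InMemorySaver from langgraph.checkpoint.memory):
#
#     from langgraph.checkpoint.memory import InMemorySaver
#
#     agent = create_react_agent(model, tools, checkpointer=InMemorySaver())
#     config = {"configurable": {"thread_id": "user-42"}}
#     agent.invoke({"messages": [("user", "hi, I'm Bob")]}, config)
#     agent.invoke({"messages": [("user", "what's my name?")]}, config)  # remembers Bob
# ----------------------------------------------------------------------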