【人工智能:Agent】--9.2.Langchain自定义中间件

通过实现在智能体执行流程中特定节点运行的钩子来构建自定义中间件。

在智能体开发中,"钩子"就像是执行流程中的检查站插座

  • 想象一下 :智能体原本的执行流程是一条高速公路(A → B → C → D)。**"钩子"**就是在这些路段之间设立的关卡。
  • 作用:当你在某个关卡(钩子)注册了一个函数,程序跑到这里时就会停下来,先执行你的函数,然后再继续上路。

2. 关键机制:什么是"特定节点" (Specific Nodes)?

LangChain 1.0 定义了智能体执行流程中的几个标准节点(时刻),你只能在这些预设的时刻进行干预。主要分为两类:

  • 节点式钩子 (Node-Style Hooks) :在特定步骤前后 运行。
    • before_model:在调用大模型之前。适合修改提示词、检查输入。
    • after_model:在大模型返回结果之后。适合检查输出、过滤敏感词。
    • before_agent / after_agent:在智能体启动前/结束后。适合初始化或清理工作。
  • 缠绕式钩子 (Wrap-Style Hooks)包裹 整个步骤,拥有最高控制权。
    • wrap_model_call:你可以完全控制模型调用的过程,比如实现重试机制、缓存逻辑,甚至直接拦截请求不让模型运行。

3. 最终目的:构建"自定义中间件"

中间件本质上就是一个或多个"钩子函数"的集合。通过在这些"特定节点"插入逻辑,你可以实现各种定制功能:

  • 监控 :在 before_model 打印日志,记录调用次数。
  • 修改 :在 after_model 修改模型的返回结果,比如把"你好"改成"Hello"。
  • 控制 :在 wrap_model_call 中实现逻辑,如果模型报错就自动重试 3 次。

🧩 举个形象的例子

假设 智能体 是一个 "厨师",他的任务是做一道菜(生成答案)。

  • 特定节点 就是厨房里的几个关键动作:"看食谱" (输入提示词)、"烹饪" (调用模型)、"装盘"(输出结果)。
  • 钩子 就是你在厨房里安装的 "监控器""控制器"
  • 这句话的意思 就是:你在"烹饪"这个动作之前安装了一个控制器(钩子),当厨师准备烹饪时,控制器会先检查食材是否过期,如果过期了就拦截掉,不让他做这道菜。

目录

1.节点钩子

2.缠绕钩子

3.创建中间件

3.1.装饰器中间件

3.2.类中间件

4.自定义状态模式

5.执行顺序

6.智能体跳转

7.实践过程

7.1.动态模型选择

7.2.工具调用监控

7.3.修改系统消息


1.节点钩子

在特定执行点顺序运行。用于日志记录、验证和状态更新。可用钩子:

  • before_agent - 智能体开始前(每次召唤一次)
  • before_model - 每次模型调用前
  • after_model - 每次模型响应后
  • after_agent - 智能体完成后(每次调用一次)
python 复制代码
from langchain.agents.middleware import before_model, after_model, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


# 定义一个在大模型调用之前的钩子
# 配置:如果触发限制,可以跳转到结束节点
@before_model(can_jump_to=["end"])
def check_message_limit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """
    中间件逻辑:检查消息数量限制
    如果对话轮数超过50条,直接拦截流程,不再调用模型
    """
    # 检查当前对话历史记录的长度
    if len(state["messages"]) >= 50:
        # 如果超过限制,返回一个包含提示信息的字典
        # 这个字典会作为新的状态更新,并触发跳转到结束节点
        return {
            # 设置返回给用户的消息(中文提示)
            "messages": [AIMessage("对话长度已达上限,本次会话结束。")],
            # 指令:跳过后续步骤,直接跳转到 'end' 节点
            "jump_to": "end"
        }
    # 如果没有超过限制,返回 None,流程将继续正常执行
    return None


# 定义一个在大模型调用之后执行的钩子
@after_model
def log_response(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """
    中间件逻辑:记录模型响应
    在模型返回结果后,打印最后一条消息的内容到控制台
    """
    # 打印模型返回的最新消息内容
    print(f"模型返回内容: {state['messages'][-1].content}")
    # 仅用于打印日志,不需要修改状态,返回 None
    return None

2.缠绕钩子

当调用处理程序时,拦截执行和控制。用于重试、缓存和转换。你决定调用处理器是零次(短路)、一次(正常流)还是多次(重试逻辑)。可用钩子:

  • wrap_model_call - 围绕每个模型调用
  • wrap_tool_call - 每次工具调用周边
python 复制代码
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from typing import Callable


# 使用 @wrap_model_call 装饰器定义一个中间件
# 该装饰器用于完全控制模型调用的过程(包裹式钩子)
@wrap_model_call
def retry_model(
    # request: 包含当前模型调用所需的所有信息(如提示词、模型参数等)
    request: ModelRequest,
    # handler: 这是一个函数,调用 handler(request) 将真正执行模型推理
    # 我们通过控制何时调用 handler 来实现重试逻辑
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse: # 返回值必须是 ModelResponse 类型
    
    # 最多重试 3 次(包括首次尝试)
    for attempt in range(3):
        try:
            # 尝试执行模型调用
            # 如果成功,直接返回结果,函数结束
            return handler(request)
            
        except Exception as e:
            # 如果发生异常(如网络超时、API 错误等)
            
            # 如果已经是第 3 次尝试(attempt 从 0 开始计数)
            if attempt == 2: 
                # 抛出异常,不再重试
                # 这将导致 Agent 运行失败或进入错误处理流程
                raise 
            else:
                # 打印重试日志
                print(f"发生错误: {e},正在重试 ({attempt + 1}/3)...")

场景一:没有这个中间件(裸奔模式)

  1. 第 1 步:助手开始调用大模型分析数据。
  2. 第 2 步 :突然网络抖动,或者模型服务器太忙了,返回了一个 504 Gateway Timeout 错误。
  3. 第 3 步:程序崩溃,用户界面上显示一大串红色的报错代码(Traceback)。
  4. 结果:体验极差。用户需要重新上传文件再试一次,浪费时间。

场景二:有这个 retry_model 中间件(防抖模式)

  1. 第 1 步:助手开始调用大模型。
  2. 第 2 步 :网络抖动,第一次调用失败了。
    • 中间件反应 :捕获到错误,打印日志 Retry 1/3 after error: ...,等待片刻。
  3. 第 3 步:自动进行第二次尝试。如果这次成功了,直接返回结果给用户。
  4. 第 4 步:如果第二次还失败,继续重试第三次。
  5. 结果静默恢复。用户可能只感觉稍微卡顿了一下,但最终拿到了结果,甚至不知道后台刚刚发生过一次故障。

3.创建中间件

你可以用两种方式创建中间件:

3.1.装饰器中间件

单钩中间件快速简单。使用装饰师来包裹各个功能。**Available decorators:**节点风格:

缠绕式:

便利性:

python 复制代码
from langchain.agents.middleware import (
    before_model,
    wrap_model_call,
    AgentState,
    ModelRequest,
    ModelResponse,
)


from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langgraph.prebuilt import create_react_agent
import os

# ==========================================
# 1. 定义中间件 (Middleware)
# ==========================================

@before_model
def log_before_model(state: AgentState, runtime) -> None:
    print(f"[日志] 即将调用模型,当前上下文长度: {len(state['messages'])} 条消息")

@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler,
) -> ModelResponse:
    for attempt in range(3):
        try:
            # 模拟网络不稳定:注释掉下面这行,取消注释后面的raise来测试重试机制
            return handler(request)
        except Exception as e:
            if attempt == 2:
                print(f"[错误] 所有重试均已耗尽,任务失败: {e}")
                raise
            print(f"[重试] 第 {attempt + 1}/3 次尝试失败: {e},正在重试...")

# ==========================================
# 2. 模拟工具 (Tools)
# ==========================================

# 假设这是你的地质数据库查询工具
@tool
def query_geology_db(formula: str) -> str:
    """模拟查询地质数据库"""
    # 这里模拟一个随机失败的情况,用来测试重试中间件
    import random
    if random.random() < 0.3: # 30% 的概率模拟失败
        raise ConnectionError("数据库连接超时!")
    return f"查询到 {formula} 的数据为: 储量 1.2亿吨,埋深 3500米"

# ==========================================
# 3. 创建智能体 (Agent)
# ==========================================


# 创建智能体,并注入中间件
# 注意:这里使用了 middleware 参数
agent = create_agent(
    llm, 
    tools=[query_geology_db], 
    # 👇 在这里注入中间件列表
    middleware=[log_before_model, retry_model] 
)

# ==========================================
# 4. 执行任务 (Invocation)
# ==========================================

if __name__ == "__main__":
    # 模拟用户输入
    input_message = "请帮我查一下 塔里木盆地A区 的石油储量和埋深。"
    
    print(f"用户输入: {input_message}")
    print("-" * 50)
    
    # 调用智能体,传入用户消息
    result = agent.invoke({
        "messages": [HumanMessage(input_message)]
        # 消息内容:查找包含 'async def' 的所有 Python 文件
    })
    for msg in result["messages"]:
            msg.pretty_print()
    
    # 调用智能体

何时使用装饰器:

  • 需要单钩
  • 无复杂配置
  • 快速原型制作

3.2.类中间件

对于拥有多个钩子或配置的复杂中间件来说,功能更强大。当你需要为同一钩子定义同步和非同步实现,或者想在单一中间件中合并多个钩子时,可以使用类。

python 复制代码
from langchain.agents.middleware import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    ModelResponse,
)
from langgraph.runtime import Runtime
from typing import Any, Callable

class LoggingMiddleware(AgentMiddleware):
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # 在模型调用前执行:打印当前上下文的消息数量
        print(f"即将调用模型,当前包含 {len(state['messages'])} 条消息")
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # 在模型调用后执行:打印模型返回的最新内容
        print(f"模型返回内容: {state['messages'][-1].content}")
        return None

agent = create_agent(
    model=llm,
    middleware=[LoggingMiddleware()],
    tools=[...],
)

何时使用类:

  • 为同一钩子定义同步和异步实现
  • 单个中间件需要多个钩子
  • 需要复杂的配置(例如,可配置阈值、自定义模型)
  • 在项目间重复使用初始时间配置

4.自定义状态模式

中间件可以通过自定义属性扩展智能体状态。这使得中间件能够:

  • 跨执行跟踪状态:维护计数器、标志或其他在整个执行生命周期中持续存在的值
  • 钩子之间共享数据 :将信息从 before_model 传递到 after_model 或不同中间件实例之间传递
  • 实现跨领域关注点:添加速率限制、使用跟踪、用户上下文或审计日志等功能,而无需修改核心代理逻辑
  • 做出有条件决策:利用累积状态决定是否继续执行、跳转到不同节点或动态修改行为

这段代码利用中间件机制创建了一个具备"防刷限流"功能的 AI 智能体,通过自定义状态记录模型调用次数,在每次运行前进行检查,一旦超过预设的 10 次上限便自动终止流程,从而有效防止了模型的无限循环或过度调用。

python 复制代码
from langchain.agents import create_agent
from langchain.messages import HumanMessage
from langchain.agents.middleware import AgentState, before_model, after_model
from typing_extensions import NotRequired
from typing import Any
from langgraph.runtime import Runtime


# 定义包含自定义字段的状态类
class CustomState(AgentState):
    # 可选字段:模型调用次数
    model_call_count: NotRequired[int]
    # 可选字段:用户ID
    user_id: NotRequired[str]


# 在模型调用前执行的钩子函数,用于检查调用次数限制
@before_model(state_schema=CustomState, can_jump_to=["end"])
def check_call_limit(state: CustomState, runtime: Runtime) -> dict[str, Any] | None:
    # 获取当前的调用次数,默认为0
    count = state.get("model_call_count", 0)
    # 如果次数超过10次,跳转到结束节点
    if count > 10:
        return {"jump_to": "end"}
    return None


# 在模型调用后执行的钩子函数,用于递增计数器
@after_model(state_schema=CustomState)
def increment_counter(state: CustomState, runtime: Runtime) -> dict[str, Any] | None:
    # 将状态中的调用次数加1
    return {"model_call_count": state.get("model_call_count", 0) + 1}


# 创建代理
agent = create_agent(
    model="gpt-4o",
    # 注册中间件钩子
    middleware=[check_call_limit, increment_counter],
    tools=[],
)

# 调用代理并传入自定义状态
result = agent.invoke({
    "messages": [HumanMessage("你好")],  # 消息内容改为中文
    "model_call_count": 0,  # 初始化调用次数
    "user_id": "用户-123",    # 用户ID改为中文标识
})

5.执行顺序

使用多个中间件时,要了解它们的执行方式:

python 复制代码
agent = create_agent(
    model="gpt-4o",
    middleware=[middleware1, middleware2, middleware3],
    tools=[...],
)
  • before_* 钩子:从头到尾
  • after_* 钩子:倒向倒转
  • wrap_* 钩子:嵌套(第一个中间件包裹所有其他中间件)

6.智能体跳转

要提前退出中间件,请返回包含 jump_to 的词典:

可用的跳跃目标:

  • 'end': 跳转到智能体执行的结尾(或第一个 after_agent 钩子)
  • "tools":跳转到工具节点
  • "model":跳转到模型节点(或第一个 before_model 钩子)

这段代码实现了一个简单的内容安全过滤器 。它在模型返回结果后立即检查内容,一旦发现包含"BLOCKED"字样,就立即拦截并替换为标准的拒绝回复("我无法响应此请求。"),同时强制结束对话流程,从而防止不当内容的输出。

@hook_config(can_jump_to=["end"])

是 LangChain Agent Middleware 中的一个装饰器配置 ,它的主要作用是为钩子函数开启"流程跳转"的权限

简单来说,它告诉系统:"这个函数被允许把执行流程直接跳转到指定的节点"。

python 复制代码
from langchain.agents.middleware import after_model, hook_config, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


@after_model
@hook_config(can_jump_to=["end"])
def check_for_blocked(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # 获取最新的消息
    last_message = state["messages"][-1]
    
    # 检查消息内容中是否包含"被屏蔽"关键词
    if "BLOCKED" in last_message.content:
        return {
            # 如果包含,则替换为中文拒绝消息
            "messages": [AIMessage("我无法响应此请求。")],
            # 并跳转到结束节点
            "jump_to": "end"
        }
    return None

7.实践过程

  1. 让中间件聚焦------每个都应该做得很好
  2. 优雅地处理错误------不要让中间件错误导致智能体崩溃
  3. 使用合适的钩子类型
    • 节点式的顺序逻辑(日志记录、验证)
    • 缠绕流的 wrap 风格(重试、回退、缓存)
  4. 明确记录任何自定义状态属性
  5. 集成前独立测试单元中间件
  6. 考虑执行顺序------将关键中间件放在列表中第一
  7. 尽量使用内置中间件

7.1.动态模型选择

实现了一个动态模型路由机制。它根据当前对话的历史消息长度来决定使用哪个模型进行处理:当对话轮次较少(小于等于10条)时,使用轻量级的 gpt-4o-mini 模型以节省成本;当对话变得复杂(超过10条消息)时,则自动切换到功能更强的 gpt-4o 模型,以保证回复质量。

python 复制代码
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.chat_models import init_chat_model
from typing import Callable


# 初始化一个功能强大但成本较高的模型
complex_model = init_chat_model("gpt-4o")
# 初始化一个轻量级且成本较低的模型
simple_model = init_chat_model("gpt-4o-mini")

@wrap_model_call
def dynamic_model(
    request: ModelRequest, # 包含当前对话消息的请求对象
    handler: Callable[[ModelRequest], ModelResponse], # 用于调用模型的处理函数
) -> ModelResponse: # 返回模型响应
    # 根据对话长度动态选择模型
    # 如果对话消息数量超过10条,使用复杂模型处理
    if len(request.messages) > 10:
        model = complex_model
    else:
        # 否则使用简单模型处理
        model = simple_model
    # 调用处理函数,但将请求中的模型替换为选定的模型
    return handler(request.override(model=model))

7.2.工具调用监控

实现了一个工具调用的监控中间件。它在工具执行前后打印日志信息,包括工具名称、输入参数以及执行结果(成功或失败及错误详情),主要用于调试和监控智能体(Agent)在运行过程中调用了哪些工具以及工具的运行状态。

python 复制代码
from langchain.agents.middleware import wrap_tool_call
from langchain.tools.tool_node import ToolCallRequest
from langchain.messages import ToolMessage
from langgraph.types import Command
from typing import Callable


@wrap_tool_call
def monitor_tool(
    request: ToolCallRequest, # 包含工具调用信息的请求对象
    handler: Callable[[ToolCallRequest], ToolMessage | Command], # 用于执行工具的处理函数
) -> ToolMessage | Command: # 返回工具执行结果或指令
    # 打印正在执行的工具名称
    print(f"执行工具: {request.tool_call['name']}")
    # 打印传递给工具的参数
    print(f"参数: {request.tool_call['args']}")
    try:
        # 调用处理函数执行工具
        result = handler(request)
        # 打印工具执行成功的提示
        print(f"工具执行成功")
        return result
    except Exception as e:
        # 如果执行失败,打印错误信息并重新抛出异常
        print(f"工具执行失败: {e}")
        raise

7.3.修改系统消息

ModelRequest 上使用 system_message 字段修改中间件中的系统消息。system_message 字段包含一个 SystemMessage 对象(即使代理是用字符串 system_prompt 创建的)。

python 复制代码
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.messages import SystemMessage
from typing import Callable


@wrap_model_call
def add_context(
    request: ModelRequest, # 包含当前模型请求信息的对象
    handler: Callable[[ModelRequest], ModelResponse], # 用于调用模型的处理函数
) -> ModelResponse: # 返回修改后的模型响应
    # 构建新的内容块列表:保留原有的系统消息内容,并添加新的上下文
    new_content = list(request.system_message.content_blocks) + [
        {"type": "text", "text": "Additional context."}
    ]
    # 创建一个新的系统消息对象,包含更新后的内容
    new_system_message = SystemMessage(content=new_content)
    # 调用处理函数,但将请求中的系统消息替换为添加了新上下文的消息
    return handler(request.override(system_message=new_system_message))

实现了一个模型请求中间件,用于动态向系统消息中注入额外的上下文信息。它在不修改原始请求逻辑的前提下,将一条包含"Additional context."(附加上下文)的新文本块追加到现有的系统消息内容之后,从而确保模型在生成回复时能够基于更新后的、包含更多背景信息的提示词进行推理。

相关推荐
蛇皮划水怪6 小时前
深入浅出LangChain4J
java·langchain·llm
、BeYourself8 小时前
LangChain4j 流式响应
langchain
、BeYourself8 小时前
LangChain4j之Chat and Language
langchain
qfljg10 小时前
langchain usage
langchain
kjkdd14 小时前
6.1 核心组件(Agent)
python·ai·语言模型·langchain·ai编程
渣渣苏18 小时前
Langchain实战快速入门
人工智能·python·langchain
小天呐19 小时前
01—langchain 架构
langchain
香芋Yu21 小时前
【LangChain1.0】第九篇 Agent 架构设计
langchain·agent·架构设计
kjkdd1 天前
5. LangChain设计理念和发展历程
python·语言模型·langchain·ai编程
ASKED_20191 天前
Langchain学习笔记一 -基础模块以及架构概览
笔记·学习·langchain