LangChain v1.0+ 如何构建自定义中间件来拦截和控制 Agent 执行过程

注:本文使用 LangChain v1.0+

Custom 中间件是 LangChain Agent 最强大的扩展机制,让开发者能够在 Agent 执行的任何关键点插入自定义逻辑。

Custom 中间件的两种风格

Custom 中间件提供两种方式来拦截和修改 Agent 行为:

1️⃣ Node-Style Hooks(节点风格)

顺序执行在特定执行点,用于日志记录、验证、状态更新。

python 复制代码
import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict

model = ChatOpenAI(
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    api_key=os.getenv("SILICONFLOW_API_KEY")
)

class Context(TypedDict):
    user_id: str
    request_id: str

# ===== 钩子 1: 模型调用前 =====
@before_model
def log_request_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """记录请求信息(节点风格)"""
    user_id = runtime.context.user_id
    msg_count = len(state.get("messages", []))
    
    print(f"📊 [用户 {user_id}] 即将调用模型")
    print(f"   消息数量: {msg_count}")
    
    # 不修改状态,返回 None
    return None

# ===== 钩子 2: 模型调用后 =====
@after_model
def log_response_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """记录响应信息(节点风格)"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        print(f"✅ 模型已响应,长度: {len(last_msg.content)}")
    
    return None

# ===== 创建 Agent =====
agent = create_agent(
    model=model,
    tools=[],
    middleware=[log_request_info, log_response_info],
    context_schema=Context
)

result = agent.invoke(
    {"messages": [{"role": "user", "content": "你好"}]},
    context={"user_id": "user123", "request_id": "req_001"}
)

2️⃣ Wrap-Style Hooks(环绕风格)

环绕执行,你控制 handler 何时被调用。用于重试、缓存、动态修改。

python 复制代码
from langchain.agents.middleware import before_model
from langchain.agents import create_agent

# ===== Wrap 风格中间件:重试逻辑 =====
def create_retry_middleware(max_retries: int = 3):
    """创建重试中间件"""
    
    # 这里用 before_model 演示,但真实场景中应该用 wrap 风格
    @before_model
    def retry_on_error(state: AgentState, runtime: Runtime[Context]) -> dict | None:
        # 在这个例子中,我们可以使用状态追踪重试次数
        # 实际的 wrap 风格需要使用特殊的装饰器
        return None
    
    return retry_on_error

agent = create_agent(
    model=model,
    tools=[],
    middleware=[create_retry_middleware(max_retries=3)],
    context_schema=Context
)

更完整的 Wrap 风格示例:

python 复制代码
from langchain.agents.middleware import before_model

def create_model_retry_middleware(max_retries: int = 3):
    """
    使用状态跟踪来实现重试逻辑
    """
    @before_model
    def handle_retry(state: AgentState, runtime: Runtime[Context]) -> dict | None:
        # 检查重试计数
        retry_count = state.get("retry_count", 0)
        
        if retry_count >= max_retries:
            return {
                "messages": [AIMessage(content="已达到最大重试次数")],
                "jumpTo": "end"  # 跳到结束
            }
        
        return None
    
    return handle_retry

修改状态的中间件

中间件可以返回一个字典来修改状态:

python 复制代码
from typing import TypedDict
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware import before_model, after_model
from langchain.messages import AIMessage
from langgraph.runtime import Runtime

class Context(TypedDict):
    user_id: str

# ===== 例子 1: 在模型调用前修改消息 =====
@before_model
def trim_long_messages(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """如果消息过多,删除旧消息"""
    messages = state.get("messages", [])
    
    # 超过 100 条消息则只保留最近 50 条
    if len(messages) > 100:
        print(f"⚠️  消息过多 ({len(messages)}),正在修剪...")
        trimmed_messages = messages[-50:]  # 只保留最后 50 条
        
        return {"messages": trimmed_messages}
    
    return None

# ===== 例子 2: 在模型响应后替换内容 =====
@after_model
def filter_sensitive_content(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """过滤敏感内容"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        content = last_msg.content.lower()
        
        # 检测敏感词
        if "password" in content or "api_key" in content:
            print("🚫 检测到敏感内容,正在替换...")
            return {
                "messages": [AIMessage(content="无法显示该内容。")]
            }
    
    return None

agent = create_agent(
    model=model,
    tools=[],
    middleware=[trim_long_messages, filter_sensitive_content],
    context_schema=Context
)

使用 jumpTo 控制流程

中间件可以用 jumpTo 提前结束或跳过执行:

python 复制代码
from langchain.agents.middleware import before_model

@before_model
def rate_limit_check(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """检查速率限制"""
    msg_count = len(state.get("messages", []))
    
    # 如果消息太多,直接返回错误并结束
    if msg_count > 1000:
        return {
            "messages": [AIMessage(content="已达到请求限制")],
            "jumpTo": "end"  # 直接跳到 agent 结束
        }
    
    return None

agent = create_agent(
    model=model,
    tools=[],
    middleware=[rate_limit_check],
    context_schema=Context
)

扩展状态 Schema

中间件可以为 Agent 状态添加自定义字段:

python 复制代码
from langchain.agents import AgentState, create_agent
from typing import TypedDict
from langchain.agents.middleware import before_model, after_model

# 创建扩展的状态类
class ExtendedAgentState(AgentState):
    """扩展状态,添加自定义字段"""
    call_count: int  # 模型调用次数
    total_tokens: int  # 总 token 数
    user_metadata: dict  # 用户元数据

# ===== 中间件追踪调用次数 =====
@before_model
def track_calls(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """记录模型调用次数"""
    current_count = state.get("call_count", 0)
    print(f"📞 这是第 {current_count + 1} 次模型调用")
    
    return {"call_count": current_count + 1}

@after_model
def track_tokens(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """记录 token 使用量"""
    last_msg = state["messages"][-1]
    
    # 估算 token 数(简单方法)
    token_count = len(str(last_msg.content).split())
    current_total = state.get("total_tokens", 0)
    
    return {"total_tokens": current_total + token_count}

agent = create_agent(
    model=model,
    tools=[],
    middleware=[track_calls, track_tokens],
    context_schema=Context,
    state_schema=ExtendedAgentState  # 使用扩展的状态
)

result = agent.invoke(
    {
        "messages": [{"role": "user", "content": "你好"}],
        "call_count": 0,
        "total_tokens": 0,
        "user_metadata": {"plan": "premium"}
    },
    context={"user_id": "user123"}
)

print(f"最终调用次数: {result['call_count']}")
print(f"总 token 数: {result['total_tokens']}")

实战:完整的中间件系统

python 复制代码
import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict
from datetime import datetime
import time

model = ChatOpenAI(
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    api_key=os.getenv("SILICONFLOW_API_KEY")
)

# ===== 上下文和状态 =====
class Context(TypedDict):
    user_id: str
    user_role: str  # "admin" 或 "user"

class ExtendedState(AgentState):
    call_count: int
    start_time: float
    execution_logs: list

# ===== 工具 =====
@tool
def search_tool(query: str) -> str:
    """搜索工具"""
    return f"找到关于 '{query}' 的 3 个结果"

# ===== 中间件 1: 请求验证 =====
@before_model
def validate_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """验证请求权限"""
    user_role = runtime.context.get("user_role", "user")
    
    # 仅管理员可以连续发送超过 10 条消息
    msg_count = len(state.get("messages", []))
    if msg_count > 10 and user_role != "admin":
        return {
            "messages": [AIMessage(content="您已达到消息限制。")],
            "jumpTo": "end"
        }
    
    return None

# ===== 中间件 2: 请求日志 =====
@before_model
def log_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """记录请求详情"""
    user_id = runtime.context.user_id
    call_count = state.get("call_count", 0)
    
    log_msg = f"[{datetime.now().strftime('%H:%M:%S')}] 用户 {user_id} - 调用 #{call_count + 1}"
    
    logs = state.get("execution_logs", [])
    logs.append(log_msg)
    
    return {
        "call_count": call_count + 1,
        "execution_logs": logs
    }

# ===== 中间件 3: 响应验证 =====
@after_model
def validate_response(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """检查响应质量"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        content = last_msg.content
        
        # 内容过短可能是错误
        if len(content) < 10:
            return {
                "messages": [AIMessage(content="模型响应异常,请重试。")]
            }
    
    return None

# ===== 中间件 4: 性能监控 =====
@after_model
def monitor_performance(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """监控执行时间"""
    start_time = state.get("start_time", time.time())
    elapsed = time.time() - start_time
    
    logs = state.get("execution_logs", [])
    logs.append(f"⏱️  耗时: {elapsed:.2f}秒")
    
    return {"execution_logs": logs}

# ===== 创建 Agent =====
agent = create_agent(
    model=model,
    tools=[search_tool],
    middleware=[
        validate_request,    # ① 验证权限
        log_request,        # ② 记录请求
        validate_response,  # ③ 验证响应
        monitor_performance # ④ 监控性能
    ],
    context_schema=Context,
    state_schema=ExtendedState
)

# ===== 测试 =====
if __name__ == "__main__":
    result = agent.invoke(
        {
            "messages": [{"role": "user", "content": "搜索 Python"}],
            "call_count": 0,
            "start_time": time.time(),
            "execution_logs": []
        },
        context={"user_id": "user_001", "user_role": "admin"}
    )
    
    print("执行日志:")
    for log in result.get("execution_logs", []):
        print(f"  {log}")

关键概念总结

特性 Node-Style Wrap-Style
执行时机 顺序执行在特定点 环绕执行
用途 日志、验证、更新 重试、缓存、动态选择
控制权 框架控制 你控制 handler
钩子 before_*, after_* wrap_*
返回值 dictNone 直接返回结果

本文使用的模型服务来自硅基流动平台。新用户通过邀请链接注册可领取 2000万免费token,支持GLM-4.6、Kimi-K2-Thinking、MiniMaxAI/MiniMax-M2、DeepSeek-R2等主流大模型调用,API稳定性与响应速度俱佳。

专属注册链接:cloud.siliconflow.cn/i/AvDmOKTO

(打直球:这是我的推荐码 感谢大佬的支持)

作者:世界那么哒哒

链接:juejin.cn/post/757240...

来源:稀土掘金

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

相关文档:

相关推荐
常先森7 小时前
【解密源码】 RAGFlow 切分最佳实践- paper 篇
架构·llm·agent
温柔哥`8 小时前
PANDA:通过代理型 AI 工程师迈向通用视频异常检测
大模型·agent·rag·vad·视频异常检测·工具调用·mllms
serve the people17 小时前
Prompts for Chat Models in LangChain
java·linux·langchain
serve the people21 小时前
Comma-Separated List Output Parser in LangChain
windows·langchain·list
后端小肥肠1 天前
别再找提示词了!n8n+Coze+Sora2:扔个链接,AI自动反推,爆款视频直存本地!
aigc·agent·coze
Joker-Tong1 天前
大模型数据洞察能力方法调研
人工智能·python·agent
AI大模型1 天前
12 节课解锁 AI Agents,让AI替你打工(一): 简介
程序员·llm·agent
AI大模型1 天前
12 节课解锁 AI Agents,让AI替你打工(二):从零开始构建一个Agent
程序员·llm·agent
用户48466566957491 天前
最小可运行 Agent 架构图(专业版)
agent