注:本文使用 LangChain v1.0+
Custom 中间件是 LangChain Agent 最强大的扩展机制,让开发者能够在 Agent 执行的任何关键点插入自定义逻辑。
Custom 中间件的两种风格
Custom 中间件提供两种方式来拦截和修改 Agent 行为:
1️⃣ Node-Style Hooks(节点风格)
在特定执行点按顺序运行,适用于日志记录、验证和状态更新。
python
import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict
# Chat model behind SiliconFlow's OpenAI-compatible endpoint. Endpoint and key
# are read from the environment; MODEL_NAME falls back to Qwen/Qwen2-7B-Instruct.
model = ChatOpenAI(
    api_key=os.getenv("SILICONFLOW_API_KEY"),
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
)
# Runtime context schema: the keys expected in the ``context`` dict passed to
# ``agent.invoke``.
Context = TypedDict("Context", {"user_id": str, "request_id": str})
# ===== Hook 1: before the model call =====
@before_model
def log_request_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """Log request info before each model call (node-style hook).

    Prints the ``user_id`` from the runtime context and the number of
    messages currently in the state. Always returns None, so the state is
    never modified.
    """
    # Context is a TypedDict, so runtime.context is the plain dict passed via
    # ``agent.invoke(context={...})``. Attribute access (``.user_id``) would
    # raise AttributeError; use key access instead.
    user_id = runtime.context["user_id"]
    msg_count = len(state.get("messages", []))
    print(f"📊 [用户 {user_id}] 即将调用模型")
    print(f" 消息数量: {msg_count}")
    # Returning None leaves the state unchanged.
    return None
# ===== Hook 2: after the model call =====
@after_model
def log_response_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """Print the length of the newest AI reply (node-style hook).

    Only reads the state; the return value is always None.
    """
    newest = state["messages"][-1]
    if not isinstance(newest, AIMessage):
        return None
    print(f"✅ 模型已响应,长度: {len(newest.content)}")
    return None
# ===== Create the agent =====
# log_request_info is registered with @before_model and log_response_info with
# @after_model; both are node-style hooks.
agent = create_agent(
    model=model,
    context_schema=Context,
    middleware=[log_request_info, log_response_info],
    tools=[],
)
# The context dict carries the keys declared by the Context schema.
result = agent.invoke(
    {"messages": [{"role": "user", "content": "你好"}]},
    context={"user_id": "user123", "request_id": "req_001"},
)
2️⃣ Wrap-Style Hooks(环绕风格)
环绕执行:由你决定何时(以及是否、调用几次)调用 handler。在 Python 中通过 `@wrap_model_call`(包裹模型调用)或 `@wrap_tool_call`(包裹工具调用)装饰器实现,常用于重试、缓存和动态修改请求。
python
from langchain.agents.middleware import before_model
from langchain.agents import create_agent
# ===== Wrap-style middleware: retry logic =====
def create_retry_middleware(max_retries: int = 3):
    """Build a wrap-style middleware that retries failed model calls.

    The model call is attempted once, plus up to ``max_retries`` retries.
    If every attempt fails, the last exception is re-raised.

    Args:
        max_retries: number of retries after the first failed attempt;
            must be >= 0.

    Returns:
        The middleware object produced by ``wrap_model_call``.

    Raises:
        ValueError: if ``max_retries`` is negative.
    """
    # Imported locally so this snippet is self-contained. wrap_model_call is the
    # wrap-style decorator: unlike before_model, it receives ``handler`` (the
    # real model call) and decides when and how often to invoke it.
    from collections.abc import Callable

    from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call

    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")

    @wrap_model_call
    def retry_on_error(
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        for attempt in range(max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:  # retry on any model/provider error
                if attempt == max_retries:
                    raise  # out of retries: surface the original error
                print(f"🔁 模型调用失败,第 {attempt + 1}/{max_retries} 次重试: {exc}")

    return retry_on_error
# Agent wired with the middleware returned by create_retry_middleware(max_retries=3).
agent = create_agent(
    context_schema=Context,
    middleware=[create_retry_middleware(max_retries=3)],
    model=model,
    tools=[],
)
更完整的 Wrap 风格示例:
python
from langchain.agents.middleware import before_model
def create_model_retry_middleware(max_retries: int = 3):
    """Build a wrap-style middleware that retries the model call with a fallback.

    The model call is attempted once, plus up to ``max_retries`` retries. If
    every attempt raises, the turn ends with the AI message
    "已达到最大重试次数" instead of propagating the error.

    Args:
        max_retries: number of retries after the first failed attempt;
            must be >= 0.

    Returns:
        The middleware object produced by ``wrap_model_call``.

    Raises:
        ValueError: if ``max_retries`` is negative.
    """
    # Imported locally so this snippet is self-contained. Attempts are counted
    # inside the wrapper, so no extra state field (such as a ``retry_count``
    # that some other hook would have to maintain) is needed.
    from collections.abc import Callable

    from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call

    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")

    @wrap_model_call
    def handle_retry(
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        for attempt in range(max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:  # best-effort: log and try again
                print(f"⚠️ 模型调用失败 (第 {attempt + 1} 次): {exc}")
        # Every attempt failed: answer with a fallback message. It carries no
        # tool calls, so the agent loop has nothing further to execute.
        return ModelResponse(result=[AIMessage(content="已达到最大重试次数")])

    return handle_retry
修改状态的中间件
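中间件可以返回一个字典来更新状态。返回的字段会按各自的 reducer 合并进状态:`messages` 字段使用 `add_messages`,新消息默认追加到末尾,只有与已有消息 id 相同的消息才会替换原消息;要删除消息,需要返回 `RemoveMessage`: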
python
from typing import TypedDict
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware import before_model, after_model
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
# Runtime context schema for this example: only a user id is expected.
Context = TypedDict("Context", {"user_id": str})
# ===== Example 1: modify messages before the model call =====
@before_model
def trim_long_messages(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """Drop old messages when the history gets too long.

    When the state holds more than 100 messages, only the most recent 50
    are kept. Otherwise returns None and leaves the state unchanged.
    """
    # Imported locally so this snippet is self-contained.
    from langchain.messages import RemoveMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES

    messages = state.get("messages", [])
    if len(messages) <= 100:
        return None
    print(f"⚠️ 消息过多 ({len(messages)}),正在修剪...")
    # ``messages`` is merged with the add_messages reducer. Returning only the
    # kept tail would re-upsert those messages (same ids) and delete nothing,
    # so clear the whole list first and then re-add the last 50.
    # NOTE(review): a cut at -50 can separate a ToolMessage from the AIMessage
    # that requested it; some providers reject that history.
    return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *messages[-50:]]}
# ===== Example 2: replace content after the model responds =====
@after_model
def filter_sensitive_content(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """Replace the latest AI reply if it mentions "password" or "api_key".

    The match is case-insensitive. Returns None when the last message is not
    an AIMessage or no keyword is found.
    """
    last_msg = state["messages"][-1]
    if not isinstance(last_msg, AIMessage):
        return None
    # ``content`` may be a string or a list of content blocks; str() makes the
    # keyword check work for both (it is a no-op for plain strings).
    content = str(last_msg.content).lower()
    if "password" in content or "api_key" in content:
        print("🚫 检测到敏感内容,正在替换...")
        # Re-using the original id makes add_messages overwrite the reply. A
        # message with a fresh id would be appended after it, and the
        # sensitive text would stay in the history.
        return {"messages": [AIMessage(content="无法显示该内容。", id=last_msg.id)]}
    return None
# Agent that trims long histories before the model call and filters
# sensitive replies after it.
agent = create_agent(
    context_schema=Context,
    middleware=[trim_long_messages, filter_sensitive_content],
    model=model,
    tools=[],
)
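使用 jump_to 控制流程
中间件可以在返回的字典中设置 `jump_to`(Python 中的键名;`jumpTo` 是 JS 版本的写法)来提前结束或跳过执行,可选目标为 `"end"`、`"tools"`、`"model"`。跳转目标必须先在装饰器的 `can_jump_to` 参数中声明,例如 `@before_model(can_jump_to=["end"])`: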
python
from langchain.agents.middleware import before_model
@before_model(can_jump_to=["end"])
def rate_limit_check(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """Stop the agent once the conversation exceeds 1000 messages.

    Over the limit, a notice message is added and execution jumps straight
    to the end of the agent. Otherwise returns None.
    """
    msg_count = len(state.get("messages", []))
    if msg_count > 1000:
        return {
            "messages": [AIMessage(content="已达到请求限制")],
            # Python middleware reads the snake_case key ``jump_to``
            # (``jumpTo`` is the JS spelling). The target must also be
            # declared in ``can_jump_to`` on the decorator above.
            "jump_to": "end",
        }
    return None
# Agent guarded by the message-count limit above.
agent = create_agent(
    context_schema=Context,
    middleware=[rate_limit_check],
    model=model,
    tools=[],
)
扩展状态 Schema
中间件可以为 Agent 状态添加自定义字段:
python
from langchain.agents import AgentState, create_agent
from typing import TypedDict
from langchain.agents.middleware import before_model, after_model
# Extended state class: AgentState plus custom fields
class ExtendedAgentState(AgentState):
    """Agent state extended with custom tracking fields."""
    call_count: int # number of model calls so far
    total_tokens: int # accumulated token count
    user_metadata: dict # per-user metadata (e.g. {"plan": "premium"} in the invoke below)
    # NOTE(review): these keys are declared required (total TypedDict), yet the
    # hooks read call_count/total_tokens with ``state.get(..., 0)``. If callers
    # may omit them, consider ``NotRequired[...]`` — confirm against the docs.
# ===== Middleware: count model calls =====
@before_model
def track_calls(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """Announce the upcoming model call number and store it in ``call_count``."""
    next_count = state.get("call_count", 0) + 1
    print(f"📞 这是第 {next_count} 次模型调用")
    return {"call_count": next_count}
@after_model
def track_tokens(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """Add the latest model call's token usage to ``total_tokens``.

    Prefers the provider-reported ``usage_metadata["total_tokens"]`` on the
    newest message (input plus output tokens of that call). When the
    provider reports nothing, it falls back to a whitespace word count of the
    reply. That fallback is a rough estimate and badly undercounts text
    without spaces, such as Chinese.
    """
    last_msg = state["messages"][-1]
    usage = getattr(last_msg, "usage_metadata", None)
    if usage and usage.get("total_tokens") is not None:
        token_count = usage["total_tokens"]
    else:
        token_count = len(str(last_msg.content).split())
    current_total = state.get("total_tokens", 0)
    return {"total_tokens": current_total + token_count}
# Agent whose state carries the extra ExtendedAgentState fields.
agent = create_agent(
    model=model,
    tools=[],
    middleware=[track_calls, track_tokens],
    context_schema=Context,
    state_schema=ExtendedAgentState,  # use the extended state schema
)
# Seed every custom state field explicitly for this run.
result = agent.invoke(
    {
        "messages": [{"role": "user", "content": "你好"}],
        "call_count": 0,
        "total_tokens": 0,
        "user_metadata": {"plan": "premium"},
    },
    context={"user_id": "user123"},
)
print(f"最终调用次数: {result['call_count']}")
print(f"总 token 数: {result['total_tokens']}")
实战:完整的中间件系统
python
import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict
from datetime import datetime
import time
# Chat model behind SiliconFlow's OpenAI-compatible endpoint. Endpoint and key
# are read from the environment; MODEL_NAME falls back to Qwen/Qwen2-7B-Instruct.
model = ChatOpenAI(
    api_key=os.getenv("SILICONFLOW_API_KEY"),
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
)
# ===== Context and state =====
# Runtime context schema: the keys expected in the ``context`` dict passed to
# ``agent.invoke``. user_role is "admin" or "user".
Context = TypedDict("Context", {"user_id": str, "user_role": str})
class ExtendedState(AgentState):
    """AgentState plus the fields used by the middleware in this example."""
    call_count: int # model-call counter, incremented by log_request
    start_time: float # time.time() timestamp supplied by the caller at invoke
    execution_logs: list # log lines accumulated by the middleware
# ===== Tools =====
@tool
def search_tool(query: str) -> str:
    """搜索工具"""
    # The docstring above is kept verbatim: @tool uses it as the tool
    # description shown to the model. This stub returns a canned message.
    return "找到关于 '{}' 的 3 个结果".format(query)
# ===== Middleware 1: request validation =====
@before_model(can_jump_to=["end"])
def validate_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """Stop non-admin users once the conversation exceeds 10 messages.

    Reads ``user_role`` from the runtime context, defaulting to "user". When
    the limit is hit, a notice is added and execution jumps to the end of
    the agent. Otherwise returns None.
    """
    user_role = runtime.context.get("user_role", "user")
    # Only admins may keep going past 10 messages.
    msg_count = len(state.get("messages", []))
    if msg_count > 10 and user_role != "admin":
        return {
            "messages": [AIMessage(content="您已达到消息限制。")],
            # Python middleware reads ``jump_to`` (``jumpTo`` is the JS
            # spelling). "end" must also be listed in ``can_jump_to`` above.
            "jump_to": "end",
        }
    return None
# ===== Middleware 2: request logging =====
@before_model
def log_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """Count the upcoming model call and record a timestamped log line.

    Returns an update that sets ``call_count`` to the new call number and
    ``execution_logs`` to a new list with the log line appended.
    """
    # Context is a TypedDict, so runtime.context is the plain dict passed via
    # ``agent.invoke(context={...})``; attribute access would raise AttributeError.
    user_id = runtime.context["user_id"]
    call_count = state.get("call_count", 0)
    log_msg = f"[{datetime.now().strftime('%H:%M:%S')}] 用户 {user_id} - 调用 #{call_count + 1}"
    # Build a new list rather than appending to the list held in the current
    # state: hooks should return updates, not mutate state in place.
    logs = [*state.get("execution_logs", []), log_msg]
    return {
        "call_count": call_count + 1,
        "execution_logs": logs,
    }
# ===== Middleware 3: response validation =====
@after_model
def validate_response(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """Replace a suspiciously short final answer with an error notice.

    Only AI messages without tool calls are checked. If the reply text is
    shorter than 10 characters, it is overwritten in place. Otherwise
    returns None.
    """
    last_msg = state["messages"][-1]
    if not isinstance(last_msg, AIMessage):
        return None
    # A tool-calling turn often has empty content. Replacing it would drop the
    # tool calls, so search_tool would never run.
    if last_msg.tool_calls:
        return None
    # ``content`` may be a string or a list of content blocks; str() is a
    # no-op for strings and gives a length for lists.
    if len(str(last_msg.content)) < 10:
        # Same id => add_messages overwrites the short reply instead of
        # appending a second AI message after it.
        return {
            "messages": [AIMessage(content="模型响应异常,请重试。", id=last_msg.id)]
        }
    return None
# ===== Middleware 4: performance monitoring =====
@after_model
def monitor_performance(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """Append the time elapsed since ``start_time`` to ``execution_logs``.

    If ``start_time`` is missing from the state, the elapsed time is reported
    as 0. Returns an update with a new ``execution_logs`` list.
    """
    now = time.time()
    elapsed = now - state.get("start_time", now)
    # New list instead of an in-place append on the state's own list.
    logs = [*state.get("execution_logs", []), f"⏱️ 耗时: {elapsed:.2f}秒"]
    return {"execution_logs": logs}
# ===== Create the agent =====
agent = create_agent(
    model=model,
    tools=[search_tool],
    context_schema=Context,
    state_schema=ExtendedState,
    middleware=[
        validate_request,     # (1) permission check (before_model)
        log_request,          # (2) request logging (before_model)
        validate_response,    # (3) response check (after_model)
        monitor_performance,  # (4) timing (after_model)
    ],
)
# ===== Smoke test =====
if __name__ == "__main__":
    # Seed the custom state fields; start_time anchors monitor_performance.
    result = agent.invoke(
        {
            "messages": [{"role": "user", "content": "搜索 Python"}],
            "call_count": 0,
            "start_time": time.time(),
            "execution_logs": [],
        },
        context={"user_id": "user_001", "user_role": "admin"},
    )
    print("执行日志:")
    for entry in result.get("execution_logs", []):
        print(f" {entry}")
关键概念总结
| 特性 | Node-Style | Wrap-Style |
|---|---|---|
| 执行时机 | 在特定执行点按顺序执行 | 环绕执行 |
| 用途 | 日志、验证、状态更新 | 重试、缓存、动态选择 |
| 控制权 | 框架控制执行流程 | 由你决定何时调用 handler |
| 钩子 | `before_*`、`after_*` | `wrap_*` |
| 返回值 | dict 或 None | 直接返回结果 |
本文使用的模型服务来自硅基流动平台。新用户通过邀请链接注册可领取 2000万免费token,支持GLM-4.6、Kimi-K2-Thinking、MiniMaxAI/MiniMax-M2、DeepSeek-R2等主流大模型调用,API稳定性与响应速度俱佳。
专属注册链接:cloud.siliconflow.cn/i/AvDmOKTO
(打直球:这是我的推荐码 感谢大佬的支持)
作者:世界那么哒哒
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
相关文档:
- Custom Middleware 详解
- Runtime 对象
- Agents 中间件