你有没有遇到过这种情况:Agent 跑得好好的,你想加个日志看看它到底在干嘛,或者想加个安全检查防止它搞出危险操作,结果发现不知道往哪儿插?
Middleware 就是来解决这个问题的。说白了,它就是一个"拦截器",让你在模型调用前后插入你自己的逻辑。听起来是不是挺简单的?别急,咱们直接上手写代码,你一看就明白了。
动画视频在《27. Agent 需要拦截模型调用?用 Middleware 给它加个"拦截器"!》。
第一个 Middleware:日志记录
咱们先从一个最简单的需求开始------记录日志。我想知道每次调用模型的时候,当前有多少条消息,模型又回了什么。
怎么做呢?很简单,写一个类,继承 AgentMiddleware,然后实现两个方法就行。
from langchain.agents.middleware import AgentMiddleware
from langchain.agents import AgentState
from langgraph.runtime import Runtime
class LoggingMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime: Runtime) -> None:
print(f"[日志] 即将调用模型,当前消息数: {len(state['messages'])}")
def after_model(self, state: AgentState, runtime: Runtime) -> None:
last_msg = state['messages'][-1]
print(f"[日志] 模型已响应: {last_msg.content[:50]}...")
你注意看啊,这里有两个方法。before_model 就是在模型调用之前执行的,我打印一下当前消息的数量。after_model 是模型响应之后执行的,我取最后一条消息,截取前 50 个字符看看模型回了啥。
关键点来了------你注意这两个方法的返回类型,都是 None。为啥?因为这个 Middleware 只是记录日志,它不需要干预流程,所以直接返回 None,就是在告诉框架:"我完事了,你继续往下走就行。"
好,日志中间件搞定了。是不是特别简单?接下来咱们上点强度------写一个安全检查中间件。
第二个 Middleware:安全检查
这个中间件的作用是啥呢?拦截危险操作。比如用户说"删除所有文件",你肯定不想让 Agent 傻乎乎地去执行吧?
同样继承 AgentMiddleware,重点在 before_model 里做文章。
from langchain_core.messages import AIMessage
class SafetyMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
last_msg = state['messages'][-1].content
if "删除" in last_msg or "危险" in last_msg:
return {
"jump_to": "end",
"messages": [AIMessage(content="检测到危险操作,已终止")]
}
return None
这里有个关键区别,你一定要注意到------返回类型变了,变成了 dict | None,而不是单纯的 None。
为什么?因为这个 Middleware 有可能需要干预流程。
逻辑是这样的:拿到最后一条消息的内容,检查一下里面有没有"删除"或者"危险"这样的关键词。如果命中了,就返回一个字典,jump_to 设为 "end",意思是跳过模型调用直接结束。同时塞一条 AIMessage 告诉用户操作被终止了。这个返回值就是在告诉框架:"别调用模型了,直接结束,并且把这条消息加进去。"
如果没有问题呢?就返回 None,意思是"我没意见,流程正常继续"。
你看,这就是 Middleware 的精髓------你可以选择不管,也可以选择直接接管整个流程。
完整实战代码
好,两个 Middleware 都写好了,接下来咱们把它们组装起来,看看实际效果。
先把需要的依赖都导入进来,然后配置模型、定义工具。
import os
import sqlite3
from dotenv import load_dotenv
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.chat_models import init_chat_model
from langchain_classic.agents import Agent
from langchain_community.tools import WriteFileTool, ReadFileTool, ListDirectoryTool
from langchain_core.messages import AIMessage
from langchain_core.tools import tool, BaseTool
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.runtime import Runtime
from langgraph.store.memory import InMemoryStore
from langgraph.store.sqlite import SqliteStore
load_dotenv()
prefix = "QWEN"
model = init_chat_model(
model_provider="openai",
configurable_fields=["model", "api_key", "base_url"],
config_prefix=prefix).with_config({
"configurable": {
f"{prefix}_model": os.getenv(f"{prefix}_MODEL"),
f"{prefix}_api_key": os.getenv(f"{prefix}_API_KEY"),
f"{prefix}_base_url": os.getenv(f"{prefix}_BASE_URL")
}})
class CalculateTool(BaseTool):
name: str = "calculate"
description: str = "计算数学表达式的值"
def _run(self, expression: str) -> str:
try:
return f"计算结果:{eval(expression)}"
except Exception as e:
return f"计算错误:{str(e)}"
async def _arun(self, expression: str) -> str:
return self._run(expression)
# Middleware 1:记录模型调用日志
class LoggingMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime: Runtime) -> None:
print(f"[日志] 即将调用模型,当前消息数: {len(state['messages'])}")
def after_model(self, state: AgentState, runtime: Runtime) -> None:
last_msg = state['messages'][-1]
print(f"[日志] 模型已响应: {last_msg.content[:50]}...")
# Middleware 2:安全检查,拦截危险操作
class SafetyMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
last_msg = state['messages'][-1].content
if "删除" in last_msg or "危险" in last_msg:
return {
"jump_to": "end",
"messages": [AIMessage(content="检测到危险操作,已终止")]
}
return None
calculate = CalculateTool()
write_file = WriteFileTool()
read_file = ReadFileTool()
list_dir = ListDirectoryTool()
checkpoint_conn = sqlite3.connect("agent.db", check_same_thread=False, isolation_level=None)
checkpointer = SqliteSaver(checkpoint_conn)
store_conn = sqlite3.connect("agent.db", check_same_thread=False, isolation_level=None)
store = SqliteStore(store_conn)
agent = create_agent(
model=model,
tools=[calculate, write_file, read_file, list_dir],
system_prompt="你是一个助手,会用工具计算、读写文件、列出目录。",
debug=True,
checkpointer=checkpointer,
store=store,
middleware=[LoggingMiddleware(), SafetyMiddleware()]
)
config = {"configurable": {"thread_id": "session-1"}}
store.put(("user", "user-1"), "profile",
{"name": "张三", "role": "developer", "skills": ["python", "typescript", "java"]})
profile = store.get(("user", "user-1"), "profile")
print(f"用户资料:{profile.value}")
queries = ["计算 2024*12+500,然后把结果保存到 result.txt",
"读取 result.txt的内容",
"列出当前目录文件",
"刚才计算的结果是多少?",
"删除所有文件"
]
for q in queries:
print(f"\n问:{q}")
response = agent.invoke({"messages": [{"role": "user", "content": q}]}, config=config)
print(response)
print(f"\n答:{response['messages'][-1].content}")
checkpoint_conn.close()
store_conn.close()
代码比较长,但核心逻辑其实就三块:模型配置、Middleware 定义、Agent 创建。
你重点看 create_agent 这一行------注意 middleware 这个参数,咱们把 LoggingMiddleware 和 SafetyMiddleware 两个实例一起传进去了。就这么一行,Agent 就同时具备了日志记录和安全拦截的能力。
然后下面准备了一些测试问题,挨个发给 Agent 执行。你注意最后一个问题------"删除所有文件",这就是用来测试安全检查中间件的。
运行效果
跑起来之后你会看到什么呢?
首先,每次模型调用前后,控制台都会打印日志信息,告诉你当前消息数和模型响应内容,这就是 LoggingMiddleware 在干活。
然后,当前面几个正常问题执行的时候,一切顺利。但是到了最后一个问题"删除所有文件"的时候,Agent 直接返回"检测到危险操作,已终止"------根本不会去调用模型。这就是 SafetyMiddleware 在拦截。
总结
好了,今天的内容就到这里,咱们快速回顾一下。
Middleware 说白了就是 Agent 的"拦截器",它给你提供了两个钩子:before_model 和 after_model。你可以在模型调用前注入上下文、检查权限,也可以在模型响应后记录日志、审计操作。
最关键的是,你可以通过返回值来控制流程------返回 None 就是"我没意见,继续走",返回一个 dict 就可以直接改变流程走向,比如跳过模型调用。
这个机制非常灵活,你可以用它做很多事情:限流、鉴权、日志、审计、上下文注入......基本上你能想到的横切关注点,都可以用 Middleware 来优雅地实现。
如果觉得有用的话,记得点赞关注,咱们下期再见!