LangChain 1.0 内置的Agent中间件详解

内置中间件清单

LangChain 为常见场景提供了预构建的中间件。每个中间件都可以在生产环境使用,可根据具体需求配置。如下列表

Middleware 中间件 Description 描述
Summarization 摘要 上下文压缩
Human-in-the-loop 人机参与 暂停执行以供人工批准工具调用。
Model call limit 模型呼叫限制 限制模型调用次数,以防止过高成本。
Tool call limit 工具调用限制 通过限制呼叫次数来控制工具执行。
Model fallback 模型的备选 当主模式失败时,会自动回退到其他模式。
PII detection PII 敏感信息检测 检测并处理敏感信息(PII)。
To-do list 待办事项列表 为客服人员配备任务规划和跟踪能力。
LLM tool selector LLM 工具选择器 在调用主模型之前,先用 LLM 选择相关工具。
Tool retry 工具重试 用指数回撤自动重试失败的工具调用。
Model retry 模型重试 自动用指数退回方式重试失败的模型调用。
LLM tool emulator LLM 工具仿真器 用 LLM 模拟工具执行以进行测试。
Context editing 上下文编辑 通过修剪或清理工具使用来管理对话上下文。
Shell tool 壳体工具 向代理开放一个持久的壳会话以执行命令。
File search 文件搜索 在文件系统文件上提供 Glob 和 Grep 搜索工具。

网址:https://docs.langchain.com/oss/python/langchain/middleware/built-in

需要安装的依赖包

python 复制代码
pip list | grep langchain
复制代码
langchain                                1.0.8

langchain-chroma                         1.0.0

langchain-classic                        1.0.0

langchain-community                      0.4.1

langchain-core                           1.0.7

langchain-deepseek                       1.0.0

langchain-experimental                   0.4.0

langchain-google-genai                   3.0.3

langchain-mcp-adapters                   0.1.13

langchain-ollama                         1.0.0

langchain-openai                         1.0.2

langchain-tavily                         0.2.13

langchain-text-splitters                 1.0.0

一、 before_model 模型调用前

1.SummarizationMiddleware 上下文压缩

中间件类型

before_model - 模型调用前中间件

概述

使用 LangChain 1.0 的 SummarizationMiddleware 来自动压缩历史会话,减少 token 使用,提高响应速度。

核心特性

  1. 官方中间件集成 :使用 from langchain.agents.middleware import SummarizationMiddleware
  2. 自动压缩 :在 create_agent 中通过 middleware 参数集成
  3. 智能保留:自动压缩历史消息,保留最近的对话
  4. 无需手动管理:中间件自动处理压缩逻辑

工作原理

当历史消息的 token 数量超过阈值(500)且消息数量超过保留数量(5条)时,中间件会自动:

  1. 将旧消息发送给摘要模型进行压缩
  2. 保留最近的 N 条消息
  3. 将摘要结果作为上下文传递给 Agent

预期结果

  • 压缩前:20 条消息,约 1000+ tokens
  • 压缩后:5-6 条消息(保留最近5条 + 摘要),约 300-500 tokens
python 复制代码
# ==================== SummarizationMiddleware 完整实现 ====================

from langchain.agents import create_agent,AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain.tools import tool, ToolRuntime
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from typing import Optional
from dotenv import load_dotenv
import logging
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
from dataclasses import dataclass
# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义上下文 ====================
@dataclass
class UserContext :
    user_id: str
    department: str
    max_history_tokens: Optional[int]=0

# ==================== 3. 定义工具 ====================
@tool
def search_patent(query: str, runtime: ToolRuntime[UserContext]) -> str:
    """搜索专利数据库"""
    print(f"专利搜索结果 runtime.context.user_id =: {runtime.context.user_id},department={runtime.context.department}")
    return f"专利搜索结果: 找到与 '{query}' 相关的 3 项专利..."

@tool
def analyze_technology(tech_desc: str) -> str:
    """分析技术可行性"""
    return f"技术分析: '{tech_desc}' 的实现可行性评估完成..."

tools = [search_patent, analyze_technology]



# ==================== 4. 配置中间件 ====================
summarization_middleware = SummarizationMiddleware(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),
    # max_tokens_before_summary=200,          # 历史消息 token 数量超过 200 时触发压缩
    trigger= ("messages",5),    # 或消息数超过 5 条时触发
    keep=("messages",4),  #  保持条件4条
    # summary_prompt="请将以下对话历史进行摘要,保留关键决策点和技术细节:\n\n{messages}\n\n摘要:"   # 摘要提示词
)

memory=InMemorySaver()

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),
    tools=tools,
    middleware=[summarization_middleware],
    context_schema=UserContext,
    debug=True,
    checkpointer=memory
)

i:int =0
# ==================== 6. 执行测试 ====================
def run_summarization_test(question:str):
    logger.info("开始 SummarizationMiddleware 测试")
    # 创建对话
    msg = [HumanMessage(content=f"问题 {i+1}: {question}")]

    # 执行
    result = agent.invoke(
        {"messages": msg},
        context=UserContext(user_id="engineer_001", department="研发部"),
        config=ensure_config({"configurable": {"thread_id": "session_001"}})
    )

    result_messages = result.get("messages", [])
    logger.info(f"执行后消息数: {len(result_messages)},若该值没有变化或变小了,则summarization_middleware起作用了")

    return result

# ==================== 7. main ====================
if __name__ == '__main__':
     while True:
        q=input("请输入问题:")
        result = run_summarization_test(q)
        print("-"*10)
        i=i+1

2.PIIMiddleware PII信息脱敏

中间件类型

before_model - 模型调用前中间件

本示例展示如何使用 PIIMiddleware 来自动检测和脱敏个人身份信息(PII),保护用户隐私和数据安全。

核心特性

  1. 自动PII检测 :使用 from langchain.agents.middleware import PIIMiddleware
  2. 智能脱敏:自动识别并处理敏感信息
  3. 多种策略:支持 block、redact、mask、hash 四种处理策略
  4. 无缝集成:在模型调用前自动处理,对业务逻辑透明

工作原理

在模型调用前,中间件会自动:

  1. 扫描消息内容,识别指定类型的PII信息
  2. 根据策略处理敏感信息(阻止/脱敏/遮蔽/哈希)
  3. 将处理后的消息传递给模型

支持的PII类型

  • email:电子邮件地址
  • credit_card:信用卡号
  • ip:IP地址
  • mac_address:MAC地址
  • url:URL地址

处理策略

  • block:阻止包含PII的消息
  • redact:完全移除PII信息
  • mask:部分遮蔽PII信息
  • hash:将PII转换为哈希值

预期结果

  • 脱敏前:"我的银行卡号是4532-1234-5678-9010"
  • 脱敏后 :"银行卡号是****-****-****-9010"
python 复制代码
# ========================================
# LangChain 1.0 信用卡PII掩码中间件实战
# ========================================
import os
from typing import Annotated
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool, BaseTool
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel, Field
import re
from dotenv import load_dotenv

# ==================== 1. 加载环境 ====================
load_dotenv(override=True)

# ==================== 2. 定义模拟工具 ====================
@tool
def verify_credit_card(card_number: Annotated[str, "信用卡号"]) -> dict:
    """
    验证信用卡号有效性(模拟工具)
    注意:实际生产环境中不应接收真实卡号
    """
    # 工具接收到的参数已经是掩码后的
    print(f"工具接收到的卡号: {card_number}")

    # 模拟验证逻辑
    if len(card_number) >= 16:  # 掩码后的长度也足够判断
        return {
            "is_valid": True,
            "card_type": "Visa",
            "masked_card": card_number
        }
    return {"is_valid": False}

@tool
def process_payment(card_number: str, amount: float) -> str:
    """
    处理信用卡支付(模拟工具)
    """
    print(f"支付工具接收到的卡号: {card_number}")
    return f"支付成功!金额: ${amount}, 卡号: {card_number}"

@tool
def search_user_history(user_id: str) -> str:
    """查询用户历史记录"""
    return f"用户 {user_id} 的历史订单:订单123, 订单456"

# 工具列表
tools: list[BaseTool] = [verify_credit_card, process_payment, search_user_history]

# ==================== 3. 定义用户上下文 ====================
class UserContext(BaseModel):
    """用户上下文 Schema"""
    user_id: str = Field(..., description="用户唯一标识")
    department: str = Field(..., description="所属部门")
    security_level: str = Field(default="normal", description="安全级别")

# ==================== 4. 配置 PIIMiddleware ====================
# 核心配置:信用卡掩码中间件
piim_credit_card = PIIMiddleware(
    "credit_card",
    detector=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",  # 匹配格式: 1234-5678-9012-3456
    strategy="mask",       # 掩码策略
    apply_to_input=True,   # 对输入消息进行掩码
    apply_to_output=False,  # 不对工具输出进行掩码(工具返回的是业务结果)
)

# ==================== 5. 创建智能体 ====================
agent = create_agent(
    # 主模型:用于决策和对话
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),

    # 工具列表
    tools=tools,

    # 中间件:只启用PII掩码(生产环境可添加日志等)
    middleware=[
        piim_credit_card,  # 信用卡掩码中间件
    ],

    # 启用上下文
    context_schema=UserContext,

    # 调试模式
    debug=True,
)

# ==================== 6. 测试用例与执行 ====================
def test_credit_card_masking():
    """测试信用卡掩码全流程"""

    print("=" * 60)
    print("测试场景:用户尝试使用信用卡支付")
    print("=" * 60)

    # 测试输入:包含多种信用卡格式
    test_query = """
    请帮我验证以下信用卡是否有效:
    我的卡号是 4532-1234-5678-9010,另外备用卡是 4532123456781234。
    请检查这两张卡,然后处理一笔 99.99 美元的支付。
    """

    print(f"\n【原始用户输入】\n{test_query}\n")

    # 执行 Agent
    result = agent.invoke(
        # 消息列表
        {"messages": [HumanMessage(content=test_query)]},

        # 上下文(必须)
        context=UserContext(
            user_id="user_789",
            department="财务部",
            security_level="high"
        ),

        # 配置(可选)
        config={"configurable": {"thread_id": "session_cc_001"}}
    )

    print("\n【Agent 最终返回的消息】")
    final_message = result["messages"][-1]
    if isinstance(final_message, AIMessage):
        print(f"角色: {final_message.type}")
        print(f"内容: {final_message.content}")

        # 检查工具调用
        if hasattr(final_message, 'tool_calls') and final_message.tool_calls:
            print("\n【工具调用记录】")
            for tc in final_message.tool_calls:
                print(f"- 工具: {tc['name']}")
                print(f"  参数: {tc['args']}")

    return result

test_credit_card_masking()

3.ModelCallLimitMiddleware 模型调用限制

中间件类型

before_model - 模型调用前中间件

本示例展示如何使用 ModelCallLimitMiddleware 来限制 Agent 的模型调用次数,防止死循环或意外的高消耗。

核心特性

  1. 安全防护:防止 Agent 陷入无限循环
  2. 简单配置 :通过 max_calls 参数设置最大调用次数
  3. 自动熔断:达到限制后自动停止并返回错误或特定消息

工作原理

中间件会跟踪当前会话中的模型调用次数。当调用次数达到设定的阈值时,中间件会阻止后续的模型调用,并引发异常或返回预设的响应。

预期结果

  • 正常情况:调用次数未超限,正常执行
  • 超限情况 :抛出 ModelCallLimitExceeded 异常或停止执行
python 复制代码
# ==================== ModelCallLimitMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
# 
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from dotenv import load_dotenv

load_dotenv(override=True)

# ==================== 1. 定义工具 ====================
@tool
def complex_calculation(x: int) -> int:
    """执行复杂计算"""
    return x * 2

@tool
def get_weather(city: str) -> str:
    """获取天气信息"""
    return f"{city}的天气:晴天,温度25°C"

tools = [complex_calculation, get_weather]

# ==================== 2. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 3. 配置中间件 ====================
limit_middleware = ModelCallLimitMiddleware(
    run_limit=3,  # 每次运行最多调用模型3次
    exit_behavior='error'  # 超限时抛出异常
)

# ==================== 4. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),
    tools=tools,
    middleware=[limit_middleware],
    context_schema=UserContext,
    debug=False,  # 关闭调试模式以减少输出
)

# ==================== 5. 执行测试 ====================
def run_limit_test():
    """测试 ModelCallLimitMiddleware 触发逻辑"""

    # 设计一个需要多次模型调用的任务
    query = """
    请按照以下步骤执行:
    1. 计算 5 的两倍
    2. 用第一步的结果再计算两倍
    3. 用第二步的结果再计算两倍
    4. 用第三步的结果再计算两倍
    5. 最后告诉我北京的天气

    请一步一步执行,每次只做一个计算。
    """

    print("=" * 60)
    print("ModelCallLimitMiddleware 测试")
    print("=" * 60)
    print(f"\n【输入】\n{query.strip()}\n")

    model_call_count = 0
    limit_triggered = False
    final_output = None

    try:
        for chunk in agent.stream(
            {"messages": [HumanMessage(content=query)]},
            context=UserContext(user_id="user_limit_test"),
            config=ensure_config({"configurable": {"thread_id": "thread_limit_001"}}),
            stream_mode="updates"
        ):
            if isinstance(chunk, dict):
                for key, value in chunk.items():
                    # 统计模型调用,这里进行了修改,in模式会把中间件节点算进去,==模式才是只算模型调用
                    if "model" == str(key).lower():
                        model_call_count += 1

                    # 检测中间件触发
                    if "ModelCallLimitMiddleware" in str(key):
                        limit_triggered = True

                    # 获取最终输出
                    if isinstance(value, dict) and "messages" in value:
                        messages = value["messages"]
                        if messages and hasattr(messages[-1], 'content'):
                            final_output = messages[-1].content

        print(f"【输出】\n{final_output}\n")

    except Exception as e:
        print(f"【输出】\n执行被中断: {str(e)}\n")

        if "limit" in str(e).lower() or "exceeded" in str(e).lower():
            limit_triggered = True

    # 输出触发结果
    print("=" * 60)
    print(f"模型调用次数: {model_call_count}")
    print(f"中间件触发: {'✅ 是 (达到 run_limit=3 限制)' if limit_triggered else '❌ 否'}")
    print("=" * 60)

# ==================== 6. 运行测试 ====================
run_limit_test()

二、 wrap_model_call (包裹模型调用)

1.ContextEditingMiddleware 管理上下文大小

中间件类型

wrap_model_call - 模型调用包装中间件

概述

本示例展示如何使用 ContextEditingMiddleware 来自动管理上下文大小,通过清理旧的工具调用结果来防止超出 token 限制。

核心特性

  1. 自动上下文管理:当 token 数量超过阈值时自动清理旧的工具结果
  2. 灵活配置:支持自定义触发阈值、保留数量、排除工具等
  3. 智能清理:保留最近的 N 个工具结果,清理较旧的内容
  4. 无缝集成:在模型调用前自动处理,对业务逻辑透明

工作原理

当消息历史的 token 数量超过配置的阈值时,中间件会自动:

  1. 统计当前消息的 token 数量
  2. 如果超过阈值,清理旧的工具调用结果
  3. 保留最近的 N 个工具结果
  4. 将清理后的消息传递给模型

ClearToolUsesEdit 配置参数

  • trigger: 触发清理的 token 阈值(默认 100,000)
  • keep: 保留最近的 N 个工具结果(默认 3)
  • clear_at_least: 最少清理的 token 数量(默认 0)
  • clear_tool_inputs: 是否清理工具调用的输入参数(默认 False)
  • exclude_tools: 排除不清理的工具列表(默认空)
  • placeholder: 清理后的占位符文本(默认 "[cleared]")

预期结果

  • 未超限:保留所有工具调用结果
  • 超限后:自动清理旧的工具结果,只保留最近的 N 个
python 复制代码
# ==================== ContextEditingMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import ContextEditingMiddleware
from langchain.agents.middleware.context_editing import ClearToolUsesEdit
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
# ensure_config 作用:确保在 Runnable 中获取到正确的配置
from langchain_core.runnables import ensure_config
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义工具 ====================
@tool
def search_database(query: str) -> str:
    """搜索数据库并返回大量结果"""
    # 每次返回约 1000 个字符(约 250 tokens)
    result = f"搜索 '{query}' 的结果:\n"
    result += "\n".join([f"记录 {i}: 这是关于 {query} 的详细信息,包含大量文本内容..." * 5 for i in range(10)])
    logger.info(f"search_database 被调用,查询: {query},返回约 {len(result)} 字符")
    return result

@tool
def analyze_data(data_id: str) -> str:
    """分析数据并返回详细报告"""
    # 每次返回约 1000 个字符(约 250 tokens)
    result = f"数据 {data_id} 的分析报告:\n"
    result += "详细分析内容包括统计数据、趋势分析、异常检测等..." * 20
    logger.info(f"analyze_data 被调用,数据ID: {data_id},返回约 {len(result)} 字符")
    return result

@tool
def generate_report(topic: str) -> str:
    """生成报告"""
    result = f"关于 '{topic}' 的报告:\n"
    result += "报告内容包括背景介绍、现状分析、未来展望等..." * 15
    logger.info(f"generate_report 被调用,主题: {topic},返回约 {len(result)} 字符")
    return result

tools = [search_database, analyze_data, generate_report]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 关键:设置较低的触发阈值,确保能够触发清理
custom_context_middleware = ContextEditingMiddleware(
    edits=[
        ClearToolUsesEdit(
            trigger=800,  # 当 token 数超过 800 时触发清理(约 3-4 次工具调用后)
            keep=1,  # 只保留最近的 1 个工具结果
            clear_at_least=0,  # 清理所有超出keep数量的内容
            clear_tool_inputs=False,  # 不清理工具输入参数
            exclude_tools=["generate_report"],  # 不清理 generate_report 的结果
            placeholder="[已清理以节省空间]",  # 自定义占位符
        )
    ],
    token_count_method="approximate"  # 使用近似计数(更快)
)

# ==================== 5. 创建 Agent(使用 checkpointer 来累积消息)====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),
    tools=tools,
    middleware=[
        custom_context_middleware,  # 使用自定义配置
    ],
    context_schema=UserContext,  # 定义上下文参数,这里是 UserContext
    checkpointer=MemorySaver(),  # 关键:使用 checkpointer 来保存消息历史
    debug=True,  # 开启调试模式以观察中间件行为
)

# ==================== 6. 执行测试 ====================
def run_context_editing_test():
    """
    测试 ContextEditingMiddleware 的上下文清理功能

    场景:在同一个线程中执行多次查询,累积消息历史,触发上下文清理
    """
    logger.info("开始 ContextEditingMiddleware 测试")
    logger.info("配置: trigger=800 tokens, keep=1, exclude_tools=['generate_report']")
    logger.info("策略: 在同一线程中执行多次查询,累积消息历史")

    # 使用同一个 thread_id 来累积消息,ensure_config 确保在 Runnable 中获取到正确的配置
    config = ensure_config({"configurable": {"thread_id": "session_context_accumulate"}})

    # 初始化上下文,这里是 UserContext
    context = UserContext(user_id="user_context_test")

    # 定义测试查询
    queries = [
        "请搜索数据库中关于 'AI技术' 的信息",
        "请分析数据 'dataset_001'",
        "请搜索数据库中关于 '机器学习' 的信息",
        "请分析数据 'dataset_002'",
        "请生成关于 '人工智能发展趋势' 的报告",
    ]

    # 记录中间件是否触发
    middleware_triggered = False

    for i, query in enumerate(queries, 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"第 {i} 次查询: {query}")
        logger.info(f"{'='*60}")

        try:
            # 执行查询
            result = agent.invoke(
                {"messages": [HumanMessage(content=query)]},
                context=context,
                config=config
            )

            # 检查消息历史
            messages = result.get("messages", [])
            logger.info(f"当前消息数量: {len(messages)}")

            # 检查是否有被清理的消息
            cleared_count = sum(
                1 for msg in messages
                if hasattr(msg, 'response_metadata')
                and msg.response_metadata.get("context_editing", {}).get("cleared")
            )

            # 检查是否触发了中间件,如果触发了,就设置 middleware_triggered 为 True
            if cleared_count > 0:
                middleware_triggered = True
                logger.info(f"✅ 检测到 {cleared_count} 个工具结果已被清理!")

        except Exception as e:
            logger.error(f"查询 {i} 出错: {e}")
            import traceback
            traceback.print_exc()

    # 输出最终结果
    logger.info("\n" + "=" * 60)
    logger.info("测试完成")
    logger.info(f"中间件触发: {'✅ 是 - 旧工具结果已被清理' if middleware_triggered else '❌ 否 - 未达到触发阈值'}")
    logger.info("=" * 60)

    # 说明
    print("\n" + "=" * 60)
    print("ContextEditingMiddleware 工作原理说明")
    print("=" * 60)
    print("1. 使用 checkpointer 在同一线程中累积消息历史")
    print("2. 当消息历史超过 800 tokens 时触发清理")
    print("3. 只保留最近的 1 个工具调用结果")
    print("4. 'generate_report' 工具的结果不会被清理(exclude_tools)")
    print("5. 被清理的内容会被替换为 '[已清理以节省空间]'")
    print("6. 每个工具返回约 250 tokens,3-4 次调用后应触发清理")
    print("=" * 60 + "\n")

# ==================== 7. 运行测试 ====================
run_context_editing_test()

2.ModelFallbackMiddleware 模型故障自动切换

中间件类型

wrap_model_call - 模型调用包装中间件

概述

本示例展示如何使用 ModelFallbackMiddleware 来实现模型故障自动切换,当主模型调用失败时自动尝试备用模型。

核心特性

  1. 自动故障转移:主模型失败时自动切换到备用模型
  2. 多级备份:支持配置多个备用模型,按顺序尝试
  3. 无缝切换:对业务逻辑透明,自动处理重试逻辑
  4. 提高可用性:显著提升系统的稳定性和可靠性

工作原理

当模型调用失败时,中间件会自动:

  1. 捕获主模型的异常
  2. 按顺序尝试备用模型
  3. 返回第一个成功的模型响应
  4. 如果所有模型都失败,抛出最后一个异常

配置参数

  • first_model: 第一个备用模型(字符串名称或模型实例)
  • additional_models: 额外的备用模型列表

预期结果

  • 主模型成功:直接返回主模型结果
  • 主模型失败:自动切换到备用模型,返回备用模型结果
python 复制代码
# ==================== ModelFallbackMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义工具 ====================
@tool
def calculate_sum(a: int, b: int) -> int:
    """计算两个数的和"""
    logger.info(f"calculate_sum 被调用: {a} + {b}")
    return a + b

@tool
def get_system_info() -> str:
    """获取系统信息"""
    logger.info("get_system_info 被调用")
    return "系统运行正常,CPU使用率: 45%, 内存使用率: 60%"

tools = [calculate_sum, get_system_info]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 配置模型故障转移:主模型 -> 备用模型1 -> 备用模型2
# 注意:这里使用相同的模型作为演示,实际应用中应使用不同的模型
fallback_middleware = ModelFallbackMiddleware(
    ChatDeepSeek(model="deepseek-chat", temperature=0.3),  # 第一个备用模型
    ChatDeepSeek(model="deepseek-reasoner", temperature=0.5),  # 第二个备用模型
)

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),  # 主模型
    tools=tools,
    middleware=[
        fallback_middleware,  # 添加故障转移中间件
    ],
    context_schema=UserContext,  # 定义上下文参数,这里是 UserContext
    debug=False,    # 关闭调试模式,避免在测试中输出详细信息
)

# ==================== 6. 执行测试 ====================
def run_fallback_test():
    """
    测试 ModelFallbackMiddleware 的故障转移功能

    场景:正常情况下使用主模型,模拟故障时自动切换到备用模型
    """
    logger.info("开始 ModelFallbackMiddleware 测试")
    logger.info("配置: 主模型(deepseek-chat) + 2个备用模型")

    # 测试场景1: 正常调用(主模型成功)
    logger.info("\n" + "="*60)
    logger.info("场景1: 正常调用 - 主模型应该成功处理")
    logger.info("="*60)

    query1 = "请计算 15 + 27 的结果"
    logger.info(f"查询: {query1}")

    try:
        result1 = agent.invoke(
            {"messages": [HumanMessage(content=query1)]},
            context=UserContext(user_id="user_fallback_test"),   # 加入上下文参数
            config=ensure_config({"configurable": {"thread_id": "session_fallback_001"}})  # 加入线程 ID
        )

        final_message = result1["messages"][-1]
        logger.info(f"✅ 场景1成功: {final_message.content[:100]}...")

    except Exception as e:
        logger.error(f"❌ 场景1失败: {e}")

    # 测试场景2: 复杂查询
    logger.info("\n" + "="*60)
    logger.info("场景2: 复杂查询 - 测试模型处理能力")
    logger.info("="*60)

    query2 = "请先获取系统信息,然后计算 100 + 200 的结果,最后总结一下"
    logger.info(f"查询: {query2}")

    try:
        result2 = agent.invoke(
            {"messages": [HumanMessage(content=query2)]},
            context=UserContext(user_id="user_fallback_test"),
            config=ensure_config({"configurable": {"thread_id": "session_fallback_002"}})
        )

        final_message = result2["messages"][-1]
        logger.info(f"✅ 场景2成功: {final_message.content[:100]}...")

    except Exception as e:
        logger.error(f"❌ 场景2失败: {e}")

    # 输出说明
    logger.info("\n" + "="*60)
    logger.info("测试完成")
    logger.info("="*60)

    print("\n" + "="*60)
    print("ModelFallbackMiddleware 工作原理说明")
    print("="*60)
    print("1. 主模型: deepseek-chat (temperature=0.1)")
    print("2. 备用模型1: deepseek-reasoner (temperature=0.3)")
    print("3. 备用模型2: deepseek-chat (temperature=0.5)")
    print("4. 当主模型调用失败时,自动尝试备用模型1")
    print("5. 如果备用模型1也失败,继续尝试备用模型2")
    print("6. 返回第一个成功的模型响应")
    print("7. 实际应用中应配置不同的模型提供商(如 OpenAI, Anthropic 等)")
    print("="*60 + "\n")

    print("\n💡 提示:")
    print("在生产环境中,建议配置不同提供商的模型,例如:")
    print("  主模型: openai:gpt-4o")
    print("  备用1: anthropic:claude-sonnet-4-5-20250929")
    print("  备用2: deepseek:deepseek-chat")
    print("这样可以在某个提供商服务中断时,自动切换到其他提供商。\n")

# ==================== 7. 运行测试 ====================
run_fallback_test()

3.LLMToolSelectorMiddleware 智能工具选择

中间件类型

wrap_model_call - 模型调用包装中间件

概述

本示例展示如何使用 LLMToolSelectorMiddleware 来智能选择最相关的工具,当 Agent 拥有大量工具时,自动筛选出最相关的工具子集。

核心特性

  1. 智能工具筛选:使用 LLM 分析查询并选择最相关的工具
  2. 减少 Token 消耗:只将相关工具传递给主模型,降低成本
  3. 提高准确性:帮助主模型聚焦于正确的工具,提升响应质量
  4. 灵活配置:支持限制工具数量、指定必选工具、自定义选择模型

工作原理

在主模型调用前,中间件会自动:

  1. 使用选择模型分析用户查询
  2. 从所有可用工具中选择最相关的 N 个工具
  3. 将筛选后的工具列表传递给主模型
  4. 主模型只能看到和使用被选中的工具

配置参数

  • model: 用于工具选择的模型(默认使用主模型)
  • system_prompt: 工具选择的系统提示词
  • max_tools: 最多选择的工具数量(默认无限制)
  • always_include: 始终包含的工具名称列表(不计入 max_tools 限制)

预期结果

  • 工具数量减少:从 10+ 个工具筛选到 2-3 个最相关的工具
  • Token 使用降低:减少传递给主模型的工具描述,节省成本
  • 响应质量提升:主模型更容易选择正确的工具
python 复制代码
# ==================== LLMToolSelectorMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义多个工具(模拟大量工具场景)====================
@tool
def search_weather(city: str) -> str:
    """查询指定城市的天气信息"""
    logger.info(f"search_weather 被调用: {city}")
    return f"{city}的天气:晴天,温度25°C,湿度60%"

@tool
def search_news(topic: str) -> str:
    """搜索指定主题的最新新闻"""
    logger.info(f"search_news 被调用: {topic}")
    return f"关于'{topic}'的最新新闻:今日头条新闻内容..."

@tool
def calculate_math(expression: str) -> str:
    """计算数学表达式的结果"""
    logger.info(f"calculate_math 被调用: {expression}")
    try:
        result = eval(expression)
        return f"计算结果: {expression} = {result}"
    except:
        return "计算错误"

@tool
def translate_text(text: str, target_lang: str) -> str:
    """将文本翻译成目标语言"""
    logger.info(f"translate_text 被调用: {text} -> {target_lang}")
    return f"翻译结果: [模拟翻译到{target_lang}]"

@tool
def search_database(query: str) -> str:
    """在数据库中搜索信息"""
    logger.info(f"search_database 被调用: {query}")
    return f"数据库搜索结果: 找到3条关于'{query}'的记录"

@tool
def send_email(recipient: str, subject: str) -> str:
    """发送电子邮件"""
    logger.info(f"send_email 被调用: {recipient}, {subject}")
    return f"邮件已发送给 {recipient}"

@tool
def get_stock_price(symbol: str) -> str:
    """获取股票价格"""
    logger.info(f"get_stock_price 被调用: {symbol}")
    return f"股票 {symbol} 当前价格: $150.25"

@tool
def book_meeting(date: str, time: str) -> str:
    """预订会议室"""
    logger.info(f"book_meeting 被调用: {date} {time}")
    return f"会议室已预订: {date} {time}"

# 所有工具列表(模拟拥有大量工具的场景)
all_tools = [
    search_weather,
    search_news,
    calculate_math,
    translate_text,
    search_database,
    send_email,
    get_stock_price,
    book_meeting,
]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 配置工具选择中间件:使用 LLM 智能选择最相关的工具
tool_selector_middleware = LLMToolSelectorMiddleware(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),  # 使用较小的模型进行工具选择
    max_tools=3,  # 最多选择3个工具
    always_include=["calculate_math"],  # 始终包含数学计算工具
    system_prompt="分析用户查询,选择最相关的工具。优先选择直接相关的工具。"
)

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),  # 主模型
    tools=all_tools,  # 提供所有8个工具
    middleware=[
        tool_selector_middleware,  # 添加工具选择中间件
    ],
    context_schema=UserContext,
    debug=True,  # 开启调试模式以观察工具选择过程
)

# ==================== 6. 执行测试 ====================
def run_tool_selector_test():
    """
    测试 LLMToolSelectorMiddleware 的智能工具选择功能

    场景:从8个工具中智能选择最相关的3个工具
    """
    logger.info("开始 LLMToolSelectorMiddleware 测试")
    logger.info(f"配置: 总共 {len(all_tools)} 个工具,最多选择 3 个,始终包含 calculate_math")

    test_queries = [
        "北京今天的天气怎么样?",
        "帮我计算 123 + 456 的结果",
        "查询苹果公司的股票价格",
        "搜索关于人工智能的最新新闻",
    ]

    for i, query in enumerate(test_queries, 1):
        logger.info("\n" + "="*60)
        logger.info(f"测试场景 {i}: {query}")
        logger.info("="*60)

        try:
            result = agent.invoke(
                {"messages": [HumanMessage(content=query)]},
                context=UserContext(user_id="user_selector_test"),
                config=ensure_config({"configurable": {"thread_id": f"session_selector_{i:03d}"}})
            )

            final_message = result["messages"][-1]
            logger.info(f"✅ 场景 {i} 完成")
            logger.info(f"响应摘要: {final_message.content[:80]}...")

        except Exception as e:
            logger.error(f"❌ 场景 {i} 失败: {e}")
            import traceback
            traceback.print_exc()

    # 输出说明
    logger.info("\n" + "="*60)
    logger.info("测试完成")
    logger.info("="*60)

    print("\n" + "="*60)
    print("LLMToolSelectorMiddleware 工作原理说明")
    print("="*60)
    print("1. Agent 配置了 8 个不同功能的工具")
    print("2. 中间件使用 LLM 分析用户查询")
    print("3. 从 8 个工具中智能选择最相关的 3 个")
    print("4. calculate_math 工具始终被包含(always_include)")
    print("5. 主模型只能看到被选中的工具")
    print("6. 这样可以减少 token 消耗,提高响应质量")
    print("="*60 + "\n")

    print("\n💡 优势:")
    print("- Token 节省:只传递相关工具描述,减少约 60-70% 的工具相关 token")
    print("- 准确性提升:主模型更容易选择正确的工具")
    print("- 成本降低:减少 API 调用成本")
    print("- 可扩展性:支持数十甚至上百个工具的场景\n")

# ==================== 7. 运行测试 ====================
run_tool_selector_test()

三、 wrap_tool_call (包裹工具调用)

1.ToolRetryMiddleware 自动重试工具调用

中间件类型

wrap_tool_call - 工具调用包装中间件

概述

本示例展示如何使用 ToolRetryMiddleware 来自动重试失败的工具调用,提高系统的稳定性和可靠性。

核心特性

  1. 自动重试:工具调用失败时自动重试,无需手动处理
  2. 指数退避:支持指数退避策略,避免过度请求
  3. 灵活配置:可配置重试次数、退避因子、延迟时间等
  4. 异常过滤:支持只重试特定类型的异常
  5. 工具级控制:可以针对特定工具配置重试策略

工作原理

当工具调用失败时,中间件会自动:

  1. 捕获工具调用异常
  2. 检查是否应该重试(基于异常类型和重试次数)
  3. 等待一段时间(使用指数退避策略)
  4. 重新执行工具调用
  5. 返回成功结果或最终失败消息

配置参数

  • max_retries: 最大重试次数(默认 2)
  • tools: 应用重试的工具列表(默认所有工具)
  • retry_on: 应该重试的异常类型或判断函数
  • on_failure: 失败时的处理方式('raise' 或 'return_message' 或自定义函数)
  • backoff_factor: 退避因子(默认 2.0)
  • initial_delay: 初始延迟时间(默认 1.0 秒)
  • max_delay: 最大延迟时间(默认 60.0 秒)
  • jitter: 是否添加随机抖动(默认 True)

预期结果

  • 临时故障:自动重试后成功执行
  • 持续故障:达到最大重试次数后返回友好的错误消息
python 复制代码
# ==================== ToolRetryMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging
import random

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义工具(模拟可能失败的工具)====================
# 全局计数器,用于模拟间歇性故障
call_counts = {}

@tool
def unreliable_api_call(query: str) -> str:
    """
    模拟不稳定的 API 调用
    前2次调用会失败,第3次成功
    """
    if 'unreliable_api_call' not in call_counts:
        call_counts['unreliable_api_call'] = 0

    call_counts['unreliable_api_call'] += 1
    attempt = call_counts['unreliable_api_call']

    logger.info(f"unreliable_api_call 第 {attempt} 次调用: {query}")

    # 前2次调用失败
    if attempt <= 2:
        logger.warning(f"模拟 API 调用失败(第 {attempt} 次尝试)")
        raise ConnectionError(f"API 连接失败(尝试 {attempt}/3)")

    # 第3次成功
    logger.info(f"✅ API 调用成功(第 {attempt} 次尝试)")
    return f"API 查询成功: '{query}' 的结果数据"

@tool
def stable_tool(data: str) -> str:
    """稳定的工具,总是成功"""
    logger.info(f"stable_tool 被调用: {data}")
    return f"处理完成: {data}"

@tool
def random_failure_tool(input_text: str) -> str:
    """
    随机失败的工具
    50% 概率失败
    """
    logger.info(f"random_failure_tool 被调用: {input_text}")

    if random.random() < 0.5:
        logger.warning("模拟随机失败")
        raise RuntimeError("随机错误:服务暂时不可用")

    logger.info("✅ 随机工具调用成功")
    return f"随机工具处理结果: {input_text}"

tools = [unreliable_api_call, stable_tool, random_failure_tool]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 配置工具重试中间件:自动重试失败的工具调用
retry_middleware = ToolRetryMiddleware(
    max_retries=3,  # 最多重试3次
    tools=["unreliable_api_call", "random_failure_tool"],  # 只对这两个工具启用重试
    retry_on=(ConnectionError, RuntimeError),  # 只重试这些异常
    on_failure="return_message",  # 失败时返回友好消息而不是抛出异常
    backoff_factor=1.5,  # 退避因子,每次重试延迟增加1.5倍
    initial_delay=0.5,  # 初始延迟0.5秒
    max_delay=5.0,  # 最大延迟5秒
    jitter=True,  # 添加随机抖动,避免同时重试
)

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),
    tools=tools,
    middleware=[
        retry_middleware,  # 添加重试中间件
    ],
    context_schema=UserContext,
    debug=False,
)

# ==================== 6. 执行测试 ====================
def run_retry_test():
    """
    测试 ToolRetryMiddleware 的自动重试功能

    场景:测试不稳定工具的自动重试机制
    """
    logger.info("开始 ToolRetryMiddleware 测试")
    logger.info("配置: max_retries=3, 对 unreliable_api_call 和 random_failure_tool 启用重试")

    # 重置计数器
    call_counts.clear()

    # 测试场景1: 不稳定的 API 调用(前2次失败,第3次成功)
    logger.info("\n" + "="*60)
    logger.info("场景1: 测试不稳定的 API 调用(应该在重试后成功)")
    logger.info("="*60)

    query1 = "请调用 unreliable_api_call 查询用户数据"
    logger.info(f"查询: {query1}")

    try:
        result1 = agent.invoke(
            {"messages": [HumanMessage(content=query1)]},
            context=UserContext(user_id="user_retry_test"),
            config=ensure_config({"configurable": {"thread_id": "session_retry_001"}})
        )

        final_message = result1["messages"][-1]
        logger.info(f"✅ 场景1完成")
        logger.info(f"响应: {final_message.content[:100]}...")

    except Exception as e:
        logger.error(f"❌ 场景1失败: {e}")

    # 测试场景2: 稳定工具(不需要重试)
    logger.info("\n" + "="*60)
    logger.info("场景2: 测试稳定工具(不需要重试)")
    logger.info("="*60)

    query2 = "请使用 stable_tool 处理数据"
    logger.info(f"查询: {query2}")

    try:
        result2 = agent.invoke(
            {"messages": [HumanMessage(content=query2)]},
            context=UserContext(user_id="user_retry_test"),
            config=ensure_config({"configurable": {"thread_id": "session_retry_002"}})
        )

        final_message = result2["messages"][-1]
        logger.info(f"✅ 场景2完成")
        logger.info(f"响应: {final_message.content[:100]}...")

    except Exception as e:
        logger.error(f"❌ 场景2失败: {e}")

    # 输出说明
    logger.info("\n" + "="*60)
    logger.info("测试完成")
    logger.info("="*60)

    print("\n" + "="*60)
    print("ToolRetryMiddleware 工作原理说明")
    print("="*60)
    print("1. unreliable_api_call 工具前2次调用失败")
    print("2. 中间件自动捕获 ConnectionError 异常")
    print("3. 使用指数退避策略等待后重试")
    print("4. 第3次调用成功,返回结果")
    print("5. stable_tool 工具始终成功,不需要重试")
    print("6. 重试机制对业务逻辑完全透明")
    print("="*60 + "\n")

    print("\n💡 重试策略:")
    print("- 第1次重试延迟: 0.5秒 × 1.5^0 = 0.5秒")
    print("- 第2次重试延迟: 0.5秒 × 1.5^1 = 0.75秒")
    print("- 第3次重试延迟: 0.5秒 × 1.5^2 = 1.125秒")
    print("- 添加随机抖动避免雷鸣群效应")
    print("\n🎯 适用场景:")
    print("- 网络请求不稳定")
    print("- 外部 API 限流")
    print("- 数据库连接超时")
    print("- 临时性服务故障\n")

# ==================== 7. 运行测试 ====================
run_retry_test()

2.LLMToolEmulator 模拟工具执行

中间件类型

wrap_tool_call - 工具调用包装中间件

概述

本示例展示如何使用 LLMToolEmulator 来使用 LLM 模拟工具执行,而不是真正调用工具。这对于测试、演示和开发非常有用。

核心特性

  1. LLM 模拟执行:使用 LLM 生成模拟的工具执行结果
  2. 选择性模拟:可以选择模拟特定工具或所有工具
  3. 安全测试:在不执行真实操作的情况下测试 Agent 逻辑
  4. 快速原型:无需实现真实工具即可测试 Agent 流程

工作原理

当工具被调用时,中间件会自动:

  1. 拦截工具调用请求
  2. 检查该工具是否在模拟列表中
  3. 使用 LLM 根据工具描述和参数生成模拟结果
  4. 返回模拟结果而不是执行真实工具

配置参数

  • tools : 要模拟的工具列表(工具名称或 BaseTool 实例)
    • None: 模拟所有工具(默认)
    • []: 不模拟任何工具
    • ["tool1", "tool2"]: 只模拟指定的工具
  • model: 用于模拟的 LLM 模型(默认 anthropic:claude-sonnet-4-5-20250929)

预期结果

  • 模拟工具:返回 LLM 生成的模拟结果,不执行真实操作
  • 非模拟工具:正常执行真实工具代码
python 复制代码
# ==================== LLMToolEmulator 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import LLMToolEmulator
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义工具 ====================
@tool
def send_real_email(recipient: str, subject: str, body: str) -> str:
    """
    发送真实邮件(在测试中会被模拟)
    实际生产环境中这会真正发送邮件
    """
    logger.info(f"⚠️ send_real_email 被真实调用: {recipient}")
    # 这里应该是真实的邮件发送逻辑
    return f"真实邮件已发送给 {recipient},主题: {subject}"

@tool
def charge_credit_card(card_number: str, amount: float) -> str:
    """
    真实扣款(在测试中会被模拟)
    实际生产环境中这会真正扣款
    """
    logger.info(f"⚠️ charge_credit_card 被真实调用: ${amount}")
    # 这里应该是真实的支付逻辑
    return f"已从卡号 {card_number} 扣款 ${amount}"

@tool
def delete_database_record(record_id: str) -> str:
    """
    删除数据库记录(在测试中会被模拟)
    实际生产环境中这会真正删除数据
    """
    logger.info(f"⚠️ delete_database_record 被真实调用: {record_id}")
    # 这里应该是真实的数据库删除逻辑
    return f"记录 {record_id} 已从数据库中删除"

@tool
def safe_query_tool(query: str) -> str:
    """
    安全的查询工具(不会被模拟,真实执行)
    """
    logger.info(f"✅ safe_query_tool 被真实调用: {query}")
    return f"查询结果: 找到关于 '{query}' 的 5 条记录"

tools = [send_real_email, charge_credit_card, delete_database_record, safe_query_tool]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 配置工具模拟中间件:使用 LLM 模拟危险操作,避免真实执行
emulator_middleware = LLMToolEmulator(
    tools=["send_real_email", "charge_credit_card", "delete_database_record"],  # 只模拟这些危险工具
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.7),  # 使用 DeepSeek 进行模拟
)

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),
    tools=tools,
    middleware=[
        emulator_middleware,  # 添加工具模拟中间件
    ],
    context_schema=UserContext,
    debug=True,  # 开启调试模式以观察模拟过程
)

# ==================== 6. 执行测试 ====================
def run_emulator_test():
    """
    测试 LLMToolEmulator 的工具模拟功能

    场景:测试危险操作的模拟执行,确保不会真实执行
    """
    logger.info("开始 LLMToolEmulator 测试")
    logger.info("配置: 模拟 send_real_email, charge_credit_card, delete_database_record")
    logger.info("safe_query_tool 不被模拟,会真实执行")

    test_scenarios = [
        ("场景1: 发送邮件(应该被模拟)", "请发送邮件给 test@example.com,主题是测试邮件"),
        ("场景2: 信用卡扣款(应该被模拟)", "请从卡号 1234-5678-9012-3456 扣款 99.99 美元"),
        ("场景3: 删除数据(应该被模拟)", "请删除数据库中 ID 为 record_123 的记录"),
        ("场景4: 安全查询(应该真实执行)", "请查询用户信息"),
    ]

    for i, (scenario_name, query) in enumerate(test_scenarios, 1):
        logger.info("\n" + "="*60)
        logger.info(scenario_name)
        logger.info("="*60)
        logger.info(f"查询: {query}")

        try:
            result = agent.invoke(
                {"messages": [HumanMessage(content=query)]},
                context=UserContext(user_id="user_emulator_test"),
                config=ensure_config({"configurable": {"thread_id": f"session_emulator_{i:03d}"}})
            )

            final_message = result["messages"][-1]
            logger.info(f"✅ {scenario_name} 完成")
            logger.info(f"响应摘要: {final_message.content[:80]}...")

        except Exception as e:
            logger.error(f"❌ {scenario_name} 失败: {e}")
            import traceback
            traceback.print_exc()

    # 输出说明
    logger.info("\n" + "="*60)
    logger.info("测试完成")
    logger.info("="*60)

    print("\n" + "="*60)
    print("LLMToolEmulator 工作原理说明")
    print("="*60)
    print("1. send_real_email, charge_credit_card, delete_database_record 被 LLM 模拟")
    print("2. 这些工具的代码不会被真实执行")
    print("3. LLM 根据工具描述和参数生成合理的模拟结果")
    print("4. safe_query_tool 不在模拟列表中,会真实执行")
    print("5. 日志中可以看到哪些工具被真实调用(⚠️)或模拟(无标记)")
    print("="*60 + "\n")

    print("\n🎯 使用场景:")
    print("- 测试环境:避免执行危险操作(删除、扣款、发送邮件等)")
    print("- 快速原型:无需实现真实工具即可测试 Agent 流程")
    print("- 演示系统:展示功能而不触发真实操作")
    print("- 开发调试:在开发阶段模拟外部 API 调用")
    print("\n💡 最佳实践:")
    print("- 在测试环境中模拟所有危险操作")
    print("- 在生产环境中移除模拟中间件")
    print("- 使用环境变量控制是否启用模拟")
    print("- 模拟结果应该尽可能接近真实结果\n")

# ==================== 7. 运行测试 ====================
run_emulator_test()
复制代码
🎯 使用场景:
- 测试环境:避免执行危险操作(删除、扣款、发送邮件等)
- 快速原型:无需实现真实工具即可测试 Agent 流程
- 演示系统:展示功能而不触发真实操作
- 开发调试:在开发阶段模拟外部 API 调用

💡 最佳实践:
- 在测试环境中模拟所有危险操作
- 在生产环境中移除模拟中间件
- 使用环境变量控制是否启用模拟
- 模拟结果应该尽可能接近真实结果

四、 after_model 模型调用后

1.HumanInTheLoopMiddleware 人工干预中间件

中间件类型

after_model - 模型调用后中间件

概述

本示例展示如何使用 HumanInTheLoopMiddleware 来实现人工审批流程,确保关键操作在执行前得到人工确认。

核心特性

  1. 官方中间件集成 :使用 from langchain.agents.middleware import HumanInTheLoopMiddleware
  2. 工具调用拦截 :在 create_agent 中通过 middleware 参数集成
  3. 灵活审批策略:支持 approve(批准)、edit(编辑)、reject(拒绝)三种决策
  4. 无缝集成:中间件自动处理中断和恢复逻辑

工作原理

当 AI 决定调用需要审批的工具时,中间件会自动:

  1. 拦截工具调用请求
  2. 触发中断(interrupt),等待人工决策
  3. 根据人工决策执行相应操作(批准/编辑/拒绝)
  4. 继续或终止执行流程

审批决策类型

  • approve:批准执行,使用原始参数
  • edit:修改参数后执行
  • reject:拒绝执行,返回错误消息

预期结果

  • 无中间件:AI 直接调用工具发送邮件
  • 有中间件:执行暂停,等待人工批准后才发送邮件
python 复制代码
# ========================================
# LangChain 1.0 人工审批中间件实战
# ========================================
import os
from dotenv import load_dotenv
from langchain_deepseek import ChatDeepSeek
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware
from langchain.agents.middleware.human_in_the_loop import (
    HITLResponse,
    ApproveDecision,
    EditDecision,
    RejectDecision
)
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from langgraph.types import Command

# 1. 加载环境变量
load_dotenv(override=True)

# ---------------------------------------------------------------------------
# 2. 定义工具 (Tools)
# ---------------------------------------------------------------------------
class SendEmailSchema(BaseModel):
    recipient: str = Field(description="邮件接收者的邮箱地址")
    subject: str = Field(description="邮件主题")
    body: str = Field(description="邮件正文内容")

@tool(args_schema=SendEmailSchema)
def send_email(recipient: str, subject: str, body: str):
    """模拟发送邮件的工具"""
    print(f"\n======== [SYSTEM ACTION: 正在执行发送邮件] ========")
    print(f"收件人: {recipient}")
    print(f"主题  : {subject}")
    print(f"内容  : {body}")
    print(f"================================================\n")
    return f"邮件已成功发送给 {recipient}"

tools = [send_email]

# ---------------------------------------------------------------------------
# 3. 创建模型
# ---------------------------------------------------------------------------
model = ChatDeepSeek(model="deepseek-chat")

# ---------------------------------------------------------------------------
# 4. 创建带 HumanInTheLoopMiddleware 的图
# ---------------------------------------------------------------------------
system_prompt = """
你是一个专业的行政助手。
当用户请求发送邮件时,你必须直接调用 `send_email` 工具。
不要问任何后续问题,不要要求确认,直接生成工具调用。
"""

# 定义中间件:指定 'send_email' 工具需要中断审批
# interrupt_on 字典中的 True 表示允许批准、编辑和拒绝
hitl_middleware = HumanInTheLoopMiddleware(
    interrupt_on={"send_email": True},
    description_prefix="需要人工批准才能发送邮件"
)

# 使用 create_agent 创建图,并注入中间件
# LangGraph Studio 会自动处理持久化,不需要传入 checkpointer
graph = create_agent(
    model=model,
    tools=tools,
    system_prompt=system_prompt,
    middleware=[hitl_middleware]
)

# ---------------------------------------------------------------------------
# 5. 定义观察和执行函数
# ---------------------------------------------------------------------------

def run_interactive_session():
    """本地运行的交互式会话(需要 checkpointer)"""
    from langgraph.checkpoint.memory import MemorySaver

    # 本地运行时需要创建带 checkpointer 的 graph
    local_graph = create_agent(
        model=model,
        tools=tools,
        system_prompt=system_prompt,
        middleware=[hitl_middleware],
        checkpointer=MemorySaver()  # 本地运行需要
    )

    # 配置线程 ID,用于区分不同的对话会话
    thread_id = "demo_thread_middleware_1"
    config = {"configurable": {"thread_id": thread_id}}

    user_input = "帮我给 hr@example.com 发一封邮件,主题是'休假申请',内容是我下周一想请假一天。"
    print(f"\n[用户]: {user_input}")

    # === 第一步:初始执行 ===
    print("\n[系统]: 开始处理请求...")
    # input 传入用户消息
    # stream_mode="values" 可以让我们看到消息流
    for event in local_graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config=config,
        stream_mode="values"
    ):
        # 简单打印最后一条消息的内容
        if "messages" in event:
            last_msg = event["messages"][-1]
            if last_msg.type == "ai" and last_msg.tool_calls:
                print(f"[AI 思考]: 决定调用工具 -> {last_msg.tool_calls[0]['name']}")

    # === 第二步:观察 (Observation) ===
    # 中间件应该触发了中断
    snapshot = local_graph.get_state(config)

    print(f"\n--- 🛑 执行已暂停 (HITL Middleware) ---")
    print(f"下一步骤 (Next): {snapshot.next}")
    print(f"任务数量: {len(snapshot.tasks) if snapshot.tasks else 0}")

    # 检查是否有任务(这表示中断发生)
    if snapshot.tasks:
        # 获取最后一条消息
        last_message = snapshot.values["messages"][-1]

        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            tool_call = last_message.tool_calls[0]
            print(f"\n[待审批操作]:")
            print(f"  - 工具: {tool_call['name']}")
            print(f"  - 参数: {tool_call['args']}")

            # === 第三步:人工介入 (Human Input) ===
            approval = input("\n[管理员]: 是否批准执行此操作? (y/n/e[编辑]): ")

            if approval.lower() == 'y':
                # === 第四步:恢复执行 (Resume) - 批准 ===
                print("\n[系统]: 操作已批准,继续执行...")

                # 创建 HITLResponse 对象,包含 ApproveDecision
                hitl_response = HITLResponse(
                    decisions=[ApproveDecision(type="approve")]
                )

                # 使用 Command(resume=hitl_response) 来批准并继续执行
                for event in local_graph.stream(
                    Command(resume=hitl_response),  # 传入 HITLResponse 对象
                    config=config,
                    stream_mode="values"
                ):
                    if "messages" in event:
                        last_msg = event["messages"][-1]
                        if last_msg.type == "tool":
                            print(f"[工具输出]: {last_msg.content}")
                        elif last_msg.type == "ai" and last_msg.content:
                            print(f"[AI 回复]: {last_msg.content}")

            elif approval.lower() == 'e':
                # === 编辑工具调用参数 ===
                print("\n[系统]: 编辑模式...")
                print(f"当前参数: {tool_call['args']}")

                # 让用户编辑参数
                new_recipient = input(f"新收件人 (当前: {tool_call['args'].get('recipient', '')},留空保持不变): ").strip()
                new_subject = input(f"新主题 (当前: {tool_call['args'].get('subject', '')},留空保持不变): ").strip()
                new_body = input(f"新内容 (当前: {tool_call['args'].get('body', '')},留空保持不变): ").strip()

                # 构建新的参数
                updated_args = tool_call['args'].copy()
                if new_recipient:
                    updated_args['recipient'] = new_recipient
                if new_subject:
                    updated_args['subject'] = new_subject
                if new_body:
                    updated_args['body'] = new_body

                print(f"\n[系统]: 使用更新后的参数继续执行...")
                print(f"更新后的参数: {updated_args}")

                # 创建 HITLResponse 对象,包含 EditDecision
                # EditDecision 需要 edited_action,包含 name 和 args
                hitl_response = HITLResponse(
                    decisions=[EditDecision(
                        type="edit",
                        edited_action={
                            "name": tool_call['name'],
                            "args": updated_args
                        }
                    )]
                )

                # 使用 Command(resume=hitl_response) 来批准并使用新参数
                for event in local_graph.stream(
                    Command(resume=hitl_response),  # 传入包含编辑决策的 HITLResponse
                    config=config,
                    stream_mode="values"
                ):
                    if "messages" in event:
                        last_msg = event["messages"][-1]
                        if last_msg.type == "tool":
                            print(f"[工具输出]: {last_msg.content}")
                        elif last_msg.type == "ai" and last_msg.content:
                            print(f"[AI 回复]: {last_msg.content}")

            else:
                # === 拒绝操作 ===
                print("\n[系统]: 操作被拒绝。")

                # 创建 HITLResponse 对象,包含 RejectDecision
                rejection_reason = input("拒绝原因 (可选): ").strip() or "操作被管理员拒绝"

                hitl_response = HITLResponse(
                    decisions=[RejectDecision(
                        type="reject",
                        message=rejection_reason
                    )]
                )

                # 使用 Command(resume=hitl_response) 来拒绝
                for event in local_graph.stream(
                    Command(resume=hitl_response),  # 传入包含拒绝决策的 HITLResponse
                    config=config,
                    stream_mode="values"
                ):
                    if "messages" in event:
                        last_msg = event["messages"][-1]
                        if last_msg.type == "ai" and last_msg.content:
                            print(f"[AI 回复]: {last_msg.content}")
                        elif last_msg.type == "tool":
                            print(f"[工具消息]: {last_msg.content}")

                print("[系统]: 流程已终止。")
        else:
            print("没有检测到待处理的工具调用。")
    else:
        print("流程已完成,没有触发中断。")
        # 打印最终结果
        if snapshot.values.get("messages"):
            last_msg = snapshot.values["messages"][-1]
            if last_msg.type == "ai" and last_msg.content:
                print(f"\n[最终回复]: {last_msg.content}")


run_interactive_session()
复制代码
[用户]: 帮我给 hr@example.com 发一封邮件,主题是'休假申请',内容是我下周一想请假一天。

[系统]: 开始处理请求...
[AI 思考]: 决定调用工具 -> send_email
[AI 思考]: 决定调用工具 -> send_email

--- 🛑 执行已暂停 (HITL Middleware) ---
下一步骤 (Next): ('HumanInTheLoopMiddleware.after_model',)
任务数量: 1

[待审批操作]:
  - 工具: send_email
  - 参数: {'recipient': 'hr@example.com', 'subject': '休假申请', 'body': '您好,\n\n我申请下周一请假一天。\n\n谢谢!\n\n此致\n敬礼'}

[系统]: 编辑模式...
当前参数: {'recipient': 'hr@example.com', 'subject': '休假申请', 'body': '您好,\n\n我申请下周一请假一天。\n\n谢谢!\n\n此致\n敬礼'}

[系统]: 使用更新后的参数继续执行...
更新后的参数: {'recipient': 'hr@gamil.com', 'subject': '病假申请', 'body': '您好,\n\n我申请下周一请假一天。\n\n谢谢!\n\n此致\n敬礼'}
[AI 回复]: 我将直接为您发送这封休假申请邮件。
[AI 回复]: 我将直接为您发送这封休假申请邮件。

======== [SYSTEM ACTION: 正在执行发送邮件] ========
收件人: hr@gamil.com
主题  : 病假申请
内容  : 您好,

我申请下周一请假一天。

谢谢!

此致
敬礼
================================================

[工具输出]: 邮件已成功发送给 hr@gamil.com
[AI 回复]: 邮件已成功发送给 hr@gamil.com,主题为"病假申请",内容包含您下周一的请假申请。

2.ToolCallLimitMiddleware 工具调用限制

中间件类型

wrap_tool_call - 工具调用包装中间件

概述

本示例展示如何使用 ToolCallLimitMiddleware 来限制 Agent 的工具调用频率或总量,防止工具被滥用或过度消耗资源。

核心特性

  1. 资源保护:防止特定工具被频繁调用
  2. 灵活配置:支持全局限制或针对特定工具的限制
  3. 自动熔断:达到限制后阻止工具执行并返回错误

工作原理

中间件会跟踪当前会话中的工具调用次数。当特定工具或总工具调用次数达到设定的阈值时,中间件会阻止后续的工具调用,并引发异常或返回预设的响应。

预期结果

  • 正常情况:工具调用次数未超限,正常执行
  • 超限情况 :抛出 ToolCallLimitExceeded 异常或返回错误信息
python 复制代码
# ==================== ToolCallLimitMiddleware 完整实现 ====================

from langchain.agents import create_agent
from langchain.agents.middleware import ToolCallLimitMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import logging

# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)

# ==================== 2. 定义工具 ====================
@tool
def check_server_status(server_id: str) -> str:
    """检查服务器状态"""
    logger.info(f"正在检查服务器 {server_id} 的状态...")
    return f"服务器 {server_id} 运行正常,负载 45%"

@tool
def restart_server(server_id: str) -> str:
    """重启服务器"""
    logger.info(f"正在重启服务器 {server_id}...")
    return f"服务器 {server_id} 已重启"

tools = [check_server_status, restart_server]

# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
    user_id: str = Field(..., description="用户唯一标识")

# ==================== 4. 配置中间件 ====================
# 方式1: 限制所有工具的调用次数(全局限制)
global_tool_limiter = ToolCallLimitMiddleware(
    tool_name=None,  # None = 限制所有工具
    run_limit=3,     # 每次运行最多调用 3 次工具
    exit_behavior="continue"  # 超限后阻止工具调用,但继续执行
)

# 方式2: 限制特定工具的调用次数
specific_tool_limiter = ToolCallLimitMiddleware(
    tool_name="check_server_status",  # 只限制 check_server_status 工具
    thread_limit=5,   # 整个线程最多调用 5 次
    run_limit=2,      # 每次运行最多调用 2 次
    exit_behavior="error"  # 超限后返回错误消息
)

# ==================== 5. 创建 Agent ====================
agent = create_agent(
    model=ChatDeepSeek(model="deepseek-chat", temperature=0.1),
    tools=tools,
    middleware=[
        specific_tool_limiter,  # 使用特定工具限制器
    ],
    context_schema=UserContext,
    debug=False,
)

# ==================== 6. 执行测试 ====================
def run_tool_limit_test():
    logger.info("开始 ToolCallLimitMiddleware 测试")
    logger.info("配置: check_server_status 工具限制为 run_limit=2")

    # 设计一个会触发多次工具调用的场景
    query = """
    请帮我检查以下服务器的状态:
    1. Server-A
    2. Server-B
    3. Server-C
    4. Server-D

    请逐个检查每台服务器。
    """

    logger.info(f"用户查询: {query.strip()}")

    tool_call_count = 0
    limit_triggered = False

    try:
        for chunk in agent.stream(
            {"messages": [HumanMessage(content=query)]},
            context=UserContext(user_id="user_tool_limit"),
            config=ensure_config({"configurable": {"thread_id": "session_tool_limit_001"}}),
            stream_mode="updates"
        ):
            if isinstance(chunk, dict):
                for key, value in chunk.items():
                    # 统计工具调用
                    if "tools" in str(key).lower():
                        tool_call_count += 1

                    # 检测中间件触发
                    if "ToolCallLimitMiddleware" in str(key):
                        limit_triggered = True
                        logger.warning("检测到 ToolCallLimitMiddleware 触发!")

        logger.info("任务完成")

    except Exception as e:
        logger.error(f"捕获到异常: {e}")
        if "limit" in str(e).lower() or "exceeded" in str(e).lower():
            limit_triggered = True
        return str(e)

    # 输出结果
    logger.info("=" * 60)
    logger.info(f"工具调用次数: {tool_call_count}")
    logger.info(f"中间件触发: {'✅ 是' if limit_triggered else '❌ 否'}")
    logger.info("=" * 60)

# ==================== 7. 运行测试 ====================
run_tool_limit_test()
复制代码
2025-12-03 01:40:51,894 - INFO - 开始 ToolCallLimitMiddleware 测试
2025-12-03 01:40:51,894 - INFO - 配置: check_server_status 工具限制为 run_limit=2
2025-12-03 01:40:51,894 - INFO - 用户查询: 请帮我检查以下服务器的状态:
    1. Server-A
    2. Server-B
    3. Server-C
    4. Server-D

    请逐个检查每台服务器。
2025-12-03 01:40:52,186 - INFO - HTTP Request: POST https://api.deepseek.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-12-03 01:40:54,645 - WARNING - 检测到 ToolCallLimitMiddleware 触发!
2025-12-03 01:40:54,648 - INFO - 正在检查服务器 Server-A 的状态...
2025-12-03 01:40:54,702 - INFO - HTTP Request: POST https://api.deepseek.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-12-03 01:40:56,763 - WARNING - 检测到 ToolCallLimitMiddleware 触发!
2025-12-03 01:40:56,766 - INFO - 正在检查服务器 Server-B 的状态...
2025-12-03 01:40:56,820 - INFO - HTTP Request: POST https://api.deepseek.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-12-03 01:40:59,361 - WARNING - 检测到 ToolCallLimitMiddleware 触发!
2025-12-03 01:40:59,416 - INFO - HTTP Request: POST https://api.deepseek.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-12-03 01:41:01,690 - WARNING - 检测到 ToolCallLimitMiddleware 触发!
2025-12-03 01:41:01,738 - INFO - HTTP Request: POST https://api.deepseek.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-12-03 01:41:05,886 - WARNING - 检测到 ToolCallLimitMiddleware 触发!
2025-12-03 01:41:05,888 - INFO - 任务完成
2025-12-03 01:41:05,888 - INFO - ============================================================
2025-12-03 01:41:05,889 - INFO - 工具调用次数: 2
2025-12-03 01:41:05,889 - INFO - 中间件触发: ✅ 是
2025-12-03 01:41:05,890 - INFO - ============================================================
相关推荐
木子啊2 小时前
PHP中间件:ThinkCMF 6.x核心利器解析
开发语言·中间件·php
Bruk.Liu3 小时前
(LangChain实战5):LangChain消息模版ChatPromptTemplate
人工智能·python·langchain·agent
爱敲代码的TOM3 小时前
大模型应用开发-LangChain框架基础
python·langchain·大模型应用
Bruk.Liu4 小时前
(LangChain实战3):LangChain阻塞式invoke与流式stream的调用
人工智能·python·langchain
Bruk.Liu4 小时前
(LangChain实战4):LangChain消息模版PromptTemplate
人工智能·python·langchain
共享家95274 小时前
LangChain初识
人工智能·langchain
Wang201220135 小时前
langchai自带的搜索功能国内tool有哪些(langchain+deepseek+百度AI搜索 打造带搜索功能的agent)
langchain
玄同7651 天前
Llama.cpp 全实战指南:跨平台部署本地大模型的零门槛方案
人工智能·语言模型·自然语言处理·langchain·交互·llama·ollama