MCP 协议深度解析(八):Prompts 提示模板与 Sampling 采样机制!

🎯 博主简介

CSDN 「新星创作者」 人工智能技术领域博主,码龄5年 ,累计发布 180+篇原创 文章,博客总访问量 24万+浏览!

🚀 持续更新AI前沿实战知识,专注于 AI 技术实战,核心方向包括:

  • 🔥 生产级 RAG 系统 --- 架构设计、多路召回、混合检索、生产优化
  • 🤖 多 Agent 协作系统 --- Swarm 架构、任务编排、状态管理
  • 🔌 MCP 协议 --- 协议规范与深度实现
  • OpenClaw --- AI 助手框架进阶应用

同时也涉猎计算机视觉、Java 后端与 Spring 生态、Transformer 等深度学习技术。坚持从架构到代码、从原理到部署的实战风格。每篇文章配套代码与扩展资料,欢迎交流探讨。

📱 公众号: Anyi研习社--- 每天一点AI:资料、笔记、工具、趋势,一起进步。

🤝 商务合作 :请搜索关注微信公众号 「Anyi研习社」

系列导读:本文是《MCP 协议深度解析》系列的第 8 篇,深入探讨 MCP 协议中 Prompts 提示模板系统与 Sampling 采样机制的设计原理、实现细节与最佳实践。通过双语言代码示例(TypeScript + Python),帮助开发者掌握 MCP 的高级交互能力。


📑 目录

  • [1. 引言:Prompt 提示模板系统设计理念](#1. 引言:Prompt 提示模板系统设计理念)
  • [2. 参数化提示与动态生成](#2. 参数化提示与动态生成)
  • [3. 提示模板注册与发现](#3. 提示模板注册与发现)
  • [4. 多轮对话上下文管理](#4. 多轮对话上下文管理)
  • [5. Sampling 采样机制:Server 请求 LLM](#5. Sampling 采样机制:Server 请求 LLM)
  • [6. 采样请求的生命周期管理](#6. 采样请求的生命周期管理)
  • [7. 嵌套调用与递归防护](#7. 嵌套调用与递归防护)
  • [8. 采样成本控制与优化](#8. 采样成本控制与优化)
  • [9. 提示模板与采样的安全考虑](#9. 提示模板与采样的安全考虑)
  • [10. 实际应用场景与案例](#10. 实际应用场景与案例)
  • [11. 常见问题 FAQ](#11. 常见问题 FAQ)
  • [12. 参考文献](#12. 参考文献)

1. 引言:Prompt 提示模板系统设计理念

1.1 什么是 MCP Prompts

Model Context Protocol(MCP)的 Prompts 特性允许服务器向客户端暴露结构化的提示模板。这些模板可以被用户直接调用,或者由 LLM 根据上下文自动选择使用。Prompts 是 MCP 协议中实现"可复用 AI 工作流"的核心机制。

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    MCP Prompts 架构概览                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   ┌─────────────┐         ┌─────────────┐         ┌──────────┐ │
│   │   Client    │◄───────►│   Server    │◄───────►│   LLM    │ │
│   │  (Host)     │  MCP    │  (Prompts   │ Sampling│ Provider │ │
│   │             │ Protocol│  Provider)  │ Protocol│          │ │
│   └─────────────┘         └─────────────┘         └──────────┘ │
│          │                       │                             │
│          ▼                       ▼                             │
│   ┌─────────────────────────────────────┐                      │
│   │        Prompt Templates             │                      │
│   │  • 代码审查模板                      │                      │
│   │  • 数据分析模板                      │                      │
│   │  • 文档生成模板                      │                      │
│   │  • 自定义工作流                      │                      │
│   └─────────────────────────────────────┘                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.2 Prompts 与 Tools 的区别

特性 Prompts Tools
用途 提供结构化提示模板 执行具体功能操作
返回值 消息列表(给 LLM 的上下文) 工具执行结果
交互方式 用户选择或 LLM 建议 LLM 自动调用
参数类型 描述性参数(用于生成提示) 功能性参数(用于执行操作)
典型场景 代码审查、分析模板 文件操作、API 调用
状态影响 通常无副作用 可能改变系统状态

1.3 设计哲学

MCP Prompts 的设计遵循以下核心原则:

  1. 可发现性(Discoverability):客户端可以动态发现服务器提供的所有提示模板
  2. 可组合性(Composability):提示模板可以嵌套组合,构建复杂工作流
  3. 类型安全(Type Safety):参数通过 JSON Schema 定义,确保类型正确
  4. 上下文感知(Context Awareness):提示可以访问 MCP 会话的完整上下文

2. 参数化提示与动态生成

2.1 参数化提示的基本概念

参数化提示允许开发者定义带有变量的模板,在运行时根据用户输入动态生成完整的提示内容。这是实现可复用提示模板的基础。

2.1.1 参数类型系统

MCP 支持以下参数类型:

类型 说明 示例
string 文本字符串 "code": "function hello() {}"
number 数值 "temperature": 0.7
boolean 布尔值 "includeTests": true
array 数组 "files": ["src/index.ts", "src/utils.ts"]
object 对象 "config": {"style": "google"}
enum 枚举值 "language": ["zh", "en", "ja"]

2.2 TypeScript 实现:参数化提示定义

typescript 复制代码
// prompts/code-review.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
  ListPromptsRequestSchema,
  GetPromptRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";

// 定义提示模板接口
interface CodeReviewArgs {
  code: string;
  language: string;
  focus?: "performance" | "security" | "readability" | "all";
  includeExamples?: boolean;
}

// 注册提示模板
server.setRequestHandler(ListPromptsRequestSchema, async () => {
  return {
    prompts: [
      {
        name: "code-review",
        description: "对代码进行专业审查,提供改进建议",
        arguments: [
          {
            name: "code",
            description: "需要审查的代码",
            required: true,
          },
          {
            name: "language",
            description: "编程语言",
            required: true,
          },
          {
            name: "focus",
            description: "审查重点",
            required: false,
          },
          {
            name: "includeExamples",
            description: "是否包含改进示例",
            required: false,
          },
        ],
      },
    ],
  };
});

// 处理提示请求
server.setRequestHandler(GetPromptRequestSchema, async (request) => {
  const { name, arguments: args } = request.params;
  
  if (name !== "code-review") {
    throw new Error(`Unknown prompt: ${name}`);
  }

  const { code, language, focus = "all", includeExamples = false } = args as CodeReviewArgs;

  // 动态生成系统提示
  const systemPrompt = generateReviewPrompt(focus, includeExamples);
  
  // 构建消息列表
  const messages = [
    {
      role: "system" as const,
      content: {
        type: "text" as const,
        text: systemPrompt,
      },
    },
    {
      role: "user" as const,
      content: {
        type: "text" as const,
        text: `请审查以下 ${language} 代码:\n\n\`\`\`${language}\n${code}\n\`\`\``,
      },
    },
  ];

  return {
    description: `代码审查 - ${language}`,
    messages,
  };
});

// 根据参数动态生成提示内容
function generateReviewPrompt(focus: string, includeExamples: boolean): string {
  const focusAreas: Record<string, string> = {
    performance: "性能优化、算法复杂度、资源使用",
    security: "安全漏洞、输入验证、权限控制",
    readability: "代码风格、命名规范、注释质量",
    all: "性能、安全、可读性、可维护性",
  };

  let prompt = `你是一位资深 ${focusAreas[focus] || focusAreas.all} 专家。`;
  prompt += `请对提供的代码进行专业审查,重点关注 ${focusAreas[focus] || focusAreas.all}。\n\n`;
  prompt += `审查要求:\n`;
  prompt += `1. 识别潜在问题和风险\n`;
  prompt += `2. 提供具体的改进建议\n`;
  prompt += `3. 说明问题的严重程度(高/中/低)\n`;
  
  if (includeExamples) {
    prompt += `4. 为每个问题提供改进后的代码示例\n`;
  }
  
  prompt += `\n请以结构化的方式输出审查结果。`;
  
  return prompt;
}

2.3 Python 实现:参数化提示定义

python 复制代码
# prompts/code_review.py
from mcp.server import Server
from mcp.types import (
    ListPromptsRequest,
    GetPromptRequest,
    Prompt,
    PromptArgument,
    TextContent,
)
from typing import Optional
import json

# 创建服务器实例
server = Server("code-review-server")

# 定义提示模板配置
PROMPT_TEMPLATES = {
    "code-review": {
        "description": "对代码进行专业审查,提供改进建议",
        "arguments": [
            PromptArgument(
                name="code",
                description="需要审查的代码",
                required=True,
            ),
            PromptArgument(
                name="language",
                description="编程语言",
                required=True,
            ),
            PromptArgument(
                name="focus",
                description="审查重点 (performance/security/readability/all)",
                required=False,
            ),
            PromptArgument(
                name="include_examples",
                description="是否包含改进示例",
                required=False,
            ),
        ],
    },
    "doc-generate": {
        "description": "为代码生成文档",
        "arguments": [
            PromptArgument(
                name="code",
                description="需要生成文档的代码",
                required=True,
            ),
            PromptArgument(
                name="style",
                description="文档风格 (google/numpy/jsdoc)",
                required=False,
            ),
        ],
    },
}

# 提示内容生成器
class PromptGenerator:
    """动态提示内容生成器"""
    
    FOCUS_AREAS = {
        "performance": "性能优化、算法复杂度、资源使用",
        "security": "安全漏洞、输入验证、权限控制",
        "readability": "代码风格、命名规范、注释质量",
        "all": "性能、安全、可读性、可维护性",
    }
    
    @classmethod
    def generate_review_prompt(cls, focus: str, include_examples: bool) -> str:
        focus_area = cls.FOCUS_AREAS.get(focus, cls.FOCUS_AREAS["all"])
        
        prompt = f"你是一位资深 {focus_area} 专家。\n"
        prompt += f"请对提供的代码进行专业审查,重点关注 {focus_area}。\n\n"
        prompt += "审查要求:\n"
        prompt += "1. 识别潜在问题和风险\n"
        prompt += "2. 提供具体的改进建议\n"
        prompt += "3. 说明问题的严重程度(高/中/低)\n"
        
        if include_examples:
            prompt += "4. 为每个问题提供改进后的代码示例\n"
        
        prompt += "\n请以结构化的方式输出审查结果。"
        
        return prompt

# 注册提示列表处理器
@server.list_prompts()
async def list_prompts() -> list[Prompt]:
    return [
        Prompt(
            name=name,
            description=config["description"],
            arguments=config["arguments"],
        )
        for name, config in PROMPT_TEMPLATES.items()
    ]

# 注册提示获取处理器
@server.get_prompt()
async def get_prompt(name: str, arguments: dict) -> dict:
    if name not in PROMPT_TEMPLATES:
        raise ValueError(f"Unknown prompt: {name}")
    
    if name == "code-review":
        code = arguments.get("code", "")
        language = arguments.get("language", "")
        focus = arguments.get("focus", "all")
        include_examples = arguments.get("include_examples", False)
        
        system_prompt = PromptGenerator.generate_review_prompt(focus, include_examples)
        
        return {
            "description": f"代码审查 - {language}",
            "messages": [
                {
                    "role": "system",
                    "content": TextContent(type="text", text=system_prompt),
                },
                {
                    "role": "user",
                    "content": TextContent(
                        type="text",
                        text=f"请审查以下 {language} 代码:\n\n\`\`\`{language}\n{code}\n\`\`\`"
                    ),
                },
            ],
        }
    
    # 其他提示模板处理...
    return {}

2.4 动态提示生成策略

在实际应用中,提示模板可以根据多种因素动态生成:

策略 说明 适用场景
条件分支 根据参数选择不同提示分支 多语言支持、不同审查级别
模板继承 基础模板 + 特定扩展 通用审查 + 专项审查
上下文注入 将运行时数据注入提示 当前项目信息、用户偏好
链式组合 多个提示模板组合 复杂多阶段分析

3. 提示模板注册与发现

3.1 注册机制

MCP 服务器通过实现 prompts/list 方法来暴露可用的提示模板。

3.1.1 注册时序图

Prompt Registry MCP Server MCP Client Prompt Registry MCP Server MCP Client prompts/list 请求 查询所有注册模板 返回模板列表 返回 prompts[] prompts/get (name, args) 查找指定模板 返回模板处理器 执行模板生成 返回 messages[]

3.2 TypeScript 实现:完整注册流程

typescript 复制代码
// server/prompt-registry.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
  ListPromptsRequestSchema,
  GetPromptRequestSchema,
  Prompt,
  PromptMessage,
} from "@modelcontextprotocol/sdk/types.js";

// 提示模板定义接口
interface PromptTemplate {
  name: string;
  description: string;
  arguments: PromptArgument[];
  handler: (args: Record<string, unknown>) => Promise<{
    description: string;
    messages: PromptMessage[];
  }>;
}

// 提示注册表
class PromptRegistry {
  private templates: Map<string, PromptTemplate> = new Map();

  register(template: PromptTemplate): void {
    this.templates.set(template.name, template);
    console.log(`[PromptRegistry] Registered: ${template.name}`);
  }

  unregister(name: string): void {
    this.templates.delete(name);
    console.log(`[PromptRegistry] Unregistered: ${name}`);
  }

  list(): Prompt[] {
    return Array.from(this.templates.values()).map((t) => ({
      name: t.name,
      description: t.description,
      arguments: t.arguments,
    }));
  }

  async execute(name: string, args: Record<string, unknown>): Promise<{
    description: string;
    messages: PromptMessage[];
  }> {
    const template = this.templates.get(name);
    if (!template) {
      throw new Error(`Prompt not found: ${name}`);
    }
    return await template.handler(args);
  }
}

// 创建注册表实例
const registry = new PromptRegistry();

// 注册代码审查模板
registry.register({
  name: "code-review",
  description: "代码审查提示模板",
  arguments: [
    { name: "code", description: "代码内容", required: true },
    { name: "language", description: "编程语言", required: true },
  ],
  handler: async (args) => {
    const { code, language } = args as { code: string; language: string };
    return {
      description: `代码审查 (${language})`,
      messages: [
        {
          role: "system",
          content: {
            type: "text",
            text: "你是一位资深代码审查专家...",
          },
        },
        {
          role: "user",
          content: {
            type: "text",
            text: `请审查以下代码:\n\n\`\`\`${language}\n${code}\n\`\`\``,
          },
        },
      ],
    };
  },
});

// 注册数据分析模板
registry.register({
  name: "data-analysis",
  description: "数据分析提示模板",
  arguments: [
    { name: "data", description: "数据内容", required: true },
    { name: "format", description: "数据格式", required: false },
  ],
  handler: async (args) => {
    // 实现...
    return { description: "数据分析", messages: [] };
  },
});

// 设置 MCP 请求处理器
export function setupPromptHandlers(server: Server): void {
  // 处理列表请求
  server.setRequestHandler(ListPromptsRequestSchema, async () => {
    return { prompts: registry.list() };
  });

  // 处理获取请求
  server.setRequestHandler(GetPromptRequestSchema, async (request) => {
    const { name, arguments: args } = request.params;
    return await registry.execute(name, args || {});
  });
}

3.3 Python 实现:完整注册流程

python 复制代码
# server/prompt_registry.py
from typing import Callable, Dict, List, Any, Optional
from dataclasses import dataclass
from mcp.types import Prompt, PromptArgument, PromptMessage, TextContent
import asyncio

@dataclass
class PromptTemplate:
    """提示模板定义"""
    name: str
    description: str
    arguments: List[PromptArgument]
    handler: Callable[[Dict[str, Any]], asyncio.Future]

class PromptRegistry:
    """提示模板注册表"""
    
    def __init__(self):
        self._templates: Dict[str, PromptTemplate] = {}
    
    def register(self, template: PromptTemplate) -> None:
        """注册提示模板"""
        self._templates[template.name] = template
        print(f"[PromptRegistry] Registered: {template.name}")
    
    def unregister(self, name: str) -> None:
        """注销提示模板"""
        if name in self._templates:
            del self._templates[name]
            print(f"[PromptRegistry] Unregistered: {name}")
    
    def list(self) -> List[Prompt]:
        """列出所有提示模板"""
        return [
            Prompt(
                name=t.name,
                description=t.description,
                arguments=t.arguments,
            )
            for t in self._templates.values()
        ]
    
    async def execute(self, name: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """执行提示模板"""
        template = self._templates.get(name)
        if not template:
            raise ValueError(f"Prompt not found: {name}")
        return await template.handler(args)
    
    def get_template(self, name: str) -> Optional[PromptTemplate]:
        """获取指定模板"""
        return self._templates.get(name)

# 创建全局注册表
registry = PromptRegistry()

# 装饰器方式注册
def prompt(
    name: str,
    description: str,
    arguments: Optional[List[PromptArgument]] = None
):
    """提示模板装饰器"""
    def decorator(func: Callable):
        template = PromptTemplate(
            name=name,
            description=description,
            arguments=arguments or [],
            handler=func,
        )
        registry.register(template)
        return func
    return decorator

# 使用示例
@prompt(
    name="code-review",
    description="代码审查提示模板",
    arguments=[
        PromptArgument(name="code", description="代码内容", required=True),
        PromptArgument(name="language", description="编程语言", required=True),
    ]
)
async def code_review_handler(args: Dict[str, Any]) -> Dict[str, Any]:
    """代码审查处理器"""
    code = args.get("code", "")
    language = args.get("language", "")
    
    return {
        "description": f"代码审查 ({language})",
        "messages": [
            {
                "role": "system",
                "content": TextContent(
                    type="text",
                    text="你是一位资深代码审查专家..."
                ),
            },
            {
                "role": "user",
                "content": TextContent(
                    type="text",
                    text=f"请审查以下代码:\n\n\`\`\`{language}\n{code}\n\`\`\`"
                ),
            },
        ],
    }

# 动态注册示例
def register_dynamic_prompt(name: str, config: Dict[str, Any]):
    """动态注册提示模板"""
    async def handler(args: Dict[str, Any]) -> Dict[str, Any]:
        # 根据配置动态生成提示
        template_content = config.get("template", "")
        return {
            "description": config.get("description", name),
            "messages": [
                {
                    "role": "system",
                    "content": TextContent(type="text", text=template_content),
                },
            ],
        }
    
    registry.register(PromptTemplate(
        name=name,
        description=config.get("description", ""),
        arguments=[
            PromptArgument(**arg)
            for arg in config.get("arguments", [])
        ],
        handler=handler,
    ))

4. 多轮对话上下文管理

4.1 上下文管理的重要性

在多轮对话场景中,Prompts 需要维护对话历史,确保 LLM 能够理解上下文并做出连贯的响应。

4.2 上下文传递机制

复制代码
┌─────────────────────────────────────────────────────────────┐
│              多轮对话上下文传递流程                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Round 1:                                                  │
│  ┌─────────┐    prompts/get    ┌─────────┐                │
│  │ Client  │ ─────────────────► │ Server  │                │
│  │         │ ◄───────────────── │         │                │
│  │         │    messages[role]  │         │                │
│  └────┬────┘                    └─────────┘                │
│       │                                                     │
│       ▼ 存储到会话上下文                                     │
│  ┌─────────────┐                                           │
│  │ SessionCtx  │ {role: "assistant", content: "..."}        │
│  └─────────────┘                                           │
│                                                             │
│  Round 2:                                                  │
│  ┌─────────┐    prompts/get    ┌─────────┐                │
│  │ Client  │ ─────────────────► │ Server  │                │
│  │ (含历史) │ ◄───────────────── │         │                │
│  └────┬────┘    messages[]      └─────────┘                │
│       │                                                     │
│       ▼ 追加到历史                                           │
│  ┌─────────────┐                                           │
│  │ SessionCtx  │ [历史消息1, 历史消息2, ...]                │
│  └─────────────┘                                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

4.3 TypeScript 实现:上下文管理

typescript 复制代码
// context/session-manager.ts
import { PromptMessage } from "@modelcontextprotocol/sdk/types.js";

interface SessionContext {
  sessionId: string;
  messages: PromptMessage[];
  metadata: Record<string, unknown>;
  createdAt: Date;
  lastAccessedAt: Date;
}

class SessionManager {
  private sessions: Map<string, SessionContext> = new Map();
  private readonly TTL_MS = 30 * 60 * 1000; // 30分钟过期

  createSession(): string {
    const sessionId = crypto.randomUUID();
    this.sessions.set(sessionId, {
      sessionId,
      messages: [],
      metadata: {},
      createdAt: new Date(),
      lastAccessedAt: new Date(),
    });
    return sessionId;
  }

  getSession(sessionId: string): SessionContext | undefined {
    const session = this.sessions.get(sessionId);
    if (session) {
      session.lastAccessedAt = new Date();
    }
    return session;
  }

  appendMessage(sessionId: string, message: PromptMessage): void {
    const session = this.sessions.get(sessionId);
    if (session) {
      session.messages.push(message);
      session.lastAccessedAt = new Date();
    }
  }

  getContextMessages(sessionId: string, maxHistory: number = 10): PromptMessage[] {
    const session = this.sessions.get(sessionId);
    if (!session) return [];
    
    // 返回最近的消息
    return session.messages.slice(-maxHistory);
  }

  // 清理过期会话
  cleanup(): void {
    const now = Date.now();
    for (const [id, session] of this.sessions.entries()) {
      if (now - session.lastAccessedAt.getTime() > this.TTL_MS) {
        this.sessions.delete(id);
        console.log(`[SessionManager] Cleaned up session: ${id}`);
      }
    }
  }
}

// 使用示例:构建带上下文的提示
async function buildContextualPrompt(
  sessionManager: SessionManager,
  sessionId: string,
  newPrompt: PromptMessage[]
): Promise<PromptMessage[]> {
  const history = sessionManager.getContextMessages(sessionId, 5);
  
  // 组合历史消息和新提示
  return [...history, ...newPrompt];
}

4.4 Python 实现:上下文管理

python 复制代码
# context/session_manager.py
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncio
import uuid

@dataclass
class SessionContext:
    """会话上下文"""
    session_id: str
    messages: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed_at: datetime = field(default_factory=datetime.now)

class SessionManager:
    """会话管理器"""
    
    TTL_SECONDS = 30 * 60  # 30分钟过期
    
    def __init__(self):
        self._sessions: Dict[str, SessionContext] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
    
    def create_session(self) -> str:
        """创建新会话"""
        session_id = str(uuid.uuid4())
        self._sessions[session_id] = SessionContext(session_id=session_id)
        print(f"[SessionManager] Created session: {session_id}")
        return session_id
    
    def get_session(self, session_id: str) -> Optional[SessionContext]:
        """获取会话"""
        session = self._sessions.get(session_id)
        if session:
            session.last_accessed_at = datetime.now()
        return session
    
    def append_message(self, session_id: str, message: Dict[str, Any]) -> None:
        """追加消息到会话"""
        session = self._sessions.get(session_id)
        if session:
            session.messages.append(message)
            session.last_accessed_at = datetime.now()
    
    def get_context_messages(
        self, 
        session_id: str, 
        max_history: int = 10
    ) -> List[Dict[str, Any]]:
        """获取上下文消息"""
        session = self._sessions.get(session_id)
        if not session:
            return []
        return session.messages[-max_history:]
    
    def cleanup(self) -> None:
        """清理过期会话"""
        now = datetime.now()
        expired = [
            sid for sid, session in self._sessions.items()
            if (now - session.last_accessed_at).total_seconds() > self.TTL_SECONDS
        ]
        for sid in expired:
            del self._sessions[sid]
            print(f"[SessionManager] Cleaned up session: {sid}")
    
    async def start_cleanup_scheduler(self):
        """启动定期清理任务"""
        while True:
            await asyncio.sleep(60)  # 每分钟检查一次
            self.cleanup()

# 使用示例
async def build_contextual_prompt(
    session_manager: SessionManager,
    session_id: str,
    new_messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """构建带上下文的提示"""
    history = session_manager.get_context_messages(session_id, max_history=5)
    return history + new_messages

5. Sampling 采样机制:Server 请求 LLM

5.1 Sampling 概述

Sampling 是 MCP 协议中允许服务器向客户端(Host)请求 LLM 采样的机制。这使得服务器能够在需要时获取 LLM 的生成能力,而不需要直接访问 LLM API。

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    Sampling 架构                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   ┌─────────────┐         sampling/createMessage    ┌─────────┐│
│   │   Server    │ ─────────────────────────────────► │ Client  ││
│   │             │                                    │ (Host)  ││
│   │             │ ◄───────────────────────────────── │         ││
│   │             │         sampling/createMessage/result        ││
│   └─────────────┘                                    └────┬────┘│
│                                                           │     │
│                                                           ▼     │
│                                                    ┌──────────┐ │
│                                                    │   LLM    │ │
│                                                    │ Provider │ │
│                                                    └──────────┘ │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

5.2 Sampling 请求流程

LLM Provider MCP Client (Host) MCP Server LLM Provider MCP Client (Host) MCP Server 包含 messages, modelPreferences, systemPrompt alt [请求有效] [请求无效] sampling/createMessage 验证请求 检查权限/配额 转发采样请求 返回生成结果 sampling/createMessage/result 返回错误

5.3 TypeScript 实现:Sampling 请求

typescript 复制代码
// sampling/sampling-client.ts
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import {
  CreateMessageRequest,
  CreateMessageResult,
  SamplingMessage,
  ModelPreferences,
} from "@modelcontextprotocol/sdk/types.js";

interface SamplingOptions {
  messages: SamplingMessage[];
  modelPreferences?: ModelPreferences;
  systemPrompt?: string;
  maxTokens?: number;
  temperature?: number;
  stopSequences?: string[];
}

class SamplingClient {
  private client: Client;

  constructor(client: Client) {
    this.client = client;
  }

  /**
   * 发送采样请求
   */
  async createMessage(options: SamplingOptions): Promise<CreateMessageResult> {
    const request: CreateMessageRequest = {
      method: "sampling/createMessage",
      params: {
        messages: options.messages,
        modelPreferences: options.modelPreferences,
        systemPrompt: options.systemPrompt,
        maxTokens: options.maxTokens,
        temperature: options.temperature,
        stopSequences: options.stopSequences,
      },
    };

    console.log("[SamplingClient] Sending request:", JSON.stringify(request, null, 2));

    const result = await this.client.request(request, CreateMessageResultSchema);
    
    console.log("[SamplingClient] Received result:", JSON.stringify(result, null, 2));
    
    return result;
  }

  /**
   * 简单文本生成
   */
  async generateText(
    prompt: string,
    options: Omit<SamplingOptions, "messages"> = {}
  ): Promise<string> {
    const result = await this.createMessage({
      messages: [
        {
          role: "user",
          content: {
            type: "text",
            text: prompt,
          },
        },
      ],
      ...options,
    });

    if (result.content.type === "text") {
      return result.content.text;
    }
    
    throw new Error("Expected text response, got: " + result.content.type);
  }

  /**
   * 带上下文的对话
   */
  async chat(
    messages: SamplingMessage[],
    options: Omit<SamplingOptions, "messages"> = {}
  ): Promise<CreateMessageResult> {
    return await this.createMessage({
      messages,
      ...options,
    });
  }
}

// 使用示例
async function demonstrateSampling() {
  const client = new Client({ name: "example-client", version: "1.0.0" });
  const sampling = new SamplingClient(client);

  // 简单生成
  const response = await sampling.generateText(
    "解释什么是 Model Context Protocol",
    {
      maxTokens: 500,
      temperature: 0.7,
      modelPreferences: {
        hints: [{ name: "claude" }],
        intelligencePriority: 0.8,
        speedPriority: 0.3,
      },
    }
  );

  console.log("Generated:", response);
}

5.4 Python 实现:Sampling 请求

python 复制代码
# sampling/sampling_client.py
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from mcp.client import Client
from mcp.types import (
    SamplingMessage,
    ModelPreferences,
    CreateMessageRequest,
    CreateMessageResult,
    TextContent,
)

@dataclass
class SamplingOptions:
    """采样选项"""
    messages: List[SamplingMessage]
    model_preferences: Optional[ModelPreferences] = None
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    stop_sequences: Optional[List[str]] = None

class SamplingClient:
    """采样客户端"""
    
    def __init__(self, client: Client):
        self._client = client
    
    async def create_message(self, options: SamplingOptions) -> CreateMessageResult:
        """发送采样请求"""
        request = CreateMessageRequest(
            method="sampling/createMessage",
            params={
                "messages": options.messages,
                "modelPreferences": options.model_preferences,
                "systemPrompt": options.system_prompt,
                "maxTokens": options.max_tokens,
                "temperature": options.temperature,
                "stopSequences": options.stop_sequences,
            }
        )
        
        print(f"[SamplingClient] Sending request: {request}")
        
        result = await self._client.request(request)
        
        print(f"[SamplingClient] Received result: {result}")
        
        return result
    
    async def generate_text(
        self,
        prompt: str,
        **kwargs
    ) -> str:
        """简单文本生成"""
        result = await self.create_message(SamplingOptions(
            messages=[
                SamplingMessage(
                    role="user",
                    content=TextContent(type="text", text=prompt)
                )
            ],
            **kwargs
        ))
        
        if result.content.type == "text":
            return result.content.text
        
        raise ValueError(f"Expected text response, got: {result.content.type}")
    
    async def chat(
        self,
        messages: List[SamplingMessage],
        **kwargs
    ) -> CreateMessageResult:
        """带上下文的对话"""
        return await self.create_message(SamplingOptions(
            messages=messages,
            **kwargs
        ))

# 使用示例
async def demonstrate_sampling():
    client = Client(name="example-client", version="1.0.0")
    sampling = SamplingClient(client)
    
    # 简单生成
    response = await sampling.generate_text(
        "解释什么是 Model Context Protocol",
        max_tokens=500,
        temperature=0.7,
        model_preferences=ModelPreferences(
            hints=[{"name": "claude"}],
            intelligencePriority=0.8,
            speedPriority=0.3,
        )
    )
    
    print(f"Generated: {response}")

6. 采样请求的生命周期管理

6.1 生命周期阶段

复制代码
┌─────────────────────────────────────────────────────────────────┐
│              Sampling 请求生命周期                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐        │
│  │ Created │──►│Pending  │──►│Active   │──►│Completed│        │
│  │         │   │         │   │         │   │         │        │
│  └─────────┘   └─────────┘   └─────────┘   └─────────┘        │
│       │            │            │            │                 │
│       ▼            ▼            ▼            ▼                 │
│  请求创建      等待处理      正在采样      完成/失败            │
│                                                                 │
│  超时处理: Pending ──(timeout)──► Failed                        │
│  取消处理: Active  ──(cancel)───► Cancelled                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

6.2 TypeScript 实现:生命周期管理

typescript 复制代码
// sampling/lifecycle-manager.ts
import { EventEmitter } from "events";

type SamplingStatus = 
  | "created" 
  | "pending" 
  | "active" 
  | "completed" 
  | "failed" 
  | "cancelled";

interface SamplingRequest {
  id: string;
  status: SamplingStatus;
  createdAt: Date;
  startedAt?: Date;
  completedAt?: Date;
  error?: string;
  result?: any;
}

class SamplingLifecycleManager extends EventEmitter {
  private requests: Map<string, SamplingRequest> = new Map();
  private readonly DEFAULT_TIMEOUT_MS = 60000; // 60秒默认超时

  createRequest(id: string): SamplingRequest {
    const request: SamplingRequest = {
      id,
      status: "created",
      createdAt: new Date(),
    };
    
    this.requests.set(id, request);
    this.emit("request:created", request);
    
    return request;
  }

  transitionTo(id: string, status: SamplingStatus, data?: any): void {
    const request = this.requests.get(id);
    if (!request) {
      throw new Error(`Request not found: ${id}`);
    }

    const oldStatus = request.status;
    request.status = status;

    switch (status) {
      case "pending":
        // 启动超时计时器
        this.startTimeout(id);
        break;
      case "active":
        request.startedAt = new Date();
        break;
      case "completed":
        request.completedAt = new Date();
        request.result = data;
        break;
      case "failed":
      case "cancelled":
        request.completedAt = new Date();
        request.error = data;
        break;
    }

    this.emit(`request:${status}`, request);
    this.emit("request:transition", { request, oldStatus, newStatus: status });
  }

  private startTimeout(id: string): void {
    setTimeout(() => {
      const request = this.requests.get(id);
      if (request && request.status === "pending") {
        this.transitionTo(id, "failed", "Request timeout");
      }
    }, this.DEFAULT_TIMEOUT_MS);
  }

  cancelRequest(id: string): boolean {
    const request = this.requests.get(id);
    if (!request) return false;
    
    if (request.status === "pending" || request.status === "active") {
      this.transitionTo(id, "cancelled", "Cancelled by user");
      return true;
    }
    
    return false;
  }

  getRequest(id: string): SamplingRequest | undefined {
    return this.requests.get(id);
  }

  getActiveRequests(): SamplingRequest[] {
    return Array.from(this.requests.values()).filter(
      (r) => r.status === "pending" || r.status === "active"
    );
  }

  // 清理完成的请求
  cleanup(maxAgeMs: number = 3600000): void {
    const now = Date.now();
    for (const [id, request] of this.requests.entries()) {
      if (request.completedAt) {
        const age = now - request.completedAt.getTime();
        if (age > maxAgeMs) {
          this.requests.delete(id);
          this.emit("request:cleaned", { id, age });
        }
      }
    }
  }
}

// 使用示例
const lifecycle = new SamplingLifecycleManager();

// 监听状态变化
lifecycle.on("request:created", (req) => {
  console.log(`[Lifecycle] Request ${req.id} created`);
});

lifecycle.on("request:completed", (req) => {
  console.log(`[Lifecycle] Request ${req.id} completed`);
});

lifecycle.on("request:failed", (req) => {
  console.log(`[Lifecycle] Request ${req.id} failed: ${req.error}`);
});

6.3 Python 实现:生命周期管理

python 复制代码
# sampling/lifecycle_manager.py
from typing import Dict, Optional, List, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
import asyncio
import uuid

class SamplingStatus(Enum):
    """采样状态"""
    CREATED = auto()
    PENDING = auto()
    ACTIVE = auto()
    COMPLETED = auto()
    FAILED = auto()
    CANCELLED = auto()

@dataclass
class SamplingRequest:
    """采样请求"""
    id: str
    status: SamplingStatus = SamplingStatus.CREATED
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    result: Optional[Any] = None

class SamplingLifecycleManager:
    """采样生命周期管理器"""
    
    DEFAULT_TIMEOUT_SECONDS = 60
    
    def __init__(self):
        self._requests: Dict[str, SamplingRequest] = {}
        self._handlers: Dict[str, List[Callable]] = {}
        self._timeout_tasks: Dict[str, asyncio.Task] = {}
    
    def create_request(self, request_id: Optional[str] = None) -> SamplingRequest:
        """创建请求"""
        req_id = request_id or str(uuid.uuid4())
        request = SamplingRequest(id=req_id)
        self._requests[req_id] = request
        self._emit("request:created", request)
        return request
    
    def transition_to(
        self, 
        request_id: str, 
        status: SamplingStatus, 
        data: Any = None
    ) -> None:
        """状态转换"""
        request = self._requests.get(request_id)
        if not request:
            raise ValueError(f"Request not found: {request_id}")
        
        old_status = request.status
        request.status = status
        
        if status == SamplingStatus.PENDING:
            self._start_timeout(request_id)
        elif status == SamplingStatus.ACTIVE:
            request.started_at = datetime.now()
            self._cancel_timeout(request_id)
        elif status in (SamplingStatus.COMPLETED, SamplingStatus.FAILED, SamplingStatus.CANCELLED):
            request.completed_at = datetime.now()
            if status != SamplingStatus.COMPLETED:
                request.error = str(data) if data else None
            else:
                request.result = data
            self._cancel_timeout(request_id)
        
        self._emit(f"request:{status.name.lower()}", request)
        self._emit("request:transition", {
            "request": request,
            "old_status": old_status,
            "new_status": status
        })
    
    def _start_timeout(self, request_id: str) -> None:
        """启动超时计时"""
        async def timeout_handler():
            await asyncio.sleep(self.DEFAULT_TIMEOUT_SECONDS)
            request = self._requests.get(request_id)
            if request and request.status == SamplingStatus.PENDING:
                self.transition_to(request_id, SamplingStatus.FAILED, "Request timeout")
        
        self._timeout_tasks[request_id] = asyncio.create_task(timeout_handler())
    
    def _cancel_timeout(self, request_id: str) -> None:
        """取消超时计时"""
        task = self._timeout_tasks.pop(request_id, None)
        if task:
            task.cancel()
    
    def cancel_request(self, request_id: str) -> bool:
        """取消请求"""
        request = self._requests.get(request_id)
        if not request:
            return False
        
        if request.status in (SamplingStatus.PENDING, SamplingStatus.ACTIVE):
            self.transition_to(request_id, SamplingStatus.CANCELLED, "Cancelled by user")
            return True
        
        return False
    
    def get_request(self, request_id: str) -> Optional[SamplingRequest]:
        """获取请求"""
        return self._requests.get(request_id)
    
    def get_active_requests(self) -> List[SamplingRequest]:
        """获取活跃请求"""
        return [
            r for r in self._requests.values()
            if r.status in (SamplingStatus.PENDING, SamplingStatus.ACTIVE)
        ]
    
    def on(self, event: str, handler: Callable):
        """注册事件处理器"""
        if event not in self._handlers:
            self._handlers[event] = []
        self._handlers[event].append(handler)
    
    def _emit(self, event: str, data: Any):
        """触发事件"""
        handlers = self._handlers.get(event, [])
        for handler in handlers:
            try:
                handler(data)
            except Exception as e:
                print(f"Error in event handler: {e}")
    
    def cleanup(self, max_age_seconds: int = 3600) -> None:
        """清理完成的请求"""
        now = datetime.now()
        to_remove = []
        
        for req_id, request in self._requests.items():
            if request.completed_at:
                age = (now - request.completed_at).total_seconds()
                if age > max_age_seconds:
                    to_remove.append(req_id)
        
        for req_id in to_remove:
            del self._requests[req_id]
            self._emit("request:cleaned", {"id": req_id})

7. 嵌套调用与递归防护

7.1 嵌套调用风险

当服务器使用 Sampling 请求 LLM,而 LLM 又调用同一服务器的 Tools 时,可能形成嵌套调用链,导致无限递归。

复制代码
危险场景:
┌─────────┐     ┌─────────┐     ┌─────────┐     ┌─────────┐
│  User   │────►│ Server  │────►│  Host   │────►│   LLM   │
└─────────┘     └────┬────┘     └─────────┘     └────┬────┘
                     │                                │
                     │◄───────────────────────────────┘
                     │        LLM 调用 tool
                     │                                
                     ▼                                
              ┌─────────────┐                         
              │ tool 执行中 │                         
              │ 需要采样    │────────────────────────►
              └─────────────┘    再次请求采样(递归!)

7.2 递归防护策略

策略 说明 实现复杂度
调用深度限制 限制嵌套调用层数
循环检测 检测重复调用模式
请求标记 标记采样请求来源
超时控制 限制单次调用时间
令牌桶限流 限制采样请求频率

7.3 TypeScript 实现:递归防护

typescript 复制代码
// guards/recursion-guard.ts
interface CallContext {
  requestId: string;
  parentId?: string;
  depth: number;
  timestamp: number;
  toolName?: string;
}

class RecursionGuard {
  private readonly MAX_DEPTH = 3;
  private readonly MAX_CALLS_PER_MINUTE = 10;
  
  private callStack: Map<string, CallContext> = new Map();
  private callHistory: Array<{ timestamp: number; requestId: string }> = [];

  /**
   * 开始新的调用上下文
   */
  beginContext(requestId: string, parentId?: string, toolName?: string): CallContext {
    const parent = parentId ? this.callStack.get(parentId) : undefined;
    const depth = parent ? parent.depth + 1 : 0;
    
    // 检查深度限制
    if (depth > this.MAX_DEPTH) {
      throw new Error(
        `Recursion depth exceeded: ${depth} > ${this.MAX_DEPTH}. ` +
        `Possible infinite recursion detected.`
      );
    }
    
    // 检查调用频率
    this.checkRateLimit();
    
    const context: CallContext = {
      requestId,
      parentId,
      depth,
      timestamp: Date.now(),
      toolName,
    };
    
    this.callStack.set(requestId, context);
    this.callHistory.push({ timestamp: Date.now(), requestId });
    
    console.log(`[RecursionGuard] Context started: ${requestId} (depth: ${depth})`);
    
    return context;
  }

  /**
   * 结束调用上下文
   */
  endContext(requestId: string): void {
    const context = this.callStack.get(requestId);
    if (context) {
      console.log(`[RecursionGuard] Context ended: ${requestId} (duration: ${Date.now() - context.timestamp}ms)`);
      this.callStack.delete(requestId);
    }
  }

  /**
   * 检查调用频率限制
   */
  private checkRateLimit(): void {
    const now = Date.now();
    const oneMinuteAgo = now - 60000;
    
    // 清理旧记录
    this.callHistory = this.callHistory.filter(c => c.timestamp > oneMinuteAgo);
    
    // 检查限制
    if (this.callHistory.length >= this.MAX_CALLS_PER_MINUTE) {
      throw new Error(
        `Rate limit exceeded: ${this.callHistory.length} calls in the last minute. ` +
        `Maximum allowed: ${this.MAX_CALLS_PER_MINUTE}`
      );
    }
  }

  /**
   * 检测循环调用模式
   */
  detectCycle(requestId: string, toolName: string): boolean {
    const context = this.callStack.get(requestId);
    if (!context) return false;
    
    // 检查父链中是否有相同 toolName
    let current: CallContext | undefined = context;
    while (current?.parentId) {
      const parent = this.callStack.get(current.parentId);
      if (parent?.toolName === toolName) {
        console.warn(`[RecursionGuard] Cycle detected: ${toolName} called recursively`);
        return true;
      }
      current = parent;
    }
    
    return false;
  }

  /**
   * 获取当前调用链
   */
  getCallChain(requestId: string): string[] {
    const chain: string[] = [];
    let current: CallContext | undefined = this.callStack.get(requestId);
    
    while (current) {
      chain.unshift(current.toolName || current.requestId);
      current = current.parentId ? this.callStack.get(current.parentId) : undefined;
    }
    
    return chain;
  }

  /**
   * 获取当前活跃上下文数量
   */
  getActiveContextCount(): number {
    return this.callStack.size;
  }
}

// 使用示例
const guard = new RecursionGuard();

async function handleToolCall(toolName: string, parentId?: string) {
  const requestId = crypto.randomUUID();
  
  try {
    // 开始上下文
    const context = guard.beginContext(requestId, parentId, toolName);
    
    // 检测循环
    if (guard.detectCycle(requestId, toolName)) {
      throw new Error(`Circular dependency detected for tool: ${toolName}`);
    }
    
    // 执行工具逻辑...
    console.log(`Call chain: ${guard.getCallChain(requestId).join(" -> ")}`);
    
    // 模拟嵌套调用
    if (toolName === "analyze" && context.depth < 2) {
      await handleToolCall("summarize", requestId);
    }
    
  } finally {
    // 结束上下文
    guard.endContext(requestId);
  }
}

7.4 Python 实现:递归防护

python 复制代码
# guards/recursion_guard.py
from typing import Dict, Optional, List
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import threading
import uuid

@dataclass
class CallContext:
    """调用上下文"""
    request_id: str
    parent_id: Optional[str] = None
    depth: int = 0
    timestamp: datetime = field(default_factory=datetime.now)
    tool_name: Optional[str] = None

class RecursionGuard:
    """递归防护器"""
    
    MAX_DEPTH = 3
    MAX_CALLS_PER_MINUTE = 10
    
    def __init__(self):
        self._call_stack: Dict[str, CallContext] = {}
        self._call_history: List[Dict] = []
        self._lock = threading.RLock()
    
    def begin_context(
        self,
        request_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        tool_name: Optional[str] = None
    ) -> CallContext:
        """开始新的调用上下文"""
        req_id = request_id or str(uuid.uuid4())
        
        with self._lock:
            parent = self._call_stack.get(parent_id) if parent_id else None
            depth = parent.depth + 1 if parent else 0
            
            # 检查深度限制
            if depth > self.MAX_DEPTH:
                raise RecursionError(
                    f"Recursion depth exceeded: {depth} > {self.MAX_DEPTH}. "
                    f"Possible infinite recursion detected."
                )
            
            # 检查调用频率
            self._check_rate_limit()
            
            context = CallContext(
                request_id=req_id,
                parent_id=parent_id,
                depth=depth,
                tool_name=tool_name
            )
            
            self._call_stack[req_id] = context
            self._call_history.append({
                "timestamp": datetime.now(),
                "request_id": req_id
            })
            
            print(f"[RecursionGuard] Context started: {req_id} (depth: {depth})")
            
            return context
    
    def end_context(self, request_id: str) -> None:
        """结束调用上下文"""
        with self._lock:
            context = self._call_stack.get(request_id)
            if context:
                duration = (datetime.now() - context.timestamp).total_seconds() * 1000
                print(f"[RecursionGuard] Context ended: {request_id} (duration: {duration:.0f}ms)")
                del self._call_stack[request_id]
    
    def _check_rate_limit(self) -> None:
        """检查调用频率限制"""
        now = datetime.now()
        one_minute_ago = now - timedelta(minutes=1)
        
        # 清理旧记录
        self._call_history = [
            c for c in self._call_history
            if c["timestamp"] > one_minute_ago
        ]
        
        # 检查限制
        if len(self._call_history) >= self.MAX_CALLS_PER_MINUTE:
            raise RuntimeError(
                f"Rate limit exceeded: {len(self._call_history)} calls in the last minute. "
                f"Maximum allowed: {self.MAX_CALLS_PER_MINUTE}"
            )
    
    def detect_cycle(self, request_id: str, tool_name: str) -> bool:
        """检测循环调用模式"""
        with self._lock:
            context = self._call_stack.get(request_id)
            if not context:
                return False
            
            # 检查父链中是否有相同 tool_name
            current: Optional[CallContext] = context
            while current and current.parent_id:
                parent = self._call_stack.get(current.parent_id)
                if parent and parent.tool_name == tool_name:
                    print(f"[RecursionGuard] Cycle detected: {tool_name} called recursively")
                    return True
                current = parent
            
            return False
    
    def get_call_chain(self, request_id: str) -> List[str]:
        """获取当前调用链"""
        with self._lock:
            chain = []
            current: Optional[CallContext] = self._call_stack.get(request_id)
            
            while current:
                chain.insert(0, current.tool_name or current.request_id)
                current = self._call_stack.get(current.parent_id) if current.parent_id else None
            
            return chain
    
    def get_active_context_count(self) -> int:
        """获取当前活跃上下文数量"""
        with self._lock:
            return len(self._call_stack)

# 上下文管理器
class GuardedContext:
    """受保护的上下文管理器"""
    
    def __init__(self, guard: RecursionGuard, tool_name: str, parent_id: Optional[str] = None):
        self.guard = guard
        self.tool_name = tool_name
        self.parent_id = parent_id
        self.request_id: Optional[str] = None
        self.context: Optional[CallContext] = None
    
    def __enter__(self) -> CallContext:
        self.request_id = str(uuid.uuid4())
        self.context = self.guard.begin_context(
            self.request_id, 
            self.parent_id, 
            self.tool_name
        )
        
        if self.guard.detect_cycle(self.request_id, self.tool_name):
            raise RecursionError(f"Circular dependency detected for tool: {self.tool_name}")
        
        return self.context
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.request_id:
            self.guard.end_context(self.request_id)

# 使用示例
guard = RecursionGuard()

def handle_tool_call(tool_name: str, parent_id: Optional[str] = None):
    with GuardedContext(guard, tool_name, parent_id) as context:
        print(f"Call chain: {' -> '.join(guard.get_call_chain(context.request_id))}")
        
        # 模拟嵌套调用
        if tool_name == "analyze" and context.depth < 2:
            handle_tool_call("summarize", context.request_id)

8. 采样成本控制与优化

8.1 成本因素分析

Sampling 请求的成本主要由以下因素决定:

因素 影响 优化策略
输入 Token 数 直接影响成本 精简提示、使用摘要
输出 Token 数 直接影响成本 设置 maxTokens
模型选择 不同模型价格差异大 根据任务选择合适模型
采样温度 影响生成质量 平衡质量与成本
请求频率 累积成本 批处理、缓存

8.2 成本优化策略

复制代码
┌─────────────────────────────────────────────────────────────────┐
│              采样成本优化策略                                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐         │
│  │   提示优化   │───►│   模型选择   │───►│   缓存策略   │         │
│  └─────────────┘    └─────────────┘    └─────────────┘         │
│         │                  │                  │                 │
│         ▼                  ▼                  ▼                 │
│  • 去除冗余内容      • 智能路由         • 结果缓存                │
│  • 使用模板变量      • 分层模型         • 相似查询合并            │
│  • 上下文压缩        • 成本预估         • 预热机制                │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

8.3 TypeScript 实现:成本控制器

typescript 复制代码
// cost/cost-controller.ts
interface CostConfig {
  maxInputTokens: number;
  maxOutputTokens: number;
  maxTotalCost: number; // 美元
  preferredModels: string[];
  budgetPeriod: "daily" | "weekly" | "monthly";
}

interface SamplingCost {
  inputTokens: number;
  outputTokens: number;
  estimatedCost: number;
  model: string;
  timestamp: Date;
}

class CostController {
  private config: CostConfig;
  private costs: SamplingCost[] = [];
  private cache: Map<string, { result: any; cost: SamplingCost }> = new Map();

  constructor(config: CostConfig) {
    this.config = config;
  }

  /**
   * 估算采样成本
   */
  estimateCost(inputText: string, outputTokens: number, model: string): SamplingCost {
    // 简化的 Token 估算(实际应使用 tokenizer)
    const inputTokens = Math.ceil(inputText.length / 4);
    
    // 模型价格表(每 1K tokens)
    const prices: Record<string, { input: number; output: number }> = {
      "gpt-4": { input: 0.03, output: 0.06 },
      "gpt-3.5-turbo": { input: 0.0015, output: 0.002 },
      "claude-3-opus": { input: 0.015, output: 0.075 },
      "claude-3-sonnet": { input: 0.003, output: 0.015 },
    };

    const price = prices[model] || prices["gpt-3.5-turbo"];
    const estimatedCost = 
      (inputTokens / 1000) * price.input + 
      (outputTokens / 1000) * price.output;

    return {
      inputTokens,
      outputTokens,
      estimatedCost,
      model,
      timestamp: new Date(),
    };
  }

  /**
   * 检查是否超出预算
   */
  checkBudget(): { allowed: boolean; remaining: number } {
    const periodCosts = this.getPeriodCosts();
    const totalCost = periodCosts.reduce((sum, c) => sum + c.estimatedCost, 0);
    const remaining = this.config.maxTotalCost - totalCost;
    
    return {
      allowed: remaining > 0,
      remaining: Math.max(0, remaining),
    };
  }

  /**
   * 选择最优模型
   */
  selectModel(complexity: "low" | "medium" | "high"): string {
    const modelTiers: Record<string, string[]> = {
      low: ["gpt-3.5-turbo", "claude-3-sonnet"],
      medium: ["claude-3-sonnet", "gpt-4"],
      high: ["claude-3-opus", "gpt-4"],
    };

    const candidates = modelTiers[complexity] || modelTiers.medium;
    
    // 优先选择配置中指定的模型
    for (const preferred of this.config.preferredModels) {
      if (candidates.includes(preferred)) {
        return preferred;
      }
    }
    
    return candidates[0];
  }

  /**
   * 检查缓存
   */
  checkCache(prompt: string): { result: any; cost: SamplingCost } | undefined {
    const hash = this.hashPrompt(prompt);
    return this.cache.get(hash);
  }

  /**
   * 存入缓存
   */
  setCache(prompt: string, result: any, cost: SamplingCost): void {
    const hash = this.hashPrompt(prompt);
    this.cache.set(hash, { result, cost });
    
    // 限制缓存大小
    if (this.cache.size > 1000) {
      const firstKey = this.cache.keys().next().value;
      this.cache.delete(firstKey);
    }
  }

  /**
   * 记录实际成本
   */
  recordCost(cost: SamplingCost): void {
    this.costs.push(cost);
    this.cleanupOldCosts();
  }

  /**
   * 获取周期内的成本
   */
  private getPeriodCosts(): SamplingCost[] {
    const now = new Date();
    const periodStart = new Date();
    
    switch (this.config.budgetPeriod) {
      case "daily":
        periodStart.setHours(0, 0, 0, 0);
        break;
      case "weekly":
        periodStart.setDate(now.getDate() - now.getDay());
        periodStart.setHours(0, 0, 0, 0);
        break;
      case "monthly":
        periodStart.setDate(1);
        periodStart.setHours(0, 0, 0, 0);
        break;
    }
    
    return this.costs.filter(c => c.timestamp >= periodStart);
  }

  /**
   * 清理旧成本记录
   */
  private cleanupOldCosts(): void {
    const oneMonthAgo = new Date();
    oneMonthAgo.setMonth(oneMonthAgo.getMonth() - 1);
    this.costs = this.costs.filter(c => c.timestamp >= oneMonthAgo);
  }

  /**
   * 生成提示哈希
   */
  private hashPrompt(prompt: string): string {
    // 简化实现,实际应使用 crypto 模块
    let hash = 0;
    for (let i = 0; i < prompt.length; i++) {
      const char = prompt.charCodeAt(i);
      hash = ((hash << 5) - hash) + char;
      hash = hash & hash;
    }
    return hash.toString(16);
  }

  /**
   * 获取成本报告
   */
  getCostReport(): {
    totalCost: number;
    totalCalls: number;
    averageCost: number;
    byModel: Record<string, { calls: number; cost: number }>;
  } {
    const periodCosts = this.getPeriodCosts();
    const totalCost = periodCosts.reduce((sum, c) => sum + c.estimatedCost, 0);
    
    const byModel: Record<string, { calls: number; cost: number }> = {};
    for (const cost of periodCosts) {
      if (!byModel[cost.model]) {
        byModel[cost.model] = { calls: 0, cost: 0 };
      }
      byModel[cost.model].calls++;
      byModel[cost.model].cost += cost.estimatedCost;
    }

    return {
      totalCost,
      totalCalls: periodCosts.length,
      averageCost: totalCost / periodCosts.length || 0,
      byModel,
    };
  }
}

// 使用示例
const costController = new CostController({
  maxInputTokens: 4000,
  maxOutputTokens: 1000,
  maxTotalCost: 10.0, // $10 per period
  preferredModels: ["claude-3-sonnet", "gpt-3.5-turbo"],
  budgetPeriod: "daily",
});

async function optimizedSampling(prompt: string) {
  // 检查缓存
  const cached = costController.checkCache(prompt);
  if (cached) {
    console.log("[CostController] Cache hit!");
    return cached.result;
  }

  // 检查预算
  const budget = costController.checkBudget();
  if (!budget.allowed) {
    throw new Error(`Budget exceeded. Remaining: $${budget.remaining.toFixed(2)}`);
  }

  // 选择模型
  const model = costController.selectModel("medium");
  
  // 估算成本
  const estimatedCost = costController.estimateCost(prompt, 500, model);
  console.log(`[CostController] Estimated cost: $${estimatedCost.estimatedCost.toFixed(4)}`);

  // 执行采样...
  const result = { /* sampling result */ };
  
  // 记录成本
  costController.recordCost(estimatedCost);
  
  // 存入缓存
  costController.setCache(prompt, result, estimatedCost);
  
  return result;
}

8.4 Python 实现:成本控制器

python 复制代码
# cost/cost_controller.py
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta
import hashlib
from enum import Enum

class BudgetPeriod(Enum):
    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"

@dataclass
class CostConfig:
    """成本配置"""
    max_input_tokens: int = 4000
    max_output_tokens: int = 1000
    max_total_cost: float = 10.0  # 美元
    preferred_models: List[str] = None
    budget_period: BudgetPeriod = BudgetPeriod.DAILY
    
    def __post_init__(self):
        if self.preferred_models is None:
            self.preferred_models = ["claude-3-sonnet", "gpt-3.5-turbo"]

@dataclass
class SamplingCost:
    """采样成本"""
    input_tokens: int
    output_tokens: int
    estimated_cost: float
    model: str
    timestamp: datetime

class CostController:
    """成本控制器"""
    
    # 模型价格表(每 1K tokens)
    PRICES = {
        "gpt-4": {"input": 0.03, "output": 0.06},
        "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
        "claude-3-opus": {"input": 0.015, "output": 0.075},
        "claude-3-sonnet": {"input": 0.003, "output": 0.015},
    }
    
    def __init__(self, config: CostConfig):
        self.config = config
        self._costs: List[SamplingCost] = []
        self._cache: Dict[str, Dict] = {}
    
    def estimate_cost(self, input_text: str, output_tokens: int, model: str) -> SamplingCost:
        """估算采样成本"""
        # 简化的 Token 估算
        input_tokens = len(input_text) // 4
        
        price = self.PRICES.get(model, self.PRICES["gpt-3.5-turbo"])
        estimated_cost = (
            (input_tokens / 1000) * price["input"] +
            (output_tokens / 1000) * price["output"]
        )
        
        return SamplingCost(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            estimated_cost=estimated_cost,
            model=model,
            timestamp=datetime.now()
        )
    
    def check_budget(self) -> Dict[str, any]:
        """检查是否超出预算"""
        period_costs = self._get_period_costs()
        total_cost = sum(c.estimated_cost for c in period_costs)
        remaining = self.config.max_total_cost - total_cost
        
        return {
            "allowed": remaining > 0,
            "remaining": max(0, remaining),
            "used": total_cost,
            "total": self.config.max_total_cost
        }
    
    def select_model(self, complexity: str = "medium") -> str:
        """选择最优模型"""
        model_tiers = {
            "low": ["gpt-3.5-turbo", "claude-3-sonnet"],
            "medium": ["claude-3-sonnet", "gpt-4"],
            "high": ["claude-3-opus", "gpt-4"],
        }
        
        candidates = model_tiers.get(complexity, model_tiers["medium"])
        
        # 优先选择配置中指定的模型
        for preferred in self.config.preferred_models:
            if preferred in candidates:
                return preferred
        
        return candidates[0]
    
    def check_cache(self, prompt: str) -> Optional[Dict]:
        """检查缓存"""
        hash_key = self._hash_prompt(prompt)
        return self._cache.get(hash_key)
    
    def set_cache(self, prompt: str, result: any, cost: SamplingCost) -> None:
        """存入缓存"""
        hash_key = self._hash_prompt(prompt)
        self._cache[hash_key] = {"result": result, "cost": cost}
        
        # 限制缓存大小
        if len(self._cache) > 1000:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
    
    def record_cost(self, cost: SamplingCost) -> None:
        """记录实际成本"""
        self._costs.append(cost)
        self._cleanup_old_costs()
    
    def _get_period_costs(self) -> List[SamplingCost]:
        """获取周期内的成本"""
        now = datetime.now()
        
        if self.config.budget_period == BudgetPeriod.DAILY:
            period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        elif self.config.budget_period == BudgetPeriod.WEEKLY:
            period_start = now - timedelta(days=now.weekday())
            period_start = period_start.replace(hour=0, minute=0, second=0, microsecond=0)
        else:  # MONTHLY
            period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        
        return [c for c in self._costs if c.timestamp >= period_start]
    
    def _cleanup_old_costs(self) -> None:
        """清理旧成本记录"""
        one_month_ago = datetime.now() - timedelta(days=30)
        self._costs = [c for c in self._costs if c.timestamp >= one_month_ago]
    
    def _hash_prompt(self, prompt: str) -> str:
        """生成提示哈希"""
        return hashlib.md5(prompt.encode()).hexdigest()
    
    def get_cost_report(self) -> Dict:
        """获取成本报告"""
        period_costs = self._get_period_costs()
        total_cost = sum(c.estimated_cost for c in period_costs)
        
        by_model: Dict[str, Dict] = {}
        for cost in period_costs:
            if cost.model not in by_model:
                by_model[cost.model] = {"calls": 0, "cost": 0.0}
            by_model[cost.model]["calls"] += 1
            by_model[cost.model]["cost"] += cost.estimated_cost
        
        return {
            "total_cost": total_cost,
            "total_calls": len(period_costs),
            "average_cost": total_cost / len(period_costs) if period_costs else 0,
            "by_model": by_model,
            "period": self.config.budget_period.value
        }

# 使用示例
def optimized_sampling_example():
    controller = CostController(CostConfig(
        max_input_tokens=4000,
        max_output_tokens=1000,
        max_total_cost=10.0,
        preferred_models=["claude-3-sonnet", "gpt-3.5-turbo"],
        budget_period=BudgetPeriod.DAILY
    ))
    
    prompt = "分析这段代码的性能问题"
    
    # 检查缓存
    cached = controller.check_cache(prompt)
    if cached:
        print("[CostController] Cache hit!")
        return cached["result"]
    
    # 检查预算
    budget = controller.check_budget()
    if not budget["allowed"]:
        raise RuntimeError(f"Budget exceeded. Remaining: ${budget['remaining']:.2f}")
    
    # 选择模型
    model = controller.select_model("medium")
    
    # 估算成本
    estimated = controller.estimate_cost(prompt, 500, model)
    print(f"[CostController] Estimated cost: ${estimated.estimated_cost:.4f}")
    
    # 执行采样...
    result = {"analysis": "code analysis result"}
    
    # 记录成本
    controller.record_cost(estimated)
    
    # 存入缓存
    controller.set_cache(prompt, result, estimated)
    
    return result

9. 提示模板与采样的安全考虑

9.1 安全风险分析

风险类型 描述 防护措施
提示注入 恶意输入操纵 LLM 行为 输入验证、参数转义
数据泄露 敏感信息通过采样泄露 数据脱敏、权限控制
资源耗尽 恶意请求消耗大量资源 限流、配额管理
递归攻击 利用嵌套调用造成 DoS 深度限制、超时控制
模型滥用 生成有害内容 内容过滤、审核机制

9.2 输入验证与净化

typescript 复制代码
// security/input-validator.ts
interface ValidationRule {
  type: "string" | "number" | "boolean" | "array" | "object";
  required?: boolean;
  minLength?: number;
  maxLength?: number;
  pattern?: RegExp;
  enum?: unknown[];
  sanitize?: (value: unknown) => unknown;
}

class InputValidator {
  private readonly DANGEROUS_PATTERNS = [
    /ignore\s+previous/i,
    /system\s+prompt/i,
    /override\s+instructions/i,
    /<script/i,
    /javascript:/i,
  ];

  /**
   * 验证并净化输入
   */
  validateAndSanitize(
    input: Record<string, unknown>,
    rules: Record<string, ValidationRule>
  ): { valid: boolean; sanitized: Record<string, unknown>; errors: string[] } {
    const sanitized: Record<string, unknown> = {};
    const errors: string[] = [];

    for (const [key, rule] of Object.entries(rules)) {
      const value = input[key];

      // 检查必填
      if (rule.required && (value === undefined || value === null)) {
        errors.push(`Missing required field: ${key}`);
        continue;
      }

      // 可选字段为空时跳过
      if (!rule.required && (value === undefined || value === null)) {
        continue;
      }

      // 类型检查
      if (!this.checkType(value, rule.type)) {
        errors.push(`Invalid type for ${key}: expected ${rule.type}`);
        continue;
      }

      // 字符串特定检查
      if (rule.type === "string" && typeof value === "string") {
        // 长度检查
        if (rule.minLength && value.length < rule.minLength) {
          errors.push(`${key} is too short (min ${rule.minLength})`);
          continue;
        }
        if (rule.maxLength && value.length > rule.maxLength) {
          errors.push(`${key} is too long (max ${rule.maxLength})`);
          continue;
        }

        // 模式检查
        if (rule.pattern && !rule.pattern.test(value)) {
          errors.push(`${key} does not match required pattern`);
          continue;
        }

        // 危险模式检查
        if (this.containsDangerousPatterns(value)) {
          errors.push(`${key} contains potentially dangerous content`);
          continue;
        }

        // 净化
        let sanitizedValue = this.sanitizeString(value);
        if (rule.sanitize) {
          sanitizedValue = rule.sanitize(sanitizedValue) as string;
        }
        sanitized[key] = sanitizedValue;
      } else {
        sanitized[key] = value;
      }
    }

    return {
      valid: errors.length === 0,
      sanitized,
      errors,
    };
  }

  /**
   * 检查类型
   */
  private checkType(value: unknown, type: string): boolean {
    switch (type) {
      case "string":
        return typeof value === "string";
      case "number":
        return typeof value === "number" && !isNaN(value);
      case "boolean":
        return typeof value === "boolean";
      case "array":
        return Array.isArray(value);
      case "object":
        return typeof value === "object" && value !== null && !Array.isArray(value);
      default:
        return false;
    }
  }

  /**
   * 检查危险模式
   */
  private containsDangerousPatterns(value: string): boolean {
    return this.DANGEROUS_PATTERNS.some((pattern) => pattern.test(value));
  }

  /**
   * 净化字符串
   */
  private sanitizeString(value: string): string {
    return value
      .replace(/[<>]/g, "") // 移除 HTML 标签
      .trim();
  }
}

// 使用示例
const validator = new InputValidator();

const rules = {
  code: {
    type: "string" as const,
    required: true,
    maxLength: 10000,
  },
  language: {
    type: "string" as const,
    required: true,
    enum: ["typescript", "python", "javascript", "java"],
  },
  focus: {
    type: "string" as const,
    required: false,
    pattern: /^[a-z]+$/,
  },
};

const result = validator.validateAndSanitize(
  {
    code: "function test() { return 1; }",
    language: "javascript",
    focus: "performance",
  },
  rules
);

if (!result.valid) {
  console.error("Validation failed:", result.errors);
}

9.3 数据脱敏处理

python 复制代码
# security/data_masking.py
import re
from typing import Dict, List, Pattern
from dataclasses import dataclass

@dataclass
class MaskingRule:
    """脱敏规则"""
    name: str
    pattern: Pattern
    replacement: str
    description: str

class DataMasker:
    """数据脱敏器"""
    
    DEFAULT_RULES = [
        MaskingRule(
            name="email",
            pattern=re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'),
            replacement="[EMAIL_MASKED]",
            description="邮箱地址"
        ),
        MaskingRule(
            name="phone",
            pattern=re.compile(r'\b1[3-9]\d{9}\b'),
            replacement="[PHONE_MASKED]",
            description="手机号码"
        ),
        MaskingRule(
            name="credit_card",
            pattern=re.compile(r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b'),
            replacement="[CARD_MASKED]",
            description="信用卡号"
        ),
        MaskingRule(
            name="api_key",
            pattern=re.compile(r'(?:api[_-]?key|token)["\']?\s*[:=]\s*["\']?[a-zA-Z0-9]{32,}["\']?', re.IGNORECASE),
            replacement="[API_KEY_MASKED]",
            description="API 密钥"
        ),
        MaskingRule(
            name="password",
            pattern=re.compile(r'(?:password|passwd|pwd)["\']?\s*[:=]\s*["\']?[^"\'\s]+["\']?', re.IGNORECASE),
            replacement="[PASSWORD_MASKED]",
            description="密码"
        ),
        MaskingRule(
            name="ssn",
            pattern=re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),
            replacement="[SSN_MASKED]",
            description="社会安全号码"
        ),
    ]
    
    def __init__(self, rules: List[MaskingRule] = None):
        self.rules = rules or self.DEFAULT_RULES
    
    def mask(self, text: str, strict: bool = False) -> str:
        """
        对文本进行脱敏处理
        
        Args:
            text: 原始文本
            strict: 是否严格模式(发现敏感信息则拒绝)
        
        Returns:
            脱敏后的文本
        """
        masked_text = text
        found_sensitive = []
        
        for rule in self.rules:
            matches = rule.pattern.findall(masked_text)
            if matches:
                found_sensitive.append({
                    "rule": rule.name,
                    "description": rule.description,
                    "count": len(matches)
                })
                masked_text = rule.pattern.sub(rule.replacement, masked_text)
        
        if strict and found_sensitive:
            raise ValueError(
                f"Sensitive data detected: "
                f"{', '.join(f['description'] for f in found_sensitive)}"
            )
        
        return masked_text
    
    def scan(self, text: str) -> List[Dict]:
        """扫描文本中的敏感信息"""
        findings = []
        
        for rule in self.rules:
            matches = rule.pattern.findall(text)
            if matches:
                findings.append({
                    "rule": rule.name,
                    "description": rule.description,
                    "matches": matches[:5],  # 最多显示5个
                    "count": len(matches)
                })
        
        return findings
    
    def add_custom_rule(self, rule: MaskingRule) -> None:
        """添加自定义脱敏规则"""
        self.rules.append(rule)

# 使用示例
def secure_sampling_example():
    masker = DataMasker()
    
    # 包含敏感信息的代码
    code_with_secrets = """
    const config = {
        apiKey: "sk-1234567890abcdef1234567890abcdef",
        email: "admin@company.com",
        password: "super_secret_password123"
    };
    """
    
    # 扫描敏感信息
    findings = masker.scan(code_with_secrets)
    print("Found sensitive data:", findings)
    
    # 脱敏处理
    masked_code = masker.mask(code_with_secrets)
    print("Masked code:", masked_code)
    
    # 严格模式(发现敏感信息则抛出异常)
    try:
        masker.mask(code_with_secrets, strict=True)
    except ValueError as e:
        print(f"Security violation: {e}")

10. 实际应用场景与案例

10.1 场景一:智能代码审查系统

复制代码
┌─────────────────────────────────────────────────────────────────┐
│              智能代码审查系统架构                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐         │
│  │   GitHub    │───►│ MCP Server  │───►│  LLM Host   │         │
│  │   Webhook   │    │ (Reviewer)  │    │  (Claude)   │         │
│  └─────────────┘    └──────┬──────┘    └─────────────┘         │
│                            │                                    │
│                            ▼                                    │
│                     ┌─────────────┐                            │
│                     │   Prompts   │                            │
│                     │ • security  │                            │
│                     │ • style     │                            │
│                     │ • perf      │                            │
│                     └─────────────┘                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

TypeScript 实现:

typescript 复制代码
// examples/code-review-system.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";

interface PullRequest {
  number: number;
  title: string;
  files: Array<{
    filename: string;
    patch: string;
    status: "added" | "modified" | "removed";
  }>;
}

class CodeReviewServer {
  private server: Server;

  constructor() {
    this.server = new Server({
      name: "code-review-server",
      version: "1.0.0",
    });

    this.setupPrompts();
    this.setupTools();
  }

  private setupPrompts(): void {
    // 注册代码审查提示
    this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
      return {
        prompts: [
          {
            name: "pr-security-review",
            description: "对 PR 进行安全审查",
            arguments: [
              { name: "prNumber", description: "PR 编号", required: true },
              { name: "focus", description: "审查重点", required: false },
            ],
          },
          {
            name: "pr-style-review",
            description: "对 PR 进行代码风格审查",
            arguments: [
              { name: "prNumber", description: "PR 编号", required: true },
            ],
          },
        ],
      };
    });

    // 处理提示请求
    this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
      const { name, arguments: args } = request.params;
      const pr = await this.fetchPR(args.prNumber as number);

      if (name === "pr-security-review") {
        return this.buildSecurityReviewPrompt(pr, args.focus as string);
      }

      if (name === "pr-style-review") {
        return this.buildStyleReviewPrompt(pr);
      }

      throw new Error(`Unknown prompt: ${name}`);
    });
  }

  private async fetchPR(prNumber: number): Promise<PullRequest> {
    // 从 GitHub API 获取 PR 信息
    // 实现省略...
    return {
      number: prNumber,
      title: "Example PR",
      files: [],
    };
  }

  private buildSecurityReviewPrompt(pr: PullRequest, focus?: string): {
    description: string;
    messages: any[];
  } {
    const codeToReview = pr.files
      .filter((f) => f.status !== "removed")
      .map((f) => `File: ${f.filename}\n\`\`\`\n${f.patch}\n\`\`\``)
      .join("\n\n");

    return {
      description: `Security Review for PR #${pr.number}`,
      messages: [
        {
          role: "system",
          content: {
            type: "text",
            text: `You are a security expert. Review the code for:
- SQL injection vulnerabilities
- XSS vulnerabilities  
- Authentication/authorization issues
- Sensitive data exposure
- ${focus || "All security concerns"}`,
          },
        },
        {
          role: "user",
          content: {
            type: "text",
            text: `Please review this PR for security issues:\n\n${codeToReview}`,
          },
        },
      ],
    };
  }

  private setupTools(): void {
    // 注册工具:提交审查评论
    this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
      if (request.params.name === "submit-review") {
        const { prNumber, review } = request.params.arguments as {
          prNumber: number;
          review: string;
        };

        // 提交审查评论到 GitHub
        await this.submitReviewToGitHub(prNumber, review);

        return {
          content: [{ type: "text", text: "Review submitted successfully" }],
        };
      }

      throw new Error("Unknown tool");
    });
  }

  private async submitReviewToGitHub(prNumber: number, review: string): Promise<void> {
    // GitHub API 调用
    // 实现省略...
  }
}

10.2 场景二:数据分析助手

python 复制代码
# examples/data-analysis-assistant.py
from mcp.server import Server
from mcp.types import Prompt, TextContent
import pandas as pd
import json

class DataAnalysisServer:
    """数据分析助手 MCP 服务器"""
    
    def __init__(self):
        self.server = Server("data-analysis-server")
        self.data_cache: Dict[str, pd.DataFrame] = {}
        
    def setup_prompts(self):
        """设置提示模板"""
        
        @self.server.list_prompts()
        async def list_prompts():
            return [
                Prompt(
                    name="analyze-dataset",
                    description="分析数据集并提供洞察",
                    arguments=[
                        {"name": "dataset_id", "description": "数据集ID", "required": True},
                        {"name": "analysis_type", "description": "分析类型", "required": False},
                    ]
                ),
                Prompt(
                    name="generate-report",
                    description="生成数据分析报告",
                    arguments=[
                        {"name": "dataset_id", "description": "数据集ID", "required": True},
                        {"name": "report_format", "description": "报告格式", "required": False},
                    ]
                ),
            ]
        
        @self.server.get_prompt()
        async def get_prompt(name: str, arguments: dict):
            dataset_id = arguments.get("dataset_id")
            df = self.data_cache.get(dataset_id)
            
            if df is None:
                raise ValueError(f"Dataset not found: {dataset_id}")
            
            # 生成数据集摘要
            summary = self._generate_summary(df)
            
            if name == "analyze-dataset":
                analysis_type = arguments.get("analysis_type", "comprehensive")
                return self._build_analysis_prompt(summary, analysis_type)
            
            elif name == "generate-report":
                report_format = arguments.get("report_format", "markdown")
                return self._build_report_prompt(summary, report_format)
            
            raise ValueError(f"Unknown prompt: {name}")
    
    def _generate_summary(self, df: pd.DataFrame) -> dict:
        """生成数据集摘要"""
        return {
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": df.dtypes.to_dict(),
            "missing": df.isnull().sum().to_dict(),
            "sample": df.head(5).to_dict(),
            "describe": df.describe().to_dict(),
        }
    
    def _build_analysis_prompt(self, summary: dict, analysis_type: str) -> dict:
        """构建分析提示"""
        
        analysis_focus = {
            "comprehensive": "全面分析,包括统计摘要、趋势、异常值和相关性",
            "trend": "重点关注时间趋势和变化模式",
            "correlation": "重点关注变量间的相关性分析",
            "anomaly": "重点关注异常值和离群点检测",
        }
        
        return {
            "description": f"Dataset Analysis ({analysis_type})",
            "messages": [
                {
                    "role": "system",
                    "content": TextContent(
                        type="text",
                        text=f"""你是一位数据分析专家。请对提供的数据集进行专业分析。

分析重点:{analysis_focus.get(analysis_type, analysis_focus["comprehensive"])}

请提供:
1. 数据概览和关键指标
2. 主要发现和洞察
3. 数据质量问题(如有)
4. 可视化建议
5. 进一步分析的建议"""
                    )
                },
                {
                    "role": "user",
                    "content": TextContent(
                        type="text",
                        text=f"请分析以下数据集:\n\n```json\n{json.dumps(summary, indent=2, default=str)}\n```"
                    )
                }
            ]
        }
    
    def _build_report_prompt(self, summary: dict, report_format: str) -> dict:
        """构建报告提示"""
        
        format_instructions = {
            "markdown": "使用 Markdown 格式,包含标题、表格和列表",
            "html": "使用 HTML 格式,包含样式和表格",
            "json": "使用 JSON 格式,结构化输出",
        }
        
        return {
            "description": f"Analysis Report ({report_format})",
            "messages": [
                {
                    "role": "system",
                    "content": TextContent(
                        type="text",
                        text=f"""你是一位数据报告专家。请基于数据集生成专业报告。

格式要求:{format_instructions.get(report_format, format_instructions["markdown"])}

报告结构:
1. 执行摘要
2. 数据描述
3. 分析方法
4. 主要发现
5. 结论和建议"""
                    )
                },
                {
                    "role": "user",
                    "content": TextContent(
                        type="text",
                        text=f"请为以下数据集生成报告:\n\n```json\n{json.dumps(summary, indent=2, default=str)}\n```"
                    )
                }
            ]
        }

# 使用示例
async def run_analysis_example():
    server = DataAnalysisServer()
    
    # 加载示例数据
    server.data_cache["sales_2024"] = pd.DataFrame({
        "date": pd.date_range("2024-01-01", periods=100),
        "product": ["A", "B", "C"] * 33 + ["A"],
        "sales": [100 + i * 2 + (i % 10) * 5 for i in range(100)],
        "region": ["North", "South", "East", "West"] * 25,
    })
    
    server.setup_prompts()

10.3 场景三:智能客服系统

typescript 复制代码
// examples/customer-service-bot.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";

interface Customer {
  id: string;
  name: string;
  tier: "basic" | "premium" | "enterprise";
  history: Array<{
    timestamp: Date;
    query: string;
    resolved: boolean;
  }>;
}

interface KnowledgeBase {
  articles: Array<{
    id: string;
    title: string;
    content: string;
    tags: string[];
  }>;
}

class CustomerServiceServer {
  private server: Server;
  private customers: Map<string, Customer> = new Map();
  private knowledgeBase: KnowledgeBase;

  constructor() {
    this.server = new Server({
      name: "customer-service-server",
      version: "1.0.0",
    });

    this.setupPrompts();
  }

  private setupPrompts(): void {
    // 注册客服提示模板
    this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
      return {
        prompts: [
          {
            name: "handle-inquiry",
            description: "处理客户咨询",
            arguments: [
              { name: "customerId", description: "客户ID", required: true },
              { name: "query", description: "客户问题", required: true },
              { name: "category", description: "问题分类", required: false },
            ],
          },
          {
            name: "escalate-ticket",
            description: "升级工单到人工",
            arguments: [
              { name: "customerId", description: "客户ID", required: true },
              { name: "reason", description: "升级原因", required: true },
            ],
          },
        ],
      };
    });

    // 处理提示请求
    this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
      const { name, arguments: args } = request.params;
      const customer = this.customers.get(args.customerId as string);

      if (!customer) {
        throw new Error(`Customer not found: ${args.customerId}`);
      }

      if (name === "handle-inquiry") {
        return this.buildInquiryPrompt(customer, args.query as string, args.category as string);
      }

      if (name === "escalate-ticket") {
        return this.buildEscalationPrompt(customer, args.reason as string);
      }

      throw new Error(`Unknown prompt: ${name}`);
    });
  }

  private buildInquiryPrompt(
    customer: Customer,
    query: string,
    category?: string
  ): { description: string; messages: any[] } {
    // 检索相关知识库文章
    const relevantArticles = this.searchKnowledgeBase(query);
    
    // 构建客户上下文
    const customerContext = `
Customer: ${customer.name}
Tier: ${customer.tier}
Previous Interactions: ${customer.history.length}
Recent Issues: ${customer.history
  .slice(-3)
  .map((h) => h.query)
  .join("; ")}
    `;

    // 构建知识库上下文
    const kbContext = relevantArticles
      .map((a) => `Article: ${a.title}\n${a.content}`)
      .join("\n\n");

    return {
      description: `Handle inquiry from ${customer.name}`,
      messages: [
        {
          role: "system",
          content: {
            type: "text",
            text: `You are a helpful customer service representative. 
Use the provided knowledge base articles to answer the customer's question.
Be polite, professional, and concise.

Customer Context:
${customerContext}

Knowledge Base:
${kbContext}`,
          },
        },
        {
          role: "user",
          content: {
            type: "text",
            text: `Customer Query${category ? ` [${category}]` : ""}: ${query}`,
          },
        },
      ],
    };
  }

  private searchKnowledgeBase(query: string): typeof this.knowledgeBase.articles {
    // 简单的关键词匹配
    const keywords = query.toLowerCase().split(" ");
    return this.knowledgeBase.articles
      .filter((article) =>
        keywords.some(
          (k) =>
            article.title.toLowerCase().includes(k) ||
            article.content.toLowerCase().includes(k) ||
            article.tags.some((t) => t.toLowerCase().includes(k))
        )
      )
      .slice(0, 3); // 返回最相关的3篇
  }

  private buildEscalationPrompt(customer: Customer, reason: string) {
    return {
      description: `Escalate ticket for ${customer.name}`,
      messages: [
        {
          role: "system",
          content: {
            type: "text",
            text: `Generate a ticket escalation summary for the support team.`,
          },
        },
        {
          role: "user",
          content: {
            type: "text",
            text: `Escalate ticket for customer ${customer.name} (Tier: ${customer.tier})
Reason: ${reason}
History: ${JSON.stringify(customer.history.slice(-5))}`,
          },
        },
      ],
    };
  }
}

11. 常见问题 FAQ

Q1: Prompts 和 Tools 的主要区别是什么?

A: Prompts 用于提供结构化的提示模板,返回的是消息列表(给 LLM 的上下文);Tools 用于执行具体的功能操作,返回的是执行结果。Prompts 通常由用户选择使用,而 Tools 由 LLM 自动调用。

Q2: 如何实现动态提示模板?

A: 动态提示模板可以通过以下方式实现:

  1. 使用参数化模板,根据输入参数动态生成提示内容
  2. GetPromptRequest 处理器中根据参数动态构建消息列表
  3. 结合外部数据源(如数据库、API)动态获取提示内容
typescript 复制代码
// 示例:动态生成提示
server.setRequestHandler(GetPromptRequestSchema, async (request) => {
  const { name, arguments: args } = request.params;
  
  // 根据参数动态生成
  const dynamicContent = await fetchDynamicContent(args);
  
  return {
    description: `Dynamic prompt for ${name}`,
    messages: [
      { role: "system", content: { type: "text", text: dynamicContent } },
    ],
  };
});

Q3: Sampling 请求失败如何处理?

A: Sampling 请求失败时应采取以下策略:

  1. 重试机制:对临时错误进行指数退避重试
  2. 降级策略:切换到备用模型或简化请求
  3. 错误反馈:向用户清晰说明失败原因
  4. 日志记录:记录失败详情用于后续分析
typescript 复制代码
try {
  const result = await samplingClient.createMessage(options);
} catch (error) {
  if (error.code === "RATE_LIMITED") {
    // 等待后重试
    await delay(1000);
    return await retryWithBackoff(options);
  }
  if (error.code === "MODEL_UNAVAILABLE") {
    // 切换到备用模型
    return await fallbackToAlternativeModel(options);
  }
  throw error;
}

Q4: 如何防止提示注入攻击?

A: 防止提示注入的主要措施:

  1. 输入验证:使用正则表达式过滤危险模式
  2. 参数转义:对用户输入进行适当的转义处理
  3. 长度限制:限制输入内容的长度
  4. 内容审查:使用内容安全策略检测恶意输入
  5. 最小权限:限制 Prompts 可访问的资源和数据

Q5: 采样请求的上下文长度有限制吗?

A: 是的,采样请求的上下文长度受以下限制:

  • 模型限制:不同模型的最大上下文长度不同(如 GPT-4 支持 8K/32K tokens)
  • 配置限制 :可以通过 maxTokens 参数限制输出长度
  • 协议限制:MCP 协议本身没有硬性限制,但建议保持合理大小

建议实现上下文压缩策略,如只保留最近的 N 条消息。

Q6: 如何实现多轮对话的上下文保持?

A: 多轮对话上下文保持可以通过以下方式实现:

  1. 会话管理:使用 Session ID 跟踪对话状态
  2. 消息历史:在服务器端存储消息历史
  3. 上下文传递:在每次请求中包含历史消息
python 复制代码
class SessionManager:
    def __init__(self):
        self.sessions = {}
    
    def append_message(self, session_id, message):
        if session_id not in self.sessions:
            self.sessions[session_id] = []
        self.sessions[session_id].append(message)
    
    def get_context(self, session_id, max_history=10):
        return self.sessions.get(session_id, [])[-max_history:]

Q7: Sampling 的成本如何控制?

A: 成本控制策略包括:

  1. 模型选择:根据任务复杂度选择性价比合适的模型
  2. Token 限制 :设置合理的 maxTokens 限制
  3. 缓存机制:缓存常见查询的结果
  4. 预算管理:设置每日/每周采样预算上限
  5. 批处理:合并多个小请求为批量请求

Q8: 嵌套调用深度应该限制在多少层?

A: 建议将嵌套调用深度限制在 3 层以内。超过这个深度通常意味着:

  • 存在设计问题,应该重构逻辑
  • 可能出现无限递归风险
  • 性能会显著下降
typescript 复制代码
const MAX_NESTING_DEPTH = 3;

if (currentDepth > MAX_NESTING_DEPTH) {
  throw new Error(`Nesting depth exceeded: ${currentDepth}`);
}

Q9: Prompts 可以返回图片或多媒体内容吗?

A: MCP 协议支持返回多媒体内容。消息内容的 type 可以是:

  • "text":纯文本内容
  • "image":图片内容(需要提供 base64 编码)
typescript 复制代码
{
  role: "user",
  content: {
    type: "image",
    data: base64EncodedImage,
    mimeType: "image/png"
  }
}

Q10: 如何调试 Prompts 和 Sampling 问题?

A: 调试建议:

  1. 日志记录:记录所有 Prompts 请求和响应
  2. 采样追踪:记录 Sampling 请求的完整生命周期
  3. 本地测试:使用 MCP Inspector 工具进行本地调试
  4. 渐进测试:从简单场景开始,逐步增加复杂度
  5. 错误隔离:使用 try-catch 隔离问题范围

12. 参考文献

  1. MCP 官方文档

  2. MCP 协议规范

  3. MCP TypeScript SDK

  4. MCP Python SDK

  5. Anthropic MCP 博客

  6. OpenAI Function Calling 指南

  7. Prompt Injection 防护研究

  8. LLM Cost Optimization


附录:完整配置示例

TypeScript 服务器配置

typescript 复制代码
// config/server.config.ts
export const serverConfig = {
  name: "advanced-mcp-server",
  version: "1.0.0",
  
  prompts: {
    enabled: true,
    maxTemplates: 100,
    cacheEnabled: true,
  },
  
  sampling: {
    enabled: true,
    maxConcurrent: 5,
    defaultTimeout: 60000,
    maxDepth: 3,
  },
  
  security: {
    inputValidation: true,
    maxInputLength: 10000,
    dangerousPatterns: [
      /ignore\s+previous/i,
      /system\s+prompt/i,
    ],
  },
  
  cost: {
    budgetPeriod: "daily",
    maxDailyCost: 10.0,
    preferredModels: ["claude-3-sonnet", "gpt-3.5-turbo"],
  },
};

Python 服务器配置

python 复制代码
# config/server_config.py
from dataclasses import dataclass
from typing import List

@dataclass
class ServerConfig:
    name: str = "advanced-mcp-server"
    version: str = "1.0.0"
    
    # Prompts 配置
    prompts_enabled: bool = True
    max_templates: int = 100
    cache_enabled: bool = True
    
    # Sampling 配置
    sampling_enabled: bool = True
    max_concurrent: int = 5
    default_timeout: int = 60
    max_depth: int = 3
    
    # 安全配置
    input_validation: bool = True
    max_input_length: int = 10000
    
    # 成本配置
    budget_period: str = "daily"
    max_daily_cost: float = 10.0
    preferred_models: List[str] = None
    
    def __post_init__(self):
        if self.preferred_models is None:
            self.preferred_models = ["claude-3-sonnet", "gpt-3.5-turbo"]

# 全局配置实例
config = ServerConfig()

相关推荐
东离与糖宝2 小时前
小米MiMo-V2-Pro开放调用,Java后端快速接入全流程实战
java·人工智能
科技林总2 小时前
【系统分析师】12.2 软件架构风格
学习
●VON2 小时前
旗舰基座大模型 MiMo-V2-Pro 初体验与实战指南
学习·小米·模型·von·mimo-v2-pro
⑩-2 小时前
Kafka 架构和工作原理?Kafka 如何保证高可用?
java·分布式·架构·kafka
balmtv2 小时前
GPT-5.4镜像实测:gpt技术拆解——当AI学会操控电脑
人工智能·gpt·电脑
大傻^2 小时前
Spring AI 2.0 生产部署指南:从 1.x 迁移、性能调优与云原生实践
人工智能·spring·云原生·springai
不懒不懒2 小时前
【机器学习模型评估:8种算法对比实战(本篇文章先介绍6种)】
人工智能·机器学习
CET中电技术2 小时前
中压(公共连接点10kV及以上)分布式光伏项目,四可如何改造?
分布式
ejjdhdjdjdjdjjsl2 小时前
halcon算子
人工智能·算法·计算机视觉