
🎯 博主简介
CSDN 「新星创作者」 人工智能技术领域博主,码龄5年 ,累计发布
180+篇原创文章,博客总访问量24万+浏览!
🚀 持续更新AI前沿实战知识,专注于 AI 技术实战,核心方向包括:
- 🔥 生产级 RAG 系统 --- 架构设计、多路召回、混合检索、生产优化
- 🤖 多 Agent 协作系统 --- Swarm 架构、任务编排、状态管理
- 🔌 MCP 协议 --- 协议规范与深度实现
- ⚡ OpenClaw --- AI 助手框架进阶应用
同时也涉猎计算机视觉、Java 后端与 Spring 生态、Transformer 等深度学习技术。坚持从架构到代码、从原理到部署的实战风格。每篇文章配套代码与扩展资料,欢迎交流探讨。
📱 公众号:
Anyi研习社--- 每天一点AI:资料、笔记、工具、趋势,一起进步。🤝 商务合作 :请搜索关注微信公众号
「Anyi研习社」

系列导读:本文是《MCP 协议深度解析》系列的第 8 篇,深入探讨 MCP 协议中 Prompts 提示模板系统与 Sampling 采样机制的设计原理、实现细节与最佳实践。通过双语言代码示例(TypeScript + Python),帮助开发者掌握 MCP 的高级交互能力。
📑 目录
- [1. 引言:Prompt 提示模板系统设计理念](#1. 引言:Prompt 提示模板系统设计理念)
- [2. 参数化提示与动态生成](#2. 参数化提示与动态生成)
- [3. 提示模板注册与发现](#3. 提示模板注册与发现)
- [4. 多轮对话上下文管理](#4. 多轮对话上下文管理)
- [5. Sampling 采样机制:Server 请求 LLM](#5. Sampling 采样机制:Server 请求 LLM)
- [6. 采样请求的生命周期管理](#6. 采样请求的生命周期管理)
- [7. 嵌套调用与递归防护](#7. 嵌套调用与递归防护)
- [8. 采样成本控制与优化](#8. 采样成本控制与优化)
- [9. 提示模板与采样的安全考虑](#9. 提示模板与采样的安全考虑)
- [10. 实际应用场景与案例](#10. 实际应用场景与案例)
- [11. 常见问题 FAQ](#11. 常见问题 FAQ)
- [12. 参考文献](#12. 参考文献)
1. 引言:Prompt 提示模板系统设计理念
1.1 什么是 MCP Prompts
Model Context Protocol(MCP)的 Prompts 特性允许服务器向客户端暴露结构化的提示模板。这些模板可以被用户直接调用,或者由 LLM 根据上下文自动选择使用。Prompts 是 MCP 协议中实现"可复用 AI 工作流"的核心机制。
┌─────────────────────────────────────────────────────────────────┐
│ MCP Prompts 架构概览 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │
│ │ Client │◄───────►│ Server │◄───────►│ LLM │ │
│ │ (Host) │ MCP │ (Prompts │ Sampling│ Provider │ │
│ │ │ Protocol│ Provider) │ Protocol│ │ │
│ └─────────────┘ └─────────────┘ └──────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────────────────┐ │
│ │ Prompt Templates │ │
│ │ • 代码审查模板 │ │
│ │ • 数据分析模板 │ │
│ │ • 文档生成模板 │ │
│ │ • 自定义工作流 │ │
│ └─────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
1.2 Prompts 与 Tools 的区别
| 特性 | Prompts | Tools |
|---|---|---|
| 用途 | 提供结构化提示模板 | 执行具体功能操作 |
| 返回值 | 消息列表(给 LLM 的上下文) | 工具执行结果 |
| 交互方式 | 用户选择或 LLM 建议 | LLM 自动调用 |
| 参数类型 | 描述性参数(用于生成提示) | 功能性参数(用于执行操作) |
| 典型场景 | 代码审查、分析模板 | 文件操作、API 调用 |
| 状态影响 | 通常无副作用 | 可能改变系统状态 |
1.3 设计哲学
MCP Prompts 的设计遵循以下核心原则:
- 可发现性(Discoverability):客户端可以动态发现服务器提供的所有提示模板
- 可组合性(Composability):提示模板可以嵌套组合,构建复杂工作流
- 类型安全(Type Safety):参数通过 JSON Schema 定义,确保类型正确
- 上下文感知(Context Awareness):提示可以访问 MCP 会话的完整上下文
2. 参数化提示与动态生成
2.1 参数化提示的基本概念
参数化提示允许开发者定义带有变量的模板,在运行时根据用户输入动态生成完整的提示内容。这是实现可复用提示模板的基础。
2.1.1 参数类型系统
MCP 支持以下参数类型:
| 类型 | 说明 | 示例 |
|---|---|---|
string |
文本字符串 | "code": "function hello() {}" |
number |
数值 | "temperature": 0.7 |
boolean |
布尔值 | "includeTests": true |
array |
数组 | "files": ["src/index.ts", "src/utils.ts"] |
object |
对象 | "config": {"style": "google"} |
enum |
枚举值 | "language": ["zh", "en", "ja"] |
2.2 TypeScript 实现:参数化提示定义
typescript
// prompts/code-review.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
ListPromptsRequestSchema,
GetPromptRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
// 定义提示模板接口
interface CodeReviewArgs {
code: string;
language: string;
focus?: "performance" | "security" | "readability" | "all";
includeExamples?: boolean;
}
// 注册提示模板
server.setRequestHandler(ListPromptsRequestSchema, async () => {
return {
prompts: [
{
name: "code-review",
description: "对代码进行专业审查,提供改进建议",
arguments: [
{
name: "code",
description: "需要审查的代码",
required: true,
},
{
name: "language",
description: "编程语言",
required: true,
},
{
name: "focus",
description: "审查重点",
required: false,
},
{
name: "includeExamples",
description: "是否包含改进示例",
required: false,
},
],
},
],
};
});
// 处理提示请求
server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
if (name !== "code-review") {
throw new Error(`Unknown prompt: ${name}`);
}
const { code, language, focus = "all", includeExamples = false } = args as CodeReviewArgs;
// 动态生成系统提示
const systemPrompt = generateReviewPrompt(focus, includeExamples);
// 构建消息列表
const messages = [
{
role: "system" as const,
content: {
type: "text" as const,
text: systemPrompt,
},
},
{
role: "user" as const,
content: {
type: "text" as const,
text: `请审查以下 ${language} 代码:\n\n\`\`\`${language}\n${code}\n\`\`\``,
},
},
];
return {
description: `代码审查 - ${language}`,
messages,
};
});
// 根据参数动态生成提示内容
function generateReviewPrompt(focus: string, includeExamples: boolean): string {
const focusAreas: Record<string, string> = {
performance: "性能优化、算法复杂度、资源使用",
security: "安全漏洞、输入验证、权限控制",
readability: "代码风格、命名规范、注释质量",
all: "性能、安全、可读性、可维护性",
};
let prompt = `你是一位资深 ${focusAreas[focus] || focusAreas.all} 专家。`;
prompt += `请对提供的代码进行专业审查,重点关注 ${focusAreas[focus] || focusAreas.all}。\n\n`;
prompt += `审查要求:\n`;
prompt += `1. 识别潜在问题和风险\n`;
prompt += `2. 提供具体的改进建议\n`;
prompt += `3. 说明问题的严重程度(高/中/低)\n`;
if (includeExamples) {
prompt += `4. 为每个问题提供改进后的代码示例\n`;
}
prompt += `\n请以结构化的方式输出审查结果。`;
return prompt;
}
2.3 Python 实现:参数化提示定义
python
# prompts/code_review.py
from mcp.server import Server
from mcp.types import (
ListPromptsRequest,
GetPromptRequest,
Prompt,
PromptArgument,
TextContent,
)
from typing import Optional
import json
# 创建服务器实例
server = Server("code-review-server")
# 定义提示模板配置
PROMPT_TEMPLATES = {
"code-review": {
"description": "对代码进行专业审查,提供改进建议",
"arguments": [
PromptArgument(
name="code",
description="需要审查的代码",
required=True,
),
PromptArgument(
name="language",
description="编程语言",
required=True,
),
PromptArgument(
name="focus",
description="审查重点 (performance/security/readability/all)",
required=False,
),
PromptArgument(
name="include_examples",
description="是否包含改进示例",
required=False,
),
],
},
"doc-generate": {
"description": "为代码生成文档",
"arguments": [
PromptArgument(
name="code",
description="需要生成文档的代码",
required=True,
),
PromptArgument(
name="style",
description="文档风格 (google/numpy/jsdoc)",
required=False,
),
],
},
}
# 提示内容生成器
class PromptGenerator:
"""动态提示内容生成器"""
FOCUS_AREAS = {
"performance": "性能优化、算法复杂度、资源使用",
"security": "安全漏洞、输入验证、权限控制",
"readability": "代码风格、命名规范、注释质量",
"all": "性能、安全、可读性、可维护性",
}
@classmethod
def generate_review_prompt(cls, focus: str, include_examples: bool) -> str:
focus_area = cls.FOCUS_AREAS.get(focus, cls.FOCUS_AREAS["all"])
prompt = f"你是一位资深 {focus_area} 专家。\n"
prompt += f"请对提供的代码进行专业审查,重点关注 {focus_area}。\n\n"
prompt += "审查要求:\n"
prompt += "1. 识别潜在问题和风险\n"
prompt += "2. 提供具体的改进建议\n"
prompt += "3. 说明问题的严重程度(高/中/低)\n"
if include_examples:
prompt += "4. 为每个问题提供改进后的代码示例\n"
prompt += "\n请以结构化的方式输出审查结果。"
return prompt
# 注册提示列表处理器
@server.list_prompts()
async def list_prompts() -> list[Prompt]:
return [
Prompt(
name=name,
description=config["description"],
arguments=config["arguments"],
)
for name, config in PROMPT_TEMPLATES.items()
]
# 注册提示获取处理器
@server.get_prompt()
async def get_prompt(name: str, arguments: dict) -> dict:
if name not in PROMPT_TEMPLATES:
raise ValueError(f"Unknown prompt: {name}")
if name == "code-review":
code = arguments.get("code", "")
language = arguments.get("language", "")
focus = arguments.get("focus", "all")
include_examples = arguments.get("include_examples", False)
system_prompt = PromptGenerator.generate_review_prompt(focus, include_examples)
return {
"description": f"代码审查 - {language}",
"messages": [
{
"role": "system",
"content": TextContent(type="text", text=system_prompt),
},
{
"role": "user",
"content": TextContent(
type="text",
text=f"请审查以下 {language} 代码:\n\n\`\`\`{language}\n{code}\n\`\`\`"
),
},
],
}
# 其他提示模板处理...
return {}
2.4 动态提示生成策略
在实际应用中,提示模板可以根据多种因素动态生成:
| 策略 | 说明 | 适用场景 |
|---|---|---|
| 条件分支 | 根据参数选择不同提示分支 | 多语言支持、不同审查级别 |
| 模板继承 | 基础模板 + 特定扩展 | 通用审查 + 专项审查 |
| 上下文注入 | 将运行时数据注入提示 | 当前项目信息、用户偏好 |
| 链式组合 | 多个提示模板组合 | 复杂多阶段分析 |
3. 提示模板注册与发现
3.1 注册机制
MCP 服务器通过实现 prompts/list 方法来暴露可用的提示模板。
3.1.1 注册时序图
Prompt Registry MCP Server MCP Client Prompt Registry MCP Server MCP Client prompts/list 请求 查询所有注册模板 返回模板列表 返回 prompts[] prompts/get (name, args) 查找指定模板 返回模板处理器 执行模板生成 返回 messages[]
3.2 TypeScript 实现:完整注册流程
typescript
// server/prompt-registry.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
ListPromptsRequestSchema,
GetPromptRequestSchema,
Prompt,
PromptMessage,
} from "@modelcontextprotocol/sdk/types.js";
// 提示模板定义接口
interface PromptTemplate {
name: string;
description: string;
arguments: PromptArgument[];
handler: (args: Record<string, unknown>) => Promise<{
description: string;
messages: PromptMessage[];
}>;
}
// 提示注册表
class PromptRegistry {
private templates: Map<string, PromptTemplate> = new Map();
register(template: PromptTemplate): void {
this.templates.set(template.name, template);
console.log(`[PromptRegistry] Registered: ${template.name}`);
}
unregister(name: string): void {
this.templates.delete(name);
console.log(`[PromptRegistry] Unregistered: ${name}`);
}
list(): Prompt[] {
return Array.from(this.templates.values()).map((t) => ({
name: t.name,
description: t.description,
arguments: t.arguments,
}));
}
async execute(name: string, args: Record<string, unknown>): Promise<{
description: string;
messages: PromptMessage[];
}> {
const template = this.templates.get(name);
if (!template) {
throw new Error(`Prompt not found: ${name}`);
}
return await template.handler(args);
}
}
// 创建注册表实例
const registry = new PromptRegistry();
// 注册代码审查模板
registry.register({
name: "code-review",
description: "代码审查提示模板",
arguments: [
{ name: "code", description: "代码内容", required: true },
{ name: "language", description: "编程语言", required: true },
],
handler: async (args) => {
const { code, language } = args as { code: string; language: string };
return {
description: `代码审查 (${language})`,
messages: [
{
role: "system",
content: {
type: "text",
text: "你是一位资深代码审查专家...",
},
},
{
role: "user",
content: {
type: "text",
text: `请审查以下代码:\n\n\`\`\`${language}\n${code}\n\`\`\``,
},
},
],
};
},
});
// 注册数据分析模板
registry.register({
name: "data-analysis",
description: "数据分析提示模板",
arguments: [
{ name: "data", description: "数据内容", required: true },
{ name: "format", description: "数据格式", required: false },
],
handler: async (args) => {
// 实现...
return { description: "数据分析", messages: [] };
},
});
// 设置 MCP 请求处理器
export function setupPromptHandlers(server: Server): void {
// 处理列表请求
server.setRequestHandler(ListPromptsRequestSchema, async () => {
return { prompts: registry.list() };
});
// 处理获取请求
server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
return await registry.execute(name, args || {});
});
}
3.3 Python 实现:完整注册流程
python
# server/prompt_registry.py
from typing import Callable, Dict, List, Any, Optional
from dataclasses import dataclass
from mcp.types import Prompt, PromptArgument, PromptMessage, TextContent
import asyncio
@dataclass
class PromptTemplate:
"""提示模板定义"""
name: str
description: str
arguments: List[PromptArgument]
handler: Callable[[Dict[str, Any]], asyncio.Future]
class PromptRegistry:
"""提示模板注册表"""
def __init__(self):
self._templates: Dict[str, PromptTemplate] = {}
def register(self, template: PromptTemplate) -> None:
"""注册提示模板"""
self._templates[template.name] = template
print(f"[PromptRegistry] Registered: {template.name}")
def unregister(self, name: str) -> None:
"""注销提示模板"""
if name in self._templates:
del self._templates[name]
print(f"[PromptRegistry] Unregistered: {name}")
def list(self) -> List[Prompt]:
"""列出所有提示模板"""
return [
Prompt(
name=t.name,
description=t.description,
arguments=t.arguments,
)
for t in self._templates.values()
]
async def execute(self, name: str, args: Dict[str, Any]) -> Dict[str, Any]:
"""执行提示模板"""
template = self._templates.get(name)
if not template:
raise ValueError(f"Prompt not found: {name}")
return await template.handler(args)
def get_template(self, name: str) -> Optional[PromptTemplate]:
"""获取指定模板"""
return self._templates.get(name)
# 创建全局注册表
registry = PromptRegistry()
# 装饰器方式注册
def prompt(
name: str,
description: str,
arguments: Optional[List[PromptArgument]] = None
):
"""提示模板装饰器"""
def decorator(func: Callable):
template = PromptTemplate(
name=name,
description=description,
arguments=arguments or [],
handler=func,
)
registry.register(template)
return func
return decorator
# 使用示例
@prompt(
name="code-review",
description="代码审查提示模板",
arguments=[
PromptArgument(name="code", description="代码内容", required=True),
PromptArgument(name="language", description="编程语言", required=True),
]
)
async def code_review_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""代码审查处理器"""
code = args.get("code", "")
language = args.get("language", "")
return {
"description": f"代码审查 ({language})",
"messages": [
{
"role": "system",
"content": TextContent(
type="text",
text="你是一位资深代码审查专家..."
),
},
{
"role": "user",
"content": TextContent(
type="text",
text=f"请审查以下代码:\n\n\`\`\`{language}\n{code}\n\`\`\`"
),
},
],
}
# 动态注册示例
def register_dynamic_prompt(name: str, config: Dict[str, Any]):
"""动态注册提示模板"""
async def handler(args: Dict[str, Any]) -> Dict[str, Any]:
# 根据配置动态生成提示
template_content = config.get("template", "")
return {
"description": config.get("description", name),
"messages": [
{
"role": "system",
"content": TextContent(type="text", text=template_content),
},
],
}
registry.register(PromptTemplate(
name=name,
description=config.get("description", ""),
arguments=[
PromptArgument(**arg)
for arg in config.get("arguments", [])
],
handler=handler,
))
4. 多轮对话上下文管理
4.1 上下文管理的重要性
在多轮对话场景中,Prompts 需要维护对话历史,确保 LLM 能够理解上下文并做出连贯的响应。
4.2 上下文传递机制
┌─────────────────────────────────────────────────────────────┐
│ 多轮对话上下文传递流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Round 1: │
│ ┌─────────┐ prompts/get ┌─────────┐ │
│ │ Client │ ─────────────────► │ Server │ │
│ │ │ ◄───────────────── │ │ │
│ │ │ messages[role] │ │ │
│ └────┬────┘ └─────────┘ │
│ │ │
│ ▼ 存储到会话上下文 │
│ ┌─────────────┐ │
│ │ SessionCtx │ {role: "assistant", content: "..."} │
│ └─────────────┘ │
│ │
│ Round 2: │
│ ┌─────────┐ prompts/get ┌─────────┐ │
│ │ Client │ ─────────────────► │ Server │ │
│ │ (含历史) │ ◄───────────────── │ │ │
│ └────┬────┘ messages[] └─────────┘ │
│ │ │
│ ▼ 追加到历史 │
│ ┌─────────────┐ │
│ │ SessionCtx │ [历史消息1, 历史消息2, ...] │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
4.3 TypeScript 实现:上下文管理
typescript
// context/session-manager.ts
import { PromptMessage } from "@modelcontextprotocol/sdk/types.js";
interface SessionContext {
sessionId: string;
messages: PromptMessage[];
metadata: Record<string, unknown>;
createdAt: Date;
lastAccessedAt: Date;
}
class SessionManager {
private sessions: Map<string, SessionContext> = new Map();
private readonly TTL_MS = 30 * 60 * 1000; // 30分钟过期
createSession(): string {
const sessionId = crypto.randomUUID();
this.sessions.set(sessionId, {
sessionId,
messages: [],
metadata: {},
createdAt: new Date(),
lastAccessedAt: new Date(),
});
return sessionId;
}
getSession(sessionId: string): SessionContext | undefined {
const session = this.sessions.get(sessionId);
if (session) {
session.lastAccessedAt = new Date();
}
return session;
}
appendMessage(sessionId: string, message: PromptMessage): void {
const session = this.sessions.get(sessionId);
if (session) {
session.messages.push(message);
session.lastAccessedAt = new Date();
}
}
getContextMessages(sessionId: string, maxHistory: number = 10): PromptMessage[] {
const session = this.sessions.get(sessionId);
if (!session) return [];
// 返回最近的消息
return session.messages.slice(-maxHistory);
}
// 清理过期会话
cleanup(): void {
const now = Date.now();
for (const [id, session] of this.sessions.entries()) {
if (now - session.lastAccessedAt.getTime() > this.TTL_MS) {
this.sessions.delete(id);
console.log(`[SessionManager] Cleaned up session: ${id}`);
}
}
}
}
// 使用示例:构建带上下文的提示
async function buildContextualPrompt(
sessionManager: SessionManager,
sessionId: string,
newPrompt: PromptMessage[]
): Promise<PromptMessage[]> {
const history = sessionManager.getContextMessages(sessionId, 5);
// 组合历史消息和新提示
return [...history, ...newPrompt];
}
4.4 Python 实现:上下文管理
python
# context/session_manager.py
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncio
import uuid
@dataclass
class SessionContext:
"""会话上下文"""
session_id: str
messages: List[Dict[str, Any]] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=datetime.now)
last_accessed_at: datetime = field(default_factory=datetime.now)
class SessionManager:
"""会话管理器"""
TTL_SECONDS = 30 * 60 # 30分钟过期
def __init__(self):
self._sessions: Dict[str, SessionContext] = {}
self._cleanup_task: Optional[asyncio.Task] = None
def create_session(self) -> str:
"""创建新会话"""
session_id = str(uuid.uuid4())
self._sessions[session_id] = SessionContext(session_id=session_id)
print(f"[SessionManager] Created session: {session_id}")
return session_id
def get_session(self, session_id: str) -> Optional[SessionContext]:
"""获取会话"""
session = self._sessions.get(session_id)
if session:
session.last_accessed_at = datetime.now()
return session
def append_message(self, session_id: str, message: Dict[str, Any]) -> None:
"""追加消息到会话"""
session = self._sessions.get(session_id)
if session:
session.messages.append(message)
session.last_accessed_at = datetime.now()
def get_context_messages(
self,
session_id: str,
max_history: int = 10
) -> List[Dict[str, Any]]:
"""获取上下文消息"""
session = self._sessions.get(session_id)
if not session:
return []
return session.messages[-max_history:]
def cleanup(self) -> None:
"""清理过期会话"""
now = datetime.now()
expired = [
sid for sid, session in self._sessions.items()
if (now - session.last_accessed_at).total_seconds() > self.TTL_SECONDS
]
for sid in expired:
del self._sessions[sid]
print(f"[SessionManager] Cleaned up session: {sid}")
async def start_cleanup_scheduler(self):
"""启动定期清理任务"""
while True:
await asyncio.sleep(60) # 每分钟检查一次
self.cleanup()
# 使用示例
async def build_contextual_prompt(
session_manager: SessionManager,
session_id: str,
new_messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建带上下文的提示"""
history = session_manager.get_context_messages(session_id, max_history=5)
return history + new_messages
5. Sampling 采样机制:Server 请求 LLM
5.1 Sampling 概述
Sampling 是 MCP 协议中允许服务器向客户端(Host)请求 LLM 采样的机制。这使得服务器能够在需要时获取 LLM 的生成能力,而不需要直接访问 LLM API。
┌─────────────────────────────────────────────────────────────────┐
│ Sampling 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ sampling/createMessage ┌─────────┐│
│ │ Server │ ─────────────────────────────────► │ Client ││
│ │ │ │ (Host) ││
│ │ │ ◄───────────────────────────────── │ ││
│ │ │ sampling/createMessage/result ││
│ └─────────────┘ └────┬────┘│
│ │ │
│ ▼ │
│ ┌──────────┐ │
│ │ LLM │ │
│ │ Provider │ │
│ └──────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
5.2 Sampling 请求流程
LLM Provider MCP Client (Host) MCP Server LLM Provider MCP Client (Host) MCP Server 包含 messages, modelPreferences, systemPrompt alt [请求有效] [请求无效] sampling/createMessage 验证请求 检查权限/配额 转发采样请求 返回生成结果 sampling/createMessage/result 返回错误
5.3 TypeScript 实现:Sampling 请求
typescript
// sampling/sampling-client.ts
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import {
CreateMessageRequest,
CreateMessageResult,
SamplingMessage,
ModelPreferences,
} from "@modelcontextprotocol/sdk/types.js";
interface SamplingOptions {
messages: SamplingMessage[];
modelPreferences?: ModelPreferences;
systemPrompt?: string;
maxTokens?: number;
temperature?: number;
stopSequences?: string[];
}
class SamplingClient {
private client: Client;
constructor(client: Client) {
this.client = client;
}
/**
* 发送采样请求
*/
async createMessage(options: SamplingOptions): Promise<CreateMessageResult> {
const request: CreateMessageRequest = {
method: "sampling/createMessage",
params: {
messages: options.messages,
modelPreferences: options.modelPreferences,
systemPrompt: options.systemPrompt,
maxTokens: options.maxTokens,
temperature: options.temperature,
stopSequences: options.stopSequences,
},
};
console.log("[SamplingClient] Sending request:", JSON.stringify(request, null, 2));
const result = await this.client.request(request, CreateMessageResultSchema);
console.log("[SamplingClient] Received result:", JSON.stringify(result, null, 2));
return result;
}
/**
* 简单文本生成
*/
async generateText(
prompt: string,
options: Omit<SamplingOptions, "messages"> = {}
): Promise<string> {
const result = await this.createMessage({
messages: [
{
role: "user",
content: {
type: "text",
text: prompt,
},
},
],
...options,
});
if (result.content.type === "text") {
return result.content.text;
}
throw new Error("Expected text response, got: " + result.content.type);
}
/**
* 带上下文的对话
*/
async chat(
messages: SamplingMessage[],
options: Omit<SamplingOptions, "messages"> = {}
): Promise<CreateMessageResult> {
return await this.createMessage({
messages,
...options,
});
}
}
// 使用示例
async function demonstrateSampling() {
const client = new Client({ name: "example-client", version: "1.0.0" });
const sampling = new SamplingClient(client);
// 简单生成
const response = await sampling.generateText(
"解释什么是 Model Context Protocol",
{
maxTokens: 500,
temperature: 0.7,
modelPreferences: {
hints: [{ name: "claude" }],
intelligencePriority: 0.8,
speedPriority: 0.3,
},
}
);
console.log("Generated:", response);
}
5.4 Python 实现:Sampling 请求
python
# sampling/sampling_client.py
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from mcp.client import Client
from mcp.types import (
SamplingMessage,
ModelPreferences,
CreateMessageRequest,
CreateMessageResult,
TextContent,
)
@dataclass
class SamplingOptions:
"""采样选项"""
messages: List[SamplingMessage]
model_preferences: Optional[ModelPreferences] = None
system_prompt: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stop_sequences: Optional[List[str]] = None
class SamplingClient:
"""采样客户端"""
def __init__(self, client: Client):
self._client = client
async def create_message(self, options: SamplingOptions) -> CreateMessageResult:
"""发送采样请求"""
request = CreateMessageRequest(
method="sampling/createMessage",
params={
"messages": options.messages,
"modelPreferences": options.model_preferences,
"systemPrompt": options.system_prompt,
"maxTokens": options.max_tokens,
"temperature": options.temperature,
"stopSequences": options.stop_sequences,
}
)
print(f"[SamplingClient] Sending request: {request}")
result = await self._client.request(request)
print(f"[SamplingClient] Received result: {result}")
return result
async def generate_text(
self,
prompt: str,
**kwargs
) -> str:
"""简单文本生成"""
result = await self.create_message(SamplingOptions(
messages=[
SamplingMessage(
role="user",
content=TextContent(type="text", text=prompt)
)
],
**kwargs
))
if result.content.type == "text":
return result.content.text
raise ValueError(f"Expected text response, got: {result.content.type}")
async def chat(
self,
messages: List[SamplingMessage],
**kwargs
) -> CreateMessageResult:
"""带上下文的对话"""
return await self.create_message(SamplingOptions(
messages=messages,
**kwargs
))
# 使用示例
async def demonstrate_sampling():
client = Client(name="example-client", version="1.0.0")
sampling = SamplingClient(client)
# 简单生成
response = await sampling.generate_text(
"解释什么是 Model Context Protocol",
max_tokens=500,
temperature=0.7,
model_preferences=ModelPreferences(
hints=[{"name": "claude"}],
intelligencePriority=0.8,
speedPriority=0.3,
)
)
print(f"Generated: {response}")
6. 采样请求的生命周期管理
6.1 生命周期阶段
┌─────────────────────────────────────────────────────────────────┐
│ Sampling 请求生命周期 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Created │──►│Pending │──►│Active │──►│Completed│ │
│ │ │ │ │ │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ 请求创建 等待处理 正在采样 完成/失败 │
│ │
│ 超时处理: Pending ──(timeout)──► Failed │
│ 取消处理: Active ──(cancel)───► Cancelled │
│ │
└─────────────────────────────────────────────────────────────────┘
6.2 TypeScript 实现:生命周期管理
typescript
// sampling/lifecycle-manager.ts
import { EventEmitter } from "events";
type SamplingStatus =
| "created"
| "pending"
| "active"
| "completed"
| "failed"
| "cancelled";
interface SamplingRequest {
id: string;
status: SamplingStatus;
createdAt: Date;
startedAt?: Date;
completedAt?: Date;
error?: string;
result?: any;
}
class SamplingLifecycleManager extends EventEmitter {
private requests: Map<string, SamplingRequest> = new Map();
private readonly DEFAULT_TIMEOUT_MS = 60000; // 60秒默认超时
createRequest(id: string): SamplingRequest {
const request: SamplingRequest = {
id,
status: "created",
createdAt: new Date(),
};
this.requests.set(id, request);
this.emit("request:created", request);
return request;
}
transitionTo(id: string, status: SamplingStatus, data?: any): void {
const request = this.requests.get(id);
if (!request) {
throw new Error(`Request not found: ${id}`);
}
const oldStatus = request.status;
request.status = status;
switch (status) {
case "pending":
// 启动超时计时器
this.startTimeout(id);
break;
case "active":
request.startedAt = new Date();
break;
case "completed":
request.completedAt = new Date();
request.result = data;
break;
case "failed":
case "cancelled":
request.completedAt = new Date();
request.error = data;
break;
}
this.emit(`request:${status}`, request);
this.emit("request:transition", { request, oldStatus, newStatus: status });
}
private startTimeout(id: string): void {
setTimeout(() => {
const request = this.requests.get(id);
if (request && request.status === "pending") {
this.transitionTo(id, "failed", "Request timeout");
}
}, this.DEFAULT_TIMEOUT_MS);
}
cancelRequest(id: string): boolean {
const request = this.requests.get(id);
if (!request) return false;
if (request.status === "pending" || request.status === "active") {
this.transitionTo(id, "cancelled", "Cancelled by user");
return true;
}
return false;
}
getRequest(id: string): SamplingRequest | undefined {
return this.requests.get(id);
}
getActiveRequests(): SamplingRequest[] {
return Array.from(this.requests.values()).filter(
(r) => r.status === "pending" || r.status === "active"
);
}
// 清理完成的请求
cleanup(maxAgeMs: number = 3600000): void {
const now = Date.now();
for (const [id, request] of this.requests.entries()) {
if (request.completedAt) {
const age = now - request.completedAt.getTime();
if (age > maxAgeMs) {
this.requests.delete(id);
this.emit("request:cleaned", { id, age });
}
}
}
}
}
// 使用示例
const lifecycle = new SamplingLifecycleManager();
// 监听状态变化
lifecycle.on("request:created", (req) => {
console.log(`[Lifecycle] Request ${req.id} created`);
});
lifecycle.on("request:completed", (req) => {
console.log(`[Lifecycle] Request ${req.id} completed`);
});
lifecycle.on("request:failed", (req) => {
console.log(`[Lifecycle] Request ${req.id} failed: ${req.error}`);
});
6.3 Python 实现:生命周期管理
python
# sampling/lifecycle_manager.py
from typing import Dict, Optional, List, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
import asyncio
import uuid
class SamplingStatus(Enum):
"""采样状态"""
CREATED = auto()
PENDING = auto()
ACTIVE = auto()
COMPLETED = auto()
FAILED = auto()
CANCELLED = auto()
@dataclass
class SamplingRequest:
"""采样请求"""
id: str
status: SamplingStatus = SamplingStatus.CREATED
created_at: datetime = field(default_factory=datetime.now)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
error: Optional[str] = None
result: Optional[Any] = None
class SamplingLifecycleManager:
"""采样生命周期管理器"""
DEFAULT_TIMEOUT_SECONDS = 60
def __init__(self):
self._requests: Dict[str, SamplingRequest] = {}
self._handlers: Dict[str, List[Callable]] = {}
self._timeout_tasks: Dict[str, asyncio.Task] = {}
def create_request(self, request_id: Optional[str] = None) -> SamplingRequest:
"""创建请求"""
req_id = request_id or str(uuid.uuid4())
request = SamplingRequest(id=req_id)
self._requests[req_id] = request
self._emit("request:created", request)
return request
def transition_to(
self,
request_id: str,
status: SamplingStatus,
data: Any = None
) -> None:
"""状态转换"""
request = self._requests.get(request_id)
if not request:
raise ValueError(f"Request not found: {request_id}")
old_status = request.status
request.status = status
if status == SamplingStatus.PENDING:
self._start_timeout(request_id)
elif status == SamplingStatus.ACTIVE:
request.started_at = datetime.now()
self._cancel_timeout(request_id)
elif status in (SamplingStatus.COMPLETED, SamplingStatus.FAILED, SamplingStatus.CANCELLED):
request.completed_at = datetime.now()
if status != SamplingStatus.COMPLETED:
request.error = str(data) if data else None
else:
request.result = data
self._cancel_timeout(request_id)
self._emit(f"request:{status.name.lower()}", request)
self._emit("request:transition", {
"request": request,
"old_status": old_status,
"new_status": status
})
def _start_timeout(self, request_id: str) -> None:
"""启动超时计时"""
async def timeout_handler():
await asyncio.sleep(self.DEFAULT_TIMEOUT_SECONDS)
request = self._requests.get(request_id)
if request and request.status == SamplingStatus.PENDING:
self.transition_to(request_id, SamplingStatus.FAILED, "Request timeout")
self._timeout_tasks[request_id] = asyncio.create_task(timeout_handler())
def _cancel_timeout(self, request_id: str) -> None:
"""取消超时计时"""
task = self._timeout_tasks.pop(request_id, None)
if task:
task.cancel()
def cancel_request(self, request_id: str) -> bool:
"""取消请求"""
request = self._requests.get(request_id)
if not request:
return False
if request.status in (SamplingStatus.PENDING, SamplingStatus.ACTIVE):
self.transition_to(request_id, SamplingStatus.CANCELLED, "Cancelled by user")
return True
return False
def get_request(self, request_id: str) -> Optional[SamplingRequest]:
"""获取请求"""
return self._requests.get(request_id)
def get_active_requests(self) -> List[SamplingRequest]:
"""获取活跃请求"""
return [
r for r in self._requests.values()
if r.status in (SamplingStatus.PENDING, SamplingStatus.ACTIVE)
]
def on(self, event: str, handler: Callable):
"""注册事件处理器"""
if event not in self._handlers:
self._handlers[event] = []
self._handlers[event].append(handler)
def _emit(self, event: str, data: Any):
"""触发事件"""
handlers = self._handlers.get(event, [])
for handler in handlers:
try:
handler(data)
except Exception as e:
print(f"Error in event handler: {e}")
def cleanup(self, max_age_seconds: int = 3600) -> None:
"""清理完成的请求"""
now = datetime.now()
to_remove = []
for req_id, request in self._requests.items():
if request.completed_at:
age = (now - request.completed_at).total_seconds()
if age > max_age_seconds:
to_remove.append(req_id)
for req_id in to_remove:
del self._requests[req_id]
self._emit("request:cleaned", {"id": req_id})
7. 嵌套调用与递归防护
7.1 嵌套调用风险
当服务器使用 Sampling 请求 LLM,而 LLM 又调用同一服务器的 Tools 时,可能形成嵌套调用链,导致无限递归。
危险场景:
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
│ User │────►│ Server │────►│ Host │────►│ LLM │
└─────────┘ └────┬────┘ └─────────┘ └────┬────┘
│ │
│◄───────────────────────────────┘
│ LLM 调用 tool
│
▼
┌─────────────┐
│ tool 执行中 │
│ 需要采样 │────────────────────────►
└─────────────┘ 再次请求采样(递归!)
7.2 递归防护策略
| 策略 | 说明 | 实现复杂度 |
|---|---|---|
| 调用深度限制 | 限制嵌套调用层数 | 低 |
| 循环检测 | 检测重复调用模式 | 中 |
| 请求标记 | 标记采样请求来源 | 低 |
| 超时控制 | 限制单次调用时间 | 低 |
| 令牌桶限流 | 限制采样请求频率 | 中 |
7.3 TypeScript 实现:递归防护
typescript
// guards/recursion-guard.ts
interface CallContext {
requestId: string;
parentId?: string;
depth: number;
timestamp: number;
toolName?: string;
}
class RecursionGuard {
private readonly MAX_DEPTH = 3;
private readonly MAX_CALLS_PER_MINUTE = 10;
private callStack: Map<string, CallContext> = new Map();
private callHistory: Array<{ timestamp: number; requestId: string }> = [];
/**
* 开始新的调用上下文
*/
beginContext(requestId: string, parentId?: string, toolName?: string): CallContext {
const parent = parentId ? this.callStack.get(parentId) : undefined;
const depth = parent ? parent.depth + 1 : 0;
// 检查深度限制
if (depth > this.MAX_DEPTH) {
throw new Error(
`Recursion depth exceeded: ${depth} > ${this.MAX_DEPTH}. ` +
`Possible infinite recursion detected.`
);
}
// 检查调用频率
this.checkRateLimit();
const context: CallContext = {
requestId,
parentId,
depth,
timestamp: Date.now(),
toolName,
};
this.callStack.set(requestId, context);
this.callHistory.push({ timestamp: Date.now(), requestId });
console.log(`[RecursionGuard] Context started: ${requestId} (depth: ${depth})`);
return context;
}
/**
* 结束调用上下文
*/
endContext(requestId: string): void {
const context = this.callStack.get(requestId);
if (context) {
console.log(`[RecursionGuard] Context ended: ${requestId} (duration: ${Date.now() - context.timestamp}ms)`);
this.callStack.delete(requestId);
}
}
/**
* 检查调用频率限制
*/
private checkRateLimit(): void {
const now = Date.now();
const oneMinuteAgo = now - 60000;
// 清理旧记录
this.callHistory = this.callHistory.filter(c => c.timestamp > oneMinuteAgo);
// 检查限制
if (this.callHistory.length >= this.MAX_CALLS_PER_MINUTE) {
throw new Error(
`Rate limit exceeded: ${this.callHistory.length} calls in the last minute. ` +
`Maximum allowed: ${this.MAX_CALLS_PER_MINUTE}`
);
}
}
/**
* 检测循环调用模式
*/
detectCycle(requestId: string, toolName: string): boolean {
const context = this.callStack.get(requestId);
if (!context) return false;
// 检查父链中是否有相同 toolName
let current: CallContext | undefined = context;
while (current?.parentId) {
const parent = this.callStack.get(current.parentId);
if (parent?.toolName === toolName) {
console.warn(`[RecursionGuard] Cycle detected: ${toolName} called recursively`);
return true;
}
current = parent;
}
return false;
}
/**
* 获取当前调用链
*/
getCallChain(requestId: string): string[] {
const chain: string[] = [];
let current: CallContext | undefined = this.callStack.get(requestId);
while (current) {
chain.unshift(current.toolName || current.requestId);
current = current.parentId ? this.callStack.get(current.parentId) : undefined;
}
return chain;
}
/**
* 获取当前活跃上下文数量
*/
getActiveContextCount(): number {
return this.callStack.size;
}
}
// 使用示例
const guard = new RecursionGuard();
async function handleToolCall(toolName: string, parentId?: string) {
const requestId = crypto.randomUUID();
try {
// 开始上下文
const context = guard.beginContext(requestId, parentId, toolName);
// 检测循环
if (guard.detectCycle(requestId, toolName)) {
throw new Error(`Circular dependency detected for tool: ${toolName}`);
}
// 执行工具逻辑...
console.log(`Call chain: ${guard.getCallChain(requestId).join(" -> ")}`);
// 模拟嵌套调用
if (toolName === "analyze" && context.depth < 2) {
await handleToolCall("summarize", requestId);
}
} finally {
// 结束上下文
guard.endContext(requestId);
}
}
7.4 Python 实现:递归防护
python
# guards/recursion_guard.py
from typing import Dict, Optional, List
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import threading
import uuid
@dataclass
class CallContext:
"""调用上下文"""
request_id: str
parent_id: Optional[str] = None
depth: int = 0
timestamp: datetime = field(default_factory=datetime.now)
tool_name: Optional[str] = None
class RecursionGuard:
"""递归防护器"""
MAX_DEPTH = 3
MAX_CALLS_PER_MINUTE = 10
def __init__(self):
self._call_stack: Dict[str, CallContext] = {}
self._call_history: List[Dict] = []
self._lock = threading.RLock()
def begin_context(
self,
request_id: Optional[str] = None,
parent_id: Optional[str] = None,
tool_name: Optional[str] = None
) -> CallContext:
"""开始新的调用上下文"""
req_id = request_id or str(uuid.uuid4())
with self._lock:
parent = self._call_stack.get(parent_id) if parent_id else None
depth = parent.depth + 1 if parent else 0
# 检查深度限制
if depth > self.MAX_DEPTH:
raise RecursionError(
f"Recursion depth exceeded: {depth} > {self.MAX_DEPTH}. "
f"Possible infinite recursion detected."
)
# 检查调用频率
self._check_rate_limit()
context = CallContext(
request_id=req_id,
parent_id=parent_id,
depth=depth,
tool_name=tool_name
)
self._call_stack[req_id] = context
self._call_history.append({
"timestamp": datetime.now(),
"request_id": req_id
})
print(f"[RecursionGuard] Context started: {req_id} (depth: {depth})")
return context
def end_context(self, request_id: str) -> None:
"""结束调用上下文"""
with self._lock:
context = self._call_stack.get(request_id)
if context:
duration = (datetime.now() - context.timestamp).total_seconds() * 1000
print(f"[RecursionGuard] Context ended: {request_id} (duration: {duration:.0f}ms)")
del self._call_stack[request_id]
def _check_rate_limit(self) -> None:
"""检查调用频率限制"""
now = datetime.now()
one_minute_ago = now - timedelta(minutes=1)
# 清理旧记录
self._call_history = [
c for c in self._call_history
if c["timestamp"] > one_minute_ago
]
# 检查限制
if len(self._call_history) >= self.MAX_CALLS_PER_MINUTE:
raise RuntimeError(
f"Rate limit exceeded: {len(self._call_history)} calls in the last minute. "
f"Maximum allowed: {self.MAX_CALLS_PER_MINUTE}"
)
def detect_cycle(self, request_id: str, tool_name: str) -> bool:
"""检测循环调用模式"""
with self._lock:
context = self._call_stack.get(request_id)
if not context:
return False
# 检查父链中是否有相同 tool_name
current: Optional[CallContext] = context
while current and current.parent_id:
parent = self._call_stack.get(current.parent_id)
if parent and parent.tool_name == tool_name:
print(f"[RecursionGuard] Cycle detected: {tool_name} called recursively")
return True
current = parent
return False
def get_call_chain(self, request_id: str) -> List[str]:
"""获取当前调用链"""
with self._lock:
chain = []
current: Optional[CallContext] = self._call_stack.get(request_id)
while current:
chain.insert(0, current.tool_name or current.request_id)
current = self._call_stack.get(current.parent_id) if current.parent_id else None
return chain
def get_active_context_count(self) -> int:
"""获取当前活跃上下文数量"""
with self._lock:
return len(self._call_stack)
# 上下文管理器
class GuardedContext:
"""受保护的上下文管理器"""
def __init__(self, guard: RecursionGuard, tool_name: str, parent_id: Optional[str] = None):
self.guard = guard
self.tool_name = tool_name
self.parent_id = parent_id
self.request_id: Optional[str] = None
self.context: Optional[CallContext] = None
def __enter__(self) -> CallContext:
self.request_id = str(uuid.uuid4())
self.context = self.guard.begin_context(
self.request_id,
self.parent_id,
self.tool_name
)
if self.guard.detect_cycle(self.request_id, self.tool_name):
raise RecursionError(f"Circular dependency detected for tool: {self.tool_name}")
return self.context
def __exit__(self, exc_type, exc_val, exc_tb):
if self.request_id:
self.guard.end_context(self.request_id)
# 使用示例
guard = RecursionGuard()
def handle_tool_call(tool_name: str, parent_id: Optional[str] = None):
with GuardedContext(guard, tool_name, parent_id) as context:
print(f"Call chain: {' -> '.join(guard.get_call_chain(context.request_id))}")
# 模拟嵌套调用
if tool_name == "analyze" and context.depth < 2:
handle_tool_call("summarize", context.request_id)
8. 采样成本控制与优化
8.1 成本因素分析
Sampling 请求的成本主要由以下因素决定:
| 因素 | 影响 | 优化策略 |
|---|---|---|
| 输入 Token 数 | 直接影响成本 | 精简提示、使用摘要 |
| 输出 Token 数 | 直接影响成本 | 设置 maxTokens |
| 模型选择 | 不同模型价格差异大 | 根据任务选择合适模型 |
| 采样温度 | 影响生成质量 | 平衡质量与成本 |
| 请求频率 | 累积成本 | 批处理、缓存 |
8.2 成本优化策略
┌─────────────────────────────────────────────────────────────────┐
│ 采样成本优化策略 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 提示优化 │───►│ 模型选择 │───►│ 缓存策略 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ • 去除冗余内容 • 智能路由 • 结果缓存 │
│ • 使用模板变量 • 分层模型 • 相似查询合并 │
│ • 上下文压缩 • 成本预估 • 预热机制 │
│ │
└─────────────────────────────────────────────────────────────────┘
8.3 TypeScript 实现:成本控制器
typescript
// cost/cost-controller.ts
interface CostConfig {
maxInputTokens: number;
maxOutputTokens: number;
maxTotalCost: number; // 美元
preferredModels: string[];
budgetPeriod: "daily" | "weekly" | "monthly";
}
interface SamplingCost {
inputTokens: number;
outputTokens: number;
estimatedCost: number;
model: string;
timestamp: Date;
}
class CostController {
private config: CostConfig;
private costs: SamplingCost[] = [];
private cache: Map<string, { result: any; cost: SamplingCost }> = new Map();
constructor(config: CostConfig) {
this.config = config;
}
/**
* 估算采样成本
*/
estimateCost(inputText: string, outputTokens: number, model: string): SamplingCost {
// 简化的 Token 估算(实际应使用 tokenizer)
const inputTokens = Math.ceil(inputText.length / 4);
// 模型价格表(每 1K tokens)
const prices: Record<string, { input: number; output: number }> = {
"gpt-4": { input: 0.03, output: 0.06 },
"gpt-3.5-turbo": { input: 0.0015, output: 0.002 },
"claude-3-opus": { input: 0.015, output: 0.075 },
"claude-3-sonnet": { input: 0.003, output: 0.015 },
};
const price = prices[model] || prices["gpt-3.5-turbo"];
const estimatedCost =
(inputTokens / 1000) * price.input +
(outputTokens / 1000) * price.output;
return {
inputTokens,
outputTokens,
estimatedCost,
model,
timestamp: new Date(),
};
}
/**
* 检查是否超出预算
*/
checkBudget(): { allowed: boolean; remaining: number } {
const periodCosts = this.getPeriodCosts();
const totalCost = periodCosts.reduce((sum, c) => sum + c.estimatedCost, 0);
const remaining = this.config.maxTotalCost - totalCost;
return {
allowed: remaining > 0,
remaining: Math.max(0, remaining),
};
}
/**
* 选择最优模型
*/
selectModel(complexity: "low" | "medium" | "high"): string {
const modelTiers: Record<string, string[]> = {
low: ["gpt-3.5-turbo", "claude-3-sonnet"],
medium: ["claude-3-sonnet", "gpt-4"],
high: ["claude-3-opus", "gpt-4"],
};
const candidates = modelTiers[complexity] || modelTiers.medium;
// 优先选择配置中指定的模型
for (const preferred of this.config.preferredModels) {
if (candidates.includes(preferred)) {
return preferred;
}
}
return candidates[0];
}
/**
* 检查缓存
*/
checkCache(prompt: string): { result: any; cost: SamplingCost } | undefined {
const hash = this.hashPrompt(prompt);
return this.cache.get(hash);
}
/**
* 存入缓存
*/
setCache(prompt: string, result: any, cost: SamplingCost): void {
const hash = this.hashPrompt(prompt);
this.cache.set(hash, { result, cost });
// 限制缓存大小
if (this.cache.size > 1000) {
const firstKey = this.cache.keys().next().value;
this.cache.delete(firstKey);
}
}
/**
* 记录实际成本
*/
recordCost(cost: SamplingCost): void {
this.costs.push(cost);
this.cleanupOldCosts();
}
/**
* 获取周期内的成本
*/
private getPeriodCosts(): SamplingCost[] {
const now = new Date();
const periodStart = new Date();
switch (this.config.budgetPeriod) {
case "daily":
periodStart.setHours(0, 0, 0, 0);
break;
case "weekly":
periodStart.setDate(now.getDate() - now.getDay());
periodStart.setHours(0, 0, 0, 0);
break;
case "monthly":
periodStart.setDate(1);
periodStart.setHours(0, 0, 0, 0);
break;
}
return this.costs.filter(c => c.timestamp >= periodStart);
}
/**
* 清理旧成本记录
*/
private cleanupOldCosts(): void {
const oneMonthAgo = new Date();
oneMonthAgo.setMonth(oneMonthAgo.getMonth() - 1);
this.costs = this.costs.filter(c => c.timestamp >= oneMonthAgo);
}
/**
* 生成提示哈希
*/
private hashPrompt(prompt: string): string {
// 简化实现,实际应使用 crypto 模块
let hash = 0;
for (let i = 0; i < prompt.length; i++) {
const char = prompt.charCodeAt(i);
hash = ((hash << 5) - hash) + char;
hash = hash & hash;
}
return hash.toString(16);
}
/**
* 获取成本报告
*/
getCostReport(): {
totalCost: number;
totalCalls: number;
averageCost: number;
byModel: Record<string, { calls: number; cost: number }>;
} {
const periodCosts = this.getPeriodCosts();
const totalCost = periodCosts.reduce((sum, c) => sum + c.estimatedCost, 0);
const byModel: Record<string, { calls: number; cost: number }> = {};
for (const cost of periodCosts) {
if (!byModel[cost.model]) {
byModel[cost.model] = { calls: 0, cost: 0 };
}
byModel[cost.model].calls++;
byModel[cost.model].cost += cost.estimatedCost;
}
return {
totalCost,
totalCalls: periodCosts.length,
averageCost: totalCost / periodCosts.length || 0,
byModel,
};
}
}
// 使用示例
const costController = new CostController({
maxInputTokens: 4000,
maxOutputTokens: 1000,
maxTotalCost: 10.0, // $10 per period
preferredModels: ["claude-3-sonnet", "gpt-3.5-turbo"],
budgetPeriod: "daily",
});
async function optimizedSampling(prompt: string) {
// 检查缓存
const cached = costController.checkCache(prompt);
if (cached) {
console.log("[CostController] Cache hit!");
return cached.result;
}
// 检查预算
const budget = costController.checkBudget();
if (!budget.allowed) {
throw new Error(`Budget exceeded. Remaining: $${budget.remaining.toFixed(2)}`);
}
// 选择模型
const model = costController.selectModel("medium");
// 估算成本
const estimatedCost = costController.estimateCost(prompt, 500, model);
console.log(`[CostController] Estimated cost: $${estimatedCost.estimatedCost.toFixed(4)}`);
// 执行采样...
const result = { /* sampling result */ };
// 记录成本
costController.recordCost(estimatedCost);
// 存入缓存
costController.setCache(prompt, result, estimatedCost);
return result;
}
8.4 Python 实现:成本控制器
python
# cost/cost_controller.py
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta
import hashlib
from enum import Enum
class BudgetPeriod(Enum):
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
@dataclass
class CostConfig:
"""成本配置"""
max_input_tokens: int = 4000
max_output_tokens: int = 1000
max_total_cost: float = 10.0 # 美元
preferred_models: List[str] = None
budget_period: BudgetPeriod = BudgetPeriod.DAILY
def __post_init__(self):
if self.preferred_models is None:
self.preferred_models = ["claude-3-sonnet", "gpt-3.5-turbo"]
@dataclass
class SamplingCost:
"""采样成本"""
input_tokens: int
output_tokens: int
estimated_cost: float
model: str
timestamp: datetime
class CostController:
"""成本控制器"""
# 模型价格表(每 1K tokens)
PRICES = {
"gpt-4": {"input": 0.03, "output": 0.06},
"gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
"claude-3-opus": {"input": 0.015, "output": 0.075},
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
}
def __init__(self, config: CostConfig):
self.config = config
self._costs: List[SamplingCost] = []
self._cache: Dict[str, Dict] = {}
def estimate_cost(self, input_text: str, output_tokens: int, model: str) -> SamplingCost:
"""估算采样成本"""
# 简化的 Token 估算
input_tokens = len(input_text) // 4
price = self.PRICES.get(model, self.PRICES["gpt-3.5-turbo"])
estimated_cost = (
(input_tokens / 1000) * price["input"] +
(output_tokens / 1000) * price["output"]
)
return SamplingCost(
input_tokens=input_tokens,
output_tokens=output_tokens,
estimated_cost=estimated_cost,
model=model,
timestamp=datetime.now()
)
def check_budget(self) -> Dict[str, any]:
"""检查是否超出预算"""
period_costs = self._get_period_costs()
total_cost = sum(c.estimated_cost for c in period_costs)
remaining = self.config.max_total_cost - total_cost
return {
"allowed": remaining > 0,
"remaining": max(0, remaining),
"used": total_cost,
"total": self.config.max_total_cost
}
def select_model(self, complexity: str = "medium") -> str:
"""选择最优模型"""
model_tiers = {
"low": ["gpt-3.5-turbo", "claude-3-sonnet"],
"medium": ["claude-3-sonnet", "gpt-4"],
"high": ["claude-3-opus", "gpt-4"],
}
candidates = model_tiers.get(complexity, model_tiers["medium"])
# 优先选择配置中指定的模型
for preferred in self.config.preferred_models:
if preferred in candidates:
return preferred
return candidates[0]
def check_cache(self, prompt: str) -> Optional[Dict]:
"""检查缓存"""
hash_key = self._hash_prompt(prompt)
return self._cache.get(hash_key)
def set_cache(self, prompt: str, result: any, cost: SamplingCost) -> None:
"""存入缓存"""
hash_key = self._hash_prompt(prompt)
self._cache[hash_key] = {"result": result, "cost": cost}
# 限制缓存大小
if len(self._cache) > 1000:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
def record_cost(self, cost: SamplingCost) -> None:
"""记录实际成本"""
self._costs.append(cost)
self._cleanup_old_costs()
def _get_period_costs(self) -> List[SamplingCost]:
"""获取周期内的成本"""
now = datetime.now()
if self.config.budget_period == BudgetPeriod.DAILY:
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
elif self.config.budget_period == BudgetPeriod.WEEKLY:
period_start = now - timedelta(days=now.weekday())
period_start = period_start.replace(hour=0, minute=0, second=0, microsecond=0)
else: # MONTHLY
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return [c for c in self._costs if c.timestamp >= period_start]
def _cleanup_old_costs(self) -> None:
"""清理旧成本记录"""
one_month_ago = datetime.now() - timedelta(days=30)
self._costs = [c for c in self._costs if c.timestamp >= one_month_ago]
def _hash_prompt(self, prompt: str) -> str:
"""生成提示哈希"""
return hashlib.md5(prompt.encode()).hexdigest()
def get_cost_report(self) -> Dict:
"""获取成本报告"""
period_costs = self._get_period_costs()
total_cost = sum(c.estimated_cost for c in period_costs)
by_model: Dict[str, Dict] = {}
for cost in period_costs:
if cost.model not in by_model:
by_model[cost.model] = {"calls": 0, "cost": 0.0}
by_model[cost.model]["calls"] += 1
by_model[cost.model]["cost"] += cost.estimated_cost
return {
"total_cost": total_cost,
"total_calls": len(period_costs),
"average_cost": total_cost / len(period_costs) if period_costs else 0,
"by_model": by_model,
"period": self.config.budget_period.value
}
# 使用示例
def optimized_sampling_example():
controller = CostController(CostConfig(
max_input_tokens=4000,
max_output_tokens=1000,
max_total_cost=10.0,
preferred_models=["claude-3-sonnet", "gpt-3.5-turbo"],
budget_period=BudgetPeriod.DAILY
))
prompt = "分析这段代码的性能问题"
# 检查缓存
cached = controller.check_cache(prompt)
if cached:
print("[CostController] Cache hit!")
return cached["result"]
# 检查预算
budget = controller.check_budget()
if not budget["allowed"]:
raise RuntimeError(f"Budget exceeded. Remaining: ${budget['remaining']:.2f}")
# 选择模型
model = controller.select_model("medium")
# 估算成本
estimated = controller.estimate_cost(prompt, 500, model)
print(f"[CostController] Estimated cost: ${estimated.estimated_cost:.4f}")
# 执行采样...
result = {"analysis": "code analysis result"}
# 记录成本
controller.record_cost(estimated)
# 存入缓存
controller.set_cache(prompt, result, estimated)
return result
9. 提示模板与采样的安全考虑
9.1 安全风险分析
| 风险类型 | 描述 | 防护措施 |
|---|---|---|
| 提示注入 | 恶意输入操纵 LLM 行为 | 输入验证、参数转义 |
| 数据泄露 | 敏感信息通过采样泄露 | 数据脱敏、权限控制 |
| 资源耗尽 | 恶意请求消耗大量资源 | 限流、配额管理 |
| 递归攻击 | 利用嵌套调用造成 DoS | 深度限制、超时控制 |
| 模型滥用 | 生成有害内容 | 内容过滤、审核机制 |
9.2 输入验证与净化
typescript
// security/input-validator.ts
interface ValidationRule {
type: "string" | "number" | "boolean" | "array" | "object";
required?: boolean;
minLength?: number;
maxLength?: number;
pattern?: RegExp;
enum?: unknown[];
sanitize?: (value: unknown) => unknown;
}
class InputValidator {
private readonly DANGEROUS_PATTERNS = [
/ignore\s+previous/i,
/system\s+prompt/i,
/override\s+instructions/i,
/<script/i,
/javascript:/i,
];
/**
* 验证并净化输入
*/
validateAndSanitize(
input: Record<string, unknown>,
rules: Record<string, ValidationRule>
): { valid: boolean; sanitized: Record<string, unknown>; errors: string[] } {
const sanitized: Record<string, unknown> = {};
const errors: string[] = [];
for (const [key, rule] of Object.entries(rules)) {
const value = input[key];
// 检查必填
if (rule.required && (value === undefined || value === null)) {
errors.push(`Missing required field: ${key}`);
continue;
}
// 可选字段为空时跳过
if (!rule.required && (value === undefined || value === null)) {
continue;
}
// 类型检查
if (!this.checkType(value, rule.type)) {
errors.push(`Invalid type for ${key}: expected ${rule.type}`);
continue;
}
// 字符串特定检查
if (rule.type === "string" && typeof value === "string") {
// 长度检查
if (rule.minLength && value.length < rule.minLength) {
errors.push(`${key} is too short (min ${rule.minLength})`);
continue;
}
if (rule.maxLength && value.length > rule.maxLength) {
errors.push(`${key} is too long (max ${rule.maxLength})`);
continue;
}
// 模式检查
if (rule.pattern && !rule.pattern.test(value)) {
errors.push(`${key} does not match required pattern`);
continue;
}
// 危险模式检查
if (this.containsDangerousPatterns(value)) {
errors.push(`${key} contains potentially dangerous content`);
continue;
}
// 净化
let sanitizedValue = this.sanitizeString(value);
if (rule.sanitize) {
sanitizedValue = rule.sanitize(sanitizedValue) as string;
}
sanitized[key] = sanitizedValue;
} else {
sanitized[key] = value;
}
}
return {
valid: errors.length === 0,
sanitized,
errors,
};
}
/**
* 检查类型
*/
private checkType(value: unknown, type: string): boolean {
switch (type) {
case "string":
return typeof value === "string";
case "number":
return typeof value === "number" && !isNaN(value);
case "boolean":
return typeof value === "boolean";
case "array":
return Array.isArray(value);
case "object":
return typeof value === "object" && value !== null && !Array.isArray(value);
default:
return false;
}
}
/**
* 检查危险模式
*/
private containsDangerousPatterns(value: string): boolean {
return this.DANGEROUS_PATTERNS.some((pattern) => pattern.test(value));
}
/**
* 净化字符串
*/
private sanitizeString(value: string): string {
return value
.replace(/[<>]/g, "") // 移除 HTML 标签
.trim();
}
}
// 使用示例
const validator = new InputValidator();
const rules = {
code: {
type: "string" as const,
required: true,
maxLength: 10000,
},
language: {
type: "string" as const,
required: true,
enum: ["typescript", "python", "javascript", "java"],
},
focus: {
type: "string" as const,
required: false,
pattern: /^[a-z]+$/,
},
};
const result = validator.validateAndSanitize(
{
code: "function test() { return 1; }",
language: "javascript",
focus: "performance",
},
rules
);
if (!result.valid) {
console.error("Validation failed:", result.errors);
}
9.3 数据脱敏处理
python
# security/data_masking.py
import re
from typing import Dict, List, Pattern
from dataclasses import dataclass
@dataclass
class MaskingRule:
"""脱敏规则"""
name: str
pattern: Pattern
replacement: str
description: str
class DataMasker:
"""数据脱敏器"""
DEFAULT_RULES = [
MaskingRule(
name="email",
pattern=re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'),
replacement="[EMAIL_MASKED]",
description="邮箱地址"
),
MaskingRule(
name="phone",
pattern=re.compile(r'\b1[3-9]\d{9}\b'),
replacement="[PHONE_MASKED]",
description="手机号码"
),
MaskingRule(
name="credit_card",
pattern=re.compile(r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b'),
replacement="[CARD_MASKED]",
description="信用卡号"
),
MaskingRule(
name="api_key",
pattern=re.compile(r'(?:api[_-]?key|token)["\']?\s*[:=]\s*["\']?[a-zA-Z0-9]{32,}["\']?', re.IGNORECASE),
replacement="[API_KEY_MASKED]",
description="API 密钥"
),
MaskingRule(
name="password",
pattern=re.compile(r'(?:password|passwd|pwd)["\']?\s*[:=]\s*["\']?[^"\'\s]+["\']?', re.IGNORECASE),
replacement="[PASSWORD_MASKED]",
description="密码"
),
MaskingRule(
name="ssn",
pattern=re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),
replacement="[SSN_MASKED]",
description="社会安全号码"
),
]
def __init__(self, rules: List[MaskingRule] = None):
self.rules = rules or self.DEFAULT_RULES
def mask(self, text: str, strict: bool = False) -> str:
"""
对文本进行脱敏处理
Args:
text: 原始文本
strict: 是否严格模式(发现敏感信息则拒绝)
Returns:
脱敏后的文本
"""
masked_text = text
found_sensitive = []
for rule in self.rules:
matches = rule.pattern.findall(masked_text)
if matches:
found_sensitive.append({
"rule": rule.name,
"description": rule.description,
"count": len(matches)
})
masked_text = rule.pattern.sub(rule.replacement, masked_text)
if strict and found_sensitive:
raise ValueError(
f"Sensitive data detected: "
f"{', '.join(f['description'] for f in found_sensitive)}"
)
return masked_text
def scan(self, text: str) -> List[Dict]:
"""扫描文本中的敏感信息"""
findings = []
for rule in self.rules:
matches = rule.pattern.findall(text)
if matches:
findings.append({
"rule": rule.name,
"description": rule.description,
"matches": matches[:5], # 最多显示5个
"count": len(matches)
})
return findings
def add_custom_rule(self, rule: MaskingRule) -> None:
"""添加自定义脱敏规则"""
self.rules.append(rule)
# 使用示例
def secure_sampling_example():
masker = DataMasker()
# 包含敏感信息的代码
code_with_secrets = """
const config = {
apiKey: "sk-1234567890abcdef1234567890abcdef",
email: "admin@company.com",
password: "super_secret_password123"
};
"""
# 扫描敏感信息
findings = masker.scan(code_with_secrets)
print("Found sensitive data:", findings)
# 脱敏处理
masked_code = masker.mask(code_with_secrets)
print("Masked code:", masked_code)
# 严格模式(发现敏感信息则抛出异常)
try:
masker.mask(code_with_secrets, strict=True)
except ValueError as e:
print(f"Security violation: {e}")
10. 实际应用场景与案例
10.1 场景一:智能代码审查系统
┌─────────────────────────────────────────────────────────────────┐
│ 智能代码审查系统架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ GitHub │───►│ MCP Server │───►│ LLM Host │ │
│ │ Webhook │ │ (Reviewer) │ │ (Claude) │ │
│ └─────────────┘ └──────┬──────┘ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Prompts │ │
│ │ • security │ │
│ │ • style │ │
│ │ • perf │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
TypeScript 实现:
typescript
// examples/code-review-system.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
interface PullRequest {
number: number;
title: string;
files: Array<{
filename: string;
patch: string;
status: "added" | "modified" | "removed";
}>;
}
class CodeReviewServer {
private server: Server;
constructor() {
this.server = new Server({
name: "code-review-server",
version: "1.0.0",
});
this.setupPrompts();
this.setupTools();
}
private setupPrompts(): void {
// 注册代码审查提示
this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
return {
prompts: [
{
name: "pr-security-review",
description: "对 PR 进行安全审查",
arguments: [
{ name: "prNumber", description: "PR 编号", required: true },
{ name: "focus", description: "审查重点", required: false },
],
},
{
name: "pr-style-review",
description: "对 PR 进行代码风格审查",
arguments: [
{ name: "prNumber", description: "PR 编号", required: true },
],
},
],
};
});
// 处理提示请求
this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
const pr = await this.fetchPR(args.prNumber as number);
if (name === "pr-security-review") {
return this.buildSecurityReviewPrompt(pr, args.focus as string);
}
if (name === "pr-style-review") {
return this.buildStyleReviewPrompt(pr);
}
throw new Error(`Unknown prompt: ${name}`);
});
}
private async fetchPR(prNumber: number): Promise<PullRequest> {
// 从 GitHub API 获取 PR 信息
// 实现省略...
return {
number: prNumber,
title: "Example PR",
files: [],
};
}
private buildSecurityReviewPrompt(pr: PullRequest, focus?: string): {
description: string;
messages: any[];
} {
const codeToReview = pr.files
.filter((f) => f.status !== "removed")
.map((f) => `File: ${f.filename}\n\`\`\`\n${f.patch}\n\`\`\``)
.join("\n\n");
return {
description: `Security Review for PR #${pr.number}`,
messages: [
{
role: "system",
content: {
type: "text",
text: `You are a security expert. Review the code for:
- SQL injection vulnerabilities
- XSS vulnerabilities
- Authentication/authorization issues
- Sensitive data exposure
- ${focus || "All security concerns"}`,
},
},
{
role: "user",
content: {
type: "text",
text: `Please review this PR for security issues:\n\n${codeToReview}`,
},
},
],
};
}
private setupTools(): void {
// 注册工具:提交审查评论
this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
if (request.params.name === "submit-review") {
const { prNumber, review } = request.params.arguments as {
prNumber: number;
review: string;
};
// 提交审查评论到 GitHub
await this.submitReviewToGitHub(prNumber, review);
return {
content: [{ type: "text", text: "Review submitted successfully" }],
};
}
throw new Error("Unknown tool");
});
}
private async submitReviewToGitHub(prNumber: number, review: string): Promise<void> {
// GitHub API 调用
// 实现省略...
}
}
10.2 场景二:数据分析助手
python
# examples/data-analysis-assistant.py
from mcp.server import Server
from mcp.types import Prompt, TextContent
import pandas as pd
import json
class DataAnalysisServer:
"""数据分析助手 MCP 服务器"""
def __init__(self):
self.server = Server("data-analysis-server")
self.data_cache: Dict[str, pd.DataFrame] = {}
def setup_prompts(self):
"""设置提示模板"""
@self.server.list_prompts()
async def list_prompts():
return [
Prompt(
name="analyze-dataset",
description="分析数据集并提供洞察",
arguments=[
{"name": "dataset_id", "description": "数据集ID", "required": True},
{"name": "analysis_type", "description": "分析类型", "required": False},
]
),
Prompt(
name="generate-report",
description="生成数据分析报告",
arguments=[
{"name": "dataset_id", "description": "数据集ID", "required": True},
{"name": "report_format", "description": "报告格式", "required": False},
]
),
]
@self.server.get_prompt()
async def get_prompt(name: str, arguments: dict):
dataset_id = arguments.get("dataset_id")
df = self.data_cache.get(dataset_id)
if df is None:
raise ValueError(f"Dataset not found: {dataset_id}")
# 生成数据集摘要
summary = self._generate_summary(df)
if name == "analyze-dataset":
analysis_type = arguments.get("analysis_type", "comprehensive")
return self._build_analysis_prompt(summary, analysis_type)
elif name == "generate-report":
report_format = arguments.get("report_format", "markdown")
return self._build_report_prompt(summary, report_format)
raise ValueError(f"Unknown prompt: {name}")
def _generate_summary(self, df: pd.DataFrame) -> dict:
"""生成数据集摘要"""
return {
"shape": df.shape,
"columns": df.columns.tolist(),
"dtypes": df.dtypes.to_dict(),
"missing": df.isnull().sum().to_dict(),
"sample": df.head(5).to_dict(),
"describe": df.describe().to_dict(),
}
def _build_analysis_prompt(self, summary: dict, analysis_type: str) -> dict:
"""构建分析提示"""
analysis_focus = {
"comprehensive": "全面分析,包括统计摘要、趋势、异常值和相关性",
"trend": "重点关注时间趋势和变化模式",
"correlation": "重点关注变量间的相关性分析",
"anomaly": "重点关注异常值和离群点检测",
}
return {
"description": f"Dataset Analysis ({analysis_type})",
"messages": [
{
"role": "system",
"content": TextContent(
type="text",
text=f"""你是一位数据分析专家。请对提供的数据集进行专业分析。
分析重点:{analysis_focus.get(analysis_type, analysis_focus["comprehensive"])}
请提供:
1. 数据概览和关键指标
2. 主要发现和洞察
3. 数据质量问题(如有)
4. 可视化建议
5. 进一步分析的建议"""
)
},
{
"role": "user",
"content": TextContent(
type="text",
text=f"请分析以下数据集:\n\n```json\n{json.dumps(summary, indent=2, default=str)}\n```"
)
}
]
}
def _build_report_prompt(self, summary: dict, report_format: str) -> dict:
"""构建报告提示"""
format_instructions = {
"markdown": "使用 Markdown 格式,包含标题、表格和列表",
"html": "使用 HTML 格式,包含样式和表格",
"json": "使用 JSON 格式,结构化输出",
}
return {
"description": f"Analysis Report ({report_format})",
"messages": [
{
"role": "system",
"content": TextContent(
type="text",
text=f"""你是一位数据报告专家。请基于数据集生成专业报告。
格式要求:{format_instructions.get(report_format, format_instructions["markdown"])}
报告结构:
1. 执行摘要
2. 数据描述
3. 分析方法
4. 主要发现
5. 结论和建议"""
)
},
{
"role": "user",
"content": TextContent(
type="text",
text=f"请为以下数据集生成报告:\n\n```json\n{json.dumps(summary, indent=2, default=str)}\n```"
)
}
]
}
# 使用示例
async def run_analysis_example():
server = DataAnalysisServer()
# 加载示例数据
server.data_cache["sales_2024"] = pd.DataFrame({
"date": pd.date_range("2024-01-01", periods=100),
"product": ["A", "B", "C"] * 33 + ["A"],
"sales": [100 + i * 2 + (i % 10) * 5 for i in range(100)],
"region": ["North", "South", "East", "West"] * 25,
})
server.setup_prompts()
10.3 场景三:智能客服系统
typescript
// examples/customer-service-bot.ts
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
interface Customer {
id: string;
name: string;
tier: "basic" | "premium" | "enterprise";
history: Array<{
timestamp: Date;
query: string;
resolved: boolean;
}>;
}
interface KnowledgeBase {
articles: Array<{
id: string;
title: string;
content: string;
tags: string[];
}>;
}
class CustomerServiceServer {
private server: Server;
private customers: Map<string, Customer> = new Map();
private knowledgeBase: KnowledgeBase;
constructor() {
this.server = new Server({
name: "customer-service-server",
version: "1.0.0",
});
this.setupPrompts();
}
private setupPrompts(): void {
// 注册客服提示模板
this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
return {
prompts: [
{
name: "handle-inquiry",
description: "处理客户咨询",
arguments: [
{ name: "customerId", description: "客户ID", required: true },
{ name: "query", description: "客户问题", required: true },
{ name: "category", description: "问题分类", required: false },
],
},
{
name: "escalate-ticket",
description: "升级工单到人工",
arguments: [
{ name: "customerId", description: "客户ID", required: true },
{ name: "reason", description: "升级原因", required: true },
],
},
],
};
});
// 处理提示请求
this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
const customer = this.customers.get(args.customerId as string);
if (!customer) {
throw new Error(`Customer not found: ${args.customerId}`);
}
if (name === "handle-inquiry") {
return this.buildInquiryPrompt(customer, args.query as string, args.category as string);
}
if (name === "escalate-ticket") {
return this.buildEscalationPrompt(customer, args.reason as string);
}
throw new Error(`Unknown prompt: ${name}`);
});
}
private buildInquiryPrompt(
customer: Customer,
query: string,
category?: string
): { description: string; messages: any[] } {
// 检索相关知识库文章
const relevantArticles = this.searchKnowledgeBase(query);
// 构建客户上下文
const customerContext = `
Customer: ${customer.name}
Tier: ${customer.tier}
Previous Interactions: ${customer.history.length}
Recent Issues: ${customer.history
.slice(-3)
.map((h) => h.query)
.join("; ")}
`;
// 构建知识库上下文
const kbContext = relevantArticles
.map((a) => `Article: ${a.title}\n${a.content}`)
.join("\n\n");
return {
description: `Handle inquiry from ${customer.name}`,
messages: [
{
role: "system",
content: {
type: "text",
text: `You are a helpful customer service representative.
Use the provided knowledge base articles to answer the customer's question.
Be polite, professional, and concise.
Customer Context:
${customerContext}
Knowledge Base:
${kbContext}`,
},
},
{
role: "user",
content: {
type: "text",
text: `Customer Query${category ? ` [${category}]` : ""}: ${query}`,
},
},
],
};
}
private searchKnowledgeBase(query: string): typeof this.knowledgeBase.articles {
// 简单的关键词匹配
const keywords = query.toLowerCase().split(" ");
return this.knowledgeBase.articles
.filter((article) =>
keywords.some(
(k) =>
article.title.toLowerCase().includes(k) ||
article.content.toLowerCase().includes(k) ||
article.tags.some((t) => t.toLowerCase().includes(k))
)
)
.slice(0, 3); // 返回最相关的3篇
}
private buildEscalationPrompt(customer: Customer, reason: string) {
return {
description: `Escalate ticket for ${customer.name}`,
messages: [
{
role: "system",
content: {
type: "text",
text: `Generate a ticket escalation summary for the support team.`,
},
},
{
role: "user",
content: {
type: "text",
text: `Escalate ticket for customer ${customer.name} (Tier: ${customer.tier})
Reason: ${reason}
History: ${JSON.stringify(customer.history.slice(-5))}`,
},
},
],
};
}
}
11. 常见问题 FAQ
Q1: Prompts 和 Tools 的主要区别是什么?
A: Prompts 用于提供结构化的提示模板,返回的是消息列表(给 LLM 的上下文);Tools 用于执行具体的功能操作,返回的是执行结果。Prompts 通常由用户选择使用,而 Tools 由 LLM 自动调用。
Q2: 如何实现动态提示模板?
A: 动态提示模板可以通过以下方式实现:
- 使用参数化模板,根据输入参数动态生成提示内容
- 在
GetPromptRequest处理器中根据参数动态构建消息列表 - 结合外部数据源(如数据库、API)动态获取提示内容
typescript
// 示例:动态生成提示
server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
// 根据参数动态生成
const dynamicContent = await fetchDynamicContent(args);
return {
description: `Dynamic prompt for ${name}`,
messages: [
{ role: "system", content: { type: "text", text: dynamicContent } },
],
};
});
Q3: Sampling 请求失败如何处理?
A: Sampling 请求失败时应采取以下策略:
- 重试机制:对临时错误进行指数退避重试
- 降级策略:切换到备用模型或简化请求
- 错误反馈:向用户清晰说明失败原因
- 日志记录:记录失败详情用于后续分析
typescript
try {
const result = await samplingClient.createMessage(options);
} catch (error) {
if (error.code === "RATE_LIMITED") {
// 等待后重试
await delay(1000);
return await retryWithBackoff(options);
}
if (error.code === "MODEL_UNAVAILABLE") {
// 切换到备用模型
return await fallbackToAlternativeModel(options);
}
throw error;
}
Q4: 如何防止提示注入攻击?
A: 防止提示注入的主要措施:
- 输入验证:使用正则表达式过滤危险模式
- 参数转义:对用户输入进行适当的转义处理
- 长度限制:限制输入内容的长度
- 内容审查:使用内容安全策略检测恶意输入
- 最小权限:限制 Prompts 可访问的资源和数据
Q5: 采样请求的上下文长度有限制吗?
A: 是的,采样请求的上下文长度受以下限制:
- 模型限制:不同模型的最大上下文长度不同(如 GPT-4 支持 8K/32K tokens)
- 配置限制 :可以通过
maxTokens参数限制输出长度 - 协议限制:MCP 协议本身没有硬性限制,但建议保持合理大小
建议实现上下文压缩策略,如只保留最近的 N 条消息。
Q6: 如何实现多轮对话的上下文保持?
A: 多轮对话上下文保持可以通过以下方式实现:
- 会话管理:使用 Session ID 跟踪对话状态
- 消息历史:在服务器端存储消息历史
- 上下文传递:在每次请求中包含历史消息
python
class SessionManager:
def __init__(self):
self.sessions = {}
def append_message(self, session_id, message):
if session_id not in self.sessions:
self.sessions[session_id] = []
self.sessions[session_id].append(message)
def get_context(self, session_id, max_history=10):
return self.sessions.get(session_id, [])[-max_history:]
Q7: Sampling 的成本如何控制?
A: 成本控制策略包括:
- 模型选择:根据任务复杂度选择性价比合适的模型
- Token 限制 :设置合理的
maxTokens限制 - 缓存机制:缓存常见查询的结果
- 预算管理:设置每日/每周采样预算上限
- 批处理:合并多个小请求为批量请求
Q8: 嵌套调用深度应该限制在多少层?
A: 建议将嵌套调用深度限制在 3 层以内。超过这个深度通常意味着:
- 存在设计问题,应该重构逻辑
- 可能出现无限递归风险
- 性能会显著下降
typescript
const MAX_NESTING_DEPTH = 3;
if (currentDepth > MAX_NESTING_DEPTH) {
throw new Error(`Nesting depth exceeded: ${currentDepth}`);
}
Q9: Prompts 可以返回图片或多媒体内容吗?
A: MCP 协议支持返回多媒体内容。消息内容的 type 可以是:
"text":纯文本内容"image":图片内容(需要提供 base64 编码)
typescript
{
role: "user",
content: {
type: "image",
data: base64EncodedImage,
mimeType: "image/png"
}
}
Q10: 如何调试 Prompts 和 Sampling 问题?
A: 调试建议:
- 日志记录:记录所有 Prompts 请求和响应
- 采样追踪:记录 Sampling 请求的完整生命周期
- 本地测试:使用 MCP Inspector 工具进行本地调试
- 渐进测试:从简单场景开始,逐步增加复杂度
- 错误隔离:使用 try-catch 隔离问题范围
12. 参考文献
-
MCP 官方文档
- URL: https://modelcontextprotocol.io/
- 描述:Model Context Protocol 官方文档,包含完整的协议规范和 API 参考
-
MCP 协议规范
- URL: https://spec.modelcontextprotocol.io/
- 描述:详细的协议规范文档,包含所有消息类型和交互流程
-
MCP TypeScript SDK
- URL: https://github.com/modelcontextprotocol/typescript-sdk
- 描述:官方 TypeScript SDK,提供类型定义和客户端/服务器实现
-
MCP Python SDK
- URL: https://github.com/modelcontextprotocol/python-sdk
- 描述:官方 Python SDK,支持 Python 3.10+
-
Anthropic MCP 博客
- URL: https://www.anthropic.com/news/model-context-protocol
- 描述:MCP 协议的官方介绍和设计理念
-
OpenAI Function Calling 指南
- URL: https://platform.openai.com/docs/guides/function-calling
- 描述:函数调用最佳实践,与 MCP Tools 设计相关
-
Prompt Injection 防护研究
- URL: https://owasp.org/www-project-top-10-for-large-language-model-applications/
- 描述:LLM 应用安全风险与防护措施
-
LLM Cost Optimization
- URL: https://platform.openai.com/docs/guides/production-best-practices
- 描述:LLM 生产环境最佳实践和成本优化策略
附录:完整配置示例
TypeScript 服务器配置
typescript
// config/server.config.ts
export const serverConfig = {
name: "advanced-mcp-server",
version: "1.0.0",
prompts: {
enabled: true,
maxTemplates: 100,
cacheEnabled: true,
},
sampling: {
enabled: true,
maxConcurrent: 5,
defaultTimeout: 60000,
maxDepth: 3,
},
security: {
inputValidation: true,
maxInputLength: 10000,
dangerousPatterns: [
/ignore\s+previous/i,
/system\s+prompt/i,
],
},
cost: {
budgetPeriod: "daily",
maxDailyCost: 10.0,
preferredModels: ["claude-3-sonnet", "gpt-3.5-turbo"],
},
};
Python 服务器配置
python
# config/server_config.py
from dataclasses import dataclass
from typing import List
@dataclass
class ServerConfig:
name: str = "advanced-mcp-server"
version: str = "1.0.0"
# Prompts 配置
prompts_enabled: bool = True
max_templates: int = 100
cache_enabled: bool = True
# Sampling 配置
sampling_enabled: bool = True
max_concurrent: int = 5
default_timeout: int = 60
max_depth: int = 3
# 安全配置
input_validation: bool = True
max_input_length: int = 10000
# 成本配置
budget_period: str = "daily"
max_daily_cost: float = 10.0
preferred_models: List[str] = None
def __post_init__(self):
if self.preferred_models is None:
self.preferred_models = ["claude-3-sonnet", "gpt-3.5-turbo"]
# 全局配置实例
config = ServerConfig()