MCP案例 - 数据可视化工具服务器

完整的 Python 实现如下:
#!/usr/bin/env python3
"""
MCP数据可视化服务器示例
演示如何将Resources和Prompts转换为Tools,让大模型自主调用
"""

import json
import asyncio
from typing import Any, Dict, List, Optional
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datetime import datetime, timedelta
import io
import base64

# MCP相关导入(假设使用mcp库)
from mcp.server import Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import (
    CallToolRequestSchema,
    ListToolsRequestSchema,
    Tool,
    TextContent,
    ImageContent,
    EmbeddedResource
)

class DataVisualizationServer:
    """MCP server that exposes sample datasets and charting as tools.

    Rather than publishing the data as MCP Resources and the analysis
    templates as Prompts, every capability is registered as a Tool so the
    model can decide by itself when to list, inspect, analyse and plot data.

    Attributes:
        server: The MCP ``Server`` the tool handlers are registered on.
        datasets: Mapping of dataset name to its in-memory ``DataFrame``.
    """

    # Human-readable descriptions returned by ``list_available_datasets``.
    DATASET_DESCRIPTIONS = {
        "sales": "销售数据 - 包含日期、销售额、产品和地区信息",
        "users": "用户数据 - 包含年龄组分布和收入信息",
        "stocks": "股票数据 - 包含日期、价格和成交量信息",
    }

    # Goals understood by ``suggest_visualization``.
    ANALYSIS_GOALS = ("trend", "distribution", "comparison", "correlation")

    # Chart types ``create_chart`` can draw. Every chart type recommended by
    # ``_generate_viz_suggestions`` appears here, so suggestions are drawable.
    SUPPORTED_CHART_TYPES = (
        "line", "area", "bar", "scatter", "pie",
        "histogram", "box", "violin", "heatmap",
    )

    # Rendering settings for ``create_chart`` (12x8 inches at 300 dpi).
    CHART_FIGSIZE = (12, 8)
    CHART_DPI = 300

    def __init__(self, seed: Optional[int] = None):
        """Create the MCP server, build the sample datasets and register tools.

        Args:
            seed: Optional seed for the randomly generated sample data.
                ``None`` (default) yields different data on every start;
                pass an int for reproducible datasets.
        """
        self.server = Server("data-visualization")
        self.datasets = self._create_sample_datasets(seed)
        self._setup_tools()

    def _create_sample_datasets(self, seed: Optional[int] = None) -> Dict[str, pd.DataFrame]:
        """Build the three synthetic demo datasets.

        Args:
            seed: Seed for a private NumPy generator (``None`` = fresh entropy).
                A local generator is used so NumPy's global random state is
                left untouched.

        Returns:
            ``{"sales": ..., "users": ..., "stocks": ...}``.
        """
        rng = np.random.default_rng(seed)

        # Sales: two years of daily rows, Gaussian noise plus a yearly sine wave.
        dates = pd.date_range('2023-01-01', '2024-12-31', freq='D')
        n_days = len(dates)
        seasonality = np.sin(np.arange(n_days) * 2 * np.pi / 365) * 100
        sales_data = pd.DataFrame({
            'date': dates,
            'sales': rng.normal(1000, 200, n_days) + seasonality,
            'product': rng.choice(['A', 'B', 'C'], n_days),
            'region': rng.choice(['North', 'South', 'East', 'West'], n_days),
        })

        # Users: a small fixed table, one row per age group.
        user_data = pd.DataFrame({
            'age_group': ['18-25', '26-35', '36-45', '46-55', '55+'],
            'count': [1200, 2500, 1800, 1100, 800],
            'revenue': [45000, 95000, 78000, 55000, 32000],
        })

        # Stocks: one year of daily prices as a random walk starting at 100.
        stock_dates = pd.date_range('2024-01-01', '2024-12-31', freq='D')
        n_stock_days = len(stock_dates)
        stock_data = pd.DataFrame({
            'date': stock_dates,
            'price': 100 + np.cumsum(rng.normal(0, 2, n_stock_days)),
            'volume': rng.exponential(1000, n_stock_days),
        })

        return {
            'sales': sales_data,
            'users': user_data,
            'stocks': stock_data,
        }

    def _setup_tools(self):
        """Register the ``list_tools`` and ``call_tool`` handlers on ``self.server``."""

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            """Advertise every tool together with its JSON input schema."""

            def dataset_name_property(description: str) -> Dict[str, Any]:
                # Enumerating valid names lets the model (and any SDK-side
                # schema validation) catch typos before the tool runs.
                return {
                    "type": "string",
                    "description": description,
                    "enum": list(self.datasets),
                }

            # Optional properties carry no "default": null -- null is not a
            # valid integer/string value under JSON Schema.
            return [
                Tool(
                    name="list_available_datasets",
                    description="列出所有可用的数据集",
                    inputSchema={
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                ),
                Tool(
                    name="get_dataset",
                    description="获取指定数据集的内容",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "dataset_name": dataset_name_property("数据集名称 (sales, users, stocks)"),
                            "limit": {
                                "type": "integer",
                                "description": "返回行数限制,默认为所有行"
                            }
                        },
                        "required": ["dataset_name"]
                    }
                ),
                Tool(
                    name="analyze_data_structure",
                    description="分析数据集的结构和特征",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "dataset_name": dataset_name_property("要分析的数据集名称")
                        },
                        "required": ["dataset_name"]
                    }
                ),
                Tool(
                    name="suggest_visualization",
                    description="根据数据特征建议最佳可视化方式",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "dataset_name": dataset_name_property("数据集名称"),
                            "analysis_goal": {
                                "type": "string",
                                "description": "分析目标 (trend, distribution, comparison, correlation)",
                                "enum": list(self.ANALYSIS_GOALS)
                            }
                        },
                        "required": ["dataset_name", "analysis_goal"]
                    }
                ),
                Tool(
                    name="create_chart",
                    description="创建数据可视化图表",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "dataset_name": dataset_name_property("数据集名称"),
                            "chart_type": {
                                "type": "string",
                                "description": "图表类型 (line, area, bar, scatter, pie, histogram, box, violin, heatmap)",
                                "enum": list(self.SUPPORTED_CHART_TYPES)
                            },
                            "x_column": {
                                "type": "string",
                                "description": "X轴列名"
                            },
                            "y_column": {
                                "type": "string",
                                "description": "Y轴列名"
                            },
                            "title": {
                                "type": "string",
                                "description": "图表标题",
                                "default": ""
                            },
                            "group_by": {
                                "type": "string",
                                "description": "分组列名(可选)"
                            }
                        },
                        "required": ["dataset_name", "chart_type"]
                    }
                ),
                Tool(
                    name="get_data_insights",
                    description="获取数据洞察和统计摘要",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "dataset_name": dataset_name_property("数据集名称")
                        },
                        "required": ["dataset_name"]
                    }
                )
            ]

        # Required arguments per tool, checked up front so a missing key gives
        # the model a readable error instead of a bare KeyError.
        required_args = {
            "list_available_datasets": (),
            "get_dataset": ("dataset_name",),
            "analyze_data_structure": ("dataset_name",),
            "suggest_visualization": ("dataset_name", "analysis_goal"),
            "create_chart": ("dataset_name", "chart_type"),
            "get_data_insights": ("dataset_name",),
        }
        # Keyword parameters accepted by _create_chart; any other key sent by
        # the model is dropped rather than raising TypeError via **kwargs.
        chart_params = ("dataset_name", "chart_type", "x_column", "y_column", "title", "group_by")

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]]) -> List[TextContent | ImageContent]:
            """Route a tool call to the coroutine that implements it.

            Raises:
                ValueError: If ``name`` is not a registered tool.
            """
            if name not in required_args:
                raise ValueError(f"Unknown tool: {name}")

            # Clients may omit "arguments" entirely for parameterless tools.
            args = arguments or {}
            missing = [key for key in required_args[name] if not args.get(key)]
            if missing:
                return self._text(f"错误: 缺少必需参数 {', '.join(missing)}")

            if name == "list_available_datasets":
                return await self._list_datasets()
            if name == "get_dataset":
                return await self._get_dataset(args["dataset_name"], args.get("limit"))
            if name == "analyze_data_structure":
                return await self._analyze_structure(args["dataset_name"])
            if name == "suggest_visualization":
                return await self._suggest_visualization(args["dataset_name"], args["analysis_goal"])
            if name == "create_chart":
                return await self._create_chart(**{key: args[key] for key in chart_params if key in args})
            return await self._get_insights(args["dataset_name"])

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _text(message: str) -> List[TextContent]:
        """Wrap *message* in the single-item content list a tool handler returns."""
        return [TextContent(type="text", text=message)]

    @classmethod
    def _missing_dataset(cls, dataset_name: str) -> List[TextContent]:
        """Error content for a dataset name that is not loaded."""
        return cls._text(f"错误: 数据集 '{dataset_name}' 不存在")

    @classmethod
    def _to_json_safe(cls, value: Any) -> Any:
        """Recursively convert NumPy scalars to builtins and NaN/inf to None.

        Without this, ``json.dumps(..., default=str)`` renders ``np.int64`` as a
        *string* and emits bare ``NaN``, which is not valid JSON. Other
        non-JSON values (e.g. ``pd.Timestamp``) are left for ``default=str``.
        """
        if isinstance(value, dict):
            return {str(key): cls._to_json_safe(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_json_safe(item) for item in value]
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            value = value.item()
        if isinstance(value, float) and not np.isfinite(value):
            return None
        return value

    @classmethod
    def _json_text(cls, payload: Any) -> List[TextContent]:
        """Serialize *payload* as pretty, non-ASCII-escaped JSON text content."""
        return cls._text(json.dumps(cls._to_json_safe(payload), ensure_ascii=False, indent=2, default=str))

    @staticmethod
    def _classify_columns(df: pd.DataFrame) -> Dict[str, List[str]]:
        """Split the columns of *df* into numeric, categorical and datetime lists."""
        return {
            "numeric": df.select_dtypes(include=[np.number]).columns.tolist(),
            "categorical": df.select_dtypes(include=["object", "category"]).columns.tolist(),
            "datetime": df.select_dtypes(include=["datetime64", "datetimetz"]).columns.tolist(),
        }

    # -------------------------------------------------------------------- tools

    async def _list_datasets(self) -> List[TextContent]:
        """Return the names, descriptions and count of loaded datasets as JSON text."""
        dataset_info = {
            "available_datasets": list(self.datasets),
            "descriptions": {
                name: self.DATASET_DESCRIPTIONS.get(name, "") for name in self.datasets
            },
            "total_datasets": len(self.datasets),
        }
        return self._json_text(dataset_info)

    async def _get_dataset(self, dataset_name: str, limit: Optional[int] = None) -> List[TextContent]:
        """Return the rows of *dataset_name* as JSON text.

        Args:
            dataset_name: Key into ``self.datasets``.
            limit: Maximum number of leading rows to return. ``None`` or a
                non-positive value returns every row.

        Returns:
            JSON with ``dataset_name``, ``shape`` (of the returned slice),
            ``total_rows``, ``columns`` and ``data`` (list of row records).
        """
        full_df = self.datasets.get(dataset_name)
        if full_df is None:
            return self._missing_dataset(dataset_name)

        df = full_df
        # Only a positive limit truncates; a negative one is ignored instead of
        # making head() silently drop rows from the end.
        if limit is not None and int(limit) > 0:
            df = full_df.head(int(limit))

        # pandas' keyword is force_ascii (ensure_ascii raises TypeError here).
        records = json.loads(df.to_json(orient='records', date_format='iso', force_ascii=False))

        result = {
            "dataset_name": dataset_name,
            "shape": df.shape,
            "total_rows": len(full_df),
            "columns": list(df.columns),
            "data": records,
        }
        return self._json_text(result)

    async def _analyze_structure(self, dataset_name: str) -> List[TextContent]:
        """Describe the shape, column kinds, data quality and numeric summary of a dataset."""
        df = self.datasets.get(dataset_name)
        if df is None:
            return self._missing_dataset(dataset_name)

        columns = self._classify_columns(df)
        numeric_cols = columns["numeric"]

        analysis = {
            "basic_info": {
                "shape": df.shape,
                "columns": list(df.columns),
                "memory_usage": f"{df.memory_usage(deep=True).sum() / 1024:.2f} KB"
            },
            "column_types": columns,
            "data_quality": {
                "missing_values": df.isnull().sum().to_dict(),
                "duplicate_rows": int(df.duplicated().sum())
            },
            # Summarise only numeric columns: describe() would otherwise include
            # datetime columns, whose missing statistics come out as NaN.
            "numeric_summary": df[numeric_cols].describe().to_dict() if numeric_cols else {},
            "recommendations": self._get_analysis_recommendations(df)
        }
        return self._json_text(analysis)

    def _get_analysis_recommendations(self, df: pd.DataFrame) -> List[str]:
        """Return analysis hints based on which column kinds *df* contains."""
        columns = self._classify_columns(df)
        has_numeric = bool(columns["numeric"])
        has_categorical = bool(columns["categorical"])

        recommendations = []
        if columns["datetime"] and has_numeric:
            recommendations.append("适合时间序列分析,建议使用线图展示趋势")
        if has_categorical and has_numeric:
            recommendations.append("适合分组分析,建议使用条形图或箱线图比较不同类别")
        if len(columns["numeric"]) >= 2:
            recommendations.append("适合相关性分析,建议使用散点图或热力图")
        if has_categorical:
            recommendations.append("适合分布分析,建议使用饼图或条形图")
        return recommendations

    async def _suggest_visualization(self, dataset_name: str, analysis_goal: str) -> List[TextContent]:
        """Return chart recommendations for *analysis_goal* on *dataset_name* as JSON text."""
        df = self.datasets.get(dataset_name)
        if df is None:
            return self._missing_dataset(dataset_name)
        return self._json_text(self._generate_viz_suggestions(df, analysis_goal))

    def _generate_viz_suggestions(self, df: pd.DataFrame, goal: str) -> Dict[str, Any]:
        """Pick chart types and column mappings suited to *goal*.

        Every recommended chart type is one ``create_chart`` supports. A
        ``note`` key explains why ``recommended_charts`` is empty, if it is.
        """
        columns = self._classify_columns(df)
        numeric_cols = columns["numeric"]
        datetime_cols = columns["datetime"]
        categorical_cols = columns["categorical"]

        suggestions = {
            "analysis_goal": goal,
            "recommended_charts": [],
            "column_mappings": {}
        }

        if goal == "trend":
            if datetime_cols and numeric_cols:
                suggestions["recommended_charts"] = ["line", "area"]
                suggestions["column_mappings"] = {
                    "x_axis": datetime_cols[0],
                    "y_axis": numeric_cols[0]
                }

        elif goal == "distribution":
            if numeric_cols:
                suggestions["recommended_charts"] = ["histogram", "box", "violin"]
                suggestions["column_mappings"]["value"] = numeric_cols[0]
            if categorical_cols:
                suggestions["recommended_charts"].extend(["pie", "bar"])
                suggestions["column_mappings"]["category"] = categorical_cols[0]

        elif goal == "comparison":
            if categorical_cols and numeric_cols:
                suggestions["recommended_charts"] = ["bar", "box"]
                suggestions["column_mappings"] = {
                    "category": categorical_cols[0],
                    "value": numeric_cols[0]
                }

        elif goal == "correlation":
            if len(numeric_cols) >= 2:
                suggestions["recommended_charts"] = ["scatter", "heatmap"]
                suggestions["column_mappings"] = {
                    "x_axis": numeric_cols[0],
                    "y_axis": numeric_cols[1]
                }

        else:
            suggestions["note"] = f"不支持的分析目标 '{goal}',可选: {', '.join(self.ANALYSIS_GOALS)}"

        if not suggestions["recommended_charts"] and "note" not in suggestions:
            suggestions["note"] = "数据集中没有适合该分析目标的列"

        return suggestions

    async def _create_chart(self, dataset_name: str, chart_type: str,
                            x_column: Optional[str] = None, y_column: Optional[str] = None,
                            title: str = "", group_by: Optional[str] = None) -> List[TextContent | ImageContent]:
        """Render a chart of *dataset_name* and return it as a base64-encoded PNG.

        Args:
            dataset_name: Key into ``self.datasets``.
            chart_type: One of ``SUPPORTED_CHART_TYPES``.
            x_column: X-axis / category column (required by line, area, bar,
                scatter and pie; optional grouping for box and violin).
            y_column: Value column (required by line, area, bar, scatter, box
                and violin). Histogram uses ``y_column`` or else ``x_column``.
            title: Chart title; derived from dataset and chart type if empty.
            group_by: Optional column splitting line/area/scatter/bar/violin
                charts into one series per group.

        Returns:
            ``[TextContent, ImageContent]`` on success, otherwise a single
            ``TextContent`` describing the problem.
        """
        df = self.datasets.get(dataset_name)
        if df is None:
            return self._missing_dataset(dataset_name)

        if chart_type not in self.SUPPORTED_CHART_TYPES:
            return self._text(
                f"错误: 不支持的图表类型 '{chart_type}',可选: {', '.join(self.SUPPORTED_CHART_TYPES)}"
            )

        # Reject missing/unknown columns up front instead of returning an empty
        # chart labelled as a success.
        problem = self._validate_chart_columns(df, chart_type, x_column, y_column, group_by)
        if problem:
            return self._text(problem)

        plt.style.use('default')
        fig, ax = plt.subplots(figsize=self.CHART_FIGSIZE)
        try:
            self._draw_chart(ax, df, chart_type, x_column, y_column, group_by)
            ax.set_title(title or f'{dataset_name.title()} - {chart_type.title()} Chart',
                         fontsize=16, fontweight='bold')
            fig.tight_layout()

            buffer = io.BytesIO()
            fig.savefig(buffer, format='png', dpi=self.CHART_DPI, bbox_inches='tight')
            image_base64 = base64.b64encode(buffer.getvalue()).decode()
        except Exception as e:
            # Tool boundary: report any plotting failure to the model as text.
            return self._text(f"创建图表时出错: {str(e)}")
        finally:
            # Close on every path; a failed draw used to leave the figure open
            # for the lifetime of the server process.
            plt.close(fig)

        return [
            TextContent(
                type="text",
                text=f"成功创建 {chart_type} 图表,数据集: {dataset_name}"
            ),
            ImageContent(
                type="image",
                data=image_base64,
                mimeType="image/png"
            )
        ]

    @staticmethod
    def _validate_chart_columns(df: pd.DataFrame, chart_type: str, x_column: Optional[str],
                                y_column: Optional[str], group_by: Optional[str]) -> Optional[str]:
        """Return an error message if *chart_type* lacks required or valid columns, else None."""
        if chart_type in ("line", "area", "bar", "scatter") and not (x_column and y_column):
            return f"错误: {chart_type} 图表需要同时指定 x_column 和 y_column"
        if chart_type == "pie" and not x_column:
            return "错误: pie 图表需要指定 x_column"
        if chart_type in ("box", "violin") and not y_column:
            return f"错误: {chart_type} 图表需要指定 y_column"
        if chart_type == "histogram" and not (x_column or y_column):
            return "错误: histogram 图表需要指定 x_column 或 y_column"

        unknown = [col for col in (x_column, y_column, group_by) if col and col not in df.columns]
        if unknown:
            return f"错误: 数据集中不存在列 {', '.join(unknown)}"
        return None

    @staticmethod
    def _draw_chart(ax, df: pd.DataFrame, chart_type: str, x_column: Optional[str],
                    y_column: Optional[str], group_by: Optional[str]) -> None:
        """Draw *chart_type* onto the matplotlib axes *ax*; columns must be pre-validated."""
        if chart_type in ("line", "area", "scatter"):
            # One series per group (first-appearance order), or a single series.
            series = df.groupby(group_by, sort=False) if group_by else [(None, df)]
            for label, part in series:
                if chart_type == "line":
                    ax.plot(part[x_column], part[y_column], label=label, marker='o')
                elif chart_type == "area":
                    ax.fill_between(part[x_column], part[y_column], alpha=0.5, label=label)
                else:
                    ax.scatter(part[x_column], part[y_column], label=label, alpha=0.7)
            if group_by:
                ax.legend()
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)

        elif chart_type == "bar":
            if group_by:
                df.groupby([x_column, group_by])[y_column].mean().unstack().plot(kind='bar', ax=ax)
            else:
                # Average rows sharing an x value (as the grouped branch does);
                # raw rows would stack on one spot with only the last visible.
                means = df.groupby(x_column, sort=False)[y_column].mean()
                ax.bar(means.index, means.values)
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)

        elif chart_type == "pie":
            counts = df[x_column].value_counts()
            ax.pie(counts.values, labels=counts.index, autopct='%1.1f%%')

        elif chart_type == "histogram":
            value_column = y_column or x_column
            ax.hist(df[value_column].dropna(), bins=30)
            ax.set_xlabel(value_column)
            ax.set_ylabel("count")

        elif chart_type == "box":
            if x_column:
                df.boxplot(column=y_column, by=x_column, ax=ax)
            else:
                # NaNs would turn matplotlib's box statistics into NaN.
                ax.boxplot(df[y_column].dropna())
            ax.set_ylabel(y_column)

        elif chart_type == "violin":
            sns.violinplot(data=df, x=x_column, y=y_column, hue=group_by, ax=ax)

        elif chart_type == "heatmap":
            correlation_matrix = df.select_dtypes(include=[np.number]).corr()
            sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', ax=ax)

    async def _get_insights(self, dataset_name: str) -> List[TextContent]:
        """Return statistics, findings and recommendations for a dataset as JSON text."""
        df = self.datasets.get(dataset_name)
        if df is None:
            return self._missing_dataset(dataset_name)
        return self._json_text(self._generate_insights(df, dataset_name))

    def _generate_insights(self, df: pd.DataFrame, dataset_name: str) -> Dict[str, Any]:
        """Compute per-column statistics plus heuristic findings and recommendations.

        Findings per numeric column:
          * high variability: std / |mean| > 0.5 (skipped when mean is 0 or NaN);
          * possible skew: |mean - median| / std > 0.5 (skipped when std is 0 or NaN).
        Any column with more than 5% missing values is reported as well.
        """
        insights = {
            "dataset": dataset_name,
            "key_findings": [],
            "statistics": {},
            "recommendations": []
        }

        columns = self._classify_columns(df)

        for col in columns["numeric"]:
            series = df[col]
            stats = {
                "mean": series.mean(),
                "median": series.median(),
                "std": series.std(),
                "min": series.min(),
                "max": series.max()
            }
            insights["statistics"][col] = stats

            mean, median, std = stats["mean"], stats["median"], stats["std"]
            # Guard both ratios: a zero denominator would divide by zero, and a
            # NaN (e.g. all-missing column) makes the comparison meaningless.
            if np.isfinite(mean) and mean != 0 and np.isfinite(std):
                variation = std / abs(mean)
                if variation > 0.5:
                    insights["key_findings"].append(f"{col} 显示高变异性,标准差与均值比为 {variation:.2f}")
            if np.isfinite(std) and std > 0 and abs(mean - median) / std > 0.5:
                insights["key_findings"].append(f"{col} 分布可能存在偏斜")

        # Percentage of missing values per column (NaN for an empty frame,
        # which never passes the threshold).
        missing_pct = df.isnull().mean() * 100
        for col, pct in missing_pct.items():
            if pct > 5:
                insights["key_findings"].append(f"{col} 有 {pct:.1f}% 的缺失值")

        if len(columns["numeric"]) > 1:
            insights["recommendations"].append("建议进行相关性分析")

        # A datetime-typed column, or one named like "date", suggests a time axis.
        if columns["datetime"] or any('date' in str(col).lower() for col in df.columns):
            insights["recommendations"].append("建议进行时间序列分析")

        return insights

async def main():
    """Build the visualization server and serve MCP over stdin/stdout until EOF."""
    # get_capabilities() reads attributes such as ``tools_changed`` from
    # notification_options whenever a list_tools handler is registered, so
    # passing None raised AttributeError at startup; use the SDK defaults.
    from mcp.server import NotificationOptions

    viz_server = DataVisualizationServer()

    async with stdio_server() as (read_stream, write_stream):
        await viz_server.server.run(
            read_stream,
            write_stream,
            InitializationOptions(
                server_name="data-visualization",
                server_version="1.0.0",
                capabilities=viz_server.server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                )
            )
        )


if __name__ == "__main__":
    asyncio.run(main())
相关推荐
虹科网络安全1 分钟前
艾体宝洞察 | 利用“隐形字符”的钓鱼邮件:传统防御为何失效,AI安全意识培训如何补上最后一道防线
运维·网络·安全
石像鬼₧魂石20 分钟前
Kali Linux 网络端口深度扫描
linux·运维·网络
alengan23 分钟前
linux上面写python3日志服务器
linux·运维·服务器
yBmZlQzJ1 小时前
免费内网穿透-端口转发配置介绍
运维·经验分享·docker·容器·1024程序员节
JH30731 小时前
docker 新手入门:10分钟搞定基础使用
运维·docker·容器
小卒过河01042 小时前
使用apache nifi 从数据库文件表路径拉取远程文件至远程服务器目的地址
运维·服务器·数据库
土星云SaturnCloud2 小时前
液冷“内卷”:在局部优化与系统重构之间,寻找第三条路
服务器·人工智能·ai·计算机外设
Empty_7772 小时前
DevOps理念
运维·devops
叶之香2 小时前
CentOS/RHEL 7、8安装exfat和ntfs文件系统
linux·运维·centos