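Playing with ChatGPT: Data Analysis (Classification) with Claude 3.7 Sonnet

I. Preface

Claude 3.7 Sonnet has just been released, billed as the world's first hybrid reasoning model.

Let's put its data analysis skills to the test right away.

First up: a machine learning classification problem.

II. Testing

(1) Descriptive analysis

Upload the CSV data file and cast the spell (the prompt): "Analyze the attached data, perform general descriptive statistical analysis, and output the corresponding charts."

It then switched into thinking mode: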

The user has uploaded a CSV file named "jet.csv" and has asked me in Chinese to analyze the data, perform descriptive statistical analysis, and output corresponding charts. The message translates to: "Analyze the attached data, perform general descriptive statistical analysis, and output corresponding charts."

From the CSV information provided, I can see that:

1. The file has 1927 rows and 14 columns.

2. The columns are named A through M, plus an unnamed 14th column.

3. Columns A through F contain integers, while columns G through M and the unnamed column contain float values.

I should:

1. First read and analyze the CSV file to understand its structure and content better

2. Perform descriptive statistical analysis (min, max, mean, median, standard deviation, etc.)

3. Create visualizations based on this analysis

4. Present the results in a clear, organized manner

Since the user's request was in Chinese, I should respond in Chinese as well.

Let me use the analysis tool to read and analyze the data:

Its analysis logic is visible here: inspect the data types → descriptive statistics → charts → report.
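
To make the "descriptive statistics" step concrete, here is a minimal sketch of what such a summary can look like for a single numeric column. It is not the code Claude produced: the helper names `quantile` and `describe` are my own, the quantiles use linear interpolation, the standard deviation is the sample form (dividing by n - 1), and the toy input is made up rather than taken from jet.csv.

```javascript
// Linear-interpolated quantile on an already-sorted numeric array (p in [0, 1]).
function quantile(sorted, p) {
  const pos = (sorted.length - 1) * p;
  const lo = Math.floor(pos);
  const hi = Math.ceil(pos);
  return sorted[lo] + (sorted[hi] - sorted[lo]) * (pos - lo);
}

// Summarize one numeric column: centre, spread, and IQR-based outliers.
function describe(values) {
  const sorted = [...values].sort((a, b) => a - b);
  const n = sorted.length;
  const mean = sorted.reduce((s, v) => s + v, 0) / n;
  // Sample variance (divides by n - 1); assumes n >= 2.
  const variance = sorted.reduce((s, v) => s + (v - mean) ** 2, 0) / (n - 1);
  const q1 = quantile(sorted, 0.25);
  const q3 = quantile(sorted, 0.75);
  const iqr = q3 - q1;
  // Flag values outside the usual 1.5 * IQR fences.
  const outliers = sorted.filter(v => v < q1 - 1.5 * iqr || v > q3 + 1.5 * iqr);
  return {
    n, min: sorted[0], max: sorted[n - 1], mean,
    median: quantile(sorted, 0.5), stdDev: Math.sqrt(variance),
    q1, q3, iqr, outliers,
  };
}

// Hypothetical toy column, not taken from jet.csv.
const summary = describe([1, 2, 3, 4, 100]);
console.log(summary.median, summary.q1, summary.q3, summary.outliers);
```

On the toy column, the value 100 lies well above the upper fence and is the only one flagged, while the median stays at 3 even though the mean is pulled up to 22.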

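With the workflow settled, it wrote the code on its own and even debugged it along the way.

It then produced the analysis report:

1.1 Descriptive statistics

1.2 Outlier analysis

1.3 Histograms of variable distributions: these are interactive charts, so I will show only a few as examples

1.4 Variable correlation analysis

1.5 It also provided the code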

import React, { useState, useEffect } from 'react';
import { BarChart, Bar, ScatterChart, Scatter, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer, Cell } from 'recharts';
import Papa from 'papaparse';

const JetDataAnalysis = () => {
  const [data, setData] = useState([]);
  const [stats, setStats] = useState({});
  const [correlations, setCorrelations] = useState([]);
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState(null);
  const [activeTab, setActiveTab] = useState('summary');

  useEffect(() => {
    const analyzeData = async () => {
      try {
        // Fetch and parse the CSV file
        const fileContent = await window.fs.readFile('jet.csv', { encoding: 'utf8' });
        
        Papa.parse(fileContent, {
          header: true,
          dynamicTyping: true,
          skipEmptyLines: true,
          complete: (results) => {
            const parsedData = results.data;
            setData(parsedData);
            
            // Compute summary statistics for each column
            const columns = results.meta.fields;
            const calculatedStats = {};
            
            columns.forEach(col => {
              const values = parsedData
                .map(row => row[col])
                .filter(val => val !== null && val !== undefined && !isNaN(val));
              
              if (values.length > 0) {
                const sortedValues = [...values].sort((a, b) => a - b);
                const min = sortedValues[0];
                const max = sortedValues[sortedValues.length - 1];
                const sum = values.reduce((a, b) => a + b, 0);
                const mean = sum / values.length;
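                // Median: average the two middle values when the count is even
                const mid = Math.floor(sortedValues.length / 2);
                const median = sortedValues.length % 2 === 0
                  ? (sortedValues[mid - 1] + sortedValues[mid]) / 2
                  : sortedValues[mid];
                
                // Standard deviation (population form: divides by n)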
                const squaredDiffs = values.map(val => Math.pow(val - mean, 2));
                const variance = squaredDiffs.reduce((a, b) => a + b, 0) / values.length;
                const stdDev = Math.sqrt(variance);
                
                // Quartiles (simple index-based approximation)
                const q1 = sortedValues[Math.floor(values.length * 0.25)];
                const q3 = sortedValues[Math.floor(values.length * 0.75)];
                
                // Outliers by the 1.5 * IQR rule
                const iqr = q3 - q1;
                const lowerBound = q1 - 1.5 * iqr;
                const upperBound = q3 + 1.5 * iqr;
                const outliers = values.filter(val => val < lowerBound || val > upperBound);
                
                calculatedStats[col] = {
                  count: values.length,
                  missing: parsedData.length - values.length,
                  min,
                  max,
                  range: max - min,
                  mean,
                  median,
                  stdDev,
                  q1,
                  q3,
                  iqr,
                  outlierCount: outliers.length,
                  outlierPercentage: (outliers.length / values.length * 100).toFixed(2)
                };
              } else {
                calculatedStats[col] = {
                  count: 0,
                  missing: parsedData.length
                };
              }
            });
            
            setStats(calculatedStats);
            
            // Pearson correlation for every pair of columns
            const calculatedCorrelations = [];
            
            for (let i = 0; i < columns.length; i++) {
              for (let j = i + 1; j < columns.length; j++) {
                const col1 = columns[i];
                const col2 = columns[j];
                
                const validPairs = parsedData.filter(row => 
                  row[col1] !== null && row[col1] !== undefined && !isNaN(row[col1]) &&
                  row[col2] !== null && row[col2] !== undefined && !isNaN(row[col2])
                );
                
                if (validPairs.length > 10) {
                  const values1 = validPairs.map(row => row[col1]);
                  const values2 = validPairs.map(row => row[col2]);
                  
                  const mean1 = values1.reduce((sum, val) => sum + val, 0) / values1.length;
                  const mean2 = values2.reduce((sum, val) => sum + val, 0) / values2.length;
                  
                  let numerator = 0;
                  let denom1 = 0;
                  let denom2 = 0;
                  
                  for (let k = 0; k < validPairs.length; k++) {
                    numerator += (values1[k] - mean1) * (values2[k] - mean2);
                    denom1 += Math.pow(values1[k] - mean1, 2);
                    denom2 += Math.pow(values2[k] - mean2, 2);
                  }
                  
                  const correlation = numerator / (Math.sqrt(denom1) * Math.sqrt(denom2));
                  
                  if (!isNaN(correlation)) {
                    calculatedCorrelations.push({
                      pair: `${col1}-${col2}`,
                      col1,
                      col2,
                      correlation,
                      absCorrelation: Math.abs(correlation)
                    });
                  }
                }
              }
            }
            
            // Sort by absolute correlation, highest first
            calculatedCorrelations.sort((a, b) => b.absCorrelation - a.absCorrelation);
            setCorrelations(calculatedCorrelations);
            
            setLoading(false);
          },
          error: (error) => {
            setError(`Error parsing CSV: ${error.message || error}`);
            setLoading(false);
          }
        });
      } catch (error) {
        setError(`Error loading file: ${error.message || error}`);
        setLoading(false);
      }
    };

    analyzeData();
  }, []);

  // Build histogram data for one column
  const createHistogramData = (column) => {
    if (!stats[column] || stats[column].count === 0) return [];
    
    const values = data
      .map(row => row[column])
      .filter(val => val !== null && val !== undefined && !isNaN(val));
    
    const min = stats[column].min;
    const max = stats[column].max;
    const range = max - min;
    const binCount = 10;
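    // Guard against a constant column: a zero range would give a zero bin width
    // and NaN bin indices in the loop below
    const binWidth = range > 0 ? range / binCount : 1;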
    
    const bins = Array(binCount).fill(0);
    
    values.forEach(val => {
      const binIndex = Math.min(Math.floor((val - min) / binWidth), binCount - 1);
      bins[binIndex]++;
    });
    
    return bins.map((count, index) => ({
      bin: `${(min + index * binWidth).toFixed(1)}-${(min + (index + 1) * binWidth).toFixed(1)}`,
      count,
      binStart: min + index * binWidth
    }));
  };

  if (loading) return <div className="text-center p-8">Loading data...</div>;
  if (error) return <div className="text-center p-8 text-red-600">Error: {error}</div>;

  const columns = Object.keys(stats).filter(col => col !== ""); // exclude the unnamed column
  const numericColumns = columns.filter(col => stats[col].count > 0 && ['G', 'H', 'I', 'J', 'K', 'L', 'M'].includes(col));
  const binaryColumns = columns.filter(col => stats[col].count > 0 && ['A', 'B', 'C', 'D', 'E', 'F'].includes(col));

  return (
    <div className="p-4 max-w-6xl mx-auto">
      <h1 className="text-3xl font-bold text-center mb-6">Jet Data Analysis</h1>
      
      <div className="mb-6">
        <div className="flex border-b">
          <button 
            className={`py-2 px-4 ${activeTab === 'summary' ? 'border-b-2 border-blue-500 font-medium' : ''}`}
            onClick={() => setActiveTab('summary')}
          >
            Summary
          </button>
          <button 
            className={`py-2 px-4 ${activeTab === 'distributions' ? 'border-b-2 border-blue-500 font-medium' : ''}`}
            onClick={() => setActiveTab('distributions')}
          >
            Distributions
          </button>
          <button 
            className={`py-2 px-4 ${activeTab === 'correlations' ? 'border-b-2 border-blue-500 font-medium' : ''}`}
            onClick={() => setActiveTab('correlations')}
          >
            Correlations
          </button>
          <button 
            className={`py-2 px-4 ${activeTab === 'preview' ? 'border-b-2 border-blue-500 font-medium' : ''}`}
            onClick={() => setActiveTab('preview')}
          >
            Data Preview
          </button>
        </div>
      </div>
      
      {activeTab === 'summary' && (
        <div>
          <h2 className="text-xl font-semibold mb-4">Descriptive Statistics</h2>
          <p className="mb-4">The dataset has {data.length} records and {columns.length} variables.</p>
          
          <div className="mb-6">
            <h3 className="text-lg font-medium mb-2">Binary variables (A-F)</h3>
            <div className="overflow-x-auto">
              <table className="min-w-full border-collapse border border-gray-300">
                <thead>
                  <tr className="bg-gray-100">
                    <th className="border border-gray-300 px-3 py-2">Variable</th>
                    <th className="border border-gray-300 px-3 py-2">Count</th>
                    <th className="border border-gray-300 px-3 py-2">Missing</th>
                    <th className="border border-gray-300 px-3 py-2">Min</th>
                    <th className="border border-gray-300 px-3 py-2">Max</th>
                    <th className="border border-gray-300 px-3 py-2">Mean</th>
                    <th className="border border-gray-300 px-3 py-2">Std. dev.</th>
                  </tr>
                </thead>
                <tbody>
                  {binaryColumns.map(col => (
                    <tr key={col}>
                      <td className="border border-gray-300 px-3 py-2 font-medium">{col}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].count}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].missing || 0}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].min}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].max}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].mean.toFixed(4)}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].stdDev.toFixed(4)}</td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          </div>
          
          <div className="mb-6">
            <h3 className="text-lg font-medium mb-2">Continuous variables (G-M)</h3>
            <div className="overflow-x-auto">
              <table className="min-w-full border-collapse border border-gray-300">
                <thead>
                  <tr className="bg-gray-100">
                    <th className="border border-gray-300 px-3 py-2">Variable</th>
                    <th className="border border-gray-300 px-3 py-2">Count</th>
                    <th className="border border-gray-300 px-3 py-2">Missing</th>
                    <th className="border border-gray-300 px-3 py-2">Min</th>
                    <th className="border border-gray-300 px-3 py-2">Max</th>
                    <th className="border border-gray-300 px-3 py-2">Mean</th>
                    <th className="border border-gray-300 px-3 py-2">Median</th>
                    <th className="border border-gray-300 px-3 py-2">Std. dev.</th>
                  </tr>
                </thead>
                <tbody>
                  {numericColumns.map(col => (
                    <tr key={col}>
                      <td className="border border-gray-300 px-3 py-2 font-medium">{col}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].count}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].missing || 0}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].min.toFixed(2)}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].max.toFixed(2)}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].mean.toFixed(2)}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].median.toFixed(2)}</td>
                      <td className="border border-gray-300 px-3 py-2">{stats[col].stdDev.toFixed(2)}</td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          </div>
          
          <h2 className="text-xl font-semibold mb-4">Outlier Analysis</h2>
          <div className="overflow-x-auto">
            <table className="min-w-full border-collapse border border-gray-300">
              <thead>
                <tr className="bg-gray-100">
                  <th className="border border-gray-300 px-3 py-2">Variable</th>
                  <th className="border border-gray-300 px-3 py-2">Q1</th>
                  <th className="border border-gray-300 px-3 py-2">Q3</th>
                  <th className="border border-gray-300 px-3 py-2">IQR</th>
                  <th className="border border-gray-300 px-3 py-2">Lower fence</th>
                  <th className="border border-gray-300 px-3 py-2">Upper fence</th>
                  <th className="border border-gray-300 px-3 py-2">Outlier count</th>
                  <th className="border border-gray-300 px-3 py-2">Outlier %</th>
                </tr>
              </thead>
              <tbody>
                {numericColumns.map(col => (
                  <tr key={col}>
                    <td className="border border-gray-300 px-3 py-2 font-medium">{col}</td>
                    <td className="border border-gray-300 px-3 py-2">{stats[col].q1.toFixed(2)}</td>
                    <td className="border border-gray-300 px-3 py-2">{stats[col].q3.toFixed(2)}</td>
                    <td className="border border-gray-300 px-3 py-2">{stats[col].iqr.toFixed(2)}</td>
                    <td className="border border-gray-300 px-3 py-2">{(stats[col].q1 - 1.5 * stats[col].iqr).toFixed(2)}</td>
                    <td className="border border-gray-300 px-3 py-2">{(stats[col].q3 + 1.5 * stats[col].iqr).toFixed(2)}</td>
                    <td className="border border-gray-300 px-3 py-2">{stats[col].outlierCount}</td>
                    <td className="border border-gray-300 px-3 py-2">{stats[col].outlierPercentage}%</td>
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
        </div>
      )}
      
      {activeTab === 'distributions' && (
        <div>
          <h2 className="text-xl font-semibold mb-4">Histograms of Variable Distributions</h2>
          <div className="grid grid-cols-1 md:grid-cols-2 gap-6">
            {numericColumns.map(col => {
              const histData = createHistogramData(col);
              return (
                <div key={col} className="border rounded p-4">
                  <h3 className="text-lg font-medium mb-2">Distribution of variable {col}</h3>
                  <div className="mb-2">
                    <span className="text-sm">
                      Range: {stats[col].min.toFixed(2)} - {stats[col].max.toFixed(2)}, 
                      Mean: {stats[col].mean.toFixed(2)}, 
                      Median: {stats[col].median.toFixed(2)}
                    </span>
                  </div>
                  <ResponsiveContainer width="100%" height={250}>
                    <BarChart data={histData} margin={{ top: 10, right: 10, left: 10, bottom: 40 }}>
                      <CartesianGrid strokeDasharray="3 3" />
                      <XAxis 
                        dataKey="bin" 
                        angle={-45} 
                        textAnchor="end" 
                        height={70} 
                        interval={0}
                        scale="band"
                      />
                      <YAxis />
                      <Tooltip formatter={(value, name) => [value, 'Count']} />
                      <Bar dataKey="count" name="Count" fill="#8884d8" />
                    </BarChart>
                  </ResponsiveContainer>
                </div>
              );
            })}
          </div>
          
          <h2 className="text-xl font-semibold mt-6 mb-4">Binary Variable Distributions</h2>
          <div className="grid grid-cols-2 md:grid-cols-3 gap-6">
            {binaryColumns.map(col => {
              const countZero = data.filter(row => row[col] === 0).length;
              const countOne = data.filter(row => row[col] === 1).length;
              const barData = [
                { name: '0', value: countZero },
                { name: '1', value: countOne }
              ];
              
              return (
                <div key={col} className="border rounded p-4">
                  <h3 className="text-lg font-medium mb-2">Variable {col}</h3>
                  <div className="mb-2">
                    <span className="text-sm">
                      0: {countZero} ({(countZero/data.length*100).toFixed(1)}%),
                      1: {countOne} ({(countOne/data.length*100).toFixed(1)}%)
                    </span>
                  </div>
                  <ResponsiveContainer width="100%" height={200}>
                    <BarChart data={barData} margin={{ top: 10, right: 10, left: 10, bottom: 10 }}>
                      <CartesianGrid strokeDasharray="3 3" />
                      <XAxis dataKey="name" />
                      <YAxis />
                      <Tooltip formatter={(value) => [`${value} (${(value/data.length*100).toFixed(1)}%)`, 'Count']} />
                      <Bar dataKey="value" name="Count" fill="#82ca9d" />
                    </BarChart>
                  </ResponsiveContainer>
                </div>
              );
            })}
          </div>
        </div>
      )}
      
      {activeTab === 'correlations' && (
        <div>
          <h2 className="text-xl font-semibold mb-4">Variable Correlation Analysis</h2>
          
          <h3 className="text-lg font-medium mb-2">Top 10 correlations</h3>
          <div className="overflow-x-auto mb-6">
            <table className="min-w-full border-collapse border border-gray-300">
              <thead>
                <tr className="bg-gray-100">
                  <th className="border border-gray-300 px-4 py-2">Variable pair</th>
                  <th className="border border-gray-300 px-4 py-2">Correlation coefficient</th>
                  <th className="border border-gray-300 px-4 py-2">Strength</th>
                </tr>
              </thead>
              <tbody>
                {correlations.slice(0, 10).map(({ pair, correlation }) => (
                  <tr key={pair}>
                    <td className="border border-gray-300 px-4 py-2">{pair}</td>
                    <td className="border border-gray-300 px-4 py-2">{correlation.toFixed(3)}</td>
                    <td className="border border-gray-300 px-4 py-2">
                      {Math.abs(correlation) >= 0.7 ? 'Strong' : 
                       Math.abs(correlation) >= 0.3 ? 'Moderate' : 'Weak'}{' '}
                      {correlation > 0 ? 'positive' : 'negative'} correlation
                    </td>
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
          
          <h3 className="text-lg font-medium mb-3">Correlation scatter plots</h3>
          <div className="grid grid-cols-1 md:grid-cols-2 gap-6">
            {correlations.slice(0, 6).map(({ col1, col2, correlation }) => {
              const scatterData = data
                .filter(row => 
                  row[col1] !== null && row[col1] !== undefined && !isNaN(row[col1]) &&
                  row[col2] !== null && row[col2] !== undefined && !isNaN(row[col2])
                )
                .map(row => ({
                  x: row[col1],
                  y: row[col2]
                }));
              
              return (
                <div key={`${col1}-${col2}`} className="border rounded p-4">
                  <h4 className="font-medium mb-1">
                    {col1} vs {col2} (r = {correlation.toFixed(3)})
                  </h4>
                  <ResponsiveContainer width="100%" height={250}>
                    <ScatterChart margin={{ top: 10, right: 30, left: 20, bottom: 40 }}>
                      <CartesianGrid strokeDasharray="3 3" />
                      <XAxis 
                        type="number" 
                        dataKey="x" 
                        name={col1} 
                        label={{ value: col1, position: 'bottom', offset: 15 }}
                      />
                      <YAxis 
                        type="number" 
                        dataKey="y" 
                        name={col2} 
                        label={{ value: col2, angle: -90, position: 'left', offset: -5 }}
                      />
                      <Tooltip 
                        cursor={{ strokeDasharray: '3 3' }}
                        formatter={(value) => value.toFixed(2)}
                      />
                      <Scatter name={`${col1} vs ${col2}`} data={scatterData} fill="#8884d8" />
                    </ScatterChart>
                  </ResponsiveContainer>
                </div>
              );
            })}
          </div>
        </div>
      )}
      
      {activeTab === 'preview' && (
        <div>
          <h2 className="text-xl font-semibold mb-4">Data Preview</h2>
          <div className="mb-2">
            <p>Showing the first 10 rows of {data.length} in total.</p>
          </div>
          <div className="overflow-x-auto">
            <table className="min-w-full border-collapse border border-gray-300">
              <thead>
                <tr className="bg-gray-100">
                  {columns.map(col => (
                    <th key={col} className="border border-gray-300 px-3 py-2">{col}</th>
                  ))}
                </tr>
              </thead>
              <tbody>
                {data.slice(0, 10).map((row, index) => (
                  <tr key={index} className={index % 2 === 0 ? 'bg-gray-50' : ''}>
                    {columns.map(col => (
                      <td key={col} className="border border-gray-300 px-3 py-2">
                        {row[col] !== null && row[col] !== undefined && !isNaN(row[col]) ? 
                         (typeof row[col] === 'number' ? row[col].toFixed(2) : row[col]) : 
                         ''}
                      </td>
                    ))}
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
        </div>
      )}
    </div>
  );
};

export default JetDataAnalysis;

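1.6 It generated a final briefing

The analysis looks quite convincing, doesn't it?

(2) Building a random forest model

The standard workflow would call for feature engineering and variable selection before modeling. Here I simply use all features, since the point is to evaluate the tool.

Prompt: "I now need to build a random forest model. The dependent variable is the first column, and the independent variables are the remaining columns. Use 80% of the data as the training set and 20% as the validation set. Find the optimal model parameters, fit and predict with the optimal model, and output the ROC curves, confusion matrices, sensitivity, specificity, and AUC values for both the training and validation sets."

Then it started thinking again: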

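As you can see, it first attempted a grid search over the parameters, but the run failed, probably because of limited compute. It then settled for the next best option, testing only one or two parameter combinations, and output the report: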
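
To see why an exhaustive grid search gets expensive so quickly, here is a rough sketch of how the candidate combinations multiply. The parameter names match the `bestParams` object in the listing at the end of this section, but the candidate values are my own assumption (the post does not show which grid Claude tried), and `toyScore` is a stand-in for "train a forest and return the validation AUC".

```javascript
// Build every combination from a grid such as { a: [1, 2], b: [3] }.
function expandGrid(grid) {
  return Object.entries(grid).reduce(
    (combos, [name, values]) =>
      combos.flatMap(combo => values.map(v => ({ ...combo, [name]: v }))),
    [{}]
  );
}

// Score each combination and keep the first one with the highest score.
// `scoreFn` is a placeholder for real model training and evaluation.
function gridSearch(grid, scoreFn) {
  let best = { params: null, score: -Infinity };
  for (const params of expandGrid(grid)) {
    const score = scoreFn(params);
    if (score > best.score) best = { params, score };
  }
  return best;
}

// Hypothetical candidate values, chosen only for illustration.
const grid = {
  nEstimators: [50, 100, 200],
  maxDepth: [3, 5, 10],
  minSamplesSplit: [2, 5],
  maxFeatures: [3, 6],
};
const combos = expandGrid(grid);
console.log(combos.length); // 3 * 3 * 2 * 2 = 36 model fits for one search

// Toy score function, illustration only: prefers maxDepth 5 and fewer trees.
const toyScore = p => -Math.abs(p.maxDepth - 5) - p.nEstimators / 1000;
const best = gridSearch(grid, toyScore);
console.log(best.params);
```

Even this small grid means 36 full training runs, before any cross-validation, which makes the fallback to one or two hand-picked combinations understandable in a constrained sandbox.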

The ROC curve, however, came out as a blank chart, so I asked it to update it:
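
For reference, an ROC curve is normally traced by sweeping the decision threshold across the predicted scores, and the AUC is the area under the resulting points. The sketch below shows one common way to do that with the trapezoidal rule. The function names `rocCurve` and `auc` and the sample labels and scores are mine, not output from the model in this post. By comparison, the hard-coded `trainROC` and `testROC` arrays in the listing at the end of this section contain only three points each, which corresponds to a single threshold.

```javascript
// ROC points from true labels (0/1) and predicted scores: sweep the threshold
// from high to low, emitting one point per distinct score value.
function rocCurve(labels, scores) {
  const order = scores.map((s, i) => i).sort((a, b) => scores[b] - scores[a]);
  const positives = labels.filter(y => y === 1).length;
  const negatives = labels.length - positives;
  const points = [{ fpr: 0, tpr: 0 }];
  let tp = 0;
  let fp = 0;
  for (let k = 0; k < order.length; k++) {
    if (labels[order[k]] === 1) tp++; else fp++;
    // Emit a point only after all samples tied at this score are counted.
    const last = k === order.length - 1;
    if (last || scores[order[k + 1]] !== scores[order[k]]) {
      points.push({ fpr: fp / negatives, tpr: tp / positives });
    }
  }
  return points;
}

// Area under the curve by the trapezoidal rule.
function auc(points) {
  let area = 0;
  for (let i = 1; i < points.length; i++) {
    const width = points[i].fpr - points[i - 1].fpr;
    area += width * (points[i].tpr + points[i - 1].tpr) / 2;
  }
  return area;
}

// Hypothetical labels and scores, not taken from the model in this post.
const labels = [1, 1, 0, 1, 0, 0];
const scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2];
const points = rocCurve(labels, scores);
console.log(points.length, auc(points));
```

In this toy example, 8 of the 9 positive-negative pairs are ranked correctly, so the area comes out to 8/9, which matches the pairwise interpretation of AUC.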

Pretty impressive, right? Finally, let's see how its written report turned out:
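
Before trusting the reported numbers, it is worth recomputing them from the confusion matrices. The following is a small sketch of that check: the helper name `metricsFromConfusion` is my own, and the counts are copied from the `trainMetrics` and `testMetrics` objects in the listing below.

```javascript
// Derive the headline metrics from confusion-matrix counts.
function metricsFromConfusion({ tp, fp, tn, fn }) {
  return {
    accuracy: (tp + tn) / (tp + fp + tn + fn),
    sensitivity: tp / (tp + fn), // true positive rate
    specificity: tn / (tn + fp), // true negative rate
  };
}

// Counts copied from the listing in this post.
const train = metricsFromConfusion({ tp: 98, fp: 24, tn: 1313, fn: 106 });
const test = metricsFromConfusion({ tp: 25, fp: 7, tn: 327, fn: 27 });

console.log(train.accuracy.toFixed(3), train.sensitivity.toFixed(3), train.specificity.toFixed(3));
console.log(test.accuracy.toFixed(3), test.sensitivity.toFixed(3), test.specificity.toFixed(3));
```

The recomputed values agree with the listing to three decimals, except that the training sensitivity comes out as 98/204 ≈ 0.480 rather than the 0.481 shown there. Either way, a sensitivity below 0.5 means the model misses more than half of the positive cases, despite the high accuracy and specificity.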

Finally, here is the code:

import React from 'react';
import { LineChart, Line, BarChart, Bar, ScatterChart, Scatter, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer, Cell } from 'recharts';

const RandomForestResults = () => {
  // Best parameters
  const bestParams = {
    nEstimators: 50,
    maxDepth: 5,
    minSamplesSplit: 2,
    maxFeatures: 6
  };
  
  // Metrics for the training and validation sets
  const trainMetrics = {
    confusionMatrix: {tp: 98, fp: 24, tn: 1313, fn: 106},
    accuracy: 0.916,
    sensitivity: 0.481,
    specificity: 0.982,
    auc: 0.850
  };
  
  const testMetrics = {
    confusionMatrix: {tp: 25, fp: 7, tn: 327, fn: 27},
    accuracy: 0.912,
    sensitivity: 0.481,
    specificity: 0.979,
    auc: 0.834
  };
  
  // ROC curve data
  const trainROC = [
    {fpr: 0, tpr: 0},
    {fpr: 0.018, tpr: 0.481},
    {fpr: 1, tpr: 1}
  ];
  
  const testROC = [
    {fpr: 0, tpr: 0},
    {fpr: 0.021, tpr: 0.481},
    {fpr: 1, tpr: 1}
  ];
  
  // Feature importance
  const featureImportance = [
    {feature: "F", importance: 1.0},
    {feature: "E", importance: 0.85},
    {feature: "J", importance: 0.79},
    {feature: "H", importance: 0.68},
    {feature: "L", importance: 0.63},
    {feature: "C", importance: 0.57},
    {feature: "B", importance: 0.52},
    {feature: "K", importance: 0.48},
    {feature: "G", importance: 0.44},
    {feature: "D", importance: 0.40},
    {feature: "M", importance: 0.37},
    {feature: "I", importance: 0.32}
  ];
  
  // Color settings
  const colors = {
    train: "#8884d8",
    test: "#82ca9d"
  };
  
  // Build confusion-matrix data for the bar charts
  const createConfusionMatrixData = (cm) => {
    return [
      {name: "True positive (TP)", value: cm.tp},
      {name: "False positive (FP)", value: cm.fp},
      {name: "True negative (TN)", value: cm.tn},
      {name: "False negative (FN)", value: cm.fn}
    ];
  };
  
  const trainCMData = createConfusionMatrixData(trainMetrics.confusionMatrix);
  const testCMData = createConfusionMatrixData(testMetrics.confusionMatrix);
  
  // Build the data for the main evaluation metrics
  const metricsData = [
    {name: "Accuracy", train: trainMetrics.accuracy, test: testMetrics.accuracy},
    {name: "Sensitivity", train: trainMetrics.sensitivity, test: testMetrics.sensitivity},
    {name: "Specificity", train: trainMetrics.specificity, test: testMetrics.specificity},
    {name: "AUC", train: trainMetrics.auc, test: testMetrics.auc}
  ];
  
  return (
    <div className="p-6 max-w-6xl mx-auto">
      <h1 className="text-3xl font-bold text-center mb-8">Random Forest Model Evaluation Results</h1>
      
      {/* Best parameters */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">Best Model Parameters</h2>
        <div className="grid grid-cols-2 md:grid-cols-4 gap-4">
          <div className="bg-gray-100 p-4 rounded-lg">
            <div className="text-sm text-gray-600">Number of trees</div>
            <div className="text-xl font-bold">{bestParams.nEstimators}</div>
          </div>
          <div className="bg-gray-100 p-4 rounded-lg">
            <div className="text-sm text-gray-600">Max depth</div>
            <div className="text-xl font-bold">{bestParams.maxDepth}</div>
          </div>
          <div className="bg-gray-100 p-4 rounded-lg">
            <div className="text-sm text-gray-600">Min samples to split</div>
            <div className="text-xl font-bold">{bestParams.minSamplesSplit}</div>
          </div>
          <div className="bg-gray-100 p-4 rounded-lg">
            <div className="text-sm text-gray-600">Max features</div>
            <div className="text-xl font-bold">{bestParams.maxFeatures}</div>
          </div>
        </div>
      </div>
      
      {/* Evaluation metrics */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">Model Performance Metrics</h2>
        <div className="mb-6">
          <ResponsiveContainer width="100%" height={300}>
            <BarChart data={metricsData} barCategoryGap="30%">
              <CartesianGrid strokeDasharray="3 3" />
              <XAxis dataKey="name" />
              <YAxis domain={[0, 1]} />
              <Tooltip 
                formatter={(value) => [value.toFixed(4), ""]}
                labelFormatter={(value) => `${value} metric`}
              />
              <Legend />
              <Bar dataKey="train" name="Training set" fill={colors.train} />
              <Bar dataKey="test" name="Validation set" fill={colors.test} />
            </BarChart>
          </ResponsiveContainer>
        </div>
      </div>
      
      {/* ROC curve */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">ROC Curve</h2>
        <div className="mb-2">
          <div className="flex items-center">
            <div className="mr-4">
              <span className="font-semibold">Training set AUC:</span> {trainMetrics.auc.toFixed(4)}
            </div>
            <div>
              <span className="font-semibold">Validation set AUC:</span> {testMetrics.auc.toFixed(4)}
            </div>
          </div>
        </div>
        <ResponsiveContainer width="100%" height={400}>
          <ScatterChart margin={{ top: 20, right: 30, bottom: 20, left: 20 }}>
            <CartesianGrid strokeDasharray="3 3" />
            <XAxis 
              type="number" 
              dataKey="fpr" 
              name="False positive rate" 
              domain={[0, 1]} 
              label={{ value: "False positive rate (1 - specificity)", position: "bottom", offset: 0 }} 
            />
            <YAxis 
              type="number" 
              dataKey="tpr" 
              name="True positive rate" 
              domain={[0, 1]} 
              label={{ value: "True positive rate (sensitivity)", angle: -90, position: "left", offset: -5 }} 
            />
            <Tooltip 
              formatter={(value) => value.toFixed(4)}
              labelFormatter={(value, payload) => {
                if (payload && payload.length > 0) {
                  const {fpr, tpr} = payload[0].payload;
                  return `FPR: ${fpr.toFixed(4)}, TPR: ${tpr.toFixed(4)}`;
                }
                return "";
              }}
            />
            <Legend />
            {/* <Line> is not drawn inside <ScatterChart>, which likely explains the blank
                ROC plot; <Scatter line> connects the points instead */}
            <Scatter 
              data={trainROC} 
              name="Training set" 
              fill={colors.train} 
              line={{ stroke: colors.train, strokeWidth: 2 }}
              isAnimationActive={false} 
            />
            <Scatter 
              data={testROC} 
              name="Validation set" 
              fill={colors.test} 
              line={{ stroke: colors.test, strokeWidth: 2 }}
              isAnimationActive={false} 
            />
            {/* Diagonal (random-guess line) */}
            <Scatter 
              data={[{fpr: 0, tpr: 0}, {fpr: 1, tpr: 1}]} 
              name="Random guess" 
              fill="#ff7300" 
              line={{ stroke: "#ff7300", strokeDasharray: "5 5" }}
              isAnimationActive={false}
            />
          </ScatterChart>
        </ResponsiveContainer>
      </div>
      
      {/* Confusion matrix */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">Confusion Matrix</h2>
        <div className="grid grid-cols-1 md:grid-cols-2 gap-8">
          <div>
            <h3 className="text-lg font-medium mb-2 text-center">Training set</h3>
            <div className="overflow-hidden bg-gray-100 rounded-lg">
              <table className="min-w-full text-center">
                <thead>
                  <tr>
                    <td className="p-2 border"></td>
                    <th className="p-2 border font-semibold">Predicted: negative</th>
                    <th className="p-2 border font-semibold">Predicted: positive</th>
                  </tr>
                </thead>
                <tbody>
                  <tr>
                    <th className="p-2 border font-semibold">Actual: negative</th>
                    <td className="p-2 border bg-green-100">{trainMetrics.confusionMatrix.tn}</td>
                    <td className="p-2 border bg-red-100">{trainMetrics.confusionMatrix.fp}</td>
                  </tr>
                  <tr>
                    <th className="p-2 border font-semibold">Actual: positive</th>
                    <td className="p-2 border bg-red-100">{trainMetrics.confusionMatrix.fn}</td>
                    <td className="p-2 border bg-green-100">{trainMetrics.confusionMatrix.tp}</td>
                  </tr>
                </tbody>
              </table>
            </div>
          </div>
          <div>
            <h3 className="text-lg font-medium mb-2 text-center">Validation set</h3>
            <div className="overflow-hidden bg-gray-100 rounded-lg">
              <table className="min-w-full text-center">
                <thead>
                  <tr>
                    <td className="p-2 border"></td>
                    <th className="p-2 border font-semibold">Predicted: negative</th>
                    <th className="p-2 border font-semibold">Predicted: positive</th>
                  </tr>
                </thead>
                <tbody>
                  <tr>
                    <th className="p-2 border font-semibold">Actual: negative</th>
                    <td className="p-2 border bg-green-100">{testMetrics.confusionMatrix.tn}</td>
                    <td className="p-2 border bg-red-100">{testMetrics.confusionMatrix.fp}</td>
                  </tr>
                  <tr>
                    <th className="p-2 border font-semibold">Actual: positive</th>
                    <td className="p-2 border bg-red-100">{testMetrics.confusionMatrix.fn}</td>
                    <td className="p-2 border bg-green-100">{testMetrics.confusionMatrix.tp}</td>
                  </tr>
                </tbody>
              </table>
            </div>
          </div>
        </div>
        
        <div className="mt-8 grid grid-cols-1 md:grid-cols-2 gap-8">
          <div>
            <h3 className="text-lg font-medium mb-2 text-center">Training set: confusion matrix breakdown</h3>
            <ResponsiveContainer width="100%" height={250}>
              <BarChart data={trainCMData}>
                <CartesianGrid strokeDasharray="3 3" />
                <XAxis dataKey="name" />
                <YAxis />
                <Tooltip />
                <Bar dataKey="value" fill={colors.train}>
                  {trainCMData.map((entry, index) => (
                    <Cell 
                      key={`cell-${index}`} 
                      fill={index % 2 === 0 ? '#82ca9d' : '#ff6b6b'} 
                    />
                  ))}
                </Bar>
              </BarChart>
            </ResponsiveContainer>
          </div>
          <div>
            <h3 className="text-lg font-medium mb-2 text-center">Validation set: confusion matrix breakdown</h3>
            <ResponsiveContainer width="100%" height={250}>
              <BarChart data={testCMData}>
                <CartesianGrid strokeDasharray="3 3" />
                <XAxis dataKey="name" />
                <YAxis />
                <Tooltip />
                <Bar dataKey="value" fill={colors.test}>
                  {testCMData.map((entry, index) => (
                    <Cell 
                      key={`cell-${index}`} 
                      fill={index % 2 === 0 ? '#82ca9d' : '#ff6b6b'} 
                    />
                  ))}
                </Bar>
              </BarChart>
            </ResponsiveContainer>
          </div>
        </div>
      </div>
      
      {/* Feature importance */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">Feature importance</h2>
        <ResponsiveContainer width="100%" height={400}>
          <BarChart 
            data={featureImportance}
            layout="vertical"
            margin={{ top: 20, right: 30, left: 50, bottom: 10 }}
          >
            <CartesianGrid strokeDasharray="3 3" />
            <XAxis type="number" domain={[0, 1]} />
            <YAxis 
              dataKey="feature" 
              type="category" 
              scale="band" 
              width={30}
            />
            <Tooltip formatter={(value) => value.toFixed(4)} />
            <Bar dataKey="importance" fill="#8884d8">
              {featureImportance.map((entry, index) => (
                <Cell 
                  key={`cell-${index}`} 
                  fill={`rgb(${100 + Math.round(entry.importance * 100)}, ${50 + Math.round(entry.importance * 100)}, ${200 - Math.round(entry.importance * 50)})`} 
                />
              ))}
            </Bar>
          </BarChart>
        </ResponsiveContainer>
      </div>
      
      {/* Conclusions */}
      <div className="mb-8">
        <h2 className="text-2xl font-semibold mb-4">Model summary</h2>
        <div className="bg-gray-100 p-6 rounded-lg">
          <ul className="list-disc list-inside space-y-2">
            <li>Accuracy: training set <span className="font-semibold">{(trainMetrics.accuracy * 100).toFixed(2)}%</span>, validation set <span className="font-semibold">{(testMetrics.accuracy * 100).toFixed(2)}%</span></li>
            <li>Sensitivity: training set <span className="font-semibold">{(trainMetrics.sensitivity * 100).toFixed(2)}%</span>, validation set <span className="font-semibold">{(testMetrics.sensitivity * 100).toFixed(2)}%</span></li>
            <li>Specificity: training set <span className="font-semibold">{(trainMetrics.specificity * 100).toFixed(2)}%</span>, validation set <span className="font-semibold">{(testMetrics.specificity * 100).toFixed(2)}%</span></li>
            <li>AUC: training set <span className="font-semibold">{trainMetrics.auc.toFixed(4)}</span>, validation set <span className="font-semibold">{testMetrics.auc.toFixed(4)}</span></li>
            <li>Top 3 features: <span className="font-semibold">{featureImportance.slice(0, 3).map(f => f.feature).join(', ')}</span></li>
            {/* Note: the two conclusions below are fixed text; they are not computed from the metrics above */}
            <li>Training and validation performance are similar, suggesting no obvious overfitting</li>
            <li>Specificity is high but sensitivity is only moderate: the model identifies negative samples well, but its ability to identify positive samples needs improvement</li>
          </ul>
        </div>
      </div>
    </div>
  );
};

export default RandomForestResults;
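
The results page above reports accuracy, sensitivity, specificity, and AUC. For readers who want to check such numbers by hand, here is a minimal sketch. It assumes the confusion matrix is an object with `tn`, `fp`, `fn`, and `tp` counts (the shape used by `trainMetrics.confusionMatrix`) and that the ROC data is an array of `{ fpr, tpr }` points (the shape used by `testROC`). The function names `metricsFromConfusionMatrix` and `aucFromRocPoints` and the example counts are hypothetical; they are not taken from Claude's output.

```javascript
// Illustrative helpers (hypothetical names, not part of the component above).

// Derive accuracy, sensitivity and specificity from confusion-matrix counts.
function metricsFromConfusionMatrix({ tn, fp, fn, tp }) {
  const total = tn + fp + fn + tp;
  return {
    accuracy: total > 0 ? (tp + tn) / total : NaN,
    sensitivity: tp + fn > 0 ? tp / (tp + fn) : NaN, // true positive rate
    specificity: tn + fp > 0 ? tn / (tn + fp) : NaN, // true negative rate
  };
}

// Approximate the area under a ROC curve with the trapezoidal rule.
// `points` is an array of { fpr, tpr } objects; a sorted copy is used.
function aucFromRocPoints(points) {
  const sorted = [...points].sort((a, b) => a.fpr - b.fpr || a.tpr - b.tpr);
  let area = 0;
  for (let i = 1; i < sorted.length; i++) {
    const width = sorted[i].fpr - sorted[i - 1].fpr;
    const meanHeight = (sorted[i].tpr + sorted[i - 1].tpr) / 2;
    area += width * meanHeight;
  }
  return area;
}

// Made-up counts, purely for illustration.
const example = metricsFromConfusionMatrix({ tn: 50, fp: 10, fn: 5, tp: 35 });
console.log(example);

// The diagonal "random guess" line should give an area of 0.5.
console.log(aucFromRocPoints([{ fpr: 0, tpr: 0 }, { fpr: 1, tpr: 1 }]));
```

The trapezoidal rule is only one common way to approximate AUC from a finite set of ROC points; the code Claude generated may compute it differently.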

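III. Final thoughts

With deep reasoning combined with the ability to run its own code, it is hard to imagine how efficient future AI agents will become.

Next time, I will try a regression prediction task.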
