SciKit-Learn 全面分析分类任务 breast_cancer 数据集

背景

乳腺癌数据集,569个样本,30个特征,2个类别(良性/恶性)

步骤

  1. 加载数据集
  2. 拆分训练集、测试集
  3. 数据预处理(标准化)
  4. 选择模型
  5. 模型训练(拟合)
  6. 测试模型效果
  7. 评估模型

分析方法

对数据集使用 7 种分类方法进行分析

  • K 近邻(K-NN)
  • 决策树
  • 支持向量机(SVM)
  • 逻辑回归
  • 随机森林
  • 朴素贝叶斯
  • 多层感知机(MLP)

模型分析结果

复制代码
--- 正在训练 K近邻 (K-NN) 模型 ---
K近邻 (K-NN) 模型的准确率: 0.9591
K近邻 (K-NN) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.95      0.94      0.94        63
      benign       0.96      0.97      0.97       108

    accuracy                           0.96       171
   macro avg       0.96      0.95      0.96       171
weighted avg       0.96      0.96      0.96       171


--- 正在训练 决策树 模型 ---
决策树 模型的准确率: 0.9415
决策树 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.90      0.95      0.92        63
      benign       0.97      0.94      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.94      0.94       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 支持向量机 (SVM) 模型 ---
支持向量机 (SVM) 模型的准确率: 0.9766
支持向量机 (SVM) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.97      0.97        63
      benign       0.98      0.98      0.98       108

    accuracy                           0.98       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 逻辑回归 模型 ---
逻辑回归 模型的准确率: 0.9825
逻辑回归 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.98      0.98        63
      benign       0.99      0.98      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 随机森林 模型 ---
随机森林 模型的准确率: 0.9708
随机森林 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.94      0.96        63
      benign       0.96      0.99      0.98       108

    accuracy                           0.97       171
   macro avg       0.97      0.96      0.97       171
weighted avg       0.97      0.97      0.97       171


--- 正在训练 朴素贝叶斯 模型 ---
朴素贝叶斯 模型的准确率: 0.9357
朴素贝叶斯 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.92      0.90      0.91        63
      benign       0.94      0.95      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.93      0.93       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 多层感知器 (MLP) 模型 ---
多层感知器 (MLP) 模型的准确率: 0.9825
多层感知器 (MLP) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.97      0.98        63
      benign       0.98      0.99      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171

代码

python 复制代码
from sklearn.datasets import load_breast_cancer

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc

import matplotlib.pyplot as plt
import numpy as np

# 设置 Matplotlib 字体以正确显示中文
plt.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Zen Hei', 'STHeiti', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False  

def perform_breast_cancer_analysis():
    """
    使用 scikit-learn 对乳腺癌数据集进行全面的分析。
    该函数包含数据加载、预处理、模型训练、评估和 ROC/AUC 可视化。
    """
    print("--- 正在加载乳腺癌数据集 ---")
    # 加载乳腺癌数据集
    cancer = load_breast_cancer()
    
    # 获取数据特征和目标标签
    X = cancer.data
    y = cancer.target
    target_names = cancer.target_names

    print("\n--- 数据集概览 ---")
    print(f"数据形状: {X.shape}")
    print(f"目标名称: {target_names}")

    # 将数据集划分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    print("\n--- 数据划分结果 ---")
    print(f"训练集形状: {X_train.shape}")
    print(f"测试集形状: {X_test.shape}")
    
    # 数据标准化
    print("\n--- 正在对数据进行标准化处理 ---")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    # 定义并训练多个分类器模型
    models = {
        "K近邻 (K-NN)": KNeighborsClassifier(n_neighbors=5),
        "决策树": DecisionTreeClassifier(random_state=42),
        "支持向量机 (SVM)": SVC(kernel='rbf', C=1.0, random_state=42, probability=True),
        "逻辑回归": LogisticRegression(random_state=42, max_iter=10000),
        "随机森林": RandomForestClassifier(random_state=42),
        "朴素贝叶斯": GaussianNB(),
        "多层感知器 (MLP)": MLPClassifier(random_state=42, max_iter=10000)
    }

    print("\n--- 模型训练与评估 ---")
    for name, model in models.items():
        print(f"\n--- 正在训练 {name} 模型 ---")
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred, target_names=target_names)
        print(f"{name} 模型的准确率: {accuracy:.4f}")
        print(f"{name} 模型的分类报告:\n{report}")

    print("\n--- ROC 曲线和 AUC 对比 ---")
    num_models = len(models)
    cols = 3
    rows = (num_models + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(18, 6 * rows))
    axes = axes.flatten()

    for i, (name, model) in enumerate(models.items()):
        ax = axes[i]
        
        # 获取预测概率
        if hasattr(model, "predict_proba"):
            y_score = model.predict_proba(X_test_scaled)[:, 1]
        else:
            y_score = model.decision_function(X_test_scaled)
        
        # 计算 ROC 曲线和 AUC
        fpr, tpr, _ = roc_curve(y_test, y_score)
        roc_auc = auc(fpr, tpr)
        
        # 绘制 ROC 曲线并填充
        ax.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}', alpha=0.7)
        ax.fill_between(fpr, tpr, alpha=0.1)
        
        # 绘制对角线 (随机猜测)
        ax.plot([0, 1], [0, 1], 'k--', lw=2)
        
        # 设置图表属性
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('假正率 (FPR)')
        ax.set_ylabel('真正率 (TPR)')
        ax.set_title(f'{name} - ROC 曲线')
        ax.legend(loc="lower right", fontsize='small')
        ax.grid(True)
    
    for j in range(num_models, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    perform_breast_cancer_analysis()
相关推荐
hrhcode18 小时前
【LangChain】一.LangChain v1.0-快速上手(核心组件、工具、中间件)
python·ai·langchain·agent
硅谷秋水18 小时前
《自动驾驶系统开发》英文版《Autonomous Driving Hanbook》推荐
人工智能·深度学习·机器学习·计算机视觉·语言模型·自动驾驶
SunnyDays101118 小时前
Python Word 转 Excel 详解(含整个文档、特定页面或表格转换)
python·word 转 excel·docx 转 xlsx·word 表格导出 excel
m0_7411733318 小时前
CSS移动端实现卡片悬浮投影_利用box-shadow设置层次感
jvm·数据库·python
西洼工作室18 小时前
uniapp+vue3+python对接阿里云短信认证服务alibabacloud_dypnsapi20170525
python·阿里云·uni-app
chushiyunen18 小时前
pygame实现射击游戏
python·游戏·pygame
sinat_3834373618 小时前
如何在 Laravel 中筛选并格式化匹配预定义列表的产品数据
jvm·数据库·python
2401_8463395618 小时前
mysql如何用执行流程思维写好SQL_SQL优化方法总结
jvm·数据库·python
forEverPlume18 小时前
SQL如何统计分组内不重复值的数量_COUNT与DISTINCT结合应用
jvm·数据库·python
啦啦啦_999919 小时前
案例之 逻辑回归_癌症预测
算法·机器学习·逻辑回归