SciKit-Learn 全面分析分类任务 breast_cancer 数据集

背景

乳腺癌数据集,569个样本,30个特征,2个类别(良性/恶性)

步骤

  1. 加载数据集
  2. 拆分训练集、测试集
  3. 数据预处理(标准化)
  4. 选择模型
  5. 模型训练(拟合)
  6. 测试模型效果
  7. 评估模型

分析方法

对数据集使用 7 种分类方法进行分析

  • K 近邻(K-NN)
  • 决策树
  • 支持向量机(SVM)
  • 逻辑回归
  • 随机森林
  • 朴素贝叶斯
  • 多层感知机(MLP)

模型分析结果

复制代码
--- 正在训练 K近邻 (K-NN) 模型 ---
K近邻 (K-NN) 模型的准确率: 0.9591
K近邻 (K-NN) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.95      0.94      0.94        63
      benign       0.96      0.97      0.97       108

    accuracy                           0.96       171
   macro avg       0.96      0.95      0.96       171
weighted avg       0.96      0.96      0.96       171


--- 正在训练 决策树 模型 ---
决策树 模型的准确率: 0.9415
决策树 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.90      0.95      0.92        63
      benign       0.97      0.94      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.94      0.94       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 支持向量机 (SVM) 模型 ---
支持向量机 (SVM) 模型的准确率: 0.9766
支持向量机 (SVM) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.97      0.97        63
      benign       0.98      0.98      0.98       108

    accuracy                           0.98       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 逻辑回归 模型 ---
逻辑回归 模型的准确率: 0.9825
逻辑回归 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.98      0.98        63
      benign       0.99      0.98      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 随机森林 模型 ---
随机森林 模型的准确率: 0.9708
随机森林 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.94      0.96        63
      benign       0.96      0.99      0.98       108

    accuracy                           0.97       171
   macro avg       0.97      0.96      0.97       171
weighted avg       0.97      0.97      0.97       171


--- 正在训练 朴素贝叶斯 模型 ---
朴素贝叶斯 模型的准确率: 0.9357
朴素贝叶斯 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.92      0.90      0.91        63
      benign       0.94      0.95      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.93      0.93       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 多层感知器 (MLP) 模型 ---
多层感知器 (MLP) 模型的准确率: 0.9825
多层感知器 (MLP) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.97      0.98        63
      benign       0.98      0.99      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171

代码

python 复制代码
from sklearn.datasets import load_breast_cancer

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc

import matplotlib.pyplot as plt
import numpy as np

# 设置 Matplotlib 字体以正确显示中文
plt.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Zen Hei', 'STHeiti', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False  

def perform_breast_cancer_analysis():
    """
    使用 scikit-learn 对乳腺癌数据集进行全面的分析。
    该函数包含数据加载、预处理、模型训练、评估和 ROC/AUC 可视化。
    """
    print("--- 正在加载乳腺癌数据集 ---")
    # 加载乳腺癌数据集
    cancer = load_breast_cancer()
    
    # 获取数据特征和目标标签
    X = cancer.data
    y = cancer.target
    target_names = cancer.target_names

    print("\n--- 数据集概览 ---")
    print(f"数据形状: {X.shape}")
    print(f"目标名称: {target_names}")

    # 将数据集划分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    print("\n--- 数据划分结果 ---")
    print(f"训练集形状: {X_train.shape}")
    print(f"测试集形状: {X_test.shape}")
    
    # 数据标准化
    print("\n--- 正在对数据进行标准化处理 ---")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    # 定义并训练多个分类器模型
    models = {
        "K近邻 (K-NN)": KNeighborsClassifier(n_neighbors=5),
        "决策树": DecisionTreeClassifier(random_state=42),
        "支持向量机 (SVM)": SVC(kernel='rbf', C=1.0, random_state=42, probability=True),
        "逻辑回归": LogisticRegression(random_state=42, max_iter=10000),
        "随机森林": RandomForestClassifier(random_state=42),
        "朴素贝叶斯": GaussianNB(),
        "多层感知器 (MLP)": MLPClassifier(random_state=42, max_iter=10000)
    }

    print("\n--- 模型训练与评估 ---")
    for name, model in models.items():
        print(f"\n--- 正在训练 {name} 模型 ---")
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred, target_names=target_names)
        print(f"{name} 模型的准确率: {accuracy:.4f}")
        print(f"{name} 模型的分类报告:\n{report}")

    print("\n--- ROC 曲线和 AUC 对比 ---")
    num_models = len(models)
    cols = 3
    rows = (num_models + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(18, 6 * rows))
    axes = axes.flatten()

    for i, (name, model) in enumerate(models.items()):
        ax = axes[i]
        
        # 获取预测概率
        if hasattr(model, "predict_proba"):
            y_score = model.predict_proba(X_test_scaled)[:, 1]
        else:
            y_score = model.decision_function(X_test_scaled)
        
        # 计算 ROC 曲线和 AUC
        fpr, tpr, _ = roc_curve(y_test, y_score)
        roc_auc = auc(fpr, tpr)
        
        # 绘制 ROC 曲线并填充
        ax.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}', alpha=0.7)
        ax.fill_between(fpr, tpr, alpha=0.1)
        
        # 绘制对角线 (随机猜测)
        ax.plot([0, 1], [0, 1], 'k--', lw=2)
        
        # 设置图表属性
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('假正率 (FPR)')
        ax.set_ylabel('真正率 (TPR)')
        ax.set_title(f'{name} - ROC 曲线')
        ax.legend(loc="lower right", fontsize='small')
        ax.grid(True)
    
    for j in range(num_models, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    perform_breast_cancer_analysis()
相关推荐
财富自由且长命百岁1 天前
移动端老兵转型端侧 AI:第一周,我跑通了 ResNet50 推理
机器学习
Csvn1 天前
🌟 LangChain 30 天保姆级教程 · Day 13|OutputParser 进阶!让 AI 输出自动转为结构化对象,并支持自动重试!
python·langchain
cch89181 天前
Python主流框架全解析
开发语言·python
sg_knight1 天前
设计模式实战:状态模式(State)
python·ui·设计模式·状态模式·state
好运的阿财1 天前
process 工具与子agent管理机制详解
网络·人工智能·python·程序人生·ai编程
张張4081 天前
(域格)环境搭建和编译
c语言·开发语言·python·ai
weixin_423533991 天前
【Windows11离线安装anaconda、python、vscode】
开发语言·vscode·python
Ricky111zzz1 天前
leetcode学python记录1
python·算法·leetcode·职场和发展
小白学大数据1 天前
Selenium+Python 爬虫:动态加载头条问答爬取
爬虫·python·selenium
Hui Baby1 天前
springboot读取配置文件
后端·python·flask