机器学习数据预处理:数据拆分(超通俗完整版)
数据拆分是把数据集分成训练集、验证集、测试集 ,让模型"学、调、考"分开,是评估模型真实能力的必做步骤,本科/研究生入门必看、面试常考。
一、什么是数据拆分?为什么必须拆?
1. 一句话理解
把全部数据分成互不重叠 的几部分,分别用来训练模型、调参、最终打分,防止模型"作弊、死记硬背"。
2. 不拆分的后果
- 模型在训练集上考100分(过拟合)
- 遇到新数据一塌糊涂(泛化能力差)
- 实验结果不可信,论文/项目不被认可
3. 三大集合职责
- 训练集 Train:教模型知识(学)
- 验证集 Val:调参数、选模型(练)
- 测试集 Test:最终考试,不许偷看(考)
二、最常用拆分比例
- 简单场景:训练 : 测试 = 8 : 2
- 标准场景:训练 : 验证 : 测试 = 6 : 2 : 2
- 小样本:用 K折交叉验证
三、3种核心拆分方法(通俗+原理)
1. Hold-Out 随机拆分(最简单)
直接按比例随机分。
- 优点:快、简单
- 缺点:运气差时分布不均匀
2. 分层拆分 Stratified Split(最常用!)
按类别比例拆分,保证每类数据分布一致。
- 优点:类别不平衡数据必用
- 缺点:仅用于分类任务
3. K折交叉验证 K-Fold CV(最稳)
把数据分成 K 份,轮流当验证集。
- 优点:结果稳定、充分利用数据
- 缺点:速度慢
四、核心概念(论文/面试)
- 经验误差:模型在训练集上的误差
- 泛化误差:模型在新数据上的真实误差
- 数据独立同分布:训练/测试集要来自同一分布
- 信息泄露:严禁用测试集信息参与训练
五、完整可运行代码(乳腺癌数据集)
包含:随机拆分、分层拆分、K折交叉验证、PCA可视化、模型打分。
python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import (
train_test_split,
StratifiedKFold
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
# ======================
# 1. 加载数据
# ======================
data = load_breast_cancer()
X = data.data
y = data.target
print("数据形状:", X.shape)
# ======================
# 2. 随机拆分 6:2:2
# ======================
X_temp, X_test, y_temp, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
X_temp, y_temp, test_size=0.25, random_state=42
)
print(f"训练集:{X_train.shape}")
print(f"验证集:{X_val.shape}")
print(f"测试集:{X_test.shape}")
# ======================
# 3. PCA 可视化分布
# ======================
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X)
plt.figure(figsize=(12,5))
plt.subplot(121)
plt.scatter(X_pca[y==0,0], X_pca[y==0,1], label='恶', alpha=0.7)
plt.scatter(X_pca[y==1,0], X_pca[y==1,1], label='良', alpha=0.7)
plt.title('全体数据分布')
plt.legend()
# 训练/验证/测试
X_train_pca = pca.transform(X_train)
X_val_pca = pca.transform(X_val)
X_test_pca = pca.transform(X_test)
plt.subplot(122)
plt.scatter(X_train_pca[:,0], X_train_pca[:,1], label='Train', alpha=0.7)
plt.scatter(X_val_pca[:,0], X_val_pca[:,1], label='Val', alpha=0.7)
plt.scatter(X_test_pca[:,0], X_test_pca[:,1], label='Test', alpha=0.7)
plt.title('Train/Val/Test 拆分')
plt.legend()
plt.show()
# ======================
# 4. 训练与评估
# ======================
model = LogisticRegression(max_iter=10000, random_state=42)
model.fit(X_train, y_train)
acc_val = accuracy_score(y_val, model.predict(X_val))
acc_test = accuracy_score(y_test, model.predict(X_test))
print(f"验证集准确率:{acc_val:.4f}")
print(f"测试集准确率:{acc_test:.4f}")
# ======================
# 5. 分层拆分(推荐)
# ======================
X_temp_s, X_test_s, y_temp_s, y_test_s = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42
)
X_train_s, X_val_s, y_train_s, y_val_s = train_test_split(
X_temp_s, y_temp_s, test_size=0.25, stratify=y_temp_s, random_state=42
)
# 训练
model_s = LogisticRegression(max_iter=10000, random_state=42)
model_s.fit(X_train_s, y_train_s)
acc_val_s = accuracy_score(y_val_s, model_s.predict(X_val_s))
acc_test_s = accuracy_score(y_test_s, model_s.predict(X_test_s))
print("\n【分层拆分】")
print(f"验证集:{acc_val_s:.4f} 测试集:{acc_test_s:.4f}")
# ======================
# 6. K折交叉验证
# ======================
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
acc_list = []
for train_idx, val_idx in skf.split(X, y):
X_tr, X_vl = X[train_idx], X[val_idx]
y_tr, y_vl = y[train_idx], y[val_idx]
clf = LogisticRegression(max_iter=10000)
clf.fit(X_tr, y_tr)
acc_list.append(accuracy_score(y_vl, clf.predict(X_vl)))
print("\n【5折交叉验证】")
print(f"平均准确率:{np.mean(acc_list):.4f}")
print(f"标准差:{np.std(acc_list):.4f}")
plt.boxplot(acc_list)
plt.title('5折交叉验证准确率')
plt.show()
六、各种拆分方法对比(一张表看懂)
| 方法 | 保持类别比例 | 稳定性 | 速度 | 适用场景 |
|---|---|---|---|---|
| 随机拆分 | ❌ | 中 | 最快 | 大数据量 |
| 分层拆分 | ✅ | 中 | 快 | 分类/不平衡数据 |
| K折CV | ❌ | 高 | 慢 | 小样本 |
| 分层K折CV | ✅ | 最高 | 较慢 | 小样本+不平衡 |
| 时间序列拆分 | ❌ | 高 | 中 | 时序数据 |
七、使用建议(工业界标准)
- 分类任务一律用分层拆分
- 数据量小 → 用分层K折
- 数据量大 → 用 8:2 或 6:2:2
- 时序数据不能随机拆,必须按时间拆
- 测试集只用一次,严禁反复调参
八、总结(面试速背)
- 数据拆分 = 训练+验证+测试,防止过拟合
- 分类优先分层拆分
- 小样本用K折,大数据直接分
- 测试集是最终标准,绝不参与训练
- 好的拆分比复杂模型更重要