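I. Splitter Quick-Reference Table (Core Selection)
| Splitter | Use case | Key characteristics | Rating |
|---|---|---|---|
| KFold | Regression; classification with balanced classes | Splits into K equal folds, no stratification, optional shuffling | ⭐⭐⭐ |
| StratifiedKFold | All classification tasks (including imbalanced classes) | Stratified to preserve class proportions; the standard K-fold cross-validation for classification | ⭐⭐⭐⭐⭐ |
| ShuffleSplit | General datasets with a custom train/test ratio | Random sampling, no stratification | ⭐⭐⭐ |
| StratifiedShuffleSplit | Classification with a custom test-set ratio | Stratified and randomly shuffled; a single split is equivalent to a stratified train/test split | ⭐⭐⭐⭐⭐ |
| GroupKFold / LeaveOneGroupOut | Grouped data (grouped by user, subject, or device) | Samples from one group never appear in both train and test, which prevents data leakage | ⭐⭐⭐⭐ |
| TimeSeriesSplit | Time series data (stock prices, traffic, forecasting) | No shuffling; splits strictly in chronological order | ⭐⭐⭐⭐ |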
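The difference between the stratified and non-stratified rows can be seen on a small imbalanced dataset. The sketch below uses made-up labels (90 negatives, 10 positives) and a hypothetical helper `positive_counts`; it is an illustration, not a benchmark.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Toy imbalanced labels (illustrative only): 90 negatives, 10 positives.
y = np.array([0] * 90 + [1] * 10)
X = np.arange(len(y)).reshape(-1, 1)


def positive_counts(splitter):
    """Hypothetical helper: number of positive samples in each test fold."""
    return [int(y[test_idx].sum()) for _, test_idx in splitter.split(X, y)]


kfold_counts = positive_counts(KFold(n_splits=5, shuffle=True, random_state=0))
strat_counts = positive_counts(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
)

# KFold ignores y, so the positives may be spread unevenly across folds.
print("KFold positives per test fold:          ", kfold_counts)
# StratifiedKFold keeps the 10% positive rate in every fold.
print("StratifiedKFold positives per test fold:", strat_counts)
```

With 10 positives and 5 folds, the stratified splitter places two positives in each test fold, while plain `KFold` gives whatever the shuffle happens to produce.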
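II. Rules of Thumb for Choosing by Scenario
- Regression → use `KFold`
- Classification + K-fold cross-validation → prefer `StratifiedKFold`
- Classification + a single split or a custom test-set ratio → prefer `StratifiedShuffleSplit`
- Data with a group identifier (several samples per individual or device) → use `GroupKFold`
- Time series data → use `TimeSeriesSplit`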
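The rules above can be written down as a small dispatch function. This is only a sketch: `choose_splitter` and its parameters are hypothetical names, and the precedence (time order first, then groups, then task type) is an assumption rather than something the rules state.

```python
from sklearn.model_selection import (
    KFold, StratifiedKFold, StratifiedShuffleSplit, GroupKFold, TimeSeriesSplit,
)


def choose_splitter(task, *, grouped=False, time_ordered=False,
                    single_split=False, n_splits=5, test_size=0.2, seed=42):
    """Hypothetical helper mirroring the rules of thumb.

    task: "regression" or "classification".
    Precedence (time order > groups > task type) is an assumption.
    """
    if time_ordered:
        return TimeSeriesSplit(n_splits=n_splits)
    if grouped:
        return GroupKFold(n_splits=n_splits)
    if task == "classification":
        if single_split:
            return StratifiedShuffleSplit(
                n_splits=1, test_size=test_size, random_state=seed
            )
        return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return KFold(n_splits=n_splits, shuffle=True, random_state=seed)


print(type(choose_splitter("classification")).__name__)
print(type(choose_splitter("regression", grouped=True)).__name__)
```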
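III. Minimal Code Templates
Shared imports
```python
import numpy as np
from sklearn.model_selection import (
    KFold, StratifiedKFold,
    ShuffleSplit, StratifiedShuffleSplit,
    GroupKFold, TimeSeriesSplit,
)
from sklearn.utils.validation import check_random_state

# One shared seed. Passing the integer to a splitter yields the same
# folds on every split() call.
seed = 42
# check_random_state turns the seed into a np.random.RandomState instance.
# An instance advances each time it is used, so repeated split() calls differ.
rng = check_random_state(seed)
```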
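How the two ways of passing `random_state` behave can be checked directly. The sketch below uses a made-up 20-row array and a hypothetical helper `collect_folds`; it compares an integer seed with a `RandomState` instance across two consecutive `split()` calls.

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.utils.validation import check_random_state

X = np.arange(20).reshape(-1, 1)  # illustrative data only


def collect_folds(splitter):
    """Hypothetical helper: list of test-index lists for one split() call."""
    return [test_idx.tolist() for _, test_idx in splitter.split(X)]


# Integer seed: every split() call re-seeds, so the folds repeat.
kf_int = KFold(n_splits=4, shuffle=True, random_state=42)
same_with_int = collect_folds(kf_int) == collect_folds(kf_int)

# RandomState instance: the generator advances between calls.
kf_rng = KFold(n_splits=4, shuffle=True, random_state=check_random_state(42))
first_call = collect_folds(kf_rng)
second_call = collect_folds(kf_rng)

print("integer seed, identical folds across calls:", same_with_int)
print("RandomState instance, identical folds across calls:",
      first_call == second_call)
```

When the same folds are needed across repeated calls (for example, when comparing several models), the integer form is the safer default.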
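1. KFold (regression / balanced classification)
```python
# X is assumed to be a NumPy array of samples; seed is the integer defined above.
kf = KFold(n_splits=5, shuffle=True, random_state=seed)
for train_idx, test_idx in kf.split(X):
    X_train, X_test = X[train_idx], X[test_idx]
```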
2. StratifiedKFold (K-fold for classification, the most common choice)
```python
# y holds the class labels; stratification is computed from it.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
for train_idx, test_idx in skf.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
```
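3. ShuffleSplit (general random splits, no stratification)
```python
# Each of the 5 iterations draws an independent random 80/20 split.
ss = ShuffleSplit(n_splits=5, test_size=0.2, random_state=seed)
for train_idx, test_idx in ss.split(X):
    X_train, X_test = X[train_idx], X[test_idx]
```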
4. StratifiedShuffleSplit (stratified random splits for classification)
```python
# n_splits=1 gives a single split, equivalent to train_test_split(..., stratify=y)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
for train_idx, test_idx in sss.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
```
5. GroupKFold (grouped data)
```python
# groups: array with one group label per sample (same length as X)
gkf = GroupKFold(n_splits=5)
for train_idx, test_idx in gkf.split(X, y, groups=groups):
    X_train, X_test = X[train_idx], X[test_idx]
```
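To see the leakage guarantee in action, the sketch below builds a made-up dataset of 12 samples from four hypothetical users (`u1` to `u4`) and checks that no user appears on both sides of any split.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Illustrative data: 4 hypothetical users with 3 samples each.
groups = np.repeat(["u1", "u2", "u3", "u4"], 3)
X = np.arange(12).reshape(-1, 1)
y = np.tile([0, 1, 0], 4)

gkf = GroupKFold(n_splits=4)
overlaps = []
for train_idx, test_idx in gkf.split(X, y, groups=groups):
    # Groups present in both train and test; expected to be empty.
    shared = set(groups[train_idx]) & set(groups[test_idx])
    overlaps.append(shared)
    print("test groups:", sorted(set(groups[test_idx])),
          "| shared with train:", shared)
```

The number of distinct groups must be at least `n_splits`, which is why four users are paired with `n_splits=4` here.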
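6. TimeSeriesSplit (time series data, no shuffling)
```python
# Rows of X must already be sorted by time; each test fold follows its training data.
tscv = TimeSeriesSplit(n_splits=5)
for train_idx, test_idx in tscv.split(X):
    X_train, X_test = X[train_idx], X[test_idx]
```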
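Printing the indices makes the chronological behaviour concrete. The sketch below uses 12 made-up time steps and `n_splits=3`; the training window grows with each split and always ends before the test window starts.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Illustrative series: 12 observations already sorted by time.
X = np.arange(12).reshape(-1, 1)

tscv = TimeSeriesSplit(n_splits=3)
splits = [(train_idx.tolist(), test_idx.tolist())
          for train_idx, test_idx in tscv.split(X)]

for train_idx, test_idx in splits:
    # Every training index precedes every test index.
    print("train:", train_idx, "-> test:", test_idx)
```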
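IV. Additional Key Reminders
- Stratified splitters (`Stratified*`) must be given the labels `y` in `split(X, y)`; the class proportions are computed from `y`, so stratification cannot work without it.
- Never shuffle time-ordered data; shuffling breaks the temporal order and lets future samples leak into training. `TimeSeriesSplit` has no shuffle option.
- For reproducible results, manage the seed in one place. Passing the integer seed as `random_state` gives identical folds on every `split()` call. A `RandomState` instance from `check_random_state` is accepted by sklearn components as well, but it advances with each use, so repeated calls yield different splits.
- For an everyday single train/test split, prefer `StratifiedShuffleSplit(n_splits=1)` or `train_test_split(stratify=y)`.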
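For the single-split case, a minimal sketch with `train_test_split(stratify=y)` is shown below. The labels are made up (80 negatives, 20 positives) to make the preserved class ratio easy to verify.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Illustrative labels: 20% positive class.
y = np.array([0] * 80 + [1] * 20)
X = np.arange(len(y)).reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Both parts keep the 20% positive rate of the full dataset.
print("train positive rate:", y_train.mean())
print("test positive rate: ", y_test.mean())
```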