机器学习实战:6种数据集划分方法详解与代码实现

在机器学习项目中,合理划分数据集是模型开发的关键第一步。本文将全面介绍6种常见数据格式的划分方法,并附完整Python代码示例,帮助初学者掌握这一核心技能。

一、数据集划分基础函数

1. 核心函数:train_test_split

python 复制代码
from sklearn.model_selection import train_test_split

# 基本用法
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.3,      # 测试集比例
    random_state=42,    # 随机种子
    stratify=y          # 保持类别比例(分类问题)
)

补充说明:

stratify=y参数的作用是保持划分后数据集的类别比例与原始数据集一致。换句话说:

  • 如果原始数据中类别A占30%,类别B占70%

  • 那么划分后的训练集和测试集中,类别A和B的比例也会保持30%和70%

二、6种数据格式划分实战

1. 列表(List)数据集划分

python 复制代码
data_list = [i for i in range(100)]  # 创建0-99的列表
labels = [0 if x < 70 else 1 for x in range(100)]  # 前70个为0,后30个为1

# 划分示例
train_data, test_data, train_labels, test_labels = train_test_split(
    data_list, labels, test_size=0.2, stratify=labels)

print(f"训练集大小: {len(train_data)}, 测试集大小: {len(test_data)}")

2. NumPy数组(ndarray)划分

python 复制代码
import numpy as np

# 创建100x5的随机数组
X = np.random.rand(100, 5)  
y = np.random.randint(0, 2, size=100)  # 二分类标签

# 划分示例
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42)

3. 二维数组划分(特征+标签)

train_test_split只划分第一维度,第二维度保持不变

python 复制代码
from sklearn.model_selection import train_test_split
import numpy as np
data1 = np.arange(1, 16, 1)
data1.shape=(5,3)
print(data1)
a, b = train_test_split(data1,  test_size=0.4, random_state=42)
print("a=\n", a)
print("b=\n", b)

\[ 1 2 3

4 5 6

7 8 9

10 11 12

13 14 15\]

a=

\[10 11 12

1 2 3

13 14 15\]

b=

\[4 5 6

7 8 9\]

4. Pandas DataFrame划分

可以划分DataFrame, 划分后的两部分还是DataFrame

python 复制代码
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
data1 = np.arange(1, 16, 1)
data1.shape=(5,3)
data1 = pd.DataFrame(data1, index=[1,2,3,4,5], columns=["one","two","three"])
print(data1)

a, b = train_test_split(data1,  test_size=0.4, random_state=42)
print("\n", a)
print("\n", b)

one two three

1 1 2 3

2 4 5 6

3 7 8 9

4 10 11 12

5 13 14 15

one two three

4 10 11 12

1 1 2 3

5 13 14 15

one two three

2 4 5 6

3 7 8 9

5. 字典(Dict)格式数据划分

python 复制代码
data_dict = {
    'features': np.random.rand(100, 4),
    'labels': np.random.choice([0, 1], size=100),
    'ids': [f'ID_{i}' for i in range(100)]
}

# 转换为数组再划分
features = data_dict['features']
labels = data_dict['labels']

# 划分并保留索引
idx = np.arange(len(labels))
train_idx, test_idx = train_test_split(idx, test_size=0.15)

train_data = {k: v[train_idx] for k, v in data_dict.items()}
test_data = {k: v[test_idx] for k, v in data_dict.items()}

6. 经典鸢尾花数据集划分

python 复制代码
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

# 专业划分方案:先分训练+临时集,再分验证+测试
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, stratify=y)  # 60%训练,40%临时

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp)  # 各20%

print(f"训练集: {X_train.shape[0]} 验证集: {X_val.shape[0]} 测试集: {X_test.shape[0]}")

三、现实世界数据集划分技巧

1. 非平衡数据处理

python 复制代码
from sklearn.utils import resample

# 假设我们有不平衡数据
majority = df[df['target'] == 0]
minority = df[df['target'] == 1]

# 上采样少数类
minority_upsampled = resample(minority,
                             replace=True,     # 允许重复采样
                             n_samples=len(majority),  # 目标数量
                             random_state=42)

# 合并后划分
balanced_df = pd.concat([majority, minority_upsampled])

2. 时间序列数据划分

python 复制代码
# 按时间顺序划分(不能随机)
time_series = pd.read_csv('sales_data.csv', parse_dates=['date'])

split_point = int(len(time_series) * 0.8)
train = time_series.iloc[:split_point]
test = time_series.iloc[split_point:]

3. 交叉验证进阶用法

python 复制代码
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for train_index, test_index in skf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    # 训练和评估...

四、作业与实践

实战作业:

  1. 使用UCI葡萄酒数据集完成以下划分:

    • 按7:2:1划分训练/验证/测试集
python 复制代码
from sklearn.datasets import load_wine
wine = load_wine()
# 你的代码...

eg:

python 复制代码
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# 1. 加载葡萄酒数据集
wine = load_wine()
X = wine.data  # 特征数据
y = wine.target  # 标签数据
target_names = wine.target_names  # 类别名称

# 打印原始数据集信息
print("原始数据集信息:")
print(f"样本数量: {len(X)}")
print(f"特征数量: {X.shape[1]}")
print("类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {sum(y == i)}个样本")

# 2. 按7:2:1比例划分数据集
# 首先分出训练集(70%)和临时集(30%)
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, 
    test_size=0.3,  # 临时集占30%
    random_state=42,  # 随机种子保证可重复性
    stratify=y  # 保持类别比例
)

# 然后将临时集按2:1比例分成验证集和测试集
# 验证集占临时集的2/3(即整体的20%),测试集占1/3(即整体的10%)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=1/3,  # 测试集占临时集的1/3
    random_state=42,
    stratify=y_temp  # 继续保持类别比例
)

# 3. 打印划分结果
print("\n划分结果:")
print(f"训练集: {len(X_train)}个样本 ({len(X_train)/len(X):.1%})")
print(f"验证集: {len(X_val)}个样本 ({len(X_val)/len(X):.1%})")
print(f"测试集: {len(X_test)}个样本 ({len(X_test)/len(X):.1%})")

print("\n训练集类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {sum(y_train == i)}个样本")

print("\n验证集类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {sum(y_val == i)}个样本")

print("\n测试集类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {sum(y_test == i)}个样本")

# 4. 验证比例是否正确
print("\n验证比例:")
total = len(X_train) + len(X_val) + len(X_test)
print(f"训练集比例: {len(X_train)/total:.2f} (应接近0.7)")
print(f"验证集比例: {len(X_val)/total:.2f} (应接近0.2)")
print(f"测试集比例: {len(X_test)/total:.2f} (应接近0.1)")

思考题:

  1. 当特征数据和标签存储在不同结构中时,如何确保划分后的一致性?

  2. 对于多输出问题(多个y值),应该如何调整划分方法?

  3. 如果数据集包含分组信息(如来自同一患者的多个样本),应该使用什么特殊划分方法?

五、总结

本文详细介绍了:

  • 6种数据格式的划分方法(列表、数组、DataFrame等)

  • 现实场景中的特殊处理技巧(不平衡数据、时间序列)

  • 专业级的划分策略(分层抽样、交叉验证)

关键要点:

  1. 分类问题务必使用stratify参数

  2. 不同数据格式需要适配不同的划分策略

  3. 时间序列数据不能随机划分

  4. 大数据集可使用简单划分,小数据集推荐交叉验证

记住:好的开始是成功的一半,合理的数据划分是构建优秀模型的基础!

相关推荐
wei_shuo1 小时前
OB Cloud 云数据库V4.3:SQL +AI全新体验
数据库·人工智能·sql
努力的搬砖人.1 小时前
AI生成视频推荐
人工智能
想要成为计算机高手2 小时前
Helix:一种用于通用人形控制的视觉语言行动模型
人工智能·计算机视觉·自然语言处理·大模型·vla
Mory_Herbert2 小时前
5.1 神经网络: 层和块
人工智能·深度学习·神经网络
Evand J3 小时前
MATLAB程序演示与编程思路,相对导航,四个小车的形式,使用集中式扩展卡尔曼滤波(fullyCN-EKF)
人工智能·算法
知来者逆4 小时前
在与大语言模型交互中的礼貌现象:技术影响、社会行为与文化意义的多维度探讨
人工智能·深度学习·语言模型·自然语言处理·llm
IT猿手6 小时前
基于 Q-learning 的城市场景无人机三维路径规划算法研究,可以自定义地图,提供完整MATLAB代码
深度学习·算法·matlab·无人机·强化学习·qlearning·无人机路径规划
xwz小王子7 小时前
Taccel:一个高性能的GPU加速视触觉机器人模拟平台
人工智能·机器人
深空数字孪生8 小时前
AI时代的数据可视化:未来已来
人工智能·信息可视化
Icoolkj8 小时前
探秘 Canva AI 图像生成器:重塑设计创作新范式
人工智能