笔记/sklearn中的数据划分方法

文章目录

    • 一、前言
    • 二、数据划分方法
      • [1. 留出法(Hold-out)](#1. 留出法(Hold-out))
      • [2. K折交叉验证(K-Fold)](#2. K折交叉验证(K-Fold))
      • [3. 留一法(Leave-One-Out)](#3. 留一法(Leave-One-Out))
    • 三、总结

一、前言

简要介绍数据划分在机器学习中的作用。

二、数据划分方法

1. 留出法(Hold-out)

  • 使用 train_test_split 将数据分为训练集和测试集。
  • 代码片段:
python 复制代码
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
print('Train obs: ', len(X_train))
print('Test obs: ', len(X_test))

2. K折交叉验证(K-Fold)

  • 用 KFold 将数据分为多折,循环训练和测试。
  • 代码片段:
python 复制代码
from sklearn.model_selection import KFold
X = np.random.randn(20, 1)
# 创建一个KFold对象,将数据分为5份,shuffle=True表示在分割前会先打乱数据
# 设置一个random state保证每次打乱的结果一致
kf = KFold(n_splits=5, shuffle=True, random_state=10)
#kf.get_n_splits(X)
for train_index, test_index in kf.split(X):
    print(train_index, test_index)
# 创建一个KFold对象,将数据分为5份,不打乱数据
kf = KFold(n_splits=5, shuffle=False)
#kf.get_n_splits(X)
for train_index, test_index in kf.split(X):
    print(train_index, test_index)    


Note:假设总共有N个样本,K折交叉验证会将数据平均分成K份。每一折中,test_index的数量大约是 N/K(如果N不能被K整除,有的折会多一个或少一个),其余的样本作为训练集,train_index的数量就是N- test_index 的数量。在本例中,test_index的数量是20/5=4。

3. 留一法(Leave-One-Out)

  • 每次留一个样本做测试,其余做训练。
  • 代码片段:
python 复制代码
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
loo.get_n_splits(X)
for train_index, test_index in loo.split(X):
    print(train_index, test_index)

三、总结

方法名称 主要思想 sklearn实现 训练集数量 测试集数量 适用场景与特点
留出法 随机划分一部分做训练,其余做测试 train_test_split 设定比例(如60%) 设定比例(如40%) 简单高效,适合大数据集
K折交叉验证 将数据均分为K份,轮流做测试 KFold N-N/K N/K 评估更稳定,适合中小数据集
留一法 每次留一个样本做测试,其余训练 LeaveOneOut N-1 1 适合样本量较小的情况

说明:

  • 训练集数量和测试集数量均为占总样本数的比例或数量。
  • K折法和留一法属于交叉验证,能更全面评估模型性能。
  • 留出法实现简单,适合数据量较大时快速实验。

参考:https://scikit-learn.org/stable/api/sklearn.model_selection.html

博客内容如有错误欢迎指正~

相关推荐
大千AI助手15 分钟前
加权分位数直方图:提升机器学习效能的关键技术
人工智能·机器学习·xgboost·直方图·加权直方图·特征分裂
明月56632 分钟前
github开源笔记应用程序项目推荐-Joplin
笔记·开源·joplin·跨平台笔记应用
YuCaiH1 小时前
网络编程的基础知识
linux·笔记·嵌入式·网络通信
AI数据皮皮侠1 小时前
中国博物馆数据
大数据·人工智能·python·深度学习·机器学习
强哥之神1 小时前
从零理解 KV Cache:大语言模型推理加速的核心机制
人工智能·深度学习·机器学习·语言模型·llm·kvcache
m0_689618281 小时前
突破亚微米光电子器件制造瓶颈!配体交换辅助打印技术实现全打印红外探测器
笔记·制造
Q26433650232 小时前
【有源码】基于Python与Spark的火锅店数据可视化分析系统-基于机器学习的火锅店综合竞争力评估与可视化分析-基于用户画像聚类的火锅店市场细分与可视化研究
大数据·hadoop·python·机器学习·数据分析·spark·毕业设计
W_chuanqi3 小时前
RDEx:一种效果驱动的混合单目标优化器,自适应选择与融合多种算子与策略
人工智能·算法·机器学习·性能优化
chenzhou__3 小时前
MYSQL学习笔记(个人)(第十五天)
linux·数据库·笔记·学习·mysql
rechol4 小时前
C++ 继承笔记
java·c++·笔记