学校对一个年级进行分班时,需要班级编号以及各同学的位号,你对数据集进行K折时,也面临这一问题:样本位于第几折第几条?有趣的是:你让为很复杂的一件事,却一句话就可以解决:
enumerate(folds.split(....))
一、两个迭代器
sklearn.model_selection.KFold 时,有两个用于"编号"的迭代器:folds.split(...) 和 enumerate(...),其中,folds.split(...) 是原始迭代器 (生成K折的训练/验证索引对),enumerate(...) 是包装迭代器(给原始迭代器的每个元素添加计数索引)。
1. folds.split(train, train['label']) → 原始迭代器
sklearn.model_selection.KFold 的 split 方法返回的是一个迭代器(更准确说是生成器),满足迭代器的核心特性:
- 惰性生成 :不会一次性生成所有折的索引对,而是遍历到某一折时才生成该折的
(trn_idx, val_idx); - 仅可遍历一次:迭代器的元素取完后就"耗尽",再次遍历会无结果;
- 可迭代 :支持
for ... in ...遍历,也可通过next()手动取元素。
验证示例:
python
import pandas as pd
from sklearn.model_selection import KFold
train = pd.DataFrame({"label": [1,0,1,1,0]})
folds = KFold(n_splits=2, shuffle=True, random_state=2020)
# 获取split返回的迭代器
split_iter = folds.split(train, train['label'])
# 验证1:是迭代器(可调用next())
print("第一次next():", next(split_iter)) # 输出第一折的索引对
print("第二次next():", next(split_iter)) # 输出第二折的索引对
# print("第三次next():", next(split_iter)) # 无元素,抛StopIteration异常
# 验证2:迭代器耗尽后无法再遍历
split_iter2 = folds.split(train, train['label'])
list(split_iter2) # 转列表,耗尽迭代器
for elem in split_iter2:
print(elem) # 无输出,因为迭代器已空
2. enumerate(...) → 包装迭代器
enumerate() 是Python内置函数,接收任意可迭代对象(包括迭代器) ,返回一个 enumerate 类型的迭代器------它不会改变原迭代器的元素内容,只是给每个元素"绑定一个递增的计数索引"。
核心特性(和原迭代器一致):
- 依然是惰性 :只有遍历到元素时,才会生成
(fold_, 原元素); - 仅可遍历一次:原迭代器耗尽,包装后的enumerate迭代器也会耗尽;
- 元素是"索引+原迭代器元素":原split迭代器的元素是
(trn_idx, val_idx),enumerate包装后变成(fold_, (trn_idx, val_idx))。
验证示例:
python
# 包装split返回的迭代器
enum_iter = enumerate(folds.split(train, train['label']), start=0)
# 验证:enumerate迭代器可遍历,元素是(折数, 索引对)
for fold_, (trn_idx, val_idx) in enum_iter:
print(f"折数:{fold_},索引对:{trn_idx, val_idx}")
# 再次遍历enum_iter,无输出(迭代器已耗尽)
for elem in enum_iter:
print(elem)
二、迭代器链条:split → enumerate 的关系
用一张图能清晰看到两者的关联:
原始数据(train)
↓
folds.split(...) → 迭代器A:生成 (trn_idx, val_idx)
↓
enumerate(迭代器A) → 迭代器B:生成 (fold_, (trn_idx, val_idx))
↓
for循环遍历迭代器B → 逐折处理数据
简单来说:
- 迭代器A(split返回)是"原料",负责生产K折的核心索引对;
- 迭代器B(enumerate返回)是"加工后的原料",给每个索引对加了"折数标签",方便代码中跟踪当前处理的是第几折。
三、迭代器的共性
- 惰性计算:不提前生成所有元素,遍历到才生成,节省内存(尤其适合大规模数据集);
- 一次性遍历 :迭代器的元素只能取一次,取完即空(所以代码中如果需要多次遍历,要重新调用
folds.split(...)生成新迭代器); - 支持for循环:这是迭代器最核心的使用方式,也是K折代码中遍历的基础;
- 不可索引 :迭代器不能用
[0][1]这样的索引取值(比如split_iter[0]会报错),只能通过next()或for循环逐个取。
四、附:enumerate(...)详解
enumerate()是Python内置的迭代器函数 ,核心作用是为可迭代对象(如列表、元组、迭代器)的每个元素"绑定一个递增的计数索引",返回一个enumerate类型的迭代器。
1、核心定义与语法
1. 基本作用
给任意可迭代对象(比如列表、迭代器)的每一个元素,添加一个从指定值开始的整数索引 ,遍历结果为(索引, 元素)的元组,解决"遍历可迭代对象时需要同时获取索引和值"的需求。
2. 语法
python
enumerate(iterable, start=0)
| 参数 | 说明 |
|---|---|
iterable |
任意可迭代对象(列表、元组、字符串、迭代器、生成器等) |
start |
计数起始值,默认0(可自定义为1、10等) |
| 返回值 | enumerate对象(迭代器),需遍历/转列表才能查看所有元素 |
2、基础用法示例(先理解核心逻辑)
先从简单场景入手,理解enumerate的基本行为:
1. 遍历列表(默认start=0)
python
# 可迭代对象:普通列表
fruits = ["苹果", "香蕉", "橙子"]
# 用enumerate包装,遍历结果
for idx, fruit in enumerate(fruits):
print(f"索引:{idx},元素:{fruit}")
输出:
索引:0,元素:苹果
索引:1,元素:香蕉
索引:2,元素:橙子
2. 自定义start(起始索引=1)
python
for idx, fruit in enumerate(fruits, start=1):
print(f"序号:{idx},水果:{fruit}")
输出:
序号:1,水果:苹果
序号:2,水果:香蕉
序号:3,水果:橙子
3. 查看enumerate迭代器的完整内容
enumerate返回的是惰性迭代器(不会一次性生成所有元素),可转成列表查看全部:
python
enum_obj = enumerate(fruits)
print("enumerate对象转列表:", list(enum_obj))
# 输出:[(0, '苹果'), (1, '香蕉'), (2, '橙子')]
3、结合K折场景的核心示例(呼应之前的代码)
回到你之前关注的enumerate(folds.split(train, train['label'])),我们拆解其执行逻辑:
1. 前置回顾
folds.split(...)返回一个迭代器 ,每个元素是(训练集索引数组, 验证集索引数组)(比如5折交叉验证时,这个迭代器有5个元素);enumerate的作用是给这个迭代器的每个元素,添加一个"折数索引"(从0开始)。
2. 完整示例(复刻之前的KFold场景)
python
import pandas as pd
from sklearn.model_selection import KFold
# 构造10行训练集
train = pd.DataFrame({"label": [1,0,1,1,0,1,1,0,1,0]})
# 初始化5折交叉验证
folds = KFold(n_splits=5, shuffle=True, random_state=2020)
# 第一步:先看folds.split(...)的迭代器内容(转列表)
split_iter = folds.split(train, train['label'])
print("folds.split(...)的迭代器内容:")
print(list(split_iter)) # 输出5个(训练索引, 验证索引)元组
print("-"*60)
# 第二步:用enumerate包装,遍历结果
split_iter2 = folds.split(train, train['label']) # 重新生成迭代器(迭代器只能遍历一次)
for fold_, (trn_idx, val_idx) in enumerate(split_iter2):
print(f"第{fold_}折:")
print(f" 训练集索引:{trn_idx}")
print(f" 验证集索引:{val_idx}")
3. 输出结果(核心部分)
folds.split(...)的迭代器内容:
[(array([0, 1, 2, 3, 4, 5, 6, 7]), array([8, 9])),
(array([0, 1, 2, 3, 4, 7, 8, 9]), array([5, 6])),
(array([0, 1, 4, 5, 6, 7, 8, 9]), array([2, 3])),
(array([2, 3, 4, 5, 6, 7, 8, 9]), array([0, 1])),
(array([0, 1, 2, 3, 5, 6, 8, 9]), array([4, 7]))]
------------------------------------------------------------
第0折:
训练集索引:[0 1 2 3 4 5 6 7]
验证集索引:[8 9]
第1折:
训练集索引:[0 1 2 3 4 7 8 9]
验证集索引:[5 6]
...(后3折略)