以下为例:
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]
在 Python 中,
yield
是一个关键字,使用yield
的函数是一个生成器函数生成器函数的基本概念
- 普通函数在执行时,遇到
return
语句就会终止函数执行,并返回相应的值。而生成器函数在执行过程中,遇到yield
语句时,会暂停函数的执行,保存当前的执行状态(包括局部变量的值等),并返回yield
后面表达式的值(如果有的话)。当下一次通过某种方式(比如在循环中迭代这个生成器)来请求生成器继续执行时,函数会从上次暂停的地方(也就是yield
语句处)继续往下执行,直到再次遇到yield
语句或者函数执行完毕(如果没有更多的yield
语句了)。在
data_iter
函数中的具体作用
- 在
data_iter
函数里,目的是将给定的数据集(features
和labels
)按照指定的batch_size
划分成一个个小批次(batch)数据来方便后续的批量训练等操作。- 当循环执行到
yield features[batch_indices], labels[batch_indices]
这一行时:
- 首先,它会基于当前批次对应的索引(
batch_indices
)从总的特征数据features
和标签数据labels
中取出相应的批次数据。- 然后,将取出的该批次的特征数据和标签数据作为一个元组返回,这个返回值可以被外部代码获取到(比如在循环中迭代这个生成器来依次获取每个批次的数据)。
- 执行完这次
yield
后,函数就暂停在这里了,等到下一次继续迭代这个生成器(比如下一次循环到这里来获取下一个批次的数据),函数会接着从这个yield
语句之后继续执行,重新去处理下一组索引范围,取出下一个批次的数据并返回,如此反复,直到整个数据集的样本都被划分成批次并返回完。总的来说,
yield
让data_iter
函数变成了一个生成器,能方便地按批次逐个生成数据,避免一次性把所有数据都处理好放入内存,节省内存空间并且符合按批次处理数据的常见深度学习训练流程需求。