有的时候训练需要对样本重复抽样为一个batch,可以按如下格式定义:
bash
class TrainLoader(Dataset):
def __init__(self, fns, repeat=1):
super(TrainLoader, self).__init__()
self.length = len(fns) # 数据数量
self.repeat = repeat # 数据重复次数
def __getitem__(self, idx):
idx = idx % self.length
def __len__(self):
return self.length * self.repeat