random_split() function:
torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)
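Parameters:
- dataset (Dataset) - the dataset to be split
- lengths (sequence) - the lengths of the splits to produce; they must add up to len(dataset) (PyTorch 1.13 and later also accept fractions that sum to 1)
- generator (torch.Generator) - the generator used for the random permutation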
Note: for details on torch.Generator, see the note "Pytorch: torch.Generator()".
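A minimal usage sketch of the three parameters, assuming a recent PyTorch. A plain `range(10)` stands in for a real dataset (it has `__len__` and `__getitem__`, which is all a map-style dataset needs here); the lengths `[7, 3]` and the seed 42 are arbitrary choices.

```python
import torch
from torch.utils.data import random_split

# A plain range works as a map-style dataset: it has __len__ and __getitem__
dataset = range(10)

# lengths must add up to len(dataset); a seeded generator fixes the shuffle
train_set, val_set = random_split(dataset, [7, 3], generator=torch.Generator().manual_seed(42))

print(len(train_set), len(val_set))  # 7 3
# Each split is a Subset that stores indices into the original dataset
print(sorted(train_set.indices + val_set.indices))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Lengths that do not add up to len(dataset) are rejected
try:
    random_split(dataset, [5, 3])
except ValueError as err:
    print("ValueError:", err)
```

The two subsets never share an index, and together they cover the whole dataset exactly once.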
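The definition of random_split() below comes from an earlier PyTorch release that did not yet have the generator argument (current releases pass generator on to randperm()):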
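```python
# The original source imports _accumulate from torch._utils;
# itertools.accumulate computes the same running totals.
from itertools import accumulate as _accumulate

from torch import randperm
from torch.utils.data import Subset


def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # One random permutation of all indices 0..N-1
    indices = randperm(sum(lengths)).tolist()
    # _accumulate(lengths) yields each split's end offset, e.g. 8, 10 for lengths [8, 2]
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
```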
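To show where the generator fits into the older definition above, here is a hypothetical `seeded_random_split` (not a PyTorch function) that adds a `generator` argument and hands it to `torch.randperm`. It is a sketch of the same permute-then-slice idea under that assumption, not the library's own code.

```python
from itertools import accumulate

import torch
from torch.utils.data import Subset


def seeded_random_split(dataset, lengths, generator=None):
    """Hypothetical variant of the old random_split with a generator argument."""
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    # One permutation of 0..N-1; a seeded generator makes it reproducible
    indices = torch.randperm(sum(lengths), generator=generator).tolist()
    # Consecutive slices of the permutation: [0:5], [5:8], [8:10] for lengths [5, 3, 2]
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


a, b, c = seeded_random_split(range(10), [5, 3, 2], generator=torch.Generator().manual_seed(0))
print(len(a), len(b), len(c))      # 5 3 2
print(list(accumulate([5, 3, 2])))  # [5, 8, 10]
```

Because every split is a slice of one shared permutation, the subsets cannot overlap; the generator only decides which permutation is drawn.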
Taking the U-Net code (see: U-Net code reproduction) as an example:
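```python
import torch
from torch.utils.data import random_split

# dataset and val_percent are defined earlier in the U-Net training script
n_val = int(len(dataset) * val_percent)  # size of the validation set
n_train = len(dataset) - n_val           # everything else goes to training
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
```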
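random_split() randomly divides the data into a training set and a validation set. Seeding the generator with manual_seed(0) makes the split identical on every run.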
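The split arithmetic and the effect of the fixed seed can be checked on a toy example. This sketch assumes a 25-item `range` in place of the U-Net image dataset and `val_percent = 0.1`; the helper `split` is hypothetical.

```python
import torch
from torch.utils.data import random_split

dataset = range(25)  # toy stand-in for the U-Net image dataset
val_percent = 0.1    # assumed value: fraction of the data held out for validation

n_val = int(len(dataset) * val_percent)  # int() truncates: int(2.5) == 2
n_train = len(dataset) - n_val           # the remainder, so the lengths sum to len(dataset)


def split(seed):
    # Hypothetical helper: same call as in the U-Net snippet, with a configurable seed
    return random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(seed))


train_a, val_a = split(0)
train_b, val_b = split(0)
print(n_train, n_val)                  # 23 2
print(val_a.indices == val_b.indices)  # True: same seed, same split
```

Because `n_train` is computed as the remainder rather than with its own `int()`, the two lengths always add up to `len(dataset)`, so the length check inside random_split() cannot fail even when the product is truncated.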