Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
扫地的小何尚8 分钟前
全新NVIDIA Llama Nemotron Nano视觉语言模型在OCR基准测试中准确率夺冠
c++·人工智能·语言模型·机器人·ocr·llama·gpu
xiaohanbao0911 分钟前
day54 python对抗生成网络
网络·python·深度学习·学习
爬虫程序猿14 分钟前
利用 Python 爬虫按关键字搜索 1688 商品
开发语言·爬虫·python
英杰.王25 分钟前
深入 Java 泛型:基础应用与实战技巧
java·windows·python
安替-AnTi30 分钟前
基于Django的购物系统
python·sql·django·毕设·购物系统
树叶@33 分钟前
Python 数据分析10
python·数据分析
岁月如歌,青春不败40 分钟前
Python-PLAXIS自动化建模技术与典型岩土工程
python·自动化·岩土工程·公路·地球科学·铁路·地质工程
m0_575470881 小时前
n8n实战:自动化生成AI日报并发布
人工智能·ai·自动化·ai自动写作
软件开发技术深度爱好者1 小时前
python类成员概要
开发语言·python