Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
岁月宁静4 分钟前
当 AI 越来越“聪明”,人类真正的护城河是什么:智商、意识与认知主权
人工智能
CareyWYR7 分钟前
每周AI论文速递(260105-260109)
人工智能
智能相对论18 分钟前
CES深度观察丨智能清洁的四大关键词:变形、出户、体验以及生态协同
大数据·人工智能
AI小怪兽22 分钟前
基于YOLOv13的汽车零件分割系统(Python源码+数据集+Pyside6界面)
开发语言·python·yolo·无人机
齐齐大魔王22 分钟前
Pascal VOC 数据集
人工智能·深度学习·数据集·voc
程途拾光15831 分钟前
幻觉抑制:检索增强生成(RAG)的优化方向
人工智能
野豹商业评论33 分钟前
千问发力:“AI家教”开始抢教培生意?
人工智能
wszy180933 分钟前
新文章标签:让用户一眼发现最新内容
java·python·harmonyos
Eric.Lee202144 分钟前
python实现 mp4转gif文件
开发语言·python·手势识别·手势交互·手势建模·xr混合现实
程序员佳佳1 小时前
【万字硬核】从零构建企业级AI中台:基于Vector Engine整合GPT-5.2、Sora2与Veo3的落地实践指南
人工智能·gpt·chatgpt·ai作画·aigc·api·ai编程