Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
kszlgy1 小时前
Day 52 神经网络调参指南
python
wrj的博客2 小时前
python环境安装
python·学习·环境配置
康康的AI博客2 小时前
腾讯王炸:CodeMoment - 全球首个产设研一体 AI IDE
ide·人工智能
中达瑞和-高光谱·多光谱2 小时前
中达瑞和LCTF:精准调控光谱,赋能显微成像新突破
人工智能
mahtengdbb13 小时前
【目标检测实战】基于YOLOv8-DynamicHGNetV2的猪面部检测系统搭建与优化
人工智能·yolo·目标检测
Pyeako3 小时前
深度学习--BP神经网络&梯度下降&损失函数
人工智能·python·深度学习·bp神经网络·损失函数·梯度下降·正则化惩罚
清 澜3 小时前
大模型面试400问第一部分第一章
人工智能·大模型·大模型面试
不大姐姐AI智能体3 小时前
搭了个小红书笔记自动生产线,一句话生成图文,一键发布,支持手机端、电脑端发布
人工智能·经验分享·笔记·矩阵·aigc
摘星编程4 小时前
OpenHarmony环境下React Native:Geolocation地理围栏
python