Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
咕白m6257 小时前
Python 将 Excel 转换为图片:实现数据可视化
后端·python
一人の梅雨7 小时前
1688 拍立淘接口深度开发:从图像识别到供应链匹配的技术实现
人工智能·算法·计算机视觉
递归不收敛7 小时前
吴恩达机器学习课程(PyTorch适配)学习笔记:1.4 模型评估与问题解决
pytorch·学习·机器学习
深蓝电商API7 小时前
不止是 Python:聊聊 Node.js/Puppeteer 在爬虫领域的应用
爬虫·python·node.js
Autumn72997 小时前
【材料学python入门】conda、 jupyter、cpu、GPAW、wsl、ubuntu
python·jupyter·conda
dundunmm7 小时前
【数据集】WebQuestions
人工智能·llm·数据集·知识库问答·知识库
却道天凉_好个秋8 小时前
OpenCV(五):鼠标控制
人工智能·opencv·鼠标控制
K2I-8 小时前
UCI中Steel Plates Faults不平衡数据集处理
python
蓑笠翁0018 小时前
Django REST Framework 全面指南:从模型到完整API接口开发
后端·python·django
IT_陈寒8 小时前
Redis性能优化:5个被低估的配置项让你的QPS提升50%
前端·人工智能·后端