Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
腾飞开源几秒前
09_Spring AI 干货笔记之多模态
图像处理·人工智能·spring ai·多模态大语言模型·多模态api·媒体输入·文本响应
CM莫问13 分钟前
详解机器学习经典模型(原理及应用)——岭回归
人工智能·python·算法·机器学习·回归
七牛云行业应用13 分钟前
告别RLHF?DeepSeek过程奖励(PRM)架构解析与推理数据流设计
人工智能·强化学习·大模型架构·deepseek
SunnyRivers14 分钟前
Python打包指南:编写你的pyproject.toml
python·打包·toml
xcLeigh14 分钟前
AI的提示词专栏:Prompt 与传统机器学习特征工程的异同
人工智能·机器学习·ai·prompt·提示词
DuHz14 分钟前
论文阅读——Edge Impulse:面向微型机器学习的MLOps平台
论文阅读·人工智能·物联网·算法·机器学习·edge·边缘计算
诚丞成16 分钟前
机器学习——生成对抗网络(GANs):原理、进展与应用前景分析
人工智能·机器学习·生成对抗网络
盼小辉丶16 分钟前
图机器学习(7)——图神经网络 (Graph Neural Network, GNN)
人工智能·神经网络·图神经网络·图机器学习
码字的字节16 分钟前
机器学习中的可解释性:深入理解SHAP值及其应用
人工智能·shap
爱数学的程序猿20 分钟前
机器学习“捷径”:自动特征工程全面解析
人工智能·机器学习