Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
邵宇然1 分钟前
Rust Unsafe 安全规范:从避免未定义行为到构建安全抽象的工程实践
人工智能
伊布拉西莫6 分钟前
LangChain LCEL源码深度剖析
python·langchain
用心_承载未来8 分钟前
从“复制链接→打开APP“到“一键解析“:我做了个短视频去水印工具
python·去水印·短视频去水印
TYUT_xiaoming10 分钟前
yolo模型训练
人工智能·python·yolo
2301_7807896615 分钟前
零信任架构中,身份感知防火墙(IAFW)的部署要点与最佳实践
linux·运维·服务器·人工智能·tcp/ip·架构
MicroTech202516 分钟前
业绩披露|微算法科技(MLGO)2025年净利润1.27亿元
大数据·人工智能·科技
百度Geek说16 分钟前
Superpowers:给 Claude Code 装上“工程大脑”
人工智能
AGIPlayer17 分钟前
没有生态的大模型不算前沿
大数据·人工智能·物联网
lulu121654407822 分钟前
OpenRouter Fusion 多模型融合架构深度拆解:预算级模型组团打平 Fable 5,多模型协作才是 AGI 的正确打开方式?
java·人工智能·架构·ai编程·agi
恋猫de小郭27 分钟前
Redis 作者反驳「中国模型之所以强,是因为通过 API 蒸馏了美国模型」
前端·人工智能·ai编程