Pytorch:torch.utils.data.random_split()

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

python 复制代码
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

以U-Net代码(详见:U-Net代码复现)为例:

python 复制代码
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

通过random_split()将数据分为训练集和验证集(随机)

相关推荐
圣殿骑士-Khtangc几秒前
【论文精读】《Scaling Laws for Neural Language Models》解读:大模型缩放定律的核心秘密
人工智能·语言模型·自然语言处理·神经网络模型
与虾牵手9 分钟前
Rust 入门:一个写了 6 年 Python 的人,被编译器骂了三天
python
2401_8578652310 分钟前
Python日志记录(Logging)最佳实践
jvm·数据库·python
慧知AI10 分钟前
用OpenClaw的架构思路,我给公司搭了一套内部AI Agent系统,替代了3个外包岗位
人工智能
balmtv12 分钟前
2026年Gemini 3 Pro技术拆解:深度推理、空间智能与Agentic系统的架构革命
人工智能·gpt·架构
AsDuang13 分钟前
Python 3.12 MagicMethods - 54 - __rrshift__
开发语言·python
Shining059615 分钟前
前沿模型系列(四)《大模型前沿架构》
人工智能·学习·其他·ai·架构·大模型·infinitensor
进击的野人23 分钟前
Prompt工程入门指南:写给AI学习新手的提示词秘籍
人工智能·aigc·ai编程
Bert.Cai24 分钟前
Python字符串详解
开发语言·python
Ekehlaft26 分钟前
同题画图大考,AiPy 适配性拉满,OpenClaw 全程 “哑火”
人工智能·ai·openclaw·aipy