Pytorch torch.utils.data.DataLoader 用法详细介绍

文章目录

  • [1. 介绍](#1. 介绍)
  • [2. 参数详解](#2. 参数详解)
  • [3. 用法](#3. 用法)
  • [4. 参考](#4. 参考)

1. 介绍

torch.utils.data.DataLoader 是 PyTorch 提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作。Pytorch DataLoader 的详细官方介绍看这里

2. 参数详解

  • dataset (Dataset) -- 加载的数据集

  • batch_size (int, optional) -- 每一次处理加载多少数据

  • shuffle (bool, optional) -- True 表示每次 epoch 都要重新打乱数据,默认 False

  • sampler (Sampler or Iterable, optional) -- 定义采样的策略。如果定义了此参数,那么 shuffle 参数必须为 False

  • batch_sampler (Sampler or Iterable, optional) -- 同 sample 一样,但每次返回数据的索引。与 batch_sizeshufflesampledrop_last 参数互斥

  • num_workers (int, optional) -- 指定用于数据加载的子进程数,可以加快数据加载速度。默认0,表示用主进程加载

  • collate_fn (Callable, optional) -- 批处理函数,用于将多个样本合并成一个批次,例如将多个张量拼接在一起,构建 mini-batch。当使用 map-style 数据集进行批量加载时使用。

  • pin_memory (bool, optional) -- True 表示在返回张量之前将张量复制到 CUDA 固定的内存中,加快 GPU 传输速度

  • drop_last (bool, optional) -- True 表示可删除最后一个不完整的批次。默认 False,如果数据集的大小不能被批次大小整除,则最后一个批次会更小。

  • timeout (numeric, optional) -- 非负数,worker 收集批次数据的超时时间,默认0

  • worker_init_fn (Callable, optional) -- 如果非None,则在种子设定之后和数据加载之前,将以worker id([0,num_workers-1]中的int)作为输入对每个 worker 子进程调用此函数。(默认值:None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) -- 如果为None,则将使用操作系统的默认多处理上下文。(默认值:None)

  • generator (torch.Generator, optional) -- 如果非None,则RandomSampler 将使用此RNG来生成随机索引,并进行多进程处理以为 workers 生成 base_seed。(默认值:None)

  • prefetch_factor (int, optional, keyword-only arg) -- 每个 worker 预先装载的批次数。2 表示在所有工作线程中总共预取2*num_workers批次。(默认值取决于为num_workers设置的值。如果num_workers=0的值,则默认为None。否则,如果num_workers>0的值,默认为2)

  • persistent_workers (bool, optional) -- True 表示不会在数据集使用一次后关闭工作进程。这允许保持 worker 实例处于活动状态。(默认值:False)

  • pin_memory_device (str, optional) -- 如果 pin_memory 为 True,该参数表示 pin_memory 所指向的设备

3. 用法

使用 DataLoader 进行迭代

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
# 假设有自定义数据集类 MyDataset
class MyDataset(Dataset):
    # 实现 __init__, __len__, 和 __getitem__ 方法...

# 实例化数据集
dataset = MyDataset(data_source)

# 创建 DataLoader
dataloader = DataLoader(dataset,
                       batch_size=64,  # 设置批次大小
                       shuffle=True,   # 是否随机打乱数据
                       num_workers=4,  # 启用4个工作进程加载数据
                       drop_last=True  # 丢弃最后一个不足批次大小的数据
                      )

# 迭代数据加载器进行训练
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        # 训练模型...
        outputs = model(inputs)
        loss = compute_loss(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

在迭代过程中,loader 会自动从数据集中加载数据,并将其组织成批次。每次迭代返回一个批次的数据,其中 batch_data 是一个包含输入数据和标签的元组或列表。

4. 参考

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

相关推荐
skywalk81632 分钟前
Trae 是一款由 AI 驱动的 IDE,让编程更加愉悦和高效。国际版集成了 GPT-4 和 Claude 3.5,国内版集成了DeepSeek-r1
人工智能·trae
WenGyyyL8 分钟前
使用OpenCV和MediaPipe库——驼背检测(姿态监控)
人工智能·python·opencv·算法·计算机视觉·numpy
梓羽玩Python21 分钟前
开源版Manus来了!14.7k标星的OpenManus,让AI替你全自动执行任务!
人工智能·github
蹦蹦跳跳真可爱58921 分钟前
Python----数据分析(Matplotlib四:Figure的用法,创建Figure对象,常用的Figure对象的方法)
python·数据分析·matplotlib
广拓科技21 分钟前
中国视频生成 AI 开源潮:腾讯阿里掀技术普惠革命,重塑内容创作格局
人工智能·开源
dr李四维31 分钟前
Java在小米SU7 Ultra汽车中的技术赋能
java·人工智能·安卓·智能驾驶·互联·小米su7ultra·hdfs架构
guanshiyishi32 分钟前
ABeam 德硕 | 中国汽车市场(1)——正在推进电动化的中国汽车市场
人工智能·物联网·汽车
思茂信息33 分钟前
CST直角反射器 --- 距离多普勒(RD图), 毫米波汽车雷达ADAS
前端·人工智能·5g·汽车·无人机·软件工程
瑞瑞大大1 小时前
简单介绍下Manus功能
人工智能
小杨4041 小时前
python入门系列六(文件操作)
人工智能·python·pycharm