【教程】DataLoader中各个参数的解释

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

目录

参数介绍

dataset

[batch_size: int = 1](#batch_size: int = 1)

[shuffle: bool = False](#shuffle: bool = False)

[sampler: Sampler | Iterable = None](#sampler: Sampler | Iterable = None)

[batch_sampler: Sampler | Iterable = None](#batch_sampler: Sampler | Iterable = None)

[num_workers: int = 0](#num_workers: int = 0)

[collate_fn: Callable = None](#collate_fn: Callable = None)

[pin_memory: bool = False](#pin_memory: bool = False)

[drop_last: bool = False](#drop_last: bool = False)

[timeout: float = 0](#timeout: float = 0)

[worker_init_fn: Callable = None](#worker_init_fn: Callable = None)

[multiprocessing_context: str](#multiprocessing_context: str)

[generator: torch.Generator = None](#generator: torch.Generator = None)

[prefetch_factor: int | None](#prefetch_factor: int | None)

[persistent_workers: bool = False](#persistent_workers: bool = False)

[pin_memory_device: str = ""](#pin_memory_device: str = "")

一些警告


torch.utils.data --- PyTorch 2.8 documentation

数据加载器(DataLoader) 将数据集(dataset)与采样器(sampler)结合起来,并提供一个可迭代对象,用于遍历给定的数据集。torch.utils.data.DataLoader 支持 map-style(映射式)iterable-style(可迭代式) 两种数据集,支持单进程或多进程加载,可以自定义加载顺序,并可选择是否自动分批(collation)以及是否将数据固定到内存(pinning)。

参数介绍

torch.utils.data.DataLoader的所有参数:

bash 复制代码
class DataLoader(
    # 要加载的数据集。
    dataset: Dataset,
    # 每个批次加载多少样本(默认:1)。   
    batch_size: int | None = 1,
    # 若为 True,则在每个 epoch 开始时打乱数据(默认:False)。
    shuffle: bool | None = None,
    # 定义从数据集中抽取样本的策略。可以是任何实现了 __len__ 的 Iterable。如果指定了 sampler,则不能再设置 shuffle。
    sampler: Sampler | Iterable | None = None,
    # 与 sampler 类似,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
    batch_sampler: Sampler[List] | Iterable[List] | None = None,
    # 用于数据加载的子进程数量。设为 0 时,数据将在主进程中加载(默认:0)。
    num_workers: int = 0,
    # 将一个样本列表合并成一个小批量(mini-batch)张量。通常在从 map-style 数据集进行批量加载时使用。
    collate_fn: _collate_fn_t | None = None,
    # 若为 True,则 DataLoader 会在返回前把张量复制到设备/CUDA 的 锁页内存(pinned memory)。如果你的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型。
    pin_memory: bool = False,
    # 若为 True,当数据集大小不能被 batch_size 整除时,丢弃最后一个不完整的批次。若为 False,则保留最后一个较小的批次(默认:False)。
    drop_last: bool = False,
    # 若大于 0,表示从 worker 获取一个批次的超时时间(秒)。必须为非负数(默认:0)。
    timeout: float = 0,
    # 若不为 None,则在每个 worker 子进程启动时调用,输入参数为 worker id(范围 [0, num_workers - 1])。它会在设置随机种子之后、开始加载数据之前被调用(默认:None)。
    worker_init_fn: _worker_init_fn_t | None = None,
    # 若为 None,则使用操作系统的默认多进程上下文(multiprocessing context)(默认:None)。
    multiprocessing_context: Any | None = None,
    # 若不为 None,则该随机数生成器将被 RandomSampler 用于生成随机索引,并被多进程机制用于生成 worker 的基础随机种子(默认:None)。
    generator: Any | None = None,
    # 每个 worker 预先加载的批次数。2 表示总共有 2 * num_workers 个批次被预取。默认值取决于 num_workers 的设置。如果 num_workers=0,则默认是 None;否则默认是 2。
    prefetch_factor: int | None = None,
    # 若为 True,则在数据集被消耗完一次后,worker 进程不会关闭。这允许在多个 epoch 之间保持 worker 中的 Dataset 实例常驻(默认:False)。
    persistent_workers: bool = False,
    # 当 pin_memory=True 时,指定锁页内存绑定到的设备。
    pin_memory_device: str = ""
)

dataset

数据来源对象,告诉 DataLoader 去哪里拿样本。pytorch提供的torch.utils.data.Dataset类是一个抽象基类,供用户继承,编写自己的dataset,实现对数据的读取。允许两种格式的dataset:

  • map-style :一个map格式的数据集必须要重写getitem(self, index) 和**len(self)**两个内建方法,用来表示从索引到样本的映射。这样一个数据集dataset,当使用dataset[idx]命令时,可以读取你的数据集中第idx张图片以及其标签;len(dataset)则会返回这个数据集的容量。

  • iterable-style :一个Iterable格式的数据集是抽象类data.IterableDataset 的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

注意

  • iterable-style 不支持 shuffle/sampler/batch_sampler(很多会被忽略或报错)。

  • 多进程(num_workers>0)会在子进程 里各自构造 dataset,所以 dataset 内的对象要可被子进程安全创建/序列化

  • 重 IO/解码别放在 __init__,放到 __getitem__/__iter__ 里。

示例(map-style)

python 复制代码
from torch.utils.data import Dataset

class ToySet(Dataset):
    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys
    def __len__(self):
        return len(self.xs)
    def __getitem__(self, idx):
        return self.xs[idx], self.ys[idx]

示例(iterable-style,含多 worker 分片)

python 复制代码
from torch.utils.data import IterableDataset, get_worker_info

class StreamSet(IterableDataset):
    def __iter__(self):
        info = get_worker_info()
        wid, wnum = (info.id, info.num_workers) if info else (0, 1)
        for i, sample in enumerate(infinite_source()):  # 你自己的流
            if i % wnum == wid:     # 简单分片,避免重复
                yield sample

batch_size: int = 1

每个批次里包含的样本数。控制一次前向/反向要处理多少样本,影响吞吐、显存占用与稳定性。默认 1,表示 不做批划分,每次返回单条样本。

注意

  • 对显存是线性增长的,建议先从 32/64 开始实验。

  • batch_sampler 互斥。指定了 batch_sampler不要 再传 batch_size

  • 可变长数据 若直接堆叠会报错(stack expects each tensor to be equal size),需要自定义 collate_fn 做 padding/对齐。

  • 最后一批不满时是否丢弃由 drop_last 决定。

  • 太大可能 OOM;太小可能训练不稳定(如 BatchNorm)。

示例

python 复制代码
from torch.utils.data import DataLoader
loader = DataLoader(ToySet(xs, ys), batch_size=32)
for xb, yb in loader:
    ...

shuffle: bool = False

每个 epoch 是否打乱样本顺序。可以避免固定顺序带来的偏差,提高泛化。

注意

  • samplerbatch_sampler 互斥。如果你已经自定义 sampler(如 RandomSampler),shuffle 必须设为 False,否则会报错。

  • 分布式训练(DDP)通常用 DistributedSampler(shuffle=True) 接管打乱与分片,不再单独传 shuffle=True

  • 训练集 → shuffle=True(防止模型记忆顺序)

  • 验证/测试集 → shuffle=False(保持顺序方便对应标签)

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), batch_size=64, shuffle=True)

sampler: Sampler | Iterable = None

返回单条样本的索引。可以自定义"取样本"的策略,实现加权采样、不放回采样、分布式分片等。

注意

  • 指定 sampler不要 再传 shuffle

  • 常用的内置 sampler:RandomSampler, SequentialSampler, WeightedRandomSampler.

  • 分布式用 DistributedSampler 时,记得 sampler.set_epoch(epoch),确保每轮乱序不同。

示例(类别不均衡的权重采样)

python 复制代码
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

weights = torch.tensor([1, 5, 1, 2, ...], dtype=torch.double)  # 每样本权重
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(ToySet(xs, ys), batch_size=32, sampler=sampler)

batch_sampler: Sampler | Iterable = None

直接返回每个 batch的索引列表。可以把"打乱+分批"的逻辑完全自定义(如按长度分桶、同类打包)。

注意

  • batch_sizeshufflesamplerdrop_last 互斥(这些都交给它控制)。

  • 返回的是"索引列表"的迭代器。

  • 适合"按类别分组采样""长度分桶"等复杂 batching 策略。

示例(用内置 BatchSampler 包一层)

python 复制代码
from torch.utils.data import DataLoader, RandomSampler, BatchSampler

base = RandomSampler(range(len(xs)))      # 先有一个样本级采样器
batches = BatchSampler(base, batch_size=16, drop_last=False)  # 再打成批
loader = DataLoader(ToySet(xs, ys), batch_sampler=batches)

num_workers: int = 0

用于数据加载的子进程个数;0 表示在主进程中加载(无并行)。I/O 或解码较重时,可以提高吞吐(读取与训练并行)。

注意

  • 每个 worker 是独立进程,有各自随机状态。

  • Windows/macOS 默认 spawn:传入的对象(datasetcollate_fn 等)必须可被 pickle ;避免 lambda、内嵌函数。

  • 在 Linux/macOS 上,num_workers = min(4, os.cpu_count()) 是一个不错的起点。

  • 在 Windows(spawn 启动方式)上,子进程启动开销更大,推荐 num_workers=0~2,或者使用 torch.multiprocessing.set_start_method('spawn') 明确启动方式。

  • 调优常用路径:从 2→4→8 逐步尝试;结合 prefetch_factorpersistent_workers

  • **死锁:**在 Windows + num_workers>0 + fork(不支持)+ DataLoader 与 multiprocessing.Queue 同时使用时容易出现 deadlock。

  • **异常捕获:**子进程异常会在下一个 batch 报错,堆栈信息在 torch.utils.data.DataLoader 中会被包装,建议使用 torch.utils.data.DataLoader(..., timeout=... ) 防止无限等待。

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4)

collate_fn: Callable = None

把"样本列表"合并成"一个 batch"的函数。可以实现对齐变长序列(padding)、拼接字典/列表、返回自定义 Batch 对象等。

注意

  • 默认的 default_collate 已能处理 Tensor/数值/np/dict/list 等"等长"数据。

  • 处理变长 数据(NLP 序列、检测框)时,通常需要自定义 collate_fn 完成对齐。

  • 尽量不要在这里做重增强,重活放到 __getitem__

示例(序列 padding)

python 复制代码
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    xs, ys = zip(*batch)
    xs = [torch.as_tensor(x) for x in xs]
    ys = torch.as_tensor(ys)
    return pad_sequence(xs, batch_first=True), ys

loader = DataLoader(ToySet(xs, ys), batch_size=32, collate_fn=pad_collate)

pin_memory: bool = False

把从 CPU 读取的数据直接拷贝到 页锁定(pinned)内存当随后调用 tensor.cuda(non_blocking=True) 时,GPU 能通过 DMA 直接读取,显著提升传输带宽。

注意

  • 只在GPU 训练 时打开。

  • 会增加主机内存占用;超大batch、内存紧张时谨慎开启。

  • 如果 collate_fn 返回的是自定义对象,需要实现 .pin_memory() 或在 collate_fn 内手动对每个张量 pin_memory()

  • 小 batchCPU‑only 训练,开启与否差别不大,甚至会稍微增加 CPU 内存占用。

示例(配合非阻塞拷贝)

python 复制代码
device = "cuda"
loader = DataLoader(ToySet(xs, ys), batch_size=64, pin_memory=True)
for xb, yb in loader:
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)

drop_last: bool = False

当数据量不是 batch_size 的整数倍时,是否丢弃最后那个"不满"的批。可以在需要固定 batch 大小的场景(如 BatchNorm、固定步长的梯度累积)保持一致性。

注意:训练可能设 True;验证/测试一般 False。

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), batch_size=64, drop_last=True)

timeout: float = 0

从 workers 收集一个 batch 的超时秒数(非负)。可以排查卡死/死锁(到点抛错)。

注意

  • num_workers>0 才有意义;设太小会误报。
  • 当数据读取涉及 网络 I/O(如远程文件、WebDataset)时,防止卡死。
  • 在多机器分布式训练中,某些节点慢导致整体阻塞。

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, timeout=30)

worker_init_fn: Callable = None

每个 worker 进程 启动后、正式取数据前会被调用一次的函数,入参是 worker_id。可以设置 numpy/random 等外部库的种子、全局变量初始化、初始化三方资源。

注意

  • 为每个 worker 设置不同的 np.random.seed / random.seed,避免所有进程产生相同的随机数。

  • 初始化自定义的全局缓存(如打开一次数据库连接)。

  • spawn 启动方式(Win/macOS 常见)下,必须是可 pickle 的顶层函数 ,不能用 lambda

  • 想要可复现,常与 generator 一起用。

示例(可复现设置)

python 复制代码
import random, numpy as np, torch

def seed_worker(worker_id):
    seed = torch.initial_seed() % 2**32
    random.seed(seed); np.random.seed(seed)

g = torch.Generator().manual_seed(42)
loader = DataLoader(
    ToySet(xs, ys), shuffle=True, num_workers=4,
    worker_init_fn=seed_worker, generator=g
)

multiprocessing_context: str

指定 Python 多进程的"启动方式":"fork" | "spawn" | "forkserver" 或对应 context。当某些库在 fork 下不安全(线程、显存句柄)时,改用 spawn 更稳。

注意

  • Windows 只支持 spawn;Linux 默认 fork

  • spawn 更通用但启动更慢。

  • 当你的代码中有 CUDA 初始化OpenMP 线程池,fork 可能导致子进程继承父进程的 GPU 上下文,产生 CUDA 错误。此时改为 'spawn' 更安全。

示例

python 复制代码
ctx = torch.multiprocessing.get_context("spawn")
loader = DataLoader(ToySet(xs, ys), num_workers=4, multiprocessing_context=ctx)

generator: torch.Generator = None

DataLoader 用到的随机数发生器。可以控制乱序/随机采样,同时为 workers 生成基础种子。

注意

  • 强可复现:generator.manual_seed(...) + 在 worker_init_fn 同步 numpy/random 的种子。

示例

python 复制代码
g = torch.Generator().manual_seed(2024)
loader = DataLoader(ToySet(xs, ys), shuffle=True, generator=g)

prefetch_factor: int | None

每个 worker 会 预取 prefetch_factor * batch_size 条样本放进内部队列。可以让读取/解码与训练形成流水线,减少"等数据"。
注意

  • 默认:num_workers==0None(无预取);num_workers>02

  • 调大更顺滑,但会多占内存;I/O 慢时可适度调大(如 3~4)。

  • 当 I/O 非常慢(比如磁盘读取大图片),可以调大到 4~8,但会消耗更多内存。

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, prefetch_factor=3)

persistent_workers: bool = False

True 时,跨 epoch 不销毁 worker 进程(复用已创建的 Dataset 实例)。长时间训练更高效,避免反复建/毁进程(特别是 num_workers 较大时)。

注意

  • 只能在 torch >= 1.(1.7+)使用。

  • num_workers>0 生效。

  • 如果你的 Dataset 每轮都需要重建或怀疑有资源泄漏,谨慎开启或定期重启。

  • 如果 Dataset 在每个 epoch 会 改变内部状态(比如随机采样器自行改变种子),需要自行在 worker_init_fn 中重置,否则子进程会复用上一次的状态。

示例

python 复制代码
loader = DataLoader(ToySet(xs, ys), num_workers=4, persistent_workers=True)

pin_memory_device: str = ""

pin_memory=True 时,指定 pinned memory 绑定到哪个设备(比如 "cuda""cuda:0")。

更细粒度地控制固定内存的目标设备;多数场景留默认即可。
注意

  • 通常只需 pin_memory=True;只有在多设备或特定优化需求时才设置它。

示例

python 复制代码
loader = DataLoader(
    ToySet(xs, ys),
    batch_size=64,
    pin_memory=True,
    pin_memory_device="cuda"   # 或 "cuda:0"
)

一些警告

  1. 关于 spawn 启动方式 :如果使用 spawn 启动多进程,worker_init_fn 不能是不可序列化对象(例如 lambda 函数)。详情见 multiprocessing 最佳实践

  2. 关于 len(dataloader) 的计算

    • 当使用普通数据集(map-style dataset)时,len(dataloader) 基于所用的 sampler 的长度。

    • 当使用 IterableDataset 时,len(dataloader) 返回的估计值为 len(dataset) / batch_size,并根据 drop_last 进行四舍五入,与是否多进程无关。

    • 这是 PyTorch 能做出的最佳猜测 ,因为它假定用户编写的 dataset 能正确处理多进程加载以避免数据重复。

    • 但如果数据被分片(sharding)后导致多个 worker 有不完整的尾批,这个估计值仍可能不准确,因为:

      1. 一个完整批次可能被拆成多个不完整的批次。

      2. 当设置了 drop_last=True 时,可能会丢弃超过一个批次的数据。

    • PyTorch 无法检测到这类情况。

  3. 关于随机性和复现 :参见文档中的 reproducibilitydataloader-workers-random-seeddata-loading-randomness 部分,了解随机种子相关问题。