转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]
如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
目录
[batch_size: int = 1](#batch_size: int = 1)
[shuffle: bool = False](#shuffle: bool = False)
[sampler: Sampler | Iterable = None](#sampler: Sampler | Iterable = None)
[batch_sampler: Sampler | Iterable = None](#batch_sampler: Sampler | Iterable = None)
[num_workers: int = 0](#num_workers: int = 0)
[collate_fn: Callable = None](#collate_fn: Callable = None)
[pin_memory: bool = False](#pin_memory: bool = False)
[drop_last: bool = False](#drop_last: bool = False)
[timeout: float = 0](#timeout: float = 0)
[worker_init_fn: Callable = None](#worker_init_fn: Callable = None)
[multiprocessing_context: str](#multiprocessing_context: str)
[generator: torch.Generator = None](#generator: torch.Generator = None)
[prefetch_factor: int | None](#prefetch_factor: int | None)
[persistent_workers: bool = False](#persistent_workers: bool = False)
[pin_memory_device: str = ""](#pin_memory_device: str = "")
数据加载器(DataLoader) 将数据集(dataset)与采样器(sampler)结合起来,并提供一个可迭代对象,用于遍历给定的数据集。torch.utils.data.DataLoader
支持 map-style(映射式) 和 iterable-style(可迭代式) 两种数据集,支持单进程或多进程加载,可以自定义加载顺序,并可选择是否自动分批(collation)以及是否将数据固定到内存(pinning)。

参数介绍
torch.utils.data.DataLoader的所有参数:
bash
class DataLoader(
# 要加载的数据集。
dataset: Dataset,
# 每个批次加载多少样本(默认:1)。
batch_size: int | None = 1,
# 若为 True,则在每个 epoch 开始时打乱数据(默认:False)。
shuffle: bool | None = None,
# 定义从数据集中抽取样本的策略。可以是任何实现了 __len__ 的 Iterable。如果指定了 sampler,则不能再设置 shuffle。
sampler: Sampler | Iterable | None = None,
# 与 sampler 类似,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
batch_sampler: Sampler[List] | Iterable[List] | None = None,
# 用于数据加载的子进程数量。设为 0 时,数据将在主进程中加载(默认:0)。
num_workers: int = 0,
# 将一个样本列表合并成一个小批量(mini-batch)张量。通常在从 map-style 数据集进行批量加载时使用。
collate_fn: _collate_fn_t | None = None,
# 若为 True,则 DataLoader 会在返回前把张量复制到设备/CUDA 的 锁页内存(pinned memory)。如果你的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型。
pin_memory: bool = False,
# 若为 True,当数据集大小不能被 batch_size 整除时,丢弃最后一个不完整的批次。若为 False,则保留最后一个较小的批次(默认:False)。
drop_last: bool = False,
# 若大于 0,表示从 worker 获取一个批次的超时时间(秒)。必须为非负数(默认:0)。
timeout: float = 0,
# 若不为 None,则在每个 worker 子进程启动时调用,输入参数为 worker id(范围 [0, num_workers - 1])。它会在设置随机种子之后、开始加载数据之前被调用(默认:None)。
worker_init_fn: _worker_init_fn_t | None = None,
# 若为 None,则使用操作系统的默认多进程上下文(multiprocessing context)(默认:None)。
multiprocessing_context: Any | None = None,
# 若不为 None,则该随机数生成器将被 RandomSampler 用于生成随机索引,并被多进程机制用于生成 worker 的基础随机种子(默认:None)。
generator: Any | None = None,
# 每个 worker 预先加载的批次数。2 表示总共有 2 * num_workers 个批次被预取。默认值取决于 num_workers 的设置。如果 num_workers=0,则默认是 None;否则默认是 2。
prefetch_factor: int | None = None,
# 若为 True,则在数据集被消耗完一次后,worker 进程不会关闭。这允许在多个 epoch 之间保持 worker 中的 Dataset 实例常驻(默认:False)。
persistent_workers: bool = False,
# 当 pin_memory=True 时,指定锁页内存绑定到的设备。
pin_memory_device: str = ""
)
dataset
数据来源对象,告诉 DataLoader 去哪里拿样本。pytorch提供的torch.utils.data.Dataset类是一个抽象基类,供用户继承,编写自己的dataset,实现对数据的读取。允许两种格式的dataset:
-
map-style :一个map格式的数据集必须要重写getitem(self, index) 和**len(self)**两个内建方法,用来表示从索引到样本的映射。这样一个数据集dataset,当使用dataset[idx]命令时,可以读取你的数据集中第idx张图片以及其标签;len(dataset)则会返回这个数据集的容量。
-
iterable-style :一个Iterable格式的数据集是抽象类data.IterableDataset 的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。
注意:
-
iterable-style 不支持
shuffle/sampler/batch_sampler
(很多会被忽略或报错)。 -
多进程(
num_workers>0
)会在子进程 里各自构造dataset
,所以dataset
内的对象要可被子进程安全创建/序列化。 -
重 IO/解码别放在
__init__
,放到__getitem__
/__iter__
里。
示例(map-style):
python
from torch.utils.data import Dataset
class ToySet(Dataset):
def __init__(self, xs, ys):
self.xs, self.ys = xs, ys
def __len__(self):
return len(self.xs)
def __getitem__(self, idx):
return self.xs[idx], self.ys[idx]
示例(iterable-style,含多 worker 分片):
python
from torch.utils.data import IterableDataset, get_worker_info
class StreamSet(IterableDataset):
def __iter__(self):
info = get_worker_info()
wid, wnum = (info.id, info.num_workers) if info else (0, 1)
for i, sample in enumerate(infinite_source()): # 你自己的流
if i % wnum == wid: # 简单分片,避免重复
yield sample
batch_size: int = 1
每个批次里包含的样本数。控制一次前向/反向要处理多少样本,影响吞吐、显存占用与稳定性。默认 1,
表示 不做批划分,每次返回单条样本。
注意:
-
对显存是线性增长的,建议先从 32/64 开始实验。
-
和
batch_sampler
互斥。指定了batch_sampler
时不要 再传batch_size
。 -
可变长数据 若直接堆叠会报错(
stack expects each tensor to be equal size
),需要自定义collate_fn
做 padding/对齐。 -
最后一批不满时是否丢弃由
drop_last
决定。 -
太大可能 OOM;太小可能训练不稳定(如 BatchNorm)。
示例:
python
from torch.utils.data import DataLoader
loader = DataLoader(ToySet(xs, ys), batch_size=32)
for xb, yb in loader:
...
shuffle: bool = False
每个 epoch 是否打乱样本顺序。可以避免固定顺序带来的偏差,提高泛化。
注意:
-
与
sampler
、batch_sampler
互斥。如果你已经自定义 sampler(如 RandomSampler),shuffle 必须设为 False,否则会报错。 -
分布式训练(DDP)通常用
DistributedSampler(shuffle=True)
接管打乱与分片,不再单独传shuffle=True
。 -
训练集 → shuffle=True(防止模型记忆顺序)
-
验证/测试集 → shuffle=False(保持顺序方便对应标签)
示例:
python
loader = DataLoader(ToySet(xs, ys), batch_size=64, shuffle=True)
sampler: Sampler | Iterable = None
返回单条样本的索引。可以自定义"取样本"的策略,实现加权采样、不放回采样、分布式分片等。
注意:
-
指定
sampler
后不要 再传shuffle
。 -
常用的内置 sampler:RandomSampler, SequentialSampler, WeightedRandomSampler.
-
分布式用
DistributedSampler
时,记得sampler.set_epoch(epoch)
,确保每轮乱序不同。
示例(类别不均衡的权重采样):
python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
weights = torch.tensor([1, 5, 1, 2, ...], dtype=torch.double) # 每样本权重
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(ToySet(xs, ys), batch_size=32, sampler=sampler)
batch_sampler: Sampler | Iterable = None
直接返回每个 batch的索引列表。可以把"打乱+分批"的逻辑完全自定义(如按长度分桶、同类打包)。
注意:
-
与
batch_size
、shuffle
、sampler
、drop_last
互斥(这些都交给它控制)。 -
返回的是"索引列表"的迭代器。
-
适合"按类别分组采样""长度分桶"等复杂 batching 策略。
示例(用内置 BatchSampler
包一层):
python
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
base = RandomSampler(range(len(xs))) # 先有一个样本级采样器
batches = BatchSampler(base, batch_size=16, drop_last=False) # 再打成批
loader = DataLoader(ToySet(xs, ys), batch_sampler=batches)
num_workers: int = 0
用于数据加载的子进程个数;0 表示在主进程中加载(无并行)。I/O 或解码较重时,可以提高吞吐(读取与训练并行)。
注意:
-
每个 worker 是独立进程,有各自随机状态。
-
Windows/macOS 默认
spawn
:传入的对象(dataset
、collate_fn
等)必须可被 pickle ;避免lambda
、内嵌函数。 -
在 Linux/macOS 上,num_workers = min(4, os.cpu_count()) 是一个不错的起点。
-
在 Windows(spawn 启动方式)上,子进程启动开销更大,推荐 num_workers=0~2,或者使用 torch.multiprocessing.set_start_method('spawn') 明确启动方式。
-
调优常用路径:从 2→4→8 逐步尝试;结合
prefetch_factor
、persistent_workers
。 -
**死锁:**在 Windows + num_workers>0 + fork(不支持)+ DataLoader 与 multiprocessing.Queue 同时使用时容易出现 deadlock。
-
**异常捕获:**子进程异常会在下一个 batch 报错,堆栈信息在 torch.utils.data.DataLoader 中会被包装,建议使用 torch.utils.data.DataLoader(..., timeout=... ) 防止无限等待。
示例:
python
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4)
collate_fn: Callable = None
把"样本列表"合并成"一个 batch"的函数。可以实现对齐变长序列(padding)、拼接字典/列表、返回自定义 Batch 对象等。
注意:
-
默认的
default_collate
已能处理 Tensor/数值/np/dict/list 等"等长"数据。 -
处理变长 数据(NLP 序列、检测框)时,通常需要自定义
collate_fn
完成对齐。 -
尽量不要在这里做重增强,重活放到
__getitem__
。
示例(序列 padding):
python
import torch
from torch.nn.utils.rnn import pad_sequence
def pad_collate(batch):
xs, ys = zip(*batch)
xs = [torch.as_tensor(x) for x in xs]
ys = torch.as_tensor(ys)
return pad_sequence(xs, batch_first=True), ys
loader = DataLoader(ToySet(xs, ys), batch_size=32, collate_fn=pad_collate)
pin_memory: bool = False
把从 CPU 读取的数据直接拷贝到 页锁定(pinned)内存 ,当随后调用 tensor.cuda(non_blocking=True) 时,GPU 能通过 DMA 直接读取,显著提升传输带宽。
注意:
-
只在GPU 训练 时打开。
-
会增加主机内存占用;超大batch、内存紧张时谨慎开启。
-
如果
collate_fn
返回的是自定义对象,需要实现.pin_memory()
或在collate_fn
内手动对每个张量pin_memory()
。 -
对 小 batch 或CPU‑only 训练,开启与否差别不大,甚至会稍微增加 CPU 内存占用。
示例(配合非阻塞拷贝):
python
device = "cuda"
loader = DataLoader(ToySet(xs, ys), batch_size=64, pin_memory=True)
for xb, yb in loader:
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
drop_last: bool = False
当数据量不是 batch_size
的整数倍时,是否丢弃最后那个"不满"的批。可以在需要固定 batch 大小的场景(如 BatchNorm、固定步长的梯度累积)保持一致性。
注意:训练可能设 True;验证/测试一般 False。
示例:
python
loader = DataLoader(ToySet(xs, ys), batch_size=64, drop_last=True)
timeout: float = 0
从 workers 收集一个 batch 的超时秒数(非负)。可以排查卡死/死锁(到点抛错)。
注意:
- 对
num_workers>0
才有意义;设太小会误报。 - 当数据读取涉及 网络 I/O(如远程文件、WebDataset)时,防止卡死。
- 在多机器分布式训练中,某些节点慢导致整体阻塞。
示例:
python
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, timeout=30)
worker_init_fn: Callable = None
每个 worker 进程 启动后、正式取数据前会被调用一次的函数,入参是 worker_id
。可以设置 numpy
/random
等外部库的种子、全局变量初始化、初始化三方资源。
注意:
-
为每个 worker 设置不同的 np.random.seed / random.seed,避免所有进程产生相同的随机数。
-
初始化自定义的全局缓存(如打开一次数据库连接)。
-
在
spawn
启动方式(Win/macOS 常见)下,必须是可 pickle 的顶层函数 ,不能用 lambda
。 -
想要可复现,常与
generator
一起用。
示例(可复现设置):
python
import random, numpy as np, torch
def seed_worker(worker_id):
seed = torch.initial_seed() % 2**32
random.seed(seed); np.random.seed(seed)
g = torch.Generator().manual_seed(42)
loader = DataLoader(
ToySet(xs, ys), shuffle=True, num_workers=4,
worker_init_fn=seed_worker, generator=g
)
multiprocessing_context: str
指定 Python 多进程的"启动方式":"fork" | "spawn" | "forkserver"
或对应 context。当某些库在 fork
下不安全(线程、显存句柄)时,改用 spawn
更稳。
注意:
-
Windows 只支持
spawn
;Linux 默认fork
。 -
spawn
更通用但启动更慢。 -
当你的代码中有 CUDA 初始化 或 OpenMP 线程池,fork 可能导致子进程继承父进程的 GPU 上下文,产生 CUDA 错误。此时改为 'spawn' 更安全。
示例:
python
ctx = torch.multiprocessing.get_context("spawn")
loader = DataLoader(ToySet(xs, ys), num_workers=4, multiprocessing_context=ctx)
generator: torch.Generator = None
DataLoader 用到的随机数发生器。可以控制乱序/随机采样,同时为 workers 生成基础种子。
注意:
- 强可复现:
generator.manual_seed(...)
+ 在worker_init_fn
同步numpy/random
的种子。
示例:
python
g = torch.Generator().manual_seed(2024)
loader = DataLoader(ToySet(xs, ys), shuffle=True, generator=g)
prefetch_factor: int | None
每个 worker 会 预取 prefetch_factor * batch_size 条样本放进内部队列。可以让读取/解码与训练形成流水线,减少"等数据"。
注意:
-
默认:
num_workers==0
→None
(无预取);num_workers>0
→2
。 -
调大更顺滑,但会多占内存;I/O 慢时可适度调大(如 3~4)。
-
当 I/O 非常慢(比如磁盘读取大图片),可以调大到 4~8,但会消耗更多内存。
示例:
python
loader = DataLoader(ToySet(xs, ys), batch_size=64, num_workers=4, prefetch_factor=3)
persistent_workers: bool = False
True 时,跨 epoch 不销毁 worker 进程(复用已创建的 Dataset 实例)。长时间训练更高效,避免反复建/毁进程(特别是 num_workers 较大时)。
注意:
-
只能在 torch >= 1.(1.7+)使用。
-
仅
num_workers>0
生效。 -
如果你的 Dataset 每轮都需要重建或怀疑有资源泄漏,谨慎开启或定期重启。
-
如果 Dataset 在每个 epoch 会 改变内部状态(比如随机采样器自行改变种子),需要自行在 worker_init_fn 中重置,否则子进程会复用上一次的状态。
示例:
python
loader = DataLoader(ToySet(xs, ys), num_workers=4, persistent_workers=True)
pin_memory_device: str = ""
当 pin_memory=True
时,指定 pinned memory 绑定到哪个设备(比如 "cuda"
或 "cuda:0"
)。
更细粒度地控制固定内存的目标设备;多数场景留默认即可。
注意:
- 通常只需
pin_memory=True
;只有在多设备或特定优化需求时才设置它。
示例:
python
loader = DataLoader(
ToySet(xs, ys),
batch_size=64,
pin_memory=True,
pin_memory_device="cuda" # 或 "cuda:0"
)
一些警告
-
关于
spawn
启动方式 :如果使用spawn
启动多进程,worker_init_fn
不能是不可序列化对象(例如 lambda 函数)。详情见 multiprocessing 最佳实践。 -
关于
len(dataloader)
的计算:-
当使用普通数据集(map-style dataset)时,
len(dataloader)
基于所用的sampler
的长度。 -
当使用
IterableDataset
时,len(dataloader)
返回的估计值为len(dataset) / batch_size
,并根据drop_last
进行四舍五入,与是否多进程无关。 -
这是 PyTorch 能做出的最佳猜测 ,因为它假定用户编写的
dataset
能正确处理多进程加载以避免数据重复。 -
但如果数据被分片(sharding)后导致多个 worker 有不完整的尾批,这个估计值仍可能不准确,因为:
-
一个完整批次可能被拆成多个不完整的批次。
-
当设置了
drop_last=True
时,可能会丢弃超过一个批次的数据。
-
-
PyTorch 无法检测到这类情况。
-
-
关于随机性和复现 :参见文档中的 reproducibility 、dataloader-workers-random-seed 、data-loading-randomness 部分,了解随机种子相关问题。