一、DataLoader核心概念
什么是DataLoader?
Dataloader是PyTorch中用于批量加载数据的工具,它将Dataset包装成一个可迭代对象,支持:
- 批量处理(batching)
- 数据打乱(shuffling)
- 多进程加载(multiprocessing)
二、基本使用参数
python
from torch.utils.data import DataLoader, Dataset
dataloader = DataLoader(
dataset=dataset, # 数据集对象
batch_size=32, # 批量大小
shuffle=True, # 是否打乱
num_workers=4, # 子进程数,0:主进程加载,调试用,>0:多进程并行加载
pin_memory=True # 锁页内存,加速CPU到GPU的数据传输
drop_last=False # 是否丢弃最后一个不完整的batch
)
for batch in dataloader:
inputs, labels = batch
# 训练代码...
三、工作流程
1、数据加载流程:
python
Dataset → Sampler → BatchSampler → DataLoader → 模型
- 初始化阶段:创建worker进程;
- 索引生成:Sampler生成样本索引;
- 批量组织:BatchSampler将索引分组;
- 数据加载:Worker进程读取数据;
- 数据合并:collate_fn整理batch数据;
- 返回结果:将batch数据返回给主进程
2、Sampler和BatchSampler
(1)Sampler-控制样本顺序
python
from torch.utils.data import RandomSampler, SequentialSampler
# 随机采样(等价于shuffle=True)
sampler = RandomSampler(dataset)
# 顺序采样(等价于shuffle=False)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
(2) 自定义Sampler
python
class ImbalancedSampler(Sampler):
def __init__(self, dataset, minority_class_indices):
self.dataset = dataset
self.minority_indices = minority_class_indices
self.majority_indices = [i for i in range(len(dataset))
if i not in minority_class_indices]
def __iter__(self):
# 过采样少数类,保持类别平衡
indices = []
# 每个少数类样本采样2次
for idx in self.minority_indices:
indices.extend([idx] * 2)
indices.extend(self.majority_indices)
random.shuffle(indices)
return iter(indices)
(3)BatchSampler-控制batch组织方式
python
from torch.utils.data import BatchSampler
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=32, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
3、collate_fn函数
作用:将多个样本整理成一个batch
python
def custom_collate_fn(batch):
"""
batch: list of (image, label, metadata) tuples
返回: 整理后的batch数据
"""
images, labels, metadatas = zip(*batch)
# 处理图像
images = torch.stack(images, dim=0)
# 处理标签
labels = torch.tensor(labels)
# 处理变长数据
metadatas = list(metadatas) # 保持列表形式
return images, labels, metadatas
dataloader = DataLoader(dataset, collate_fn=custom_collate_fn, batch_size=32)
处理变长序列的collate_fn
python
def pad_collate_fn(batch):
# batch: [(sequence, label), ...]
sequences, labels = zip(*batch)
# 计算最大长度
max_len = max(len(seq) for seq in sequences)
# 填充序列
padded_sequences = []
for seq in sequences:
pad_size = max_len - len(seq)
padded_seq = torch.cat([seq, torch.zeros(pad_size)])
padded_sequences.append(padded_seq)
return torch.stack(padded_sequences), torch.tensor(labels)
4、多进程数据加载
(1)Worker进程工程流程
python
# 主进程
dataloader = DataLoader(dataset, num_workers=4)
# 内部工作:
# 1. 主进程创建4个worker进程
# 2. 每个worker有自己的数据集副本
# 3. Worker预加载多个batch到队列
# 4. 主进程从队列获取batch
(2)常见多进程问题及解决
问题1:内存泄漏
python
# 错误:每个epoch都重新创建DataLoader
for epoch in range(epochs):
dataloader = DataLoader(dataset, num_workers=4) # 内存泄漏!
# 正确:只创建一次
dataloader = DataLoader(dataset, num_workers=4)
for epoch in range(epochs):
for batch in dataloader:
# ...
问题2:CUDA在多进程中的使用
python
# 错误:在worker进程中初始化CUDA
def bad_worker_init(worker_id):
torch.cuda.set_device(0) # 错误!
# 正确:只在主进程使用CUDA
# Worker进程只负责数据加载,不涉及GPU操作
5、性能优化技巧
(1)选择合适的num_workers
python
# 测试最佳worker数量
import time
def find_optimal_workers(dataset, max_workers=8):
for n in range(max_workers + 1):
loader = DataLoader(dataset, num_workers=n, batch_size=32)
start = time.time()
for batch in loader:
pass
duration = time.time() - start
print(f"Workers: {n}, Time: {duration:.2f}s")
(2)使用pin_memory加速
python
dataloader = DataLoader(
dataset,
num_workers=4,
pin_memory=True, # GPU训练时开启
batch_size=32
)
(3)预加载策略
python
# 使用prefetch_factor(PyTorch 1.7+)
dataloader = DataLoader(
dataset,
num_workers=4,
prefetch_factor=2, # 每个worker预加载2个batch
batch_size=32
)
6、其他问题
(1)DataLoader中的num_workers设置多少合适?
通常设置为CPU核心数的2-4倍,但需要实际测试,设置过多会增加进程切换开销。
(2)shuffle=True是如何工作的?
在每个epoch开始时,RandomSampler会生成新的随机索引序列,实现数据打乱。
(3)多进程数据加载时,如何避免重复数据?
每个worker进程通过sampler获取不同的索引范围,使用不同的随机种子确保数据不重复。
(4)pin_memory为什么能加速?
锁页内存允许DMA直接传输到GPU,避免了额外的内存复制。
(5)如何处理变长数据?
使用自定义的collate_fn进行填充(padding)或打包(packing)
7、实际代码示例
python
class AdvancedDataLoader:
def __init__(self, dataset, config):
self.dataset = dataset
self.config = config
def create_dataloader(self, is_training=True):
# 选择sampler
if is_training:
sampler = RandomSampler(self.dataset)
else:
sampler = SequentialSampler(self.dataset)
# 创建DataLoader
dataloader = DataLoader(
dataset=self.dataset,
batch_size=self.config.batch_size,
sampler=sampler,
num_workers=self.config.num_workers,
pin_memory=self.config.use_gpu,
drop_last=is_training, # 训练时丢弃不完整batch
collate_fn=self.pad_collate_fn if self.config.pad_sequences else None,
persistent_workers=self.config.persistent_workers
)
return dataloader