DataLoader is the core data-handling component in PyTorch: it provides efficient data loading, batching, and parallel worker processes. Below is a practical guide to DataLoader, with code examples and best practices.
Basic usage: simple data loading
import torch
from torch.utils.data import Dataset, DataLoader

# 1. Define a custom dataset
class SimpleDataset(Dataset):
    def __init__(self, size=1000):
        self.data = torch.randn(size, 3, 32, 32)     # fake image data
        self.labels = torch.randint(0, 10, (size,))  # labels in 0-9

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 2. Create the DataLoader
dataset = SimpleDataset(1000)
dataloader = DataLoader(
    dataset,
    batch_size=64,    # batch size
    shuffle=True,     # shuffle the data each epoch
    num_workers=4,    # load data with 4 worker processes
    pin_memory=True   # use pinned (page-locked) memory to speed up GPU transfers
)

# 3. Iterate over the DataLoader
for epoch in range(3):
    print(f"Epoch {epoch+1}")
    for batch_idx, (data, targets) in enumerate(dataloader):
        # Batches are assembled automatically: data.shape = [64, 3, 32, 32], targets.shape = [64]
        if batch_idx % 10 == 0:
            print(f"  Batch {batch_idx}: {data.shape}, {targets.shape}")
    print("Epoch completed\n")
Advanced features: custom datasets and transforms
Image dataset example
import os
from PIL import Image
from torchvision import transforms
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_names = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]
        # Assumes filenames of the form "label_imageid.jpg", e.g. "3_001.jpg"
        self.labels = [int(f.split('_')[0]) for f in self.img_names]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Define the data transforms
transform = transforms.Compose([
    transforms.Resize((256, 256)),      # resize
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomRotation(15),      # random rotation of up to ±15 degrees
    transforms.ToTensor(),              # convert to a tensor with values in [0, 1]
    transforms.Normalize(               # normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# Create the dataset and DataLoader
dataset = CustomImageDataset('/path/to/images', transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    # Custom collate: returns (tuple of images, tuple of labels) instead of stacked tensors.
    # This is only needed when samples cannot be stacked (e.g. different sizes); with the
    # Resize transform above, the default collate_fn would also work.
    collate_fn=lambda batch: tuple(zip(*batch))
)
Text dataset example
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

class TextDataset(Dataset):
    def __init__(self, file_path, max_len=100):
        self.max_len = max_len
        self.tokenizer = get_tokenizer('basic_english')
        # Read the texts and labels; each line is "<label>\t<text>"
        self.texts = []
        self.labels = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                label, text = line.split('\t', 1)  # split only on the first tab
                self.labels.append(int(label))
                self.texts.append(text.strip())
        # Build the vocabulary
        self.vocab = build_vocab_from_iterator(
            (self.tokenizer(text) for text in self.texts),
            specials=['<unk>', '<pad>']
        )
        self.vocab.set_default_index(self.vocab['<unk>'])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer(text)
        # Convert tokens to vocabulary indices
        indices = [self.vocab[token] for token in tokens]
        # Truncate long sequences; padding to a common length is done per batch in collate_fn
        if len(indices) > self.max_len:
            indices = indices[:self.max_len]
        return torch.tensor(indices, dtype=torch.long), self.labels[idx]
# Custom collate function: pad variable-length sequences within each batch
PAD_IDX = 1  # index of '<pad>' (the specials ['<unk>', '<pad>'] are inserted first, in order)

def collate_fn(batch):
    texts, labels = zip(*batch)
    # Length of the longest sequence in this batch
    max_len = max(len(t) for t in texts)
    # Pad every sequence to that length with the padding index
    padded_texts = []
    for text in texts:
        padding = torch.full((max_len - len(text),), PAD_IDX, dtype=torch.long)
        padded_texts.append(torch.cat((text, padding)))
    return torch.stack(padded_texts), torch.tensor(labels)
# Create the DataLoader
text_dataset = TextDataset('/path/to/text_data.txt', max_len=100)
text_dataloader = DataLoader(
    text_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn  # use the custom collate function
)
Performance optimization tips
1. Parallel loading
# Choose num_workers based on the number of CPU cores
import os
num_workers = min(4, os.cpu_count())  # no more than 4 workers, and no more than the CPU core count
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # important for GPU training
    persistent_workers=True   # keep worker processes alive between epochs (PyTorch 1.7+)
)
2. Data prefetching
# prefetch_factor is a built-in DataLoader argument (PyTorch 1.7+, requires num_workers > 0)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2  # number of batches each worker loads ahead of time
)
# Alternatively, use a custom prefetcher that copies batches to the GPU on a side stream
class PrefetchLoader:
    def __init__(self, loader, device):
        self.loader = loader
        self.device = device
        self.stream = torch.cuda.Stream() if device.type == 'cuda' else None

    def __iter__(self):
        for batch in self.loader:
            if self.stream is not None:
                # Issue the host-to-device copy on the side stream
                with torch.cuda.stream(self.stream):
                    batch = self._preprocess(batch)
                # Make the default stream wait for the copy before the batch is used
                torch.cuda.current_stream().wait_stream(self.stream)
            else:
                batch = self._preprocess(batch)
            yield batch

    def _preprocess(self, batch):
        data, target = batch
        return (data.to(self.device, non_blocking=True),
                target.to(self.device, non_blocking=True))

# Using the custom prefetcher
device = torch.device('cuda')
prefetch_dataloader = PrefetchLoader(dataloader, device)
3. Memory-mapped files for large datasets
import numpy as np
import torch
from torch.utils.data import Dataset

class MmapDataset(Dataset):
    def __init__(self, file_path, shape, dtype=np.float32):
        # The file is memory-mapped, so samples are read from disk on demand
        self.data = np.memmap(file_path, dtype=dtype, mode='r', shape=shape)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Copy the sample out of the memmap before converting it to a tensor
        return torch.from_numpy(np.array(self.data[idx]))
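A hypothetical usage sketch: the file name and shape below are placeholders for a raw float32 file written beforehand, for example with numpy.memmap or ndarray.tofile:

# Hypothetical example: a raw float32 file holding 10000 samples of shape 3x32x32
dataset = MmapDataset('features.bin', shape=(10000, 3, 32, 32))
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)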
Distributed data loading
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

# Initialize the distributed environment
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Create a distributed sampler
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
    seed=42
)

# Create the distributed DataLoader
dist_dataloader = DataLoader(
    dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=True  # drop the last incomplete batch
)

# In each process
for epoch in range(10):
    # Set the epoch so that shuffling is consistent across processes
    dist_dataloader.sampler.set_epoch(epoch)
    for batch in dist_dataloader:
        # process the batch
        pass
Data augmentation strategies
Image augmentation
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Using torchvision
torchvision_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Using Albumentations (a richer set of augmentations)
albumentations_transform = A.Compose([
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.Rotate(limit=30),
    A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
# Inside a Dataset class that stores self.img_paths / self.labels (requires `import cv2`)
def __getitem__(self, idx):
    img_path = self.img_paths[idx]
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Albumentations expects RGB numpy arrays
    if self.transform:
        augmented = self.transform(image=image)
        image = augmented['image']
    return image, self.labels[idx]
Text augmentation
import random
import nlpaug.augmenter.word as naw

# Create a text augmenter
augmenter = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    action="substitute",  # substitute, insert, ...
    aug_p=0.1             # fraction of words to augment
)

# Inside a Dataset class
def __getitem__(self, idx):
    text = self.texts[idx]
    if self.augment and random.random() < 0.5:  # augment with 50% probability
        text = augmenter.augment(text)  # note: recent nlpaug versions return a list of strings
    # further processing ...
Data visualization and debugging
import matplotlib.pyplot as plt
import numpy as np

def show_batch(dataloader, n=4):
    """Display one batch of images with their labels."""
    dataiter = iter(dataloader)
    images, labels = next(dataiter)
    fig, axes = plt.subplots(1, n, figsize=(15, 4))
    for i in range(n):
        img = images[i].permute(1, 2, 0).numpy()  # CHW -> HWC
        # Undo the normalization applied in the transform above
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        axes[i].imshow(img)
        axes[i].set_title(f"Label: {labels[i].item()}")
        axes[i].axis('off')
    plt.show()

# Usage
show_batch(dataloader, n=8)
Common problems and solutions
1. Running out of memory
# Solution 1: use a smaller batch size
dataloader = DataLoader(dataset, batch_size=16)

# Solution 2: use memory-mapped files
# (see the MmapDataset example above)

# Solution 3: stream the data with an IterableDataset
from torch.utils.data import IterableDataset

class LargeIterableDataset(IterableDataset):
    def __init__(self, file_path, chunk_size=1000):
        self.file_path = file_path
        self.chunk_size = chunk_size

    def __iter__(self):
        with open(self.file_path, 'r') as f:
            chunk = []
            for line in f:
                chunk.append(process_line(line))  # process_line is your own parsing function
                if len(chunk) == self.chunk_size:
                    yield from chunk
                    chunk = []
            if chunk:
                yield from chunk

# Usage
dataset = LargeIterableDataset('large_file.txt')
dataloader = DataLoader(dataset, batch_size=64)
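One caveat with IterableDataset: when num_workers > 0, every worker runs the same __iter__ and the stream would be yielded once per worker. A common fix is to shard the stream by worker id via torch.utils.data.get_worker_info(); a minimal sketch (the modulo-based sharding below is just one possible scheme, and process_line is the same hypothetical parser as above):

from torch.utils.data import get_worker_info

class ShardedIterableDataset(IterableDataset):
    def __init__(self, file_path):
        self.file_path = file_path

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0
        with open(self.file_path, 'r') as f:
            for i, line in enumerate(f):
                # Each worker keeps only every num_workers-th line
                if i % num_workers == worker_id:
                    yield process_line(line)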
2. Multiprocessing issues on Windows

# Solution: put the main code inside an `if __name__ == '__main__'` block
if __name__ == '__main__':
    # create the DataLoader here
    dataloader = DataLoader(dataset, num_workers=4)
    # training code ...
3. Data loading is the bottleneck

# Solution 1: increase num_workers
dataloader = DataLoader(dataset, num_workers=os.cpu_count())

# Solution 2: use prefetching
# (see the PrefetchLoader example above)

# Solution 3: use faster storage (e.g. SSD instead of HDD)

# Solution 4: use a more efficient on-disk format (e.g. HDF5 or LMDB), as sketched below
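As an illustration of Solution 4, here is a minimal sketch of an HDF5-backed dataset using h5py. The file name and the 'data'/'labels' dataset names are assumptions about how the file was written; the file is opened lazily so that each DataLoader worker gets its own handle:

import h5py
import torch
from torch.utils.data import Dataset

class H5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.file = None
        # Read the length once without keeping a handle open across worker processes
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f['labels'])  # assumed dataset name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.file is None:
            # Opened lazily so each DataLoader worker has its own file handle
            self.file = h5py.File(self.h5_path, 'r')
        x = torch.from_numpy(self.file['data'][idx])  # assumed dataset name
        y = int(self.file['labels'][idx])
        return x, y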
Best practices summary

- Batch size: use the largest batch size that fits in GPU memory (a combined configuration sketch follows this list)
- Number of workers: start around the number of CPU cores and tune against measured throughput
- Pinned memory: always set pin_memory=True when training on a GPU
- Data augmentation: run it on the CPU (inside the workers) so it does not compete for GPU resources
- Distributed training: use DistributedSampler so the data is partitioned correctly across processes
- Memory: use memory-mapped files or an IterableDataset for datasets that do not fit in RAM
- Prefetching: use the built-in prefetch_factor or a custom prefetcher
- Data validation: visualize batches regularly to confirm the augmentations behave as intended
- Resource monitoring: watch CPU/GPU utilization to identify bottlenecks
- Format: use efficient data formats (e.g. TFRecord, LMDB) to speed up I/O
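Putting several of these recommendations together, a typical GPU-training configuration might look like the following sketch; the specific numbers are starting points to tune, not fixed rules:

import os
import torch
from torch.utils.data import DataLoader

num_workers = min(8, os.cpu_count() or 1)  # at least 1; tune against measured throughput

dataloader = DataLoader(
    dataset,
    batch_size=64,            # increase until GPU memory becomes the limit
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # for GPU training
    persistent_workers=True,  # keep workers alive between epochs
    prefetch_factor=2,        # batches loaded ahead per worker
    drop_last=True
)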
With a well-configured DataLoader you can significantly improve training throughput, make full use of your hardware, and speed up model iteration.