PyTorch中的Dataset与DataLoader详解
1. Dataset基础
Dataset是PyTorch中表示数据集的抽象类,我们需要继承它并实现两个关键方法:
python
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
"""初始化方法,加载数据"""
self.data = data
self.labels = labels
def __len__(self):
"""返回数据集的大小"""
return len(self.data)
def __getitem__(self, idx):
"""根据索引获取单个样本"""
sample = self.data[idx]
label = self.labels[idx]
return sample, label
使用示例
python
# 假设我们有一些简单的数据
data = [[1, 2], [3, 4], [5, 6], [7, 8]]
labels = [0, 1, 0, 1]
# 创建数据集实例
dataset = CustomDataset(data, labels)
# 测试数据集
print(f"数据集大小: {len(dataset)}") # 输出: 4
print(dataset[0]) # 输出: ([1, 2], 0)
2. DataLoader功能
DataLoader负责从Dataset中加载数据,并提供批处理、打乱顺序和多线程加载等功能。
python
from torch.utils.data import DataLoader
# 创建DataLoader
dataloader = DataLoader(
dataset, # 数据集对象
batch_size=2, # 每批数据大小
shuffle=True, # 是否打乱数据
num_workers=2 # 使用多少子进程加载数据
)
# 遍历数据
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
print(f"批次 {batch_idx}:")
print("数据:", batch_data)
print("标签:", batch_labels)
3. 实际应用示例
图像数据集示例
python
import os
from PIL import Image
class ImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_dir = img_dir
self.transform = transform
self.img_names = os.listdir(img_dir)
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_names[idx])
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
# 假设文件名格式为 "label_image.jpg"
label = int(self.img_names[idx].split('_')[0])
return image, label
使用数据增强
python
from torchvision import transforms
# 定义数据转换
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 创建数据集
dataset = ImageDataset("path/to/images", transform=transform)
# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
4. 高级功能
自定义批处理
python
from torch.utils.data.dataloader import default_collate
def custom_collate(batch):
# 过滤掉None样本
batch = [item for item in batch if item is not None]
if len(batch) == 0:
return None
return default_collate(batch)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)
使用Subset划分数据集
python
from torch.utils.data import random_split
# 假设我们有一个大的数据集
full_dataset = CustomDataset(data, labels)
# 划分训练集和测试集
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
# 创建对应的DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
5. 性能优化技巧
- num_workers设置:根据CPU核心数设置合理的num_workers值(通常2-4)
- pin_memory:在GPU训练时设置pin_memory=True可以加速数据传输
- 预取数据:使用prefetch_factor参数(PyTorch 1.7+)
python
dataloader = DataLoader(
dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True,
prefetch_factor=2
)
6. 常见问题解决
- 内存不足:减小batch_size或使用IterableDataset
- 数据加载慢:确保数据存储在SSD上,使用更快的文件格式(如HDF5)
- 数据不平衡:使用WeightedRandomSampler
python
from torch.utils.data import WeightedRandomSampler
# 假设我们有不平衡的数据集
weights = [1.0 if label == 0 else 0.1 for _, label in dataset]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
balanced_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
通过合理使用Dataset和DataLoader,可以高效地管理和加载大规模数据集,为深度学习模型训练提供稳定、高效的数据管道。