PyTorch数据集与数据集加载

PyTorch中的Dataset与DataLoader详解

1. Dataset基础

Dataset是PyTorch中表示数据集的抽象类,我们需要继承它并实现两个关键方法:

python 复制代码
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """初始化方法,加载数据"""
        self.data = data
        self.labels = labels
    
    def __len__(self):
        """返回数据集的大小"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """根据索引获取单个样本"""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

使用示例

python 复制代码
# 假设我们有一些简单的数据
data = [[1, 2], [3, 4], [5, 6], [7, 8]]
labels = [0, 1, 0, 1]

# 创建数据集实例
dataset = CustomDataset(data, labels)

# 测试数据集
print(f"数据集大小: {len(dataset)}")  # 输出: 4
print(dataset[0])  # 输出: ([1, 2], 0)

2. DataLoader功能

DataLoader负责从Dataset中加载数据,并提供批处理、打乱顺序和多线程加载等功能。

python 复制代码
from torch.utils.data import DataLoader

# 创建DataLoader
dataloader = DataLoader(
    dataset,          # 数据集对象
    batch_size=2,     # 每批数据大小
    shuffle=True,     # 是否打乱数据
    num_workers=2     # 使用多少子进程加载数据
)

# 遍历数据
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"批次 {batch_idx}:")
    print("数据:", batch_data)
    print("标签:", batch_labels)

3. 实际应用示例

图像数据集示例

python 复制代码
import os
from PIL import Image

class ImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_names = os.listdir(img_dir)
    
    def __len__(self):
        return len(self.img_names)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        # 假设文件名格式为 "label_image.jpg"
        label = int(self.img_names[idx].split('_')[0])
        
        return image, label

使用数据增强

python 复制代码
from torchvision import transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

# 创建数据集
dataset = ImageDataset("path/to/images", transform=transform)

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

4. 高级功能

自定义批处理

python 复制代码
from torch.utils.data.dataloader import default_collate

def custom_collate(batch):
    # 过滤掉None样本
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None
    return default_collate(batch)

dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)

使用Subset划分数据集

python 复制代码
from torch.utils.data import random_split

# 假设我们有一个大的数据集
full_dataset = CustomDataset(data, labels)

# 划分训练集和测试集
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# 创建对应的DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

5. 性能优化技巧

  1. num_workers设置:根据CPU核心数设置合理的num_workers值(通常2-4)
  2. pin_memory:在GPU训练时设置pin_memory=True可以加速数据传输
  3. 预取数据:使用prefetch_factor参数(PyTorch 1.7+)
python 复制代码
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2
)

6. 常见问题解决

  1. 内存不足:减小batch_size或使用IterableDataset
  2. 数据加载慢:确保数据存储在SSD上,使用更快的文件格式(如HDF5)
  3. 数据不平衡:使用WeightedRandomSampler
python 复制代码
from torch.utils.data import WeightedRandomSampler

# 假设我们有不平衡的数据集
weights = [1.0 if label == 0 else 0.1 for _, label in dataset]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

balanced_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

通过合理使用Dataset和DataLoader,可以高效地管理和加载大规模数据集,为深度学习模型训练提供稳定、高效的数据管道。

相关推荐
Fuliy9611 分钟前
【自然语言处理】——基于与训练模型的方法【复习篇1】
人工智能·自然语言处理
项目管理打工人12 分钟前
高端装备制造企业如何选择适配的项目管理系统提升项目执行效率?附选型案例
大数据·人工智能·驱动开发·科技·硬件工程·团队开发·制造
江苏泊苏系统集成有限公司14 分钟前
集成电路制造设备防震基座选型指南:为稳定护航-江苏泊苏系统集成有限公司
人工智能·深度学习·目标检测·机器学习·制造·材料工程·精益工程
吹风看太阳15 分钟前
机器学习03-色彩空间:RGB、HSV、HLS
人工智能·机器学习
lichuangcsdn44 分钟前
利用python工具you-get下载网页的视频文件
python
Ronin-Lotus1 小时前
深度学习篇---Pytorch框架下OC-SORT实现
人工智能·pytorch·python·深度学习·oc-sort
雾迟sec1 小时前
TensorFlow 的基本概念和使用场景
人工智能·python·tensorflow
Blossom.1181 小时前
人工智能在智能健康监测中的创新应用与未来趋势
java·人工智能·深度学习·机器学习·语音识别
烛阴1 小时前
从零打造属于你的Python容器类型:全流程图解+实战案例
前端·python
GIS小天1 小时前
AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年5月31日第94弹
人工智能·算法·机器学习·彩票