PyTorch处理数据--Dataset和DataLoader

在 PyTorch 中,DatasetDataLoader 是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。

一、Dataset:数据集的抽象

Dataset 是一个抽象类,用于表示数据集的接口。你需要继承 torch.utils.data.Dataset 并实现以下两个方法:

  • __len__(): 返回数据集的总样本数。
  • __getitem__(idx): 根据索引 idx 返回一个样本(数据和标签)。
示例:自定义 Dataset
复制代码
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform  # 数据预处理/增强函数

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {
            "data": self.data[idx], 
            "label": self.labels[idx]
        }
        if self.transform:
            sample = self.transform(sample)
        return sample
使用场景
  • 加载图像、文本、表格数据等。
  • 支持数据预处理(如归一化、裁剪)和数据增强(如随机翻转)。

二、 DataLoader:高效加载数据

DataLoader 负责将 Dataset 包装成一个可迭代对象,支持批量加载、多线程加速和数据打乱。

基本用法
复制代码
from torch.utils.data import DataLoader

# 假设 dataset 是你的 CustomDataset 实例
data_loader = DataLoader(
    dataset,
    batch_size=32,       # 批量大小
    shuffle=True,        # 是否打乱数据(训练时建议开启)
    num_workers=4,       # 多线程加载数据的进程数
    drop_last=False      # 是否丢弃最后不足一个 batch 的数据
)

遍历 DataLoader

复制代码
for batch in data_loader:
    data = batch["data"]    # 形状:[batch_size, ...]
    labels = batch["label"] # 形状:[batch_size]
    # 将数据送入模型训练...

三**、pytorch内置数据集**

PyTorch 提供了一系列内置数据集,这些数据集可以直接用于训练模型。这些数据集涵盖了多种领域,如图像、文本、音频等。以下是一些常用的PyTorch内置数据集:

图像数据集
  1. MNIST: 手写数字数据集,包含0到9的手写数字图片。

    复制代码
    from torchvision import datasets
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  2. CIFAR10/CIFAR100: 包含彩色图片的数据集,CIFAR10有60000张32x32的彩色图片,分为10个类别;CIFAR100类似但有100个类别。

    复制代码
    cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  3. ImageNet: 包含超过1400万张图片的非常庞大的数据集,常用于图像识别和分类任务。

    复制代码
    import torchvision.datasets as datasets
    imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
  4. STL10: 一个用于计算机视觉研究的小型图像数据集,包含96x96的彩色图片。

    复制代码
    stl10_train = datasets.STL10(root='./data', split='train', download=True)
  5. SVHN: 包含数字图片的数据集,与MNIST类似但包含更多实际场景的图片。

    复制代码
    svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本数据集

1.Text8: 一个用于自然语言处理的小型文本数据集。

复制代码
from torchtext.datasets import Text8
text8_train = Text8(split=('train',))
  1. AG_NEWS: 包含新闻文章的文本数据集,分为4个类别。

    from torchtext.datasets import AG_NEWS
    ag_news_train = AG_NEWS(split=('train',))

音频数据集
  1. Speech Commands: 一个用于语音识别的数据集,包含约65,000个单词发音的音频文件。

    from torchaudio.datasets import SPEECHCOMMANDS
    speech_commands = SPEECHCOMMANDS(root="./data", download=True)

使用方法

要使用这些数据集,首先需要导入torchvision(对于图像数据集)、torchtext(对于文本数据集)或torchaudio(对于音频数据集),然后使用其提供的类来加载数据。通常还包括一些数据预处理步骤,例如转换(transforms)。

复制代码
import torchvision.transforms as transforms
from torchvision import datasets
 
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

四、完整代码示例

步骤 1:创建数据集
复制代码
import numpy as np
from torch.utils.data import Dataset, DataLoader

# 生成示例数据(假设是 10 个样本,每个样本是长度为 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,))  # 二分类标签

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "data": self.data[idx],
            "label": self.labels[idx]
        }

dataset = MyDataset(data, labels)
步骤 2:创建 DataLoader
复制代码
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2
)

步骤 3:使用 DataLoader 训练模型

复制代码
model = ...  # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(10):
    for batch in data_loader:
        x = batch["data"]
        y = batch["label"]
        
        # 前向传播
        outputs = model(x)
        loss = loss_fn(outputs, y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

五、常见问题解决

‌**(1)数据格式不匹配**‌
  • 问题 ‌:DataLoader 返回的数据形状与模型输入不匹配。
  • 解决 ‌:检查 Dataset__getitem__ 返回的数据类型和形状,确保与模型输入一致。
‌**(2)多线程加载卡顿**‌
  • 问题 ‌:设置 num_workers>0 时程序卡死或报错。
  • 解决 ‌:在 Windows 系统中,多线程可能需要将代码放在 if __name__ == "__main__": 块中运行。
‌**(3)数据增强**‌
  • 使用 torchvision.transforms 中的工具(如 RandomCropRandomHorizontalFlip)对图像数据进行增强:

    复制代码
    from torchvision import transforms
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
‌**(4)内存不足**‌
  • 对于大型数据集,使用 torch.utils.data.DataLoaderpersistent_workers=True(PyTorch 1.7+)或优化数据加载逻辑。

六、高级功能

  • 分布式训练 ‌:使用 torch.utils.data.distributed.DistributedSampler 配合多 GPU。
  • 预加载数据 ‌:使用 torch.utils.data.TensorDataset 直接加载 Tensor 数据。
  • 自定义采样器 ‌:通过 sampler 参数控制数据采样顺序(如平衡类别采样)。
相关推荐
林泽毅1 小时前
SwanLab硬件监控:英伟达、昇腾、寒武纪
python·深度学习·昇腾·英伟达·swanlab·寒武纪·训练实战
Watermelo6171 小时前
Manus使用的MCP协议是什么?人工智能知识分享的“万能插头”
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·数据挖掘
这就是编程2 小时前
自回归模型的新浪潮?GPT-4o图像生成技术解析与未来展望
人工智能·算法·机器学习·数据挖掘·回归
自由鬼2 小时前
GPT Workspace体验
人工智能·gpt
星际码仔7 小时前
AutoGLM沉思,仍然没有摆脱DeepResearch产品的通病
人工智能·ai编程·chatglm (智谱)
喝拿铁写前端7 小时前
前端与 AI 结合的 10 个可能路径图谱
前端·人工智能
城电科技8 小时前
城电科技|零碳园区光伏太阳花绽放零碳绿色未来
人工智能·科技·能源
HyperAI超神经8 小时前
Stable Virtual Camera 重新定义3D内容生成,解锁图像新维度;BatteryLife助力更精准预测电池寿命
图像处理·人工智能·3d·数学推理·视频生成·对话语音生成·蛋白质突变
Chaos_Wang_8 小时前
NLP高频面试题(二十三)对抗训练的发展脉络,原理,演化路径
人工智能·自然语言处理
Yeats_Liao9 小时前
华为开源自研AI框架昇思MindSpore应用案例:基于MindSpore框架实现PWCNet光流估计
人工智能·华为