在 PyTorch 中,Dataset
和 DataLoader
是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。
一、Dataset:数据集的抽象
Dataset
是一个抽象类,用于表示数据集的接口。你需要继承 torch.utils.data.Dataset
并实现以下两个方法:
__len__()
: 返回数据集的总样本数。__getitem__(idx)
: 根据索引idx
返回一个样本(数据和标签)。
示例:自定义 Dataset
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
self.data = data
self.labels = labels
self.transform = transform # 数据预处理/增强函数
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = {
"data": self.data[idx],
"label": self.labels[idx]
}
if self.transform:
sample = self.transform(sample)
return sample
使用场景
- 加载图像、文本、表格数据等。
- 支持数据预处理(如归一化、裁剪)和数据增强(如随机翻转)。
二、 DataLoader:高效加载数据
DataLoader
负责将 Dataset
包装成一个可迭代对象,支持批量加载、多线程加速和数据打乱。
基本用法
from torch.utils.data import DataLoader
# 假设 dataset 是你的 CustomDataset 实例
data_loader = DataLoader(
dataset,
batch_size=32, # 批量大小
shuffle=True, # 是否打乱数据(训练时建议开启)
num_workers=4, # 多线程加载数据的进程数
drop_last=False # 是否丢弃最后不足一个 batch 的数据
)
遍历 DataLoader
for batch in data_loader:
data = batch["data"] # 形状:[batch_size, ...]
labels = batch["label"] # 形状:[batch_size]
# 将数据送入模型训练...
三**、pytorch内置数据集**
PyTorch 提供了一系列内置数据集,这些数据集可以直接用于训练模型。这些数据集涵盖了多种领域,如图像、文本、音频等。以下是一些常用的PyTorch内置数据集:
图像数据集
-
MNIST: 手写数字数据集,包含0到9的手写数字图片。
from torchvision import datasets mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
-
CIFAR10/CIFAR100: 包含彩色图片的数据集,CIFAR10有60000张32x32的彩色图片,分为10个类别;CIFAR100类似但有100个类别。
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
-
ImageNet: 包含超过1400万张图片的非常庞大的数据集,常用于图像识别和分类任务。
import torchvision.datasets as datasets imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
-
STL10: 一个用于计算机视觉研究的小型图像数据集,包含96x96的彩色图片。
stl10_train = datasets.STL10(root='./data', split='train', download=True)
-
SVHN: 包含数字图片的数据集,与MNIST类似但包含更多实际场景的图片。
svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本数据集
1.Text8: 一个用于自然语言处理的小型文本数据集。
from torchtext.datasets import Text8
text8_train = Text8(split=('train',))
-
AG_NEWS: 包含新闻文章的文本数据集,分为4个类别。
from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))
音频数据集
-
Speech Commands: 一个用于语音识别的数据集,包含约65,000个单词发音的音频文件。
from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)
使用方法
要使用这些数据集,首先需要导入torchvision
(对于图像数据集)、torchtext
(对于文本数据集)或torchaudio
(对于音频数据集),然后使用其提供的类来加载数据。通常还包括一些数据预处理步骤,例如转换(transforms)。
import torchvision.transforms as transforms
from torchvision import datasets
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
四、完整代码示例
步骤 1:创建数据集
import numpy as np
from torch.utils.data import Dataset, DataLoader
# 生成示例数据(假设是 10 个样本,每个样本是长度为 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,)) # 二分类标签
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = torch.tensor(data, dtype=torch.float32)
self.labels = torch.tensor(labels, dtype=torch.long)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return {
"data": self.data[idx],
"label": self.labels[idx]
}
dataset = MyDataset(data, labels)
步骤 2:创建 DataLoader
data_loader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
num_workers=2
)
步骤 3:使用 DataLoader 训练模型
model = ... # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(10):
for batch in data_loader:
x = batch["data"]
y = batch["label"]
# 前向传播
outputs = model(x)
loss = loss_fn(outputs, y)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
五、常见问题解决
**(1)数据格式不匹配**
- 问题 :
DataLoader
返回的数据形状与模型输入不匹配。 - 解决 :检查
Dataset
的__getitem__
返回的数据类型和形状,确保与模型输入一致。
**(2)多线程加载卡顿**
- 问题 :设置
num_workers>0
时程序卡死或报错。 - 解决 :在 Windows 系统中,多线程可能需要将代码放在
if __name__ == "__main__":
块中运行。
**(3)数据增强**
-
使用
torchvision.transforms
中的工具(如RandomCrop
、RandomHorizontalFlip
)对图像数据进行增强:from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]), ])
**(4)内存不足**
- 对于大型数据集,使用
torch.utils.data.DataLoader
的persistent_workers=True
(PyTorch 1.7+)或优化数据加载逻辑。
六、高级功能
- 分布式训练 :使用
torch.utils.data.distributed.DistributedSampler
配合多 GPU。 - 预加载数据 :使用
torch.utils.data.TensorDataset
直接加载 Tensor 数据。 - 自定义采样器 :通过
sampler
参数控制数据采样顺序(如平衡类别采样)。