在 PyTorch 中,数据加载通常使用 torch.utils.data.Dataset
和 torch.utils.data.DataLoader
进行高效管理和批量处理。我们来详细了解如何加载数据,包括 Dataset、DataLoader、数据预处理和自定义数据集。
1. PyTorch 的 Dataset 和 DataLoader 🚀
(1) Dataset
PyTorch 提供了 torch.utils.data.Dataset
作为所有数据集的基类,所有数据集需要继承这个类,并实现:
__len__()
→ 返回数据集的大小__getitem__(index)
→ 通过索引获取一个样本
(2) DataLoader
DataLoader
用于批量加载数据,同时支持 多线程(num_workers) 加速数据读取。
2. 加载内置数据集(MNIST 数据集)
PyTorch 提供了一些 内置数据集 (如 MNIST、CIFAR-10、ImageNet),可以直接使用 torchvision.datasets
进行下载和加载。
加载 MNIST 数据集
python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 1. 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化:均值0.5,标准差0.5
])
# 2. 加载数据集
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
# 3. 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 4. 读取一个 batch 数据
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(f"Batch shape: {images.shape}") # 输出 (64, 1, 28, 28)
print(f"Labels: {labels[:10]}") # 前 10 个标签
3. 自定义数据集(自定义 Dataset)
有时候,我们的数据并不在 PyTorch 提供的数据集里,比如 自定义图片数据、CSV 数据或文本数据 ,这时就需要自己继承 torch.utils.data.Dataset
来创建自己的数据集。
自定义 CSV 数据集
python
import torch
from torch.utils.data import Dataset
import pandas as pd
class MyDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file) # 读取 CSV 数据
def __len__(self):
return len(self.data) # 数据集大小
def __getitem__(self, idx):
row = self.data.iloc[idx] # 获取一行数据
features = torch.tensor(row[:-1].values, dtype=torch.float32) # 特征数据
label = torch.tensor(row[-1], dtype=torch.long) # 标签数据
return features, label
# 创建数据集
dataset = MyDataset("data.csv")
# 创建 DataLoader
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# 读取数据
for batch in data_loader:
features, labels = batch
print(features.shape, labels.shape)
break
4. 加载图片数据(ImageFolder) 🖼️
如果你的数据存放在文件夹里,例如:
data/
│── train/
│ ├── cats/ (类别 0)
│ │ ├── cat1.jpg
│ │ ├── cat2.jpg
│ │ └── ...
│ ├── dogs/ (类别 1)
│ │ ├── dog1.jpg
│ │ ├── dog2.jpg
│ │ └── ...
│── test/
可以使用 datasets.ImageFolder
自动加载图片数据集:
python
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
# 图片预处理
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整大小
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化
])
# 加载训练数据
train_dataset = ImageFolder(root="data/train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 获取类别映射
print(train_dataset.class_to_idx) # 输出:{'cats': 0, 'dogs': 1}
# 读取一个 batch
images, labels = next(iter(train_loader))
print(images.shape) # 输出: (32, 3, 128, 128) (32 张 128x128 RGB 图片)
5. 训练数据集迭代
在训练过程中,我们会使用 DataLoader
进行批量数据加载,示例代码如下:
python
for epoch in range(5): # 训练 5 轮
for images, labels in train_loader:
images, labels = images.to("cuda"), labels.to("cuda")
# 训练代码(前向传播、损失计算、反向传播等)
print(f"Epoch: {epoch+1}, Batch Size: {images.shape[0]}")
6. 结语
PyTorch 的 Dataset
和 DataLoader
让数据加载变得 高效、灵活、易扩展 ,无论是 标准数据集 还是 自定义数据集,都可以轻松处理。🔥 🚀
✅ 总结:
torchvision.datasets
直接加载常见数据集Dataset
继承自定义数据集(CSV、图片等)ImageFolder
适用于结构化图片数据集DataLoader
高效批量处理数据
希望这个教程对你有帮助!😃