PyTorch 的 Dataset 和 DataLoader

在 PyTorch 中,数据加载通常使用 torch.utils.data.Datasettorch.utils.data.DataLoader 进行高效管理和批量处理。我们来详细了解如何加载数据,包括 Dataset、DataLoader、数据预处理和自定义数据集


1. PyTorch 的 Dataset 和 DataLoader 🚀

(1) Dataset

PyTorch 提供了 torch.utils.data.Dataset 作为所有数据集的基类,所有数据集需要继承这个类,并实现:

  • __len__() → 返回数据集的大小
  • __getitem__(index) → 通过索引获取一个样本

(2) DataLoader

DataLoader 用于批量加载数据,同时支持 多线程(num_workers) 加速数据读取。


2. 加载内置数据集(MNIST 数据集)

PyTorch 提供了一些 内置数据集 (如 MNIST、CIFAR-10、ImageNet),可以直接使用 torchvision.datasets 进行下载和加载。

加载 MNIST 数据集

python 复制代码
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化:均值0.5,标准差0.5
])

# 2. 加载数据集
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# 3. 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 4. 读取一个 batch 数据
data_iter = iter(train_loader)
images, labels = next(data_iter)

print(f"Batch shape: {images.shape}")  # 输出 (64, 1, 28, 28)
print(f"Labels: {labels[:10]}")  # 前 10 个标签

3. 自定义数据集(自定义 Dataset)

有时候,我们的数据并不在 PyTorch 提供的数据集里,比如 自定义图片数据、CSV 数据或文本数据 ,这时就需要自己继承 torch.utils.data.Dataset 来创建自己的数据集。

自定义 CSV 数据集

python 复制代码
import torch
from torch.utils.data import Dataset
import pandas as pd

class MyDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)  # 读取 CSV 数据
       
    def __len__(self):
        return len(self.data)  # 数据集大小

    def __getitem__(self, idx):
        row = self.data.iloc[idx]  # 获取一行数据
        features = torch.tensor(row[:-1].values, dtype=torch.float32)  # 特征数据
        label = torch.tensor(row[-1], dtype=torch.long)  # 标签数据
        return features, label

# 创建数据集
dataset = MyDataset("data.csv")

# 创建 DataLoader
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 读取数据
for batch in data_loader:
    features, labels = batch
    print(features.shape, labels.shape)
    break

4. 加载图片数据(ImageFolder) 🖼️

如果你的数据存放在文件夹里,例如:

复制代码
data/
│── train/
│   ├── cats/ (类别 0)
│   │   ├── cat1.jpg
│   │   ├── cat2.jpg
│   │   └── ...
│   ├── dogs/ (类别 1)
│   │   ├── dog1.jpg
│   │   ├── dog2.jpg
│   │   └── ...
│── test/

可以使用 datasets.ImageFolder 自动加载图片数据集:

python 复制代码
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

# 图片预处理
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 调整大小
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 加载训练数据
train_dataset = ImageFolder(root="data/train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 获取类别映射
print(train_dataset.class_to_idx)  # 输出:{'cats': 0, 'dogs': 1}

# 读取一个 batch
images, labels = next(iter(train_loader))
print(images.shape)  # 输出: (32, 3, 128, 128) (32 张 128x128 RGB 图片)

5. 训练数据集迭代

在训练过程中,我们会使用 DataLoader 进行批量数据加载,示例代码如下:

python 复制代码
for epoch in range(5):  # 训练 5 轮
    for images, labels in train_loader:
        images, labels = images.to("cuda"), labels.to("cuda")

        # 训练代码(前向传播、损失计算、反向传播等)
        
        print(f"Epoch: {epoch+1}, Batch Size: {images.shape[0]}")

6. 结语

PyTorch 的 DatasetDataLoader 让数据加载变得 高效、灵活、易扩展 ,无论是 标准数据集 还是 自定义数据集,都可以轻松处理。🔥 🚀

总结:

  • torchvision.datasets 直接加载常见数据集
  • Dataset 继承自定义数据集(CSV、图片等)
  • ImageFolder 适用于结构化图片数据集
  • DataLoader 高效批量处理数据

希望这个教程对你有帮助!😃

相关推荐
聆风吟º19 分钟前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys27 分钟前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_567827 分钟前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子30 分钟前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder33 分钟前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能1 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144871 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile1 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5771 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥1 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造