103_PyTorch 快速上手:官方 torchvision 数据集加载与应用

在学习深度学习时,手动编写 Dataset 类虽然灵活,但对于经典的基准数据集(如 CIFAR-10、MNIST),PyTorch 提供了更高效的"一键加载"方案。本文将解析如何通过 torchvision.datasets 快速获取数据并配合 DataLoader 进行批量管理。

1. 为什么使用 torchvision.datasets?

  • 官方维护:无需担心数据格式兼容性问题。
  • 自动下载:支持一行代码自动从云端下载并解压数据。
  • 无缝集成 :原生支持 transforms 预处理逻辑。

2. 核心实战:加载 CIFAR-10 数据集

CIFAR-10 是计算机视觉最常用的入门数据集,包含 10 个类别的 60,000 张 32 \\times 32 彩色图像。

① 定义预处理流水线

在加载数据集之前,我们通常先定义好 transforms,以便在数据读取时直接进行转换(例如将 PIL 图片转为 Tensor)。

Python

复制代码
import torchvision
from torchvision import transforms

# 定义转换逻辑:转为张量
dataset_transform = transforms.Compose([
    transforms.ToTensor()
])

② 一键下载并加载

使用 torchvision.datasets.CIFAR10 接口,我们可以轻松区分训练集和测试集。

Python

复制代码
# 加载训练集
train_set = torchvision.datasets.CIFAR10(
    root="./dataset",       # 数据存储路径
    train=True,             # 是否为训练集
    transform=dataset_transform, # 应用预处理
    download=True           # 若本地无数据则自动下载
)

# 加载测试集
test_set = torchvision.datasets.CIFAR10(
    root="./dataset", 
    train=False, 
    transform=dataset_transform, 
    download=True
)

3. 数据集的探索与验证

加载完成后,我们可以像操作列表一样访问数据集。

  • 查看单条数据img, target = test_set[0] 会返回处理后的图像张量及其对应的类别索引。
  • 类别映射 :通过 test_set.classes 可以查看数字索引对应的真实名称(如 airplane, dog 等)。

4. 迈向模型训练:DataLoader 的配合

虽然 Dataset 解决了"取数据"的问题,但模型训练需要"成批次"的数据。这时我们需要配合 DataLoader

Python

复制代码
from torch.utils.data import DataLoader

# 创建数据加载器
test_loader = DataLoader(
    dataset=test_set, 
    batch_size=4,     # 每次取出 4 张图片
    shuffle=True,     # 打乱顺序
    num_workers=0,    # 多进程加载设置
    drop_last=False   # 是否舍弃最后不满足一个 batch 的数据
)

5. 总结:标准流水线逻辑

通过分析该文件,我们可以总结出 PyTorch 加载官方数据集的标准流程:

  1. 设置 Transforms:规划好数据进入模型前需要进行的数学变换。
  2. 实例化 Dataset :利用 torchvision.datasets 接口指定路径、类型及转换工具。
  3. 封装 DataLoader:设定 Batch 大小,实现自动化的批量读取。

这种方式极大地简化了实验准备阶段的代码量,让你能将更多精力投入到模型结构的设计中。

相关推荐
weixin_156241575764 分钟前
基于YOLO深度学习的动物检测与识别系统
人工智能·深度学习·yolo
叶舟15 分钟前
LYT-NET:一个超级轻量的低光照图像增强Transformer网络
人工智能·深度学习·transformer·llie·低光照图像增强
管二狗赶快去工作!44 分钟前
体系结构论文(九十八):NPUEval: Optimizing NPU Kernels with LLMs and Open Source Compilers
人工智能·深度学习·自然语言处理·体系结构
LaughingZhu1 小时前
Product Hunt 每日热榜 | 2026-04-10
人工智能·经验分享·深度学习·神经网络·产品运营
nap-joker2 小时前
FT-Mamba:一种高效的表回归的新深度学习模型
人工智能·深度学习·ftmamba
m0_372257022 小时前
bert和LLM训练的时候输入输出的格式是什么有什么区别
人工智能·深度学习·bert
杨夏同学3 小时前
AI入门——如何计算神经网络的参数
人工智能·深度学习·神经网络
龙文浩_3 小时前
AI中NLP的注意力机制的计算公式解析
人工智能·pytorch·深度学习·神经网络·自然语言处理
赵药师4 小时前
YOLO中task.py改复杂的模块
python·深度学习·yolo
Pelb4 小时前
求导 z = (x + y)^2
人工智能·深度学习·数学建模