需要了解具体解决方案的可以直接跳至第三块内容
前言
在 PyTorch 深度学习入门实践中,Fashion-MNIST 时尚服饰数据集是新手必练的经典数据集,而load_data_fashion_mnist函数几乎是加载该数据集的标配工具。
但很多小伙伴在使用时都会遇到两个痛点:
- 函数运行超慢,等待时间长
- 不清楚函数底层原理,只会照搬调用
今天这篇文章就从函数简介、运行原理、卡顿原因 + 极致优化方案三个维度,把这个函数彻底讲透,新手也能轻松掌握!
一、load_data_fashion_mnist 函数核心简介
1. 函数定位
load_data_fashion_mnist 不是 Python 原生函数 ,也不是 PyTorch 官方内置函数,而是《动手学深度学习》等经典教材、深度学习项目中高频使用的自定义封装函数。
它的核心价值:一行代码完成 Fashion-MNIST 数据集的下载、预处理、迭代器生成,极大简化新手数据加载流程。
2. 标准完整代码(可直接复制使用)
python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def load_data_fashion_mnist(batch_size, resize=None):
"""
加载Fashion-MNIST数据集,返回训练集和测试集的DataLoader
:param batch_size: 每个批次的样本数量
:param resize: 可选,将图片resize到指定尺寸(如resize=64)
:return: train_iter训练集迭代器, test_iter测试集迭代器
"""
# 1. 定义数据预处理操作:转张量 + 标准化
trans = [transforms.ToTensor()] # 转成PyTorch张量
if resize:
trans.insert(0, transforms.Resize(resize)) # 可选resize
trans = transforms.Compose(trans) # 组合预处理
# 2. 加载数据集(自动下载,root为存储路径)
mnist_train = datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True
)
mnist_test = datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True
)
# 3. 生成数据迭代器(训练集打乱,测试集不打乱)
train_iter = DataLoader(
mnist_train, batch_size=batch_size, shuffle=True, num_workers=0
)
test_iter = DataLoader(
mnist_test, batch_size=batch_size, shuffle=False, num_workers=0
)
return train_iter, test_iter
二、load_data_fashion_mnist 运行原理解析
这个函数本质是三步流水线作业,清晰易懂:
1. 数据集自动下载
函数会检测指定路径(../data)下是否存在 Fashion-MNIST 数据集,不存在则自动从官方服务器下载,包含 60000 张训练图 + 10000 张测试图。
2. 数据标准化与预处理
将原始图片转换为 PyTorch 模型可识别的张量格式,支持自定义图片缩放,完成数据归一化预处理。
3. 生成训练 / 测试迭代器
最终返回两个DataLoader迭代器:
- 训练集迭代器:打乱数据,防止模型学习顺序特征
- 测试集迭代器:不打乱数据 ,保证评估结果稳定迭代器会按
batch_size分批输出数据,直接对接模型训练。
三、关于load_data_fashion_mnist为什么会慢和如何提高效率
这是大家最关心的核心问题。函数运行慢主要分两种场景,对应不同解决方案:
场景 1:首次运行极慢 → 数据集在线下载导致
问题原因
函数默认开启download=True,首次运行时会从国外服务器下载数据集,网速受限导致等待时间极长。
解决方案:本地手动下载数据集
- 提前下载 Fashion-MNIST 数据集压缩包(4 个文件)
- 在项目根目录创建
data文件夹,再新建FashionMNIST\raw子文件夹 - 将下载的 4 个文件直接放入
raw文件夹中 - 再次运行函数,会直接读取本地文件
场景 2:本地已有数据集,依旧读取缓慢 → CPU 单核读取导致
问题原因
原生函数默认num_workers=0,当数据集在本地,但依旧读取速度慢,还有一种原因,那就是load_data_fashion_mnist函数本身默认的读取是采用cpu单核读取的,而大部分代码常常只有一个batch_size参数,而没有num_workers参数,而这个参数影响有多少个cpu核心参与本次读取工作,因此需要将num_workers设置一下,推荐是自己cpu核心数量的1/3~2/3。 *window操作系统下不支持修改num_workers 最后,Linux和Mac系统下,num_workers设置如果出现报错,那么可能是你使用的是d2l下的load_data_fashion_mnist,早期版本的d2l是不支持设置load_data_fashion_mnist的num_workers,而最新的部分版本是支持的。更新指令为:
python
pip install --upgrade d2l