关于load_data_fashion_mnist函数运行原理以及运行速度慢解决方案

需要了解具体解决方案的可以直接跳至第三块内容

前言

在 PyTorch 深度学习入门实践中,Fashion-MNIST 时尚服饰数据集是新手必练的经典数据集,而load_data_fashion_mnist函数几乎是加载该数据集的标配工具。

但很多小伙伴在使用时都会遇到两个痛点:

  1. 函数运行超慢,等待时间长
  2. 不清楚函数底层原理,只会照搬调用

今天这篇文章就从函数简介、运行原理、卡顿原因 + 极致优化方案三个维度,把这个函数彻底讲透,新手也能轻松掌握!


一、load_data_fashion_mnist 函数核心简介

1. 函数定位

load_data_fashion_mnist 不是 Python 原生函数 ,也不是 PyTorch 官方内置函数,而是《动手学深度学习》等经典教材、深度学习项目中高频使用的自定义封装函数

它的核心价值:一行代码完成 Fashion-MNIST 数据集的下载、预处理、迭代器生成,极大简化新手数据加载流程。

2. 标准完整代码(可直接复制使用)

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_data_fashion_mnist(batch_size, resize=None):
    """
    加载Fashion-MNIST数据集,返回训练集和测试集的DataLoader
    :param batch_size: 每个批次的样本数量
    :param resize: 可选,将图片resize到指定尺寸(如resize=64)
    :return: train_iter训练集迭代器, test_iter测试集迭代器
    """
    # 1. 定义数据预处理操作:转张量 + 标准化
    trans = [transforms.ToTensor()]  # 转成PyTorch张量
    if resize:
        trans.insert(0, transforms.Resize(resize))  # 可选resize
    trans = transforms.Compose(trans)  # 组合预处理
    
    # 2. 加载数据集(自动下载,root为存储路径)
    mnist_train = datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True
    )
    mnist_test = datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True
    )
    
    # 3. 生成数据迭代器(训练集打乱,测试集不打乱)
    train_iter = DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_iter = DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_iter, test_iter

二、load_data_fashion_mnist 运行原理解析

这个函数本质是三步流水线作业,清晰易懂:

1. 数据集自动下载

函数会检测指定路径(../data)下是否存在 Fashion-MNIST 数据集,不存在则自动从官方服务器下载,包含 60000 张训练图 + 10000 张测试图。

2. 数据标准化与预处理

将原始图片转换为 PyTorch 模型可识别的张量格式,支持自定义图片缩放,完成数据归一化预处理。

3. 生成训练 / 测试迭代器

最终返回两个DataLoader迭代器:

  • 训练集迭代器:打乱数据,防止模型学习顺序特征
  • 测试集迭代器:不打乱数据 ,保证评估结果稳定迭代器会按batch_size分批输出数据,直接对接模型训练。

三、关于load_data_fashion_mnist为什么会慢和如何提高效率

这是大家最关心的核心问题。函数运行慢主要分两种场景,对应不同解决方案:

场景 1:首次运行极慢 → 数据集在线下载导致

问题原因

函数默认开启download=True,首次运行时会从国外服务器下载数据集,网速受限导致等待时间极长。

解决方案:本地手动下载数据集
  1. 提前下载 Fashion-MNIST 数据集压缩包(4 个文件)
  2. 在项目根目录创建data文件夹,再新建FashionMNIST\raw子文件夹
  3. 将下载的 4 个文件直接放入raw文件夹中
  4. 再次运行函数,会直接读取本地文件

场景 2:本地已有数据集,依旧读取缓慢 → CPU 单核读取导致

问题原因

原生函数默认num_workers=0,当数据集在本地,但依旧读取速度慢,还有一种原因,那就是load_data_fashion_mnist函数本身默认的读取是采用cpu单核读取的,而大部分代码常常只有一个batch_size参数,而没有num_workers参数,而这个参数影响有多少个cpu核心参与本次读取工作,因此需要将num_workers设置一下,推荐是自己cpu核心数量的1/3~2/3。 *window操作系统下不支持修改num_workers 最后,Linux和Mac系统下,num_workers设置如果出现报错,那么可能是你使用的是d2l下的load_data_fashion_mnist,早期版本的d2l是不支持设置load_data_fashion_mnist的num_workers,而最新的部分版本是支持的。更新指令为:

python 复制代码
pip install --upgrade d2l
相关推荐
m0_617493941 天前
PyTorch CUDA设备不可用错误解决方案
人工智能·pytorch·python
Soari1 天前
告别玩具级 Demo!深度拆解 agents-towards-production,用硬核工程把 AI Agent 推向工业级生产线
人工智能·软件工程·llmops·架构优化·genai·aiagent·生产级部署
minhuan1 天前
RTX 4090显存终极优化:模型分层加载、CPU Offload显存和内存动态置换实践.179
人工智能·大模型应用·rtx 4090显存优化·模型分层加载·cpu offload优化
小郑加油1 天前
python学习Day15:综合训练——数据清洗与缺失值补充
开发语言·python·学习
完成大叔1 天前
Agent入门:用本地模型从零搭建
开发语言·python·langchain
2601_958548481 天前
电镀整流机源头厂家:企业采购选型策略深度解析
人工智能
光锥智能1 天前
智元WITA成为全国首例完成大模型备案的具身智能交互模型
人工智能
墨神谕1 天前
人工智能(一)—AI的起源和发展
人工智能
科技云报道1 天前
当攻击开始“自主决策”,安全体系如何应战?
人工智能
一切皆是因缘际会1 天前
AI低代码开发实战:轻量化部署与多场景落地
人工智能·深度学习·低代码·机器学习·ai·架构