关于load_data_fashion_mnist函数运行原理以及运行速度慢解决方案

需要了解具体解决方案的可以直接跳至第三块内容

前言

在 PyTorch 深度学习入门实践中,Fashion-MNIST 时尚服饰数据集是新手必练的经典数据集,而load_data_fashion_mnist函数几乎是加载该数据集的标配工具。

但很多小伙伴在使用时都会遇到两个痛点:

  1. 函数运行超慢,等待时间长
  2. 不清楚函数底层原理,只会照搬调用

今天这篇文章就从函数简介、运行原理、卡顿原因 + 极致优化方案三个维度,把这个函数彻底讲透,新手也能轻松掌握!


一、load_data_fashion_mnist 函数核心简介

1. 函数定位

load_data_fashion_mnist 不是 Python 原生函数 ,也不是 PyTorch 官方内置函数,而是《动手学深度学习》等经典教材、深度学习项目中高频使用的自定义封装函数

它的核心价值:一行代码完成 Fashion-MNIST 数据集的下载、预处理、迭代器生成,极大简化新手数据加载流程。

2. 标准完整代码(可直接复制使用)

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_data_fashion_mnist(batch_size, resize=None):
    """
    加载Fashion-MNIST数据集,返回训练集和测试集的DataLoader
    :param batch_size: 每个批次的样本数量
    :param resize: 可选,将图片resize到指定尺寸(如resize=64)
    :return: train_iter训练集迭代器, test_iter测试集迭代器
    """
    # 1. 定义数据预处理操作:转张量 + 标准化
    trans = [transforms.ToTensor()]  # 转成PyTorch张量
    if resize:
        trans.insert(0, transforms.Resize(resize))  # 可选resize
    trans = transforms.Compose(trans)  # 组合预处理
    
    # 2. 加载数据集(自动下载,root为存储路径)
    mnist_train = datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True
    )
    mnist_test = datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True
    )
    
    # 3. 生成数据迭代器(训练集打乱,测试集不打乱)
    train_iter = DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_iter = DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_iter, test_iter

二、load_data_fashion_mnist 运行原理解析

这个函数本质是三步流水线作业,清晰易懂:

1. 数据集自动下载

函数会检测指定路径(../data)下是否存在 Fashion-MNIST 数据集,不存在则自动从官方服务器下载,包含 60000 张训练图 + 10000 张测试图。

2. 数据标准化与预处理

将原始图片转换为 PyTorch 模型可识别的张量格式,支持自定义图片缩放,完成数据归一化预处理。

3. 生成训练 / 测试迭代器

最终返回两个DataLoader迭代器:

  • 训练集迭代器:打乱数据,防止模型学习顺序特征
  • 测试集迭代器:不打乱数据 ,保证评估结果稳定迭代器会按batch_size分批输出数据,直接对接模型训练。

三、关于load_data_fashion_mnist为什么会慢和如何提高效率

这是大家最关心的核心问题。函数运行慢主要分两种场景,对应不同解决方案:

场景 1:首次运行极慢 → 数据集在线下载导致

问题原因

函数默认开启download=True,首次运行时会从国外服务器下载数据集,网速受限导致等待时间极长。

解决方案:本地手动下载数据集
  1. 提前下载 Fashion-MNIST 数据集压缩包(4 个文件)
  2. 在项目根目录创建data文件夹,再新建FashionMNIST\raw子文件夹
  3. 将下载的 4 个文件直接放入raw文件夹中
  4. 再次运行函数,会直接读取本地文件

场景 2:本地已有数据集,依旧读取缓慢 → CPU 单核读取导致

问题原因

原生函数默认num_workers=0,当数据集在本地,但依旧读取速度慢,还有一种原因,那就是load_data_fashion_mnist函数本身默认的读取是采用cpu单核读取的,而大部分代码常常只有一个batch_size参数,而没有num_workers参数,而这个参数影响有多少个cpu核心参与本次读取工作,因此需要将num_workers设置一下,推荐是自己cpu核心数量的1/3~2/3。 *window操作系统下不支持修改num_workers 最后,Linux和Mac系统下,num_workers设置如果出现报错,那么可能是你使用的是d2l下的load_data_fashion_mnist,早期版本的d2l是不支持设置load_data_fashion_mnist的num_workers,而最新的部分版本是支持的。更新指令为:

python 复制代码
pip install --upgrade d2l
相关推荐
东离与糖宝2 小时前
2026 Java AI框架选型:Spring AI/LangChain4j企业级对比
java·人工智能
林姜泽樾2 小时前
python入门第六课,其他字符串格式化和input
开发语言·python·pycharm
yunpeng.zhou2 小时前
深度理解agent与llm之间的关系、及mcp与skill的区别
人工智能·python·ai
CoderJia程序员甲2 小时前
GitHub 热榜项目 - 日榜(2026-04-03)
人工智能·ai·大模型·github·ai教程
TDengine (老段)2 小时前
TDengine IDMP 可视化 —— 趋势图
大数据·数据库·人工智能·物联网·时序数据库·tdengine·涛思数据
东离与糖宝2 小时前
Java AI工程化:PyTorch On Java+SpringBoot微服务部署(2025-2026最新实战)
java·人工智能
2601_955363152 小时前
技术赋能B端拓客:号码核验行业的迭代与价值升级
大数据·人工智能
Etherious_Young2 小时前
基于ResNet的石化图像及数据分类项目——从模型训练到GUI应用开发的完整实践
人工智能·机器学习·分类·卷积神经网络
有Li2 小时前
ACE-ProtoNet: 基于自适应协方差特征门和不确定性感知原型学习的冠状动脉分割/文献速递-多模态医学影像最新进展
人工智能·智能电视·文献·医学生