关于load_data_fashion_mnist函数运行原理以及运行速度慢解决方案

需要了解具体解决方案的可以直接跳至第三块内容

前言

在 PyTorch 深度学习入门实践中,Fashion-MNIST 时尚服饰数据集是新手必练的经典数据集,而load_data_fashion_mnist函数几乎是加载该数据集的标配工具。

但很多小伙伴在使用时都会遇到两个痛点:

  1. 函数运行超慢,等待时间长
  2. 不清楚函数底层原理,只会照搬调用

今天这篇文章就从函数简介、运行原理、卡顿原因 + 极致优化方案三个维度,把这个函数彻底讲透,新手也能轻松掌握!


一、load_data_fashion_mnist 函数核心简介

1. 函数定位

load_data_fashion_mnist 不是 Python 原生函数 ,也不是 PyTorch 官方内置函数,而是《动手学深度学习》等经典教材、深度学习项目中高频使用的自定义封装函数

它的核心价值:一行代码完成 Fashion-MNIST 数据集的下载、预处理、迭代器生成,极大简化新手数据加载流程。

2. 标准完整代码(可直接复制使用)

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_data_fashion_mnist(batch_size, resize=None):
    """
    加载Fashion-MNIST数据集,返回训练集和测试集的DataLoader
    :param batch_size: 每个批次的样本数量
    :param resize: 可选,将图片resize到指定尺寸(如resize=64)
    :return: train_iter训练集迭代器, test_iter测试集迭代器
    """
    # 1. 定义数据预处理操作:转张量 + 标准化
    trans = [transforms.ToTensor()]  # 转成PyTorch张量
    if resize:
        trans.insert(0, transforms.Resize(resize))  # 可选resize
    trans = transforms.Compose(trans)  # 组合预处理
    
    # 2. 加载数据集(自动下载,root为存储路径)
    mnist_train = datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True
    )
    mnist_test = datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True
    )
    
    # 3. 生成数据迭代器(训练集打乱,测试集不打乱)
    train_iter = DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_iter = DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_iter, test_iter

二、load_data_fashion_mnist 运行原理解析

这个函数本质是三步流水线作业,清晰易懂:

1. 数据集自动下载

函数会检测指定路径(../data)下是否存在 Fashion-MNIST 数据集,不存在则自动从官方服务器下载,包含 60000 张训练图 + 10000 张测试图。

2. 数据标准化与预处理

将原始图片转换为 PyTorch 模型可识别的张量格式,支持自定义图片缩放,完成数据归一化预处理。

3. 生成训练 / 测试迭代器

最终返回两个DataLoader迭代器:

  • 训练集迭代器:打乱数据,防止模型学习顺序特征
  • 测试集迭代器:不打乱数据 ,保证评估结果稳定迭代器会按batch_size分批输出数据,直接对接模型训练。

三、关于load_data_fashion_mnist为什么会慢和如何提高效率

这是大家最关心的核心问题。函数运行慢主要分两种场景,对应不同解决方案:

场景 1:首次运行极慢 → 数据集在线下载导致

问题原因

函数默认开启download=True,首次运行时会从国外服务器下载数据集,网速受限导致等待时间极长。

解决方案:本地手动下载数据集
  1. 提前下载 Fashion-MNIST 数据集压缩包(4 个文件)
  2. 在项目根目录创建data文件夹,再新建FashionMNIST\raw子文件夹
  3. 将下载的 4 个文件直接放入raw文件夹中
  4. 再次运行函数,会直接读取本地文件

场景 2:本地已有数据集,依旧读取缓慢 → CPU 单核读取导致

问题原因

原生函数默认num_workers=0,当数据集在本地,但依旧读取速度慢,还有一种原因,那就是load_data_fashion_mnist函数本身默认的读取是采用cpu单核读取的,而大部分代码常常只有一个batch_size参数,而没有num_workers参数,而这个参数影响有多少个cpu核心参与本次读取工作,因此需要将num_workers设置一下,推荐是自己cpu核心数量的1/3~2/3。 *window操作系统下不支持修改num_workers 最后,Linux和Mac系统下,num_workers设置如果出现报错,那么可能是你使用的是d2l下的load_data_fashion_mnist,早期版本的d2l是不支持设置load_data_fashion_mnist的num_workers,而最新的部分版本是支持的。更新指令为:

python 复制代码
pip install --upgrade d2l
相关推荐
云烟成雨TD几秒前
Spring AI 1.x 系列【52】可观测集成 SkyWalking
人工智能·spring·skywalking
云烟成雨TD几秒前
Spring AI 1.x 系列【57】动态工具发现:Tool Search Tool
java·人工智能·spring
AndrewHZ1 分钟前
【LLM技术全景】规模定律与模型演进:为什么模型越大越强?
人工智能·gpt·深度学习·语言模型·llm·openai·规模定律
galaxylove1 分钟前
Gartner发布创新洞察:AI SOC智能体加速通信运营商安全运营转型
大数据·人工智能·安全
甩手网软件11 分钟前
Shopee2026新规:费率重构与履约收紧下,卖家如何破局?
大数据·人工智能
数据库小学妹13 分钟前
AI时代数据库怎么选?多模融合、数据统一存储与选型实战指南
数据库·人工智能·经验分享·ai
lizhihai_9920 分钟前
股市学习心得-AI 产业链核心标的梳理清单
大数据·服务器·人工智能·科技·学习
天佑木枫22 分钟前
15天Python入门系列 · 序
开发语言·python
happylifetree23 分钟前
Python017-第二章15.数据容器-dict常用操作
python
暮雪倾风24 分钟前
【AI】国内使用Claude Code,配置Claude Code,使用DeepSeek为例
人工智能