Pytorch下载Mnist手写数据识别训练数据集的代码详解

python 复制代码
datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

1. datasets.MNIST

这是torchvision.datasets模块中的一个类,专门用于加载MNIST数据集。MNIST是一个著名的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本是28x28的灰度图像。

2. 参数解释

root='./data'
  • 作用:指定数据集下载和存储的根目录。

  • 解释 :这里设置为当前目录下的data文件夹。如果该文件夹不存在,PyTorch会自动创建它。

  • 默认值:通常没有默认值,必须指定。

train=False
  • 作用:指定加载的是训练集还是测试集。

  • 解释

    • train=True:加载训练集(60,000个样本)。

    • train=False:加载测试集(10,000个样本)。

  • 默认值 :通常是True(加载训练集),但明确指定是好的实践。

download=True
  • 作用:控制是否下载数据集。

  • 解释

    • download=True:如果数据集在root目录下不存在,则自动下载。

    • download=False:不下载,仅尝试从root目录加载。

  • 默认值 :通常是False,但这里显式设置为True以确保下载。

transform=transforms.ToTensor()
  • 作用:指定对加载的数据进行何种预处理或转换。

  • 解释

    • transforms.ToTensor()是PyTorch的一个转换函数,它将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor),并自动进行以下操作:

      1. 将图像的像素值从[0, 255]缩放到[0.0, 1.0](除以255)。

      2. 将图像的形状从(H, W, C)(高度、宽度、通道)转换为(C, H, W)(通道、高度、宽度)。对于MNIST,因为是灰度图像,所以通道数为1,形状从(28, 28)变为(1, 28, 28)

    • 如果不指定transform,返回的是PIL图像格式。

  • 默认值:如果不指定,返回原始数据(通常是PIL图像)。

3. 返回值

这行代码的返回值是一个torchvision.datasets.MNIST对象,可以像数据集一样使用:

  • 可以通过索引(如dataset[0])访问单个样本。

  • 可以通过len(dataset)获取数据集大小。

  • 通常用于DataLoader中批量加载数据。

4. 完整示例

python

复制代码
from torchvision import datasets, transforms

# 下载并加载MNIST测试集
test_dataset = datasets.MNIST(
    root='./data',      # 数据存储目录
    train=False,        # 加载测试集
    download=True,      # 如果数据不存在则下载
    transform=transforms.ToTensor()  # 转换为张量并归一化到[0,1]
)

# 打印数据集大小
print(len(test_dataset))  # 输出: 10000

# 获取第一个样本
image, label = test_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 28, 28])
print(label)        # 输出: 7(标签)

5. 其他常见参数(虽然不是这里使用的)

  • target_transform:对标签(target)进行转换的函数(类似transform对图像的作用)。

  • 某些数据集可能有额外参数(如MNIST通常没有,但其他数据集可能有split等)。

总结:这行代码的作用是从PyTorch自动下载MNIST测试集到./data目录,并将图像转换为PyTorch张量格式,方便后续用于深度学习模型的测试。

相关推荐
糖葫芦君2 分钟前
25-GRPO IS SECRETLY A PROCESS REWARD MODEL
人工智能·大模型
俊男无期10 分钟前
【AI入门】通俗易懂讲AI(初稿)
人工智能
工业互联网专业25 分钟前
基于协同过滤算法的小说推荐系统_django+spider
python·django·毕业设计·源码·课程设计·spider·协同过滤算法
星星的月亮叫太阳31 分钟前
large-scale-DRL-exploration 代码阅读 总结
python·算法
喜欢吃豆1 小时前
GraphRAG 技术教程:从核心概念到高级架构
人工智能·架构·大模型
王哈哈^_^1 小时前
YOLOv11视觉检测实战:安全距离测算全解析
人工智能·数码相机·算法·yolo·计算机视觉·目标跟踪·视觉检测
Q_Q19632884751 小时前
python+django/flask基于Echarts+Python的图书零售监测系统设计与实现(带大屏)
spring boot·python·django·flask·node.js·php
AI浩1 小时前
FeatEnHancer:在低光视觉下增强目标检测及其他任务的分层特征
人工智能·目标检测·目标跟踪