Pytorch下载Mnist手写数据识别训练数据集的代码详解

python 复制代码
datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

1. datasets.MNIST

这是torchvision.datasets模块中的一个类,专门用于加载MNIST数据集。MNIST是一个著名的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本是28x28的灰度图像。

2. 参数解释

root='./data'
  • 作用:指定数据集下载和存储的根目录。

  • 解释 :这里设置为当前目录下的data文件夹。如果该文件夹不存在,PyTorch会自动创建它。

  • 默认值:通常没有默认值,必须指定。

train=False
  • 作用:指定加载的是训练集还是测试集。

  • 解释

    • train=True:加载训练集(60,000个样本)。

    • train=False:加载测试集(10,000个样本)。

  • 默认值 :通常是True(加载训练集),但明确指定是好的实践。

download=True
  • 作用:控制是否下载数据集。

  • 解释

    • download=True:如果数据集在root目录下不存在,则自动下载。

    • download=False:不下载,仅尝试从root目录加载。

  • 默认值 :通常是False,但这里显式设置为True以确保下载。

transform=transforms.ToTensor()
  • 作用:指定对加载的数据进行何种预处理或转换。

  • 解释

    • transforms.ToTensor()是PyTorch的一个转换函数,它将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor),并自动进行以下操作:

      1. 将图像的像素值从[0, 255]缩放到[0.0, 1.0](除以255)。

      2. 将图像的形状从(H, W, C)(高度、宽度、通道)转换为(C, H, W)(通道、高度、宽度)。对于MNIST,因为是灰度图像,所以通道数为1,形状从(28, 28)变为(1, 28, 28)

    • 如果不指定transform,返回的是PIL图像格式。

  • 默认值:如果不指定,返回原始数据(通常是PIL图像)。

3. 返回值

这行代码的返回值是一个torchvision.datasets.MNIST对象,可以像数据集一样使用:

  • 可以通过索引(如dataset[0])访问单个样本。

  • 可以通过len(dataset)获取数据集大小。

  • 通常用于DataLoader中批量加载数据。

4. 完整示例

python

复制代码
from torchvision import datasets, transforms

# 下载并加载MNIST测试集
test_dataset = datasets.MNIST(
    root='./data',      # 数据存储目录
    train=False,        # 加载测试集
    download=True,      # 如果数据不存在则下载
    transform=transforms.ToTensor()  # 转换为张量并归一化到[0,1]
)

# 打印数据集大小
print(len(test_dataset))  # 输出: 10000

# 获取第一个样本
image, label = test_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 28, 28])
print(label)        # 输出: 7(标签)

5. 其他常见参数(虽然不是这里使用的)

  • target_transform:对标签(target)进行转换的函数(类似transform对图像的作用)。

  • 某些数据集可能有额外参数(如MNIST通常没有,但其他数据集可能有split等)。

总结:这行代码的作用是从PyTorch自动下载MNIST测试集到./data目录,并将图像转换为PyTorch张量格式,方便后续用于深度学习模型的测试。

相关推荐
非著名架构师7 小时前
超级工程的“数字风洞”:高精度AI气象如何在数字孪生中预演台风、暴雪,确保重大基础设施全生命周期安全?
人工智能·智慧农业·灾害预警·galeweather.cn·ai气象模型·高精度农业气象
qq_356196957 小时前
Day 45 简单CNN@浙大疏锦行
python
superman超哥7 小时前
仓颉语言中字典的增删改查:深度剖析与工程实践
c语言·开发语言·c++·python·仓颉
延凡科技7 小时前
延凡智慧水库系统:数字孪生+AI驱动水库安全与智能调度
人工智能·安全
magic_ll7 小时前
【yolo系列】yolov10的结构解析、一致性双重分配
人工智能
Christo37 小时前
2024《Three-way clustering: Foundations, survey and challenges》
人工智能·算法·机器学习·数据挖掘
carver w7 小时前
智能医学工程选题分享
python
醒过来摸鱼7 小时前
Java Compiler API使用
java·开发语言·python
言之。7 小时前
Claude Code IDE 集成工作原理详解
ide·人工智能
肥猪猪爸8 小时前
计算机视觉中的Mask是干啥的
图像处理·人工智能·深度学习·神经网络·目标检测·计算机视觉·视觉检测