Pytorch下载Mnist手写数据识别训练数据集的代码详解

python 复制代码
datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

1. datasets.MNIST

这是torchvision.datasets模块中的一个类,专门用于加载MNIST数据集。MNIST是一个著名的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本是28x28的灰度图像。

2. 参数解释

root='./data'
  • 作用:指定数据集下载和存储的根目录。

  • 解释 :这里设置为当前目录下的data文件夹。如果该文件夹不存在,PyTorch会自动创建它。

  • 默认值:通常没有默认值,必须指定。

train=False
  • 作用:指定加载的是训练集还是测试集。

  • 解释

    • train=True:加载训练集(60,000个样本)。

    • train=False:加载测试集(10,000个样本)。

  • 默认值 :通常是True(加载训练集),但明确指定是好的实践。

download=True
  • 作用:控制是否下载数据集。

  • 解释

    • download=True:如果数据集在root目录下不存在,则自动下载。

    • download=False:不下载,仅尝试从root目录加载。

  • 默认值 :通常是False,但这里显式设置为True以确保下载。

transform=transforms.ToTensor()
  • 作用:指定对加载的数据进行何种预处理或转换。

  • 解释

    • transforms.ToTensor()是PyTorch的一个转换函数,它将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor),并自动进行以下操作:

      1. 将图像的像素值从[0, 255]缩放到[0.0, 1.0](除以255)。

      2. 将图像的形状从(H, W, C)(高度、宽度、通道)转换为(C, H, W)(通道、高度、宽度)。对于MNIST,因为是灰度图像,所以通道数为1,形状从(28, 28)变为(1, 28, 28)

    • 如果不指定transform,返回的是PIL图像格式。

  • 默认值:如果不指定,返回原始数据(通常是PIL图像)。

3. 返回值

这行代码的返回值是一个torchvision.datasets.MNIST对象,可以像数据集一样使用:

  • 可以通过索引(如dataset[0])访问单个样本。

  • 可以通过len(dataset)获取数据集大小。

  • 通常用于DataLoader中批量加载数据。

4. 完整示例

python

复制代码
from torchvision import datasets, transforms

# 下载并加载MNIST测试集
test_dataset = datasets.MNIST(
    root='./data',      # 数据存储目录
    train=False,        # 加载测试集
    download=True,      # 如果数据不存在则下载
    transform=transforms.ToTensor()  # 转换为张量并归一化到[0,1]
)

# 打印数据集大小
print(len(test_dataset))  # 输出: 10000

# 获取第一个样本
image, label = test_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 28, 28])
print(label)        # 输出: 7(标签)

5. 其他常见参数(虽然不是这里使用的)

  • target_transform:对标签(target)进行转换的函数(类似transform对图像的作用)。

  • 某些数据集可能有额外参数(如MNIST通常没有,但其他数据集可能有split等)。

总结:这行代码的作用是从PyTorch自动下载MNIST测试集到./data目录,并将图像转换为PyTorch张量格式,方便后续用于深度学习模型的测试。

相关推荐
java1234_小锋19 分钟前
【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博类别信息爬取
开发语言·python·flask
盼小辉丶1 小时前
图机器学习(11)——链接预测
人工智能·机器学习·图机器学习
CareyWYR1 小时前
每周AI论文速递(250714-250718)
人工智能
想要成为计算机高手2 小时前
9. isaacsim4.2教程-ROS加相机/CLOCK
人工智能·机器人·ros·仿真·具身智能·isaacsim
Elastic 中国社区官方博客2 小时前
AI 驱动的仪表板:从愿景到 Kibana
大数据·数据库·人工智能·elasticsearch·搜索引擎·全文检索·kibana
西柚小萌新2 小时前
【大模型:知识图谱】--6.Neo4j DeskTop安装+使用
人工智能·知识图谱
BTU_YC2 小时前
Neo4j Python 驱动库完整教程(带输入输出示例)
开发语言·python·neo4j
lishaoan772 小时前
用Python实现神经网络(四)
python·神经网络·多层神经网络
曾几何时`2 小时前
分别使用Cypher与python构建neo4j图谱
开发语言·python·机器学习
杨小扩2 小时前
开发者进化论:驾驭AI,开启软件工程新纪元
人工智能·软件工程