Python PyTorch 获取 MNIST 数据

Python PyTorch 获取 MNIST 数据

  • [1 PyTorch 获取 MNIST 数据](#1 PyTorch 获取 MNIST 数据)
  • [2 PyTorch 保存 MNIST 数据](#2 PyTorch 保存 MNIST 数据)
  • [3 PyTorch 显示 MNIST 数据](#3 PyTorch 显示 MNIST 数据)

1 PyTorch 获取 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)
相关推荐
F_D_Z9 小时前
数据集相关类代码回顾理解 | sns.distplot\%matplotlib inline\sns.scatterplot
python·深度学习·matplotlib
9***P3349 小时前
PHP代码覆盖率
开发语言·php·代码覆盖率
daidaidaiyu10 小时前
一文入门 LangGraph 开发
python·ai
CoderYanger10 小时前
优选算法-栈:67.基本计算器Ⅱ
java·开发语言·算法·leetcode·职场和发展·1024程序员节
jllllyuz10 小时前
Matlab实现基于Matrix Pencil算法实现声源信号角度和时间估计
开发语言·算法·matlab
多多*10 小时前
Java复习 操作系统原理 计算机网络相关 2025年11月23日
java·开发语言·网络·算法·spring·microsoft·maven
p***434810 小时前
Rust网络编程模型
开发语言·网络·rust
ᐇ95910 小时前
Java集合框架深度实战:构建智能教育管理与娱乐系统
java·开发语言·娱乐
不知更鸟11 小时前
前端报错:快速解决Django接口404问题
前端·python·django
4***721311 小时前
【玩转全栈】----Django模板语法、请求与响应
数据库·python·django