Python PyTorch 获取 MNIST 数据

Python PyTorch 获取 MNIST 数据

  • [1 PyTorch 获取 MNIST 数据](#1 PyTorch 获取 MNIST 数据)
  • [2 PyTorch 保存 MNIST 数据](#2 PyTorch 保存 MNIST 数据)
  • [3 PyTorch 显示 MNIST 数据](#3 PyTorch 显示 MNIST 数据)

1 PyTorch 获取 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)
相关推荐
炼丹师小米2 分钟前
Ubuntu24.04.1系统下VideoMamba环境配置
python·环境配置·videomamba
GFCGUO8 分钟前
ubuntu18.04运行OpenPCDet出现的问题
linux·python·学习·ubuntu·conda·pip
快乐就好ya33 分钟前
Java多线程
java·开发语言
CS_GaoMing1 小时前
Centos7 JDK 多版本管理与 Maven 构建问题和注意!
java·开发语言·maven·centos7·java多版本
985小水博一枚呀2 小时前
【深度学习基础模型】神经图灵机(Neural Turing Machines, NTM)详细理解并附实现代码。
人工智能·python·rnn·深度学习·lstm·ntm
2401_858120532 小时前
Spring Boot框架下的大学生就业招聘平台
java·开发语言
转调2 小时前
每日一练:地下城游戏
开发语言·c++·算法·leetcode
Java探秘者2 小时前
Maven下载、安装与环境配置详解:从零开始搭建高效Java开发环境
java·开发语言·数据库·spring boot·spring cloud·maven·idea
2303_812044463 小时前
Bean,看到P188没看了与maven
java·开发语言
秋夫人3 小时前
idea 同一个项目不同模块如何设置不同的jdk版本
java·开发语言·intellij-idea