Python PyTorch 获取 MNIST 数据

Python PyTorch 获取 MNIST 数据

  • [1 PyTorch 获取 MNIST 数据](#1 PyTorch 获取 MNIST 数据)
  • [2 PyTorch 保存 MNIST 数据](#2 PyTorch 保存 MNIST 数据)
  • [3 PyTorch 显示 MNIST 数据](#3 PyTorch 显示 MNIST 数据)

1 PyTorch 获取 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)
相关推荐
DaphneOdera178 分钟前
Git Bash 配置 zsh
开发语言·git·bash
Code侠客行15 分钟前
Scala语言的编程范式
开发语言·后端·golang
lozhyf34 分钟前
Go语言-学习一
开发语言·学习·golang
dujunqiu44 分钟前
bash: ./xxx: No such file or directory
开发语言·bash
爱偷懒的程序源1 小时前
解决go.mod文件中replace不生效的问题
开发语言·golang
日月星宿~1 小时前
【JVM】调优
java·开发语言·jvm
加德霍克1 小时前
【机器学习】使用scikit-learn中的KNN包实现对鸢尾花数据集或者自定义数据集的的预测
人工智能·python·学习·机器学习·作业
2401_843785231 小时前
C语言 指针_野指针 指针运算
c语言·开发语言
matlabgoodboy1 小时前
代码编写java代做matlab程序代编Python接单c++代写web系统设计
java·python·matlab
l1x1n01 小时前
No.37 笔记 | Python面向对象编程学习笔记:探索代码世界的奇妙之旅
笔记·python·学习