Python PyTorch 获取 MNIST 数据

Python PyTorch 获取 MNIST 数据

  • [1 PyTorch 获取 MNIST 数据](#1 PyTorch 获取 MNIST 数据)
  • [2 PyTorch 保存 MNIST 数据](#2 PyTorch 保存 MNIST 数据)
  • [3 PyTorch 显示 MNIST 数据](#3 PyTorch 显示 MNIST 数据)

1 PyTorch 获取 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)
相关推荐
秃头佛爷5 分钟前
Python学习大纲总结及注意事项
开发语言·python·学习
待磨的钝刨6 分钟前
【格式化查看JSON文件】coco的json文件内容都在一行如何按照json格式查看
开发语言·javascript·json
深度学习lover1 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
XiaoLeisj2 小时前
【JavaEE初阶 — 多线程】单例模式 & 指令重排序问题
java·开发语言·java-ee
API快乐传递者2 小时前
淘宝反爬虫机制的主要手段有哪些?
爬虫·python
励志成为嵌入式工程师3 小时前
c语言简单编程练习9
c语言·开发语言·算法·vim
捕鲸叉4 小时前
创建线程时传递参数给线程
开发语言·c++·算法
A charmer4 小时前
【C++】vector 类深度解析:探索动态数组的奥秘
开发语言·c++·算法
Peter_chq4 小时前
【操作系统】基于环形队列的生产消费模型
linux·c语言·开发语言·c++·后端
阡之尘埃4 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控