Python PyTorch 获取 MNIST 数据

Python PyTorch 获取 MNIST 数据

  • [1 PyTorch 获取 MNIST 数据](#1 PyTorch 获取 MNIST 数据)
  • [2 PyTorch 保存 MNIST 数据](#2 PyTorch 保存 MNIST 数据)
  • [3 PyTorch 显示 MNIST 数据](#3 PyTorch 显示 MNIST 数据)

1 PyTorch 获取 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)
相关推荐
liwulin05063 分钟前
【JAVA】JVM 堆内存“缓冲空间”的压缩机制及调整方法
java·开发语言·jvm
Mysticbinary7 分钟前
Python 迭代器和生成器概念
python·迭代器·生成器
kaka.liulin -study9 分钟前
Multi Agents Collaboration OS:数据与知识协同构建数据工作流自动化
人工智能·python·深度学习·数据分析
Simon—欧阳14 分钟前
C#异步方法返回Task<T>的同步调用
开发语言·前端·javascript
红队it24 分钟前
【机器学习算法】基于python商品销量数据分析大屏可视化预测系统(完整系统源码+数据库+开发笔记+详细启动教程)✅
python·机器学习·数据分析
michaelzhouh25 分钟前
php调用大模型应用接口实现流式输出以及数据过滤
开发语言·php·php调用大模型api流式输出
小郝 小郝25 分钟前
【C语言】浮点数在内存的储存
c语言·开发语言
韩zj28 分钟前
springboot调用python文件,python文件使用其他dat文件,适配windows和linux,以及docker环境的方案
windows·spring boot·python
佩奇的技术笔记36 分钟前
Java学习手册:JVM、JRE和JDK的关系
java·开发语言·jvm
拖拉机1 小时前
Python(五)字典
后端·python