python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image

#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
   save_image(img, f"./data/MNIST_PNG/x/{index}.png")
   # 以写入模式打开文件
   with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
        # 将字符串写入文件
        file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

python 复制代码
class MyMNISTDataset(Dataset):
   def __init__(self, data):
       self.data = data

   def __len__(self):
       return len(self.data)

   def __getitem__(self, idx):
       x = self.data[idx][0] #图像
       y = self.data[idx][1] #标签
       return x, y
   
def load_data(dataNum=60000):
    data = []
    pbar = tqdm(range(dataNum))
    for i in pbar:
        # 指定图片路径
        image_path = f'./data/MNIST_PNG/x/{i}.png'
        cond_path=f'./data/MNIST_PNG/c/{i}.txt'
        # 定义图像预处理
        preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)
        transforms.ToTensor()
        ])
        # 使用预处理加载图像
        image_tensor = preprocess(Image.open(image_path))
        # 加载条件文档(tag)
        with open(cond_path, 'r') as file:
            line = file.readline()
            number = int(line)  # 将字符串转换为整数,图像的类别
            data.append((image_tensor, number))
    return data
   

data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)

for index, (img,cond) in enumerate(pbar):
    #这里对每一批进行训练...
    print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

bash 复制代码
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
相关推荐
TF男孩8 小时前
ARQ:一款低成本的消息队列,实现每秒万级吞吐
后端·python·消息队列
该用户已不存在13 小时前
Mojo vs Python vs Rust: 2025年搞AI,该学哪个?
后端·python·rust
站大爷IP15 小时前
Java调用Python的5种实用方案:从简单到进阶的全场景解析
python
用户83562907805121 小时前
从手动编辑到代码生成:Python 助你高效创建 Word 文档
后端·python
c8i21 小时前
python中类的基本结构、特殊属性于MRO理解
python
隐语SecretFlow21 小时前
国人自研开源隐私计算框架SecretFlow,深度拆解框架及使用【开发者必看】
深度学习
liwulin050621 小时前
【ESP32-CAM】HELLO WORLD
python
Doris_20231 天前
Python条件判断语句 if、elif 、else
前端·后端·python
Doris_20231 天前
Python 模式匹配match case
前端·后端·python
Billy_Zuo1 天前
人工智能深度学习——卷积神经网络(CNN)
人工智能·深度学习·cnn