python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image

#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
   save_image(img, f"./data/MNIST_PNG/x/{index}.png")
   # 以写入模式打开文件
   with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
        # 将字符串写入文件
        file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

python 复制代码
class MyMNISTDataset(Dataset):
   def __init__(self, data):
       self.data = data

   def __len__(self):
       return len(self.data)

   def __getitem__(self, idx):
       x = self.data[idx][0] #图像
       y = self.data[idx][1] #标签
       return x, y
   
def load_data(dataNum=60000):
    data = []
    pbar = tqdm(range(dataNum))
    for i in pbar:
        # 指定图片路径
        image_path = f'./data/MNIST_PNG/x/{i}.png'
        cond_path=f'./data/MNIST_PNG/c/{i}.txt'
        # 定义图像预处理
        preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)
        transforms.ToTensor()
        ])
        # 使用预处理加载图像
        image_tensor = preprocess(Image.open(image_path))
        # 加载条件文档(tag)
        with open(cond_path, 'r') as file:
            line = file.readline()
            number = int(line)  # 将字符串转换为整数,图像的类别
            data.append((image_tensor, number))
    return data
   

data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)

for index, (img,cond) in enumerate(pbar):
    #这里对每一批进行训练...
    print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

bash 复制代码
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
相关推荐
湫ccc25 分钟前
《Python基础》之字符串格式化输出
开发语言·python
mqiqe1 小时前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin1 小时前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪4081 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
湫ccc2 小时前
《Python基础》之基本数据类型
开发语言·python
余炜yw3 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
drebander3 小时前
使用 Java Stream 优雅实现List 转化为Map<key,Map<key,value>>
java·python·list
莫叫石榴姐3 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
96773 小时前
对抗样本存在的原因
深度学习
威威猫的栗子3 小时前
Python Turtle召唤童年:喜羊羊与灰太狼之懒羊羊绘画
开发语言·python