python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image

#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
   save_image(img, f"./data/MNIST_PNG/x/{index}.png")
   # 以写入模式打开文件
   with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
        # 将字符串写入文件
        file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

python 复制代码
class MyMNISTDataset(Dataset):
   def __init__(self, data):
       self.data = data

   def __len__(self):
       return len(self.data)

   def __getitem__(self, idx):
       x = self.data[idx][0] #图像
       y = self.data[idx][1] #标签
       return x, y
   
def load_data(dataNum=60000):
    data = []
    pbar = tqdm(range(dataNum))
    for i in pbar:
        # 指定图片路径
        image_path = f'./data/MNIST_PNG/x/{i}.png'
        cond_path=f'./data/MNIST_PNG/c/{i}.txt'
        # 定义图像预处理
        preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)
        transforms.ToTensor()
        ])
        # 使用预处理加载图像
        image_tensor = preprocess(Image.open(image_path))
        # 加载条件文档(tag)
        with open(cond_path, 'r') as file:
            line = file.readline()
            number = int(line)  # 将字符串转换为整数,图像的类别
            data.append((image_tensor, number))
    return data
   

data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)

for index, (img,cond) in enumerate(pbar):
    #这里对每一批进行训练...
    print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

bash 复制代码
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
相关推荐
坐吃山猪1 小时前
Python进度条
linux·服务器·python
l1t1 小时前
四种python工具包用SQL查询csv和parquet文件的方法比较
大数据·python·sql
清水白石0081 小时前
Python 并发三剑客:多线程、多进程与协程的实战抉择
java·服务器·python
2301_793804691 小时前
更优雅的测试:Pytest框架入门
jvm·数据库·python
zhangfeng11331 小时前
国家超算中心 命令行是否会消耗算力卡,找不到显卡,是否需要退出
人工智能·深度学习
dinl_vin2 小时前
python:常用的基础工具包
开发语言·python
renhongxia12 小时前
PostTrainBench:LLM 代理能否自动化 LLM 后培训?
运维·人工智能·深度学习·机器学习·架构·自动化·transformer
wefly20172 小时前
无需安装、开箱即用!m3u8live.cn 在线 HLS 播放器,调试直播流效率翻倍
前端·后端·python·前端开发工具·后端开发工具
2301_815482932 小时前
用Python实现自动化的Web测试(Selenium)
jvm·数据库·python
将心ONE2 小时前
melo tts安装使用
python