python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image

#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
   save_image(img, f"./data/MNIST_PNG/x/{index}.png")
   # 以写入模式打开文件
   with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
        # 将字符串写入文件
        file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

python 复制代码
class MyMNISTDataset(Dataset):
   def __init__(self, data):
       self.data = data

   def __len__(self):
       return len(self.data)

   def __getitem__(self, idx):
       x = self.data[idx][0] #图像
       y = self.data[idx][1] #标签
       return x, y
   
def load_data(dataNum=60000):
    data = []
    pbar = tqdm(range(dataNum))
    for i in pbar:
        # 指定图片路径
        image_path = f'./data/MNIST_PNG/x/{i}.png'
        cond_path=f'./data/MNIST_PNG/c/{i}.txt'
        # 定义图像预处理
        preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)
        transforms.ToTensor()
        ])
        # 使用预处理加载图像
        image_tensor = preprocess(Image.open(image_path))
        # 加载条件文档(tag)
        with open(cond_path, 'r') as file:
            line = file.readline()
            number = int(line)  # 将字符串转换为整数,图像的类别
            data.append((image_tensor, number))
    return data
   

data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)

for index, (img,cond) in enumerate(pbar):
    #这里对每一批进行训练...
    print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

bash 复制代码
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
相关推荐
CryptoRzz34 分钟前
股票数据源对接技术指南:印度尼西亚、印度、韩国
数据库·python·金融·数据分析·区块链
KangkangLoveNLP42 分钟前
Llama:开源的急先锋
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·llama
胖哥真不错1 小时前
Python实现NOA星雀优化算法优化卷积神经网络CNN回归模型项目实战
python·cnn·卷积神经网络·项目实战·cnn回归模型·noa星雀优化算法
久邦科技1 小时前
《Deepseek从入门到精通》清华大学中文pdf完整版
人工智能·深度学习·机器学习
来自于狂人1 小时前
给大模型“贴膏药”:LoRA微调原理说明书
人工智能·深度学习·transformer
love530love2 小时前
【笔记】记一次PyCharm的问题反馈
ide·人工智能·windows·笔记·python·pycharm
梦醒沉醉2 小时前
MCP(一)——QuickStart
python·mcp
照物华2 小时前
httpx[http2] 和 httpx 的核心区别及使用场景如下
python·httpx
山海不说话2 小时前
PyGame游戏开发(入门知识+组件拆分+历史存档/回放+人机策略)
开发语言·python·pygame
tyatyatya2 小时前
MATLAB中进行深度学习网络训练的模型评估步骤
网络·深度学习·matlab