python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image

#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
   save_image(img, f"./data/MNIST_PNG/x/{index}.png")
   # 以写入模式打开文件
   with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
        # 将字符串写入文件
        file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

python 复制代码
class MyMNISTDataset(Dataset):
   def __init__(self, data):
       self.data = data

   def __len__(self):
       return len(self.data)

   def __getitem__(self, idx):
       x = self.data[idx][0] #图像
       y = self.data[idx][1] #标签
       return x, y
   
def load_data(dataNum=60000):
    data = []
    pbar = tqdm(range(dataNum))
    for i in pbar:
        # 指定图片路径
        image_path = f'./data/MNIST_PNG/x/{i}.png'
        cond_path=f'./data/MNIST_PNG/c/{i}.txt'
        # 定义图像预处理
        preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)
        transforms.ToTensor()
        ])
        # 使用预处理加载图像
        image_tensor = preprocess(Image.open(image_path))
        # 加载条件文档(tag)
        with open(cond_path, 'r') as file:
            line = file.readline()
            number = int(line)  # 将字符串转换为整数,图像的类别
            data.append((image_tensor, number))
    return data
   

data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)

for index, (img,cond) in enumerate(pbar):
    #这里对每一批进行训练...
    print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

bash 复制代码
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
相关推荐
阿尔的代码屋4 小时前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python
AI探索者1 天前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者1 天前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
FishCoderh1 天前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅1 天前
Python函数入门详解(定义+调用+参数)
python
曲幽1 天前
我用FastAPI接ollama大模型,差点被asyncio整崩溃(附对话窗口实战)
python·fastapi·web·async·httpx·asyncio·ollama
两万五千个小时1 天前
落地实现 Anthropic Multi-Agent Research System
人工智能·python·架构
CoovallyAIHub1 天前
仿生学突破:SILD模型如何让无人机在电力线迷宫中发现“隐形威胁”
深度学习·算法·计算机视觉
CoovallyAIHub1 天前
从春晚机器人到零样本革命:YOLO26-Pose姿态估计实战指南
深度学习·算法·计算机视觉
CoovallyAIHub1 天前
Le-DETR:省80%预训练数据,这个实时检测Transformer刷新SOTA|Georgia Tech & 北交大
深度学习·算法·计算机视觉