小土堆- P5-笔记

Dataset = 数据仓库,DataLoader = 搬运工


比喻:

复制代码
Dataset    = 一整箱苹果(所有数据)
DataLoader = 每次拿几个苹果出来(分批取数据)

代码理解:

复制代码
from torch.utils.data import Dataset, DataLoader

# Dataset:定义数据怎么存、怎么取
class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8]   # 所有数据
    
    def __len__(self):
        return len(self.data)        # 一共多少个
    
    def __getitem__(self, idx):
        return self.data[idx]        # 取第 idx 个

# DataLoader:分批次取
dataset = MyDataset()                         # 8个数据
loader = DataLoader(dataset, batch_size=2)    # 每次取2个

for batch in loader:
    print(batch)   # 输出: [1,2], [3,4], [5,6], [7,8]

为什么要分开?

角色 职责
Dataset 告诉我"数据在哪、怎么读"
DataLoader 告诉我"每次取几个、要不要打乱"

一句话:Dataset 存数据,DataLoader 分批喂给模型。

用已经有的数据集

mydataset.py 里写这个:

复制代码
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, root, split='train'):
        """
        root: 数据集路径,如 '/path/to/SECOND'
        split: 'train', 'val', 或 'test'
        """
        self.dir_t1 = os.path.join(root, split, 't1')
        self.dir_t2 = os.path.join(root, split, 't2')
        self.dir_label = os.path.join(root, split, 'change')
        
        # 获取所有图片名
        self.images = os.listdir(self.dir_t1)
        
        # 预处理
        self.transform = transforms.ToTensor()
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        name = self.images[idx]
        
        # 读图
        img_t1 = Image.open(os.path.join(self.dir_t1, name))
        img_t2 = Image.open(os.path.join(self.dir_t2, name))
        label = Image.open(os.path.join(self.dir_label, name))
        
        # 转 Tensor
        img_t1 = self.transform(img_t1)
        img_t2 = self.transform(img_t2)
        label = self.transform(label)
        
        return img_t1, img_t2, label


# 测试代码
if __name__ == '__main__':
    root = '/path/to/SECOND'  # ← 改成你的路径
    
    dataset = MyDataset(root, split='test')
    print("数据量:", len(dataset))
    
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    for t1, t2, label in loader:
        print("t1 形状:", t1.shape)
        print("t2 形状:", t2.shape)
        print("label 形状:", label.shape)
        break

然后运行:

复制代码
python mydataset.py

预期输出:

复制代码
数据量: 1000(或其他数字)
t1 形状: torch.Size([2, 3, H, W])
t2 形状: torch.Size([2, 3, H, W])
label 形状: torch.Size([2, 1, H, W])

记得改 root 路径,改成本地/服务器上 SECOND 数据集的实际位置。

读数据用 CPU,训练模型才用 GPU。

相关推荐
子午9 小时前
【2026计算机毕设~AI项目】鸟类识别系统~Python+深度学习+人工智能+图像识别+算法模型
图像处理·人工智能·python·深度学习
历程里程碑9 小时前
Linxu14 进程一
linux·c语言·开发语言·数据结构·c++·笔记·算法
传说故事9 小时前
【论文自动阅读】Goal Force: 教视频模型实现Physics-Conditioned Goals
人工智能·深度学习·视频生成
FPGA小c鸡9 小时前
【FPGA深度学习加速】RNN与LSTM硬件加速完全指南:从算法原理到硬件实现
rnn·深度学习·fpga开发
三水不滴10 小时前
Redis 持久化机制
数据库·经验分享·redis·笔记·缓存·性能优化
不断进步的咕咕怪10 小时前
meme分析
笔记
中屹指纹浏览器10 小时前
进程级沙箱隔离与WebGL指纹抗识别:指纹浏览器核心技术难点与工程落地
经验分享·笔记
sayang_shao10 小时前
Rust多线程编程学习笔记
笔记·学习·rust
进阶的猪10 小时前
Qt学习笔记
笔记·学习
mango_mangojuice10 小时前
Linux学习笔记 1.19
linux·服务器·数据库·笔记·学习