小土堆- P5-笔记

Dataset = 数据仓库,DataLoader = 搬运工


比喻:

复制代码
Dataset    = 一整箱苹果(所有数据)
DataLoader = 每次拿几个苹果出来(分批取数据)

代码理解:

复制代码
from torch.utils.data import Dataset, DataLoader

# Dataset:定义数据怎么存、怎么取
class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8]   # 所有数据
    
    def __len__(self):
        return len(self.data)        # 一共多少个
    
    def __getitem__(self, idx):
        return self.data[idx]        # 取第 idx 个

# DataLoader:分批次取
dataset = MyDataset()                         # 8个数据
loader = DataLoader(dataset, batch_size=2)    # 每次取2个

for batch in loader:
    print(batch)   # 输出: [1,2], [3,4], [5,6], [7,8]

为什么要分开?

角色 职责
Dataset 告诉我"数据在哪、怎么读"
DataLoader 告诉我"每次取几个、要不要打乱"

一句话:Dataset 存数据,DataLoader 分批喂给模型。

用已经有的数据集

mydataset.py 里写这个:

复制代码
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, root, split='train'):
        """
        root: 数据集路径,如 '/path/to/SECOND'
        split: 'train', 'val', 或 'test'
        """
        self.dir_t1 = os.path.join(root, split, 't1')
        self.dir_t2 = os.path.join(root, split, 't2')
        self.dir_label = os.path.join(root, split, 'change')
        
        # 获取所有图片名
        self.images = os.listdir(self.dir_t1)
        
        # 预处理
        self.transform = transforms.ToTensor()
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        name = self.images[idx]
        
        # 读图
        img_t1 = Image.open(os.path.join(self.dir_t1, name))
        img_t2 = Image.open(os.path.join(self.dir_t2, name))
        label = Image.open(os.path.join(self.dir_label, name))
        
        # 转 Tensor
        img_t1 = self.transform(img_t1)
        img_t2 = self.transform(img_t2)
        label = self.transform(label)
        
        return img_t1, img_t2, label


# 测试代码
if __name__ == '__main__':
    root = '/path/to/SECOND'  # ← 改成你的路径
    
    dataset = MyDataset(root, split='test')
    print("数据量:", len(dataset))
    
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    for t1, t2, label in loader:
        print("t1 形状:", t1.shape)
        print("t2 形状:", t2.shape)
        print("label 形状:", label.shape)
        break

然后运行:

复制代码
python mydataset.py

预期输出:

复制代码
数据量: 1000(或其他数字)
t1 形状: torch.Size([2, 3, H, W])
t2 形状: torch.Size([2, 3, H, W])
label 形状: torch.Size([2, 1, H, W])

记得改 root 路径,改成本地/服务器上 SECOND 数据集的实际位置。

读数据用 CPU,训练模型才用 GPU。

相关推荐
码农三叔1 天前
(10-3)大模型时代的人形机器人感知:多模态Transformer
深度学习·机器人·大模型·transformer·人形机器人
左左右右左右摇晃1 天前
Java异常处理笔记
笔记
今儿敲了吗1 天前
44| 汉诺塔问题
数据结构·c++·笔记·学习·算法·深度优先
bryant_meng1 天前
【AI】《Explainable Machine Learning》
人工智能·深度学习·机器学习·计算机视觉·可解释性
就叫你天选之人啦1 天前
GBDT系列八股(XGBoost、LightGBM)
人工智能·深度学习·学习·机器学习
CoderIsArt1 天前
StarCoder-3B微调和RAG的技术原理
人工智能·深度学习·机器学习
黄嚯嚯1 天前
从字段堆砌到类型建模:一个 PricingDetails 的重构实践
java·笔记
智算菩萨1 天前
通往AGI之路:基于性能与通用性的等级划分框架深度解析
论文阅读·人工智能·深度学习·ai·agi
困死,根本不会1 天前
蓝桥杯 Python 备考全攻略:从入门到进阶的学习路线
笔记·python·学习·算法·蓝桥杯
郝学胜-神的一滴1 天前
深度学习入门基石:PyTorch张量核心技术全解析
人工智能·pytorch·python·深度学习·算法·机器学习