小土堆- P5-笔记

Dataset = 数据仓库,DataLoader = 搬运工


比喻:

复制代码
Dataset    = 一整箱苹果(所有数据)
DataLoader = 每次拿几个苹果出来(分批取数据)

代码理解:

复制代码
from torch.utils.data import Dataset, DataLoader

# Dataset:定义数据怎么存、怎么取
class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8]   # 所有数据
    
    def __len__(self):
        return len(self.data)        # 一共多少个
    
    def __getitem__(self, idx):
        return self.data[idx]        # 取第 idx 个

# DataLoader:分批次取
dataset = MyDataset()                         # 8个数据
loader = DataLoader(dataset, batch_size=2)    # 每次取2个

for batch in loader:
    print(batch)   # 输出: [1,2], [3,4], [5,6], [7,8]

为什么要分开?

角色 职责
Dataset 告诉我"数据在哪、怎么读"
DataLoader 告诉我"每次取几个、要不要打乱"

一句话:Dataset 存数据,DataLoader 分批喂给模型。

用已经有的数据集

mydataset.py 里写这个:

复制代码
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, root, split='train'):
        """
        root: 数据集路径,如 '/path/to/SECOND'
        split: 'train', 'val', 或 'test'
        """
        self.dir_t1 = os.path.join(root, split, 't1')
        self.dir_t2 = os.path.join(root, split, 't2')
        self.dir_label = os.path.join(root, split, 'change')
        
        # 获取所有图片名
        self.images = os.listdir(self.dir_t1)
        
        # 预处理
        self.transform = transforms.ToTensor()
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        name = self.images[idx]
        
        # 读图
        img_t1 = Image.open(os.path.join(self.dir_t1, name))
        img_t2 = Image.open(os.path.join(self.dir_t2, name))
        label = Image.open(os.path.join(self.dir_label, name))
        
        # 转 Tensor
        img_t1 = self.transform(img_t1)
        img_t2 = self.transform(img_t2)
        label = self.transform(label)
        
        return img_t1, img_t2, label


# 测试代码
if __name__ == '__main__':
    root = '/path/to/SECOND'  # ← 改成你的路径
    
    dataset = MyDataset(root, split='test')
    print("数据量:", len(dataset))
    
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    for t1, t2, label in loader:
        print("t1 形状:", t1.shape)
        print("t2 形状:", t2.shape)
        print("label 形状:", label.shape)
        break

然后运行:

复制代码
python mydataset.py

预期输出:

复制代码
数据量: 1000(或其他数字)
t1 形状: torch.Size([2, 3, H, W])
t2 形状: torch.Size([2, 3, H, W])
label 形状: torch.Size([2, 1, H, W])

记得改 root 路径,改成本地/服务器上 SECOND 数据集的实际位置。

读数据用 CPU,训练模型才用 GPU。

相关推荐
小瑞瑞acd17 小时前
【小瑞瑞精讲】卷积神经网络(CNN):从入门到精通,计算机如何“看”懂世界?
人工智能·python·深度学习·神经网络·机器学习
驭渊的小故事17 小时前
简单模板笔记
数据结构·笔记·算法
芷栀夏18 小时前
CANN ops-math:揭秘异构计算架构下数学算子的低延迟高吞吐优化逻辑
人工智能·深度学习·神经网络·cann
孤狼warrior18 小时前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
Rorsion18 小时前
PyTorch实现线性回归
人工智能·pytorch·线性回归
机器学习之心18 小时前
TCN-Transformer-BiGRU组合模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析
深度学习·回归·transformer·shap分析
LLWZAI18 小时前
让朱雀AI检测无法判断的AI公众号文章,当创作者开始与算法「躲猫猫」
大数据·人工智能·深度学习
智者知已应修善业19 小时前
【洛谷P9975奶牛被病毒传染最少数量推导,导出多样例】2025-2-26
c语言·c++·经验分享·笔记·算法·推荐算法
Junlan2719 小时前
Cursor使用入门及连接服务器方法(更新中)
服务器·人工智能·笔记
霖大侠19 小时前
【无标题】
人工智能·深度学习·机器学习