01-PyTorch加载数据初认识(dataset运用)

一、先看整体结构

这是一个标准的 PyTorch 自定义数据集模板,核心分为 3 个部分:

  1. 类定义 + __init__:初始化路径和数据列表
  2. __getitem__:按索引读取单张图片和标签
  3. __len__:返回数据集总长度

二、逐行代码讲解

1. 导入依赖

python运行

复制代码
from torch.utils.data import Dataset
from PIL import Image
import os
  • Dataset:PyTorch 提供的抽象基类,所有自定义数据集都要继承它,这样才能被 DataLoader 识别;
  • Image:来自 PIL 库,用来读取、处理图片;
  • os:用来拼接文件路径、读取目录下的文件名,处理本地文件系统。

2. 类定义与初始化方法 __init__

python运行

复制代码
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
  • class MyData(Dataset):定义一个新的类 MyData,继承自 Dataset
  • def __init__(self, root_dir, label_dir):类的构造函数,创建数据集对象时会自动执行,接收两个参数:
    • root_dir:数据集的根目录,比如 dataset/train
    • label_dir:类别目录,比如 ants(代表蚂蚁的图片文件夹);
  • self.root_dir = root_dir:把根目录保存到实例变量中,后续可以在类的其他方法里调用;
  • self.label_dir = label_dir:把类别目录保存到实例变量中;
  • self.path = os.path.join(self.root_dir, self.label_dir):拼接根目录和类别目录,得到完整的图片文件夹路径,比如 dataset/train/ants
  • self.img_path = os.listdir(self.path):读取 dataset/train/ants 目录下的所有文件名,存入 self.img_path 列表,后续可以按索引读取。

3. 核心方法 __getitem__

python

运行

复制代码
def __getitem__(self, idx):
    img_name = self.img_path[idx]
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
    img = Image.open(img_item_path)
    label = self.label_dir
    return img, label
  • def __getitem__(self, idx):PyTorch 规定的方法,按索引读取数据,idx 就是索引(从 0 开始);
  • img_name = self.img_path[idx]:根据索引 idx,从 self.img_path 列表中取出对应的图片文件名;
  • img_item_path = os.path.join(self.root_dir, self.label_dir, img_name):拼接根目录、类别目录和图片文件名,得到单张图片的完整路径,比如 dataset/train/ants/001.jpg
  • img = Image.open(img_item_path):用 PIL 读取图片,得到一个 Image 对象;
  • label = self.label_dir:把类别目录名(比如 ants)作为标签;
  • return img, label:返回图片和对应的标签,后续模型训练时会接收这两个值。

4. 长度方法 __len__

python

运行

复制代码
def __len__(self):
    return len(self.img_path)
  • def __len__(self):PyTorch 规定的方法,返回数据集的总样本数;
  • return len(self.img_path)self.img_path 是图片文件名列表,len(self.img_path) 就是图片总数,比如 dataset/train/ants 目录下有 124 张图片,就返回 124。

三、代码执行流程(结合你的控制台)

python

运行

复制代码
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = MyData(root_dir, ants_label_dir)
  1. 创建 MyData 对象,传入根目录和类别目录;
  2. 自动执行 __init__:拼接路径、读取图片列表;
  3. 当你调用 len(ants_dataset) 时,会执行 __len__,返回图片总数;
  4. 当你调用 ants_dataset[0] 时,会执行 __getitem__(0),返回第 1 张图片和标签。

四、补充说明与小优化

  1. 标签处理 :这段代码里直接用 label = self.label_dir,后续训练时,模型需要的是数字标签,比如 ants=0bees=1,可以改成:

    python

    运行

    复制代码
    # 比如 ants 标签设为 0
    label = 0
  2. 路径拼接os.path.join 是跨平台的,Windows、Linux 都能正常拼接路径,避免手动写 /\ 出错;

  3. 遥感影像适配 :如果你后续要处理 .tif 格式的遥感影像,把 Image.open 换成 rasterio.open 即可,核心逻辑不变。

相关推荐
云烟成雨TD2 小时前
Agent Scope Java 2.x 系列【8】工具调用
java·人工智能·agent
abcy0712132 小时前
python fastapi celery hdfs 异步上传
python·hdfs·fastapi
TMT星球2 小时前
魔法原子上交会首秀VLA K02大模型,完成具身智能从“执行”到“理解”的能力跃迁
人工智能·算法·机器学习
Dxy12393102162 小时前
Python多线程如何操作全局变量:从踩坑到最佳实践
python
2301_764441332 小时前
番茄钟+AI:高效专注的秘密武器
人工智能·算法·数学建模·动态规划·交互
SilentSamsara2 小时前
RAG 系统入门:LangChain/LlamaIndex + Chroma 向量数据库的检索增强实战
数据库·人工智能·python·青少年编程·langchain
东方佑2 小时前
分形递归状态机 (FRSM) 实验报告-更新对比
人工智能·语言模型·自然语言处理·开源
YOLO视觉与编程2 小时前
jetson orin nano烧录jetpack7.2系统
人工智能·深度学习·yolo·目标检测·机器学习
昇腾CANN2 小时前
6月15号新课开讲|HCCL入门系列课,正式上线!
人工智能·开源·昇腾·cann