01-PyTorch加载数据初认识(dataset运用)

一、先看整体结构

这是一个标准的 PyTorch 自定义数据集模板,核心分为 3 个部分:

  1. 类定义 + __init__:初始化路径和数据列表
  2. __getitem__:按索引读取单张图片和标签
  3. __len__:返回数据集总长度

二、逐行代码讲解

1. 导入依赖

python运行

复制代码
from torch.utils.data import Dataset
from PIL import Image
import os
  • Dataset:PyTorch 提供的抽象基类,所有自定义数据集都要继承它,这样才能被 DataLoader 识别;
  • Image:来自 PIL 库,用来读取、处理图片;
  • os:用来拼接文件路径、读取目录下的文件名,处理本地文件系统。

2. 类定义与初始化方法 __init__

python运行

复制代码
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
  • class MyData(Dataset):定义一个新的类 MyData,继承自 Dataset
  • def __init__(self, root_dir, label_dir):类的构造函数,创建数据集对象时会自动执行,接收两个参数:
    • root_dir:数据集的根目录,比如 dataset/train
    • label_dir:类别目录,比如 ants(代表蚂蚁的图片文件夹);
  • self.root_dir = root_dir:把根目录保存到实例变量中,后续可以在类的其他方法里调用;
  • self.label_dir = label_dir:把类别目录保存到实例变量中;
  • self.path = os.path.join(self.root_dir, self.label_dir):拼接根目录和类别目录,得到完整的图片文件夹路径,比如 dataset/train/ants
  • self.img_path = os.listdir(self.path):读取 dataset/train/ants 目录下的所有文件名,存入 self.img_path 列表,后续可以按索引读取。

3. 核心方法 __getitem__

python

运行

复制代码
def __getitem__(self, idx):
    img_name = self.img_path[idx]
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
    img = Image.open(img_item_path)
    label = self.label_dir
    return img, label
  • def __getitem__(self, idx):PyTorch 规定的方法,按索引读取数据,idx 就是索引(从 0 开始);
  • img_name = self.img_path[idx]:根据索引 idx,从 self.img_path 列表中取出对应的图片文件名;
  • img_item_path = os.path.join(self.root_dir, self.label_dir, img_name):拼接根目录、类别目录和图片文件名,得到单张图片的完整路径,比如 dataset/train/ants/001.jpg
  • img = Image.open(img_item_path):用 PIL 读取图片,得到一个 Image 对象;
  • label = self.label_dir:把类别目录名(比如 ants)作为标签;
  • return img, label:返回图片和对应的标签,后续模型训练时会接收这两个值。

4. 长度方法 __len__

python

运行

复制代码
def __len__(self):
    return len(self.img_path)
  • def __len__(self):PyTorch 规定的方法,返回数据集的总样本数;
  • return len(self.img_path)self.img_path 是图片文件名列表,len(self.img_path) 就是图片总数,比如 dataset/train/ants 目录下有 124 张图片,就返回 124。

三、代码执行流程(结合你的控制台)

python

运行

复制代码
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = MyData(root_dir, ants_label_dir)
  1. 创建 MyData 对象,传入根目录和类别目录;
  2. 自动执行 __init__:拼接路径、读取图片列表;
  3. 当你调用 len(ants_dataset) 时,会执行 __len__,返回图片总数;
  4. 当你调用 ants_dataset[0] 时,会执行 __getitem__(0),返回第 1 张图片和标签。

四、补充说明与小优化

  1. 标签处理 :这段代码里直接用 label = self.label_dir,后续训练时,模型需要的是数字标签,比如 ants=0bees=1,可以改成:

    python

    运行

    复制代码
    # 比如 ants 标签设为 0
    label = 0
  2. 路径拼接os.path.join 是跨平台的,Windows、Linux 都能正常拼接路径,避免手动写 /\ 出错;

  3. 遥感影像适配 :如果你后续要处理 .tif 格式的遥感影像,把 Image.open 换成 rasterio.open 即可,核心逻辑不变。

相关推荐
火山引擎开发者社区19 小时前
没有长期记忆,Agent 谈何持续进化?一图看懂火山 Mem0:解锁 Agent 持续学习与进化之路
人工智能
冬奇Lab1 天前
Workflow 系列(06):安全——跨步骤注入传播与四层防御
人工智能·工作流引擎
冬奇Lab1 天前
每日一个开源项目(第149篇):RAG-Anything - 把图片、表格、公式当成一等公民的多模态 RAG 框架
人工智能·开源
米小虾1 天前
AI Agent 安全实战指南:当智能体开始"不听话",开发者该如何应对?
人工智能·安全·agent
IT_陈寒1 天前
Vite的热更新突然不香了,排查三小时差点砸键盘
前端·人工智能·后端
用户8356290780511 天前
Python 实现 PDF 文件加密与解密方法
后端·python
用户8356290780511 天前
使用 Python 冻结与拆分 Excel 窗格教程
后端·python
阿里云大数据AI技术1 天前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu12271 天前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude