PyTorch深度学习(小土堆)笔记3:小土堆 Dataset 类实战笔记,99% 的新手都踩坑!看完秒懂数据加载底层逻辑!

文章目录

  • [🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!](#🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!)
  • [💡 核心概念:Dataset vs DataLoader](#💡 核心概念:Dataset vs DataLoader)
    • [📸 基础补给站:如何用 Python "看见"图片?>>PIL](#📸 基础补给站:如何用 Python “看见”图片?>>PIL)
      • [🔍 深度拆解:新手必知的 3 个细节](#🔍 深度拆解:新手必知的 3 个细节)
        • [① 路径的"陷阱" 🪤](#① 路径的“陷阱” 🪤)
        • [② img 到底是个啥? 🧬](#② img 到底是个啥? 🧬)
    • [📂 进阶补给站:数据集的"档案管理员">>os](#📂 进阶补给站:数据集的“档案管理员”>>os)
      • [🔍 深度拆解:为什么这一步至关重要?](#🔍 深度拆解:为什么这一步至关重要?)
        • [① 自动化索引 🤖](#① 自动化索引 🤖)
        • [② 路径拼接的"智慧" 🧠](#② 路径拼接的“智慧” 🧠)
        • [③ listdir 的小细节 📌](#③ listdir 的小细节 📌)
  • [🛠️ 实战演练:从零实现一个 Dataset(以"蚂蚁🐜 与 蜜蜂🐝"为例)](#🛠️ 实战演练:从零实现一个 Dataset(以“蚂蚁🐜 与 蜜蜂🐝”为例))
    • [1 准备工作:导入工具包](#1 准备工作:导入工具包)
    • [2 核心类:myDataset 的结构](#2 核心类:myDataset 的结构)
    • [3🏃 运行测试:见证奇迹的时刻](#3🏃 运行测试:见证奇迹的时刻)
  • [🌟 关于练手数据集下载](#🌟 关于练手数据集下载)

🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!

Hello 大家好!这是我跟着"小土堆"深度学习打卡的第3条笔记。今天要攻克的是 PyTorch 数据处理的核心 ------ Dataset 类。不管是做图像分类还是目标检测,第一步永远是把数据"喂"给模型。而 Dataset 就是那个最辛苦的"工具人"。


💡 核心概念:Dataset vs DataLoader

在正式写代码前,我们先用一个生动的垃圾处理比喻来搞清楚这两个家伙:

  • Dataset(数据集): 负责"提取垃圾"并"给垃圾编号"。它提供了一种方式,让我们能通过编号找到对应的垃圾(数据)和它的种类(标签)。
  • DataLoader(数据加载器): 负责"打包垃圾"并"装上垃圾车"。它把 Dataset 整理好的数据按批次组合,喂给神经网络。

📸 基础补给站:如何用 Python "看见"图片?>>PIL

在定义 Dataset 之前,我们必须先学会如何手动读取一张图片。这里用到的是最常用的库:PIL

python 复制代码
from PIL import Image

# 第一步:指定图片所在的"书架位置"(第xx排第xx列的第xx本书)
# 建议加上 r 防止转义字符报错
img_path = r"E:\P-Pytorch学习\dataset\train\ants\0013035.jpg"

# 第二步:把图片加载进内存,变成一个对象
# 就像把第xx排第xx列的第xx本书(img_path)从书架上拿下来,还没开始读,但已经拿在手里了
img = Image.open(img_path)

# 第三步:调用系统默认查看器打开图片
# 运行这一行,你的电脑会自动弹出一个窗口显示这张图,你就可以开始看书了!
img.show()

🔍 深度拆解:新手必知的 3 个细节

① 路径的"陷阱" 🪤

在 Windows 系统下,文件夹路径常用反斜杠 \。但在 Python 字符串里,\ 有特殊含义(比如 \n 代表换行)。

  • 解决方法: 在路径字符串前加一个 r(代表 raw string,原始字符串)。
  • 错误示范:"C:\new_folder\test.jpg"(这里的 \n 会被误认)
  • 正确示范: r"C:\new_folder\test.jpg"
② img 到底是个啥? 🧬

当你执行 img = Image.open(img_path) 后,img 不仅仅是一个名字,它包含了这张图片的所有属性。你可以试着打印一下:

  • print(img.format):查看格式(如 JPEG, PNG)
  • print(img.size):查看尺寸(宽, 高)
  • print(img.mode):查看色彩模式(如 RGB, L)

📂 进阶补给站:数据集的"档案管理员">>os

在 Dataset 类的 init 初始化阶段,我们需要准确地定位到存放图片的文件夹。

python 复制代码
#假设你的文件夹结构u是这样的
dataset/
└── train/
    ├── ants/ (100张图)
    └── bees/ (100张图)
python 复制代码
import os

# 第一招:获取设定根目录和子目录
root_dir = r"dataset/train"   # 总的训练集目录
label_dir = "ants"            # 具体的分类目录(也是标签名)

# 第二招:【路径拼接】 os.path.join
# 为什么要用它?因为不同系统(Windows/Linux)的斜杠方向不一样
# 它会自动帮你处理成:dataset/train/ants (Linux) 或 dataset/train\ants (Win)
path = os.path.join(root_dir, label_dir)

# 第三招:【获取文件列表】 os.listdir
# 就像打开文件夹扫视一眼,把里面所有文件的名字存成一个列表
img_path_list = os.listdir(path)

# 看看我们拿到了什么?
print(img_path_list) 
# 输出示例:['0013035.jpg', '5650394.jpg', ...]

🔍 深度拆解:为什么这一步至关重要?

① 自动化索引 🤖

有了 img_path_list 这个列表,我们在 __getitem__(self, idx) 里就可以通过 索引 idx 轻松拿到图片名字,这样,不论数据集有 100 张还是 100 万张,代码逻辑都是通用的。

python 复制代码
img_name = self.img_path_list[idx]
② 路径拼接的"智慧" 🧠

很多新手喜欢用字符串相加:path = root_dir + "/" + label_dir。

注意: 这样做非常容易出错(比如漏写斜杠)。使用 os.path.join 不仅更专业,还能保证你的代码在服务器(Linux)和本地(Windows)都能完美运行。

③ listdir 的小细节 📌

os.listdir 只会给出文件名(如 1.jpg),而不是完整路径。所以当你真正要去 Image.open 的时候,记得要用os.path.join 把路径补全哦!


🛠️ 实战演练:从零实现一个 Dataset(以"蚂蚁🐜 与 蜜蜂🐝"为例)

Dataset 类实战所做的就是要实现获取文件夹下每一张图片的地址!实现下面两个功能。

  1. 如何获取每一个数据的地址及其label
  2. 告诉我们总共有多少的数据

1 准备工作:导入工具包

python 复制代码
import os#用来处理系统文件的工具包
from torch.utils.data import Dataset#用来装数据集的工具包
from PIL import Image#用来读取图片的工具包

2 核心类:myDataset 的结构

我们要继承 PyTorch 的 Dataset 类,并重写三个关键方法:

  • __init__:初始化。相当于告诉程序:去哪找数据?
  • __getitem__:获取数据。给它一个索引(idx),它还你一张图片和标签。
  • __len__:统计长度。告诉程序:总共有多少张图?

🚩小白疑问:self 是什么?

简单来说,self 就像是一个"公共储物柜"。在 init 里存进去的东西,通过 self.xxx 的形式,在 getitem 或其他函数里也能随时取出来用。

python 复制代码
class myDataset(Dataset):
    def __init__(self, root_dir, label_dir):# 这里我要做的是:获取文件夹下所有图片名称列表
        # 初始化:通过拼接路径,获取所有图片的名称列表
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir) # 完整文件夹地址
        self.img_path = os.listdir(self.path)                 # 文件夹下所有图片名

    def __getitem__(self, idx):# 这里我要做的是:获取图片名称列表下的每一张图片
        # 1. 根据索引获取单张图片的名称
        img_name = self.img_path[idx]
        # 2. 拼接出这张图片的完整路径
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # 3. 读取图片内容
        img = Image.open(img_item_path)
        # 4. 获取标签(这里直接用文件夹名作为标签)
        label = self.label_dir
        return img, label# 这里返回两个变量,第一个是图片,第二个为图片的标签

    def __len__(self):
        # 返回数据集的总张数
        return len(self.img_path)#查看数据集的长度len(train_dataset)

3🏃 运行测试:见证奇迹的时刻

代码写好了,让我们看看怎么实例化并使用它:

python 复制代码
# 设置路径(记得改成你自己的路径哦!)
root_dir = r'E:\dataset\train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'

# 1. 创建蚂蚁数据集实例
ants_dataset = myDataset(root_dir, ants_label_dir)

# 2. 创建蜜蜂数据集实例
bees_dataset = myDataset(root_dir, bees_label_dir)

# 3. 关键操作:数据集相加!这个时候如果我有个蜜蜂的,而我想要的实整个数据集(蚂蚁加蜜蜂)!
# 这样就把蚂蚁和蜜蜂的数据合二为一了
train_dataset = ants_dataset + bees_dataset

# 查看总长度
print(f"训练集总长度: {len(train_dataset)}")

# 获取第一张图,因为我设置的是返回两个变量,所以这里也是两个
img, label = train_dataset[0]
img.show()
print(f"这张图的标签是: {label}")

🌟 关于练手数据集下载

我是从bilibili小土堆的视频下的链接下载的,大家可以直接去下载,或者我放到了下面也可以下载。

代码:https://github.com/xiaotudui/pytorch-tutorial

国内仓库:https://gitcode.com/xiaotudui1/pytorch-tutorial/

蚂蚁蜜蜂/练手数据集:https://pan.quark.cn/s/e4c425fc4c0d

课程资源:https://pan.quark.cn/s/c59a198b005d

相关推荐
会飞的不留神1 小时前
【图形学笔记】概率密度函数的通俗理解和应用
笔记
悠哉悠哉愿意1 小时前
【强化学习学习笔记】马尔科夫决策过程
笔记·学习·交互·强化学习
山岚的运维笔记2 小时前
SQL Server笔记 -- 第52章 拆分字符串函数
数据库·笔记·sql·mysql·microsoft·sqlserver
陈天伟教授2 小时前
人工智能应用- 搜索引擎:02. 搜索引擎发展史
人工智能·深度学习·神经网络·游戏·搜索引擎·机器翻译
小lo想吃棒棒糖2 小时前
思路启发:超越Transformer的无限上下文:SSM-Attention混合架构的理论分析
人工智能·pytorch·python
pop_xiaoli2 小时前
effective-Objective-C 第三章阅读笔记
笔记·ios·objective-c
勾股导航2 小时前
灰狼优化算法GWO
人工智能·深度学习·机器学习
盼小辉丶2 小时前
Transformer实战——Transformer跨语言零样本学习
深度学习·transformer·零样本学习
sheyuDemo2 小时前
关于深度学习的d2l库的安装
人工智能·python·深度学习·机器学习·numpy