文章目录
- [🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!](#🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!)
- [💡 核心概念:Dataset vs DataLoader](#💡 核心概念:Dataset vs DataLoader)
-
- [📸 基础补给站:如何用 Python "看见"图片?>>PIL](#📸 基础补给站:如何用 Python “看见”图片?>>PIL)
-
- [🔍 深度拆解:新手必知的 3 个细节](#🔍 深度拆解:新手必知的 3 个细节)
-
- [① 路径的"陷阱" 🪤](#① 路径的“陷阱” 🪤)
- [② img 到底是个啥? 🧬](#② img 到底是个啥? 🧬)
- [📂 进阶补给站:数据集的"档案管理员">>os](#📂 进阶补给站:数据集的“档案管理员”>>os)
-
- [🔍 深度拆解:为什么这一步至关重要?](#🔍 深度拆解:为什么这一步至关重要?)
-
- [① 自动化索引 🤖](#① 自动化索引 🤖)
- [② 路径拼接的"智慧" 🧠](#② 路径拼接的“智慧” 🧠)
- [③ listdir 的小细节 📌](#③ listdir 的小细节 📌)
- [🛠️ 实战演练:从零实现一个 Dataset(以"蚂蚁🐜 与 蜜蜂🐝"为例)](#🛠️ 实战演练:从零实现一个 Dataset(以“蚂蚁🐜 与 蜜蜂🐝”为例))
-
- [1 准备工作:导入工具包](#1 准备工作:导入工具包)
- [2 核心类:myDataset 的结构](#2 核心类:myDataset 的结构)
- [3🏃 运行测试:见证奇迹的时刻](#3🏃 运行测试:见证奇迹的时刻)
- [🌟 关于练手数据集下载](#🌟 关于练手数据集下载)
🚀 PyTorch 学习笔记:手把手带你解锁 Dataset 类!
Hello 大家好!这是我跟着"小土堆"深度学习打卡的第3条笔记。今天要攻克的是 PyTorch 数据处理的核心 ------ Dataset 类。不管是做图像分类还是目标检测,第一步永远是把数据"喂"给模型。而 Dataset 就是那个最辛苦的"工具人"。
💡 核心概念:Dataset vs DataLoader
在正式写代码前,我们先用一个生动的垃圾处理比喻来搞清楚这两个家伙:
- Dataset(数据集): 负责"提取垃圾"并"给垃圾编号"。它提供了一种方式,让我们能通过编号找到对应的垃圾(数据)和它的种类(标签)。
- DataLoader(数据加载器): 负责"打包垃圾"并"装上垃圾车"。它把 Dataset 整理好的数据按批次组合,喂给神经网络。
📸 基础补给站:如何用 Python "看见"图片?>>PIL
在定义 Dataset 之前,我们必须先学会如何手动读取一张图片。这里用到的是最常用的库:PIL
python
from PIL import Image
# 第一步:指定图片所在的"书架位置"(第xx排第xx列的第xx本书)
# 建议加上 r 防止转义字符报错
img_path = r"E:\P-Pytorch学习\dataset\train\ants\0013035.jpg"
# 第二步:把图片加载进内存,变成一个对象
# 就像把第xx排第xx列的第xx本书(img_path)从书架上拿下来,还没开始读,但已经拿在手里了
img = Image.open(img_path)
# 第三步:调用系统默认查看器打开图片
# 运行这一行,你的电脑会自动弹出一个窗口显示这张图,你就可以开始看书了!
img.show()
🔍 深度拆解:新手必知的 3 个细节
① 路径的"陷阱" 🪤
在 Windows 系统下,文件夹路径常用反斜杠 \。但在 Python 字符串里,\ 有特殊含义(比如 \n 代表换行)。
- 解决方法: 在路径字符串前加一个 r(代表 raw string,原始字符串)。
- 错误示范:"C:\new_folder\test.jpg"(这里的 \n 会被误认)
- 正确示范: r"C:\new_folder\test.jpg"
② img 到底是个啥? 🧬
当你执行 img = Image.open(img_path) 后,img 不仅仅是一个名字,它包含了这张图片的所有属性。你可以试着打印一下:
- print(img.format):查看格式(如 JPEG, PNG)
- print(img.size):查看尺寸(宽, 高)
- print(img.mode):查看色彩模式(如 RGB, L)
📂 进阶补给站:数据集的"档案管理员">>os
在 Dataset 类的 init 初始化阶段,我们需要准确地定位到存放图片的文件夹。
python
#假设你的文件夹结构u是这样的
dataset/
└── train/
├── ants/ (100张图)
└── bees/ (100张图)
python
import os
# 第一招:获取设定根目录和子目录
root_dir = r"dataset/train" # 总的训练集目录
label_dir = "ants" # 具体的分类目录(也是标签名)
# 第二招:【路径拼接】 os.path.join
# 为什么要用它?因为不同系统(Windows/Linux)的斜杠方向不一样
# 它会自动帮你处理成:dataset/train/ants (Linux) 或 dataset/train\ants (Win)
path = os.path.join(root_dir, label_dir)
# 第三招:【获取文件列表】 os.listdir
# 就像打开文件夹扫视一眼,把里面所有文件的名字存成一个列表
img_path_list = os.listdir(path)
# 看看我们拿到了什么?
print(img_path_list)
# 输出示例:['0013035.jpg', '5650394.jpg', ...]
🔍 深度拆解:为什么这一步至关重要?
① 自动化索引 🤖
有了 img_path_list 这个列表,我们在 __getitem__(self, idx) 里就可以通过 索引 idx 轻松拿到图片名字,这样,不论数据集有 100 张还是 100 万张,代码逻辑都是通用的。
python
img_name = self.img_path_list[idx]
② 路径拼接的"智慧" 🧠
很多新手喜欢用字符串相加:path = root_dir + "/" + label_dir。
注意: 这样做非常容易出错(比如漏写斜杠)。使用 os.path.join 不仅更专业,还能保证你的代码在服务器(Linux)和本地(Windows)都能完美运行。
③ listdir 的小细节 📌
os.listdir 只会给出文件名(如 1.jpg),而不是完整路径。所以当你真正要去 Image.open 的时候,记得要用os.path.join 把路径补全哦!
🛠️ 实战演练:从零实现一个 Dataset(以"蚂蚁🐜 与 蜜蜂🐝"为例)
Dataset 类实战所做的就是要实现获取文件夹下每一张图片的地址!实现下面两个功能。
- 如何获取每一个数据的地址及其label
- 告诉我们总共有多少的数据
1 准备工作:导入工具包
python
import os#用来处理系统文件的工具包
from torch.utils.data import Dataset#用来装数据集的工具包
from PIL import Image#用来读取图片的工具包
2 核心类:myDataset 的结构
我们要继承 PyTorch 的 Dataset 类,并重写三个关键方法:
__init__:初始化。相当于告诉程序:去哪找数据?__getitem__:获取数据。给它一个索引(idx),它还你一张图片和标签。__len__:统计长度。告诉程序:总共有多少张图?
🚩小白疑问:self 是什么?
简单来说,self 就像是一个"公共储物柜"。在 init 里存进去的东西,通过 self.xxx 的形式,在 getitem 或其他函数里也能随时取出来用。
python
class myDataset(Dataset):
def __init__(self, root_dir, label_dir):# 这里我要做的是:获取文件夹下所有图片名称列表
# 初始化:通过拼接路径,获取所有图片的名称列表
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir) # 完整文件夹地址
self.img_path = os.listdir(self.path) # 文件夹下所有图片名
def __getitem__(self, idx):# 这里我要做的是:获取图片名称列表下的每一张图片
# 1. 根据索引获取单张图片的名称
img_name = self.img_path[idx]
# 2. 拼接出这张图片的完整路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
# 3. 读取图片内容
img = Image.open(img_item_path)
# 4. 获取标签(这里直接用文件夹名作为标签)
label = self.label_dir
return img, label# 这里返回两个变量,第一个是图片,第二个为图片的标签
def __len__(self):
# 返回数据集的总张数
return len(self.img_path)#查看数据集的长度len(train_dataset)
3🏃 运行测试:见证奇迹的时刻
代码写好了,让我们看看怎么实例化并使用它:
python
# 设置路径(记得改成你自己的路径哦!)
root_dir = r'E:\dataset\train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 1. 创建蚂蚁数据集实例
ants_dataset = myDataset(root_dir, ants_label_dir)
# 2. 创建蜜蜂数据集实例
bees_dataset = myDataset(root_dir, bees_label_dir)
# 3. 关键操作:数据集相加!这个时候如果我有个蜜蜂的,而我想要的实整个数据集(蚂蚁加蜜蜂)!
# 这样就把蚂蚁和蜜蜂的数据合二为一了
train_dataset = ants_dataset + bees_dataset
# 查看总长度
print(f"训练集总长度: {len(train_dataset)}")
# 获取第一张图,因为我设置的是返回两个变量,所以这里也是两个
img, label = train_dataset[0]
img.show()
print(f"这张图的标签是: {label}")
🌟 关于练手数据集下载
我是从bilibili小土堆的视频下的链接下载的,大家可以直接去下载,或者我放到了下面也可以下载。
代码:https://github.com/xiaotudui/pytorch-tutorial
国内仓库:https://gitcode.com/xiaotudui1/pytorch-tutorial/
蚂蚁蜜蜂/练手数据集:https://pan.quark.cn/s/e4c425fc4c0d