数据加载之前,先学习两个Pytorch中的类:Dataset和DataLoader。在 PyTorch 中,Dataset 和 DataLoader 是两个非常重要的类,用于高效地加载和处理数据。它们通常一起使用,以便在训练深度学习模型时更好地管理数据。如果数据是图像数据,可以使用Image模块来完成图片操作。除此之外,可能还需要用到Python标准库中的os模块,如文件路径拼接,文件列表。
1、Dataset类和DataLoader类,Image模块、os模块
1. Dataset 类
Dataset 是一个抽象类,表示一个数据集。你可以通过继承 Dataset 类来创建自定义的数据集。Dataset 类主要定义了以下两个核心方法:
-
len(self): 返回数据集中样本的数量。
-
getitem(self, idx): 根据索引 idx 返回对应的样本。
通过这些方法,Dataset 类允许你以统一的方式访问数据集中的样本。
自定义 Dataset 示例
python
from torch.utils.data import Dataset
import torch
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 示例数据
data = torch.randn(100, 3) # 100个样本,每个样本有3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签,0或1
# 创建自定义数据集
dataset = CustomDataset(data, labels)
# 访问数据集中的样本
sample, label = dataset[0]
print(f"Sample: {sample}, Label: {label}")
更详细的介绍可以见:Pytorch中的torch.utils.data.Dataset 类
2. DataLoader 类
DataLoader 是一个迭代器,用于从 Dataset 中高效地加载数据。它提供了以下功能:
- 批量加载数据: 可以将数据分成多个小批量(mini-batches)进行加载。
- 多线程加载: 可以使用多个线程并行加载数据,减少 I/O 瓶颈。
- 数据打乱: 可以在每次迭代时打乱数据顺序,以避免模型过拟合。
- 自定义采样策略: 可以通过 Sampler 和 BatchSampler 自定义数据加载的顺序。
使用 DataLoader 示例
python
from torch.utils.data import DataLoader
# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
# 遍历DataLoader
for batch_idx, (samples, labels) in enumerate(dataloader):
print(f"Batch {batch_idx}:")
print(f"Samples: {samples}, Labels: {labels}")
参数说明
- dataset: 需要加载的数据集,通常是 Dataset 类的实例。
- batch_size: 每个批次的样本数量。
- shuffle: 是否在每个 epoch 打乱数据顺序。
- num_workers: 用于数据加载的子进程数量。如果设置为 0,则数据加载在主进程中进行。
总结
- Dataset 类用于定义数据集的结构和如何访问数据。
- DataLoader 类用于高效地加载数据,支持批量加载、多线程加载和数据打乱等功能。
结合使用 Dataset 和 DataLoader 可以让你在训练深度学习模型时更加高效地处理数据。
3. Image模块
PIL(Python Imaging Library)是一个用于图像处理的Python库,而Image是其中的一个核心模块,提供了丰富的图像操作功能。PIL库已经停止更新,但其分支Pillow继续维护并扩展了功能,因此现在通常使用Pillow来代替PIL。
Image模块的主要功能
Image模块提供了以下主要功能:
打开和保存图像:
- Image.open(file):打开图像文件。
- Image.save(file, format):保存图像到文件。
图像属性:
- image.size:获取图像的尺寸(宽度, 高度)。
- image.mode:获取图像的模式(如 "RGB", "L", "CMYK" 等)。
- image.format:获取图像的格式(如 "JPEG", "PNG" 等)。
图像操作:
- image.resize(size):调整图像大小。
- image.rotate(angle):旋转图像。
- image.crop(box):裁剪图像。
- image.transpose(method):翻转或旋转图像(如 Image.FLIP_LEFT_RIGHT)。
图像转换:
- image.convert(mode):转换图像模式(如将彩色图像转换为灰度图像)。
- image.thumbnail(size):生成缩略图。
图像显示:
- image.show():显示图像(使用系统默认的图像查看器)。
图像处理:
- image.filter(filter):应用滤镜(如模糊、锐化等)。
- image.paste(im, box):将一张图像粘贴到另一张图像上。
示例代码
以下是一个简单的示例,展示如何使用Image模块:
python
from PIL import Image
# 打开图像
image = Image.open("example.jpg")
# 获取图像属性
print(f"Size: {image.size}")
print(f"Format: {image.format}")
print(f"Mode: {image.mode}")
# 调整图像大小
resized_image = image.resize((200, 200))
# 旋转图像
rotated_image = image.rotate(45)
# 保存图像
resized_image.save("resized_example.jpg")
rotated_image.save("rotated_example.jpg")
# 显示图像
image.show()
总结
Image模块是Pillow库中用于处理图像的核心模块,提供了从打开、操作到保存图像的完整功能。无论是简单的图像处理任务还是复杂的图像操作,Image模块都能胜任。
4. os模块
4.1 os.listdir函数
os.listdir是 Python 标准库 os 模块中的一个函数,用于列出指定目录中的所有文件和子目录的名称。它是一个非常常用的函数,用于遍历目录内容或检查文件系统中的文件。
函数定义
python
os.listdir(path='.')
参数:
path(可选):要列出内容的目录路径。默认为当前目录(.)。
返回值:
返回一个包含目录中所有文件和子目录名称的列表(list)。列表中的每个元素是一个字符串,表示文件或目录的名称。
主要特点
- 列出文件和目录:
- os.listdir 会返回指定目录下的所有文件和子目录的名称,包括隐藏文件(以 . 开头的文件)。
- 它不会递归列出子目录中的内容。
- 返回相对路径:
- 返回的文件名是相对于给定路径的相对路径,而不是绝对路径。
- 不区分文件和目录:
- 返回的列表中包含文件和目录的名称,但不会区分它们是文件还是目录。如果需要区分,可以结合 os.path.isfile 或os.path.isdir 使用。
- 不保证顺序:
- 返回的文件列表的顺序是不确定的,通常取决于文件系统的实现。
示例代码
- 列出当前目录的内容
python
import os
# 列出当前目录中的所有文件和子目录
contents = os.listdir()
print("当前目录内容:", contents)
- 列出指定目录的内容
python
import os
# 列出指定目录中的所有文件和子目录
path = "/path/to/your/directory"
contents = os.listdir(path)
print(f"目录 {path} 的内容:", contents)
- 区分文件和目录
python
import os
path = "/path/to/your/directory"
contents = os.listdir(path)
# 区分文件和目录
files = [f for f in contents if os.path.isfile(os.path.join(path, f))]
directories = [d for d in contents if os.path.isdir(os.path.join(path, d))]
print("文件:", files)
print("目录:", directories)
- 列出特定类型的文件
python
import os
path = "/path/to/your/directory"
contents = os.listdir(path)
# 列出所有 .txt 文件
txt_files = [f for f in contents if f.endswith(".txt")]
print("文本文件:", txt_files)
注意事项
- 路径问题:
- 如果路径不存在或没有访问权限,os.listdir 会抛出 FileNotFoundError 或 PermissionError 异常。
- 可以使用 os.path.exists 检查路径是否存在。
- 隐藏文件:
- os.listdir 会列出隐藏文件(如 .gitignore 或 .DS_Store),如果需要过滤隐藏文件,可以手动处理。
- 性能问题:
- 对于包含大量文件的目录,os.listdir 可能会比较慢。如果需要更高效的文件遍历,可以考虑使用 os.scandir。
与 os.scandir 的区别
os.listdir 返回的是文件名列表,而 os.scandir 返回的是 DirEntry 对象的迭代器,提供了更多的文件信息(如文件类型、文件属性等),性能也更好。例如:
python
import os
path = "/path/to/your/directory"
with os.scandir(path) as entries:
for entry in entries:
print(entry.name, entry.is_file())
总结
os.listdir 是一个简单而强大的工具,用于列出目录中的文件和子目录。它非常适合快速查看目录内容,但在需要更多文件信息或处理大量文件时,可以考虑使用 os.scandir 或其他更高级的文件遍历方法。
4.2 os.path.join函数
os.path.join 是 Python 标准库 os.path 模块中的一个函数,用于将多个路径组件连接成一个完整的路径。它能够根据操作系统的不同,自动使用正确的路径分隔符(例如,在 Windows 上使用 \,在 Linux 和 macOS 上使用 /),从而避免手动拼接路径时可能出现的错误。
函数定义
python
os.path.join(path, *paths)
参数:
- path:第一个路径组件(通常是目录路径)。
- *paths:可变参数,表示后续的路径组件(可以是文件名或子目录名)。
返回值:
返回一个字符串,表示拼接后的路径。
主要特点
- 跨平台兼容性:
- os.path.join 会根据操作系统的不同,自动使用正确的路径分隔符。
例如:
在 Windows 上:os.path.join("dir", "file.txt") 返回 "dir\file.txt"。
在 Linux/macOS 上:os.path.join("dir", "file.txt") 返回 "dir/file.txt"。
- 处理多余的分隔符:
- 如果路径组件中已经包含路径分隔符,os.path.join 会自动处理多余的分隔符,避免出现重复的分隔符。
- 支持绝对路径:
- 如果某个路径组件是绝对路径(例如以 / 或 C:\ 开头),os.path.join 会忽略之前的路径组件,从该绝对路径开始拼接。
- 灵活性:
- 可以拼接任意数量的路径组件。
示例代码
- 基本用法
python
import os
# 拼接路径
path = os.path.join("dir", "subdir", "file.txt")
print("拼接后的路径:", path)
- 在 Windows 上输出:
拼接后的路径: dir\subdir\file.txt - 在 Linux/macOS 上输出:
拼接后的路径: dir/subdir/file.txt
- 处理绝对路径
python
import os
# 如果某个组件是绝对路径,之前的组件会被忽略
path = os.path.join("dir", "/absolute/path", "file.txt")
print("拼接后的路径:", path)
输出:
拼接后的路径: /absolute/path/file.txt
- 处理多余的分隔符
python
import os
# 自动处理多余的分隔符
path = os.path.join("dir/", "/subdir/", "file.txt")
print("拼接后的路径:", path)
- 在 Windows 上输出:
拼接后的路径: dir\subdir\file.txt - 在 Linux/macOS 上输出:
拼接后的路径: dir/subdir/file.txt
- 拼接多个组件
python
import os
# 拼接多个路径组件
path = os.path.join("root", "dir1", "dir2", "file.txt")
print("拼接后的路径:", path)
- 在 Windows 上输出:
拼接后的路径: root\dir1\dir2\file.txt - 在 Linux/macOS 上输出:
拼接后的路径: root/dir1/dir2/file.txt
注意事项
- 路径规范化:
os.path.join 不会自动规范化路径(例如处理 . 或 ...)。如果需要规范化路径,可以使用 os.path.normpath。
python
import os
path = os.path.join("dir", "..", "file.txt")
normalized_path = os.path.normpath(path)
print("规范化后的路径:", normalized_path)
输出:
python
规范化后的路径: file.txt
- 空路径组件:
如果某个路径组件为空字符串,os.path.join 会忽略它。
python
import os
path = os.path.join("dir", "", "file.txt")
print("拼接后的路径:", path)
输出:
python
拼接后的路径: dir/file.txt
- 避免手动拼接路径:
手动拼接路径(例如使用 + 或字符串格式化)可能会导致跨平台兼容性问题,因此推荐使用 os.path.join。
总结
os.path.join 是一个非常有用的工具,用于安全、跨平台地拼接路径。它能够自动处理路径分隔符和多余的分隔符,确保生成的路径在不同操作系统上都能正常工作。在编写文件路径相关的代码时,强烈推荐使用 os.path.join 来避免潜在的错误。
2、数据加载示例
2.1 示例一
2.1.1 了解数据集内容及格式
首先,将数据集放到项目所在文件夹中。数据集示例如下:
整个数据集分为训练集(train)和验证集(val);训练集和验证集中都包含两种数据:蚂蚁(ants)和蜜蜂(bees),文件名就是数据的标签;ants中又包含很多张图片数据,bees同理:
2.1.2 编写程序加载数据集
- 导入所需的类或模块
python
from torch.utils.data import Dataset
from PIL import Image
import os
- 编写数据类
编写一个表示数据集的类,继承自Dataset类,并重写方法__init__()和方法__getitem__()和方法__len__():
python
class Mydata(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.image_path[idx]
img_path = os.path.join(self.path, img_name)
img = Image.open(img_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.image_path)
- 创建实例对象时,需要传入根目录(root_dir,即上述数据集的"train"或者"val"),以及标签目录(label_dir,即数据集中的"ants"或者"bees")。方法__init__()将路径拼接,并得到图片文件名的列表。
- 方法__getitem__()在对该示例索引时,通过索引找到图片文件名的列表中的名字(img_name = self.image_path[idx]),通过路径拼接找到图片文件路径(img_path = os.path.join(self.path, img_name))并读取图片文件(img = Image.open(img_path)),最后返回图片对象和标签。
- 方法__len__()返回图片文件名的列表的长度,即图片的数量。
- 读取训练数据集
创建Mydata实例并传入所需的文件路径("hymenoptera_data/train"和 "ants")。
python
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
ants_dataset = Mydata(root_dir, ants_label_dir)
通过索引该实例查看一下返回值:返回的是一个图像对象和一个标签
python
>>> ants_dataset[0]
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512 at 0x2BDD6D6A220>, 'ants')
用两个变量分别接收图片对象和其标签,并展示一下图片:
python
>>> img, label=ants_dataset[0]
>>> img.show()

训练集中还有一类bees数据。也用一样的方法加载,最后把两个数据集(ants和bees)合并拼接(+)起来:
python
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
合并之后的数据集长度等于两个数据集长度之和:
python
>>> len(ants_dataset)
124
>>> len(bees_dataset)
121
>>> len(train_dataset)
245
可以看到合并之后的数据集索引为123以及之前的标签是ants,索引为124及之后的标签是bees:
python
train_dataset[123][1]
'ants'
train_dataset[124][1]
'bees'
2.2 示例二
2.2.1 了解数据集内容及格式
首先,将数据集放到项目所在文件夹中。数据集示例如下:

同样包含训练集(train)和验证集(val)。

不过,不同的是,例如在train文件夹下,图片数据和图片标签分别存放在不同的文件夹中(如ants_image和ants_label分别存放蚂蚁的图片和对应图片的标签)。
ants_image中为.ipg的图片文件,ants_label中为.txt文本文件,他们的名字是一一对应的。

文本文件中是该图片的标签。
2.2.2 编写程序加载数据集
- 导入所需的类或模块
python
from torch.utils.data import Dataset
from PIL import Image
import os
- 编写数据类
编写一个表示数据集的类,继承自Dataset类,并重写方法__init__()和方法__getitem__()和方法__len__():
python
from torch.utils.data import Dataset
from PIL import Image
import os
class Mydata(Dataset):
def __init__(self,root_dir, image_dir, label_dir):
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.image_path = os.path.join(self.root_dir, self.image_dir)
self.label_path = os.path.join(self.root_dir, self.label_dir)
self.image_list = os.listdir(self.image_path)
self.label_list = os.listdir(self.label_path)
def __getitem__(self, idx):
img_name = self.image_list[idx]
img_item_path = os.path.join(self.image_path, img_name)
img = Image.open(img_item_path)
label_name = self.label_list[idx]
label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
with open(label_item_path, 'r') as f:
label = f.readline()
return img,label
def __len__(self):
return len(self.image_list)
- 创建实例对象时,需要传入根目录(root_dir,即上述数据集的"train"或者"val"),以及图像文件目录(image_dir,即数据集中的"ants_image"或者"bees_image"),标签目录(label_dir,即数据集中的"ants_label"或者"bees_label")。方法__init__()将路径拼接,并得到图片文件名的列表和图片标签文件名的列表。
- 方法__getitem__()在对该示例索引时,通过索引找到图片文件名列表中的名字(img_name = self.image_list[idx]),通过路径拼接找到图片文件路径(img_item_path = os.path.join(self.image_path, img_name))并读取图片文件(img = Image.open(img_item_path));通过索引找到对应的标签文件名(label_name = self.label_list[idx]),通过路径拼接找到标签文件路径(label_item_path = os.path.join(self.root_dir, self.label_dir, label_name))并读取.txt文件的内容。最后返回图片对象和标签。
- 方法__len__()返回图片文件名的列表的长度,即图片的数量。
- 读取训练数据集
创建Mydata实例并传入所需的文件路径("练手数据集/train", "ants_image"和"ants_label")。
python
root_dir = "练手数据集/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
ants_dataset = Mydata(root_dir, ants_image_dir, ants_label_dir)
通过索引该实例查看一下返回值:返回的是一个图像对象和一个标签
python
>>> ants_dataset[1]
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x333 at 0x1DC370ACB80>, 'ants')
用两个变量分别接收图片对象和其标签,并展示一下图片:
python
>>> img, label=ants_dataset[1]
>>> img.show()

训练集中还有一类bees数据。也用一样的方法加载,最后把两个数据集(ants和bees)合并拼接(+)起来:
python
root_dir = "练手数据集/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
bees_image_dir = "bees_image"
bees_label_dir = "bees_label"
ants_dataset = Mydata(root_dir, ants_image_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_image_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
合并之后的数据集长度等于两个数据集长度之和:
python
>>> len(ants_dataset)
124
>>> len(bees_dataset)
121
>>> len(train_dataset)
245
可以看到合并之后的数据集索引为123以及之前的标签是ants,索引为124及之后的标签是bees:
python
train_dataset[123][1]
'ants'
train_dataset[124][1]
'bees'