#c 目的 需要的目的
专门处理数据的代码可能会变得「混乱且难以维护」,理想情况下是将「数据集代码」与「模型训练代码」「解耦(decoupled)」,以提高可读性和模块性。
torch.utils.data.DataLoader
和torch.utils.data.Dataset
,可以使用「预加载数据集」以及「自定义数据」。
Dataset
存储样本及其相应的标签。
DataLoader
则围绕Dataset
包装了一个可迭代对象,以便于轻松访问样本。
1 下载数据集
#e MNIST数据集
以下载TorchVision
下的Fashion-MNIST
为例。Fashion-MNIST是一个由Zalando的文章图片组成的数据集,包含60,000个训练样本和10,000个测试样本。每个样本包括一个28×28的灰度图像以及来自10个类别之一的相关标签。
python
# 下载训练数据集
train_data = datasets.FashionMNIST(
root="data", # 数据存储的路径
train=True, # 指定下载的是训练数据集
download=True, # 如果数据不存在,则通过网络下载
transform=ToTensor() # 将图片转换为Tensor
)
# 下载测试数据集
test_data = datasets.FashionMNIST(
root="data", # 数据存储的路径
train=False, # 指定下载的是测试数据集
download=True, # 如果数据不存在,则通过网络下载
transform=ToTensor() # 将图片转换为Tensor
)
2 迭代和可视化数据
#e 迭代和可视化
python
lables_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))#创建一个matplotlib图形对象,设置图形的大小为8x8英寸。
cols, rows = 3, 3#设置列数和行数
for i in range(1, cols * rows +1):#循环9次
sample_idx = torch.randint(len(training_data),size=(1,)).item()
'''
使用torch.randint随机生成一个介于0和训练数据集长度之间的整数,作为随机选取的图像的索引。
size=(1,)指定生成一个数,item()将其转换为Python的标准整数。
'''
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)#添加子图,设置行数、列数和子图的索引,位置由i决定
plt.title(lables_map[label])#设置标题
plt.axis("off")#关闭坐标轴
plt.imshow(img.squeeze(), cmap="gray")#灰度显示
# plt.imshow(img.squeeze())#彩色显示,无需指定cmap
'''
img.squeeze()将图像张量的维度为1的轴删除,因为imshow函数预期的是一个二维图像。
cmap="gray"指定了灰度图像。
'''
plt.show()
3 自定义数据集
#c 要素 自定义数据集要素
自定义的Dataset类必须实现以下三个函数:
__init__
:初始化函数,用于设置数据集的属性,如加载数据、预处理步骤等。
__len__
:返回数据集中样本的数量。这个函数使得Dataset对象可以被len()函数调用,通常返回数据集中样本的总数。
__getitem__
:根据索引获取单个样本。这个函数允许通过索引访问数据集中的每个样本。索引从0开始,对应于数据集中的第一个样本。
#e 三要素 自定义数据集要素
python
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transfrom=None, target_tansform=None):
self.img_labels = pd.read_csv(annotations_file)#读取CSV文件
self.img_dir = img_dir#图像目录
self.transfrom = transfrom#图像转换
self.target_tansform = target_tansform#目标转换
def __len__(self):
return len(self.img_labels)#返回数据集的长度
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
imgage = read_image(img_path)#读取图像,转换成张量
label = self.img_labels.iloc[idx, 1]#检索对应的标签
if self.transfrom:#转换图像
imgage = self.transfrom(imgage)
if self.target_tansform:
label = self.target_tansform(label)
return imgage, label #以元组的形式返回图像和标签
4 使用DataLoader准备训练数据
#c 思路 数据准备思路
在训练模型的过程中,通常希望以"小批量"的形式传递样本,每个周期重新打乱数据以减少模型的过拟合,并使用Python的multiprocessing
多进程来加速数据检索。数据集(Dataset)负责逐个样本地获取数据集的特征和标签。DataLoader是一个可迭代对象,它抽象了这些复杂性,提供了一个简单的API。
#e 准备代码
python
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#在这里,DataLoader将训练数据集传递给train_dataloader,每个小批量包含64个特征和标签对,shuffle=True表示在每个周期重新打乱数据。
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
5 通过DataLoader迭代
#c 特点 迭代特点
将该数据集加载到DataLoader中,可以根据需要迭代遍历数据集。每次迭代都会返回一批train_features
和train_labels
(分别包含batch_size=64
个特征和标签)。若指定了shuffle=True
,在遍历完所有的批次之后,数据会被重新打乱。这意味着每个周期(epoch)开始时,数据的顺序都会随机化,有助于模型学习到「更加泛化「的特征,从而减少「过拟合」的风险。
#e 迭代代码
python
train_features, train_labels = next(iter(train_dataloader))
#iter(train_dataloader)返回一个迭代器对象,next()函数返回迭代器的下一批数据
print(f"Feature batch shape: {train_features.size()}")#size()返回张量的形状(批量大小、通道数、高度、宽度)
print(f"Labels batch shape: {train_labels.size()}")#size()返回张量的形状(批量大小)
img = train_features[0].squeeze()#删除维度为1的轴,特别是当图像以(1,高度,宽度)或(1,通道数,高度,宽度)的形式存在时。
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
'''
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 6
'''