1 数据加载Dataset
PyTorch的数据读取机制主要依赖于
Dataset
和DataLoader
这两个核心组件。它们用于加载和处理数据,以便在训练模型时进行高效的数据流动和处理。Dataset
Dataset
是一个抽象类,用户可以继承这个类并重载以下两个方法来创建自定义的数据集:
__init__
方法:
csv_file
:指向包含图像路径和标签的CSV文件路径。root_dir
:包含所有图像的根目录路径。transform
:一个可选的变换,用于在返回样本之前处理数据。在初始化过程中,读取CSV文件并存储在
self.data_frame
中,还设置了图像的根目录和可选的变换。
__len__
方法:
- 返回数据集中样本的数量,即CSV文件中记录的行数。
__getitem__
方法:
- 接收一个索引
idx
,从CSV文件中获取对应的图像路径和标签。- 使用PIL库打开图像文件,并将其转换为RGB格式。
- 如果定义了变换,则将其应用到图像。
- 返回处理后的图像和对应的标签。
自定义Dataset示例
python
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
label = self.labels[index]
return sample, label
# 示例数据
data = torch.randn(100, 3) # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签
# 创建自定义数据集
dataset = CustomDataset(data, labels)
2 可迭代的数据装载器DataLoader
DataLoader
是 PyTorch 中一个非常重要的类,用于构建可迭代的数据装载器。它能够有效地加载数据并在训练模型时提供数据批次。下面我们详细介绍DataLoader
的各个参数和使用方法。DataLoader 的功能
DataLoader
主要用于在训练过程中,每个for
循环中从数据集中获取一个指定大小(batch_size
)的数据批次。参数解释
dataset
:
- 类型:
Dataset
类实例- 功能:决定数据从哪里读取以及如何读取。
Dataset
类定义了数据集的具体内容及访问方式。
batch_size
:
- 类型:整数
- 功能:每个数据批次的大小。例如,
batch_size=32
表示每次从数据集中获取32个样本。
num_workers
:
- 类型:整数
- 功能:决定使用多少个子进程来加载数据。更多的进程数可以加快数据加载速度,但过多的进程数可能会导致系统资源不足,建议设置为 4、8、16 等。
shuffle
:
- 类型:布尔值
- 功能:决定每个 epoch 开始时是否打乱数据顺序。打乱数据可以增加训练过程的随机性,通常设置为
True
。
drop_last
:
- 类型:布尔值
- 功能:如果数据集中的样本数不能被
batch_size
整除,决定是否舍弃最后一个不完整的数据批次。设置为True
表示舍弃。重要概念
Epoch:
- 定义:所有训练样本都已输入到模型中,称为一个 epoch。
Iteration:
- 定义:一个批次的样本输入到模型中,称为一次 iteration。
Batch Size:
- 定义:批大小,决定一个 epoch 中有多少次 iteration。
python
# 创建 DataLoader 实例
dataloader = DataLoader(
dataset=dataset, # 自定义数据集
batch_size=32, # 每批次32个样本
shuffle=True, # 每个epoch开始时打乱数据
num_workers=4, # 使用4个子进程加载数据
drop_last=True # 当样本数不能被batch_size整除时,舍弃最后一批数据
)
# 训练循环示例
for epoch in range(num_epochs):
for batch_idx, (data, labels) in enumerate(dataloader):
# 模型训练代码
pass
3 图像预处理transforms
在PyTorch中,
transforms
是一个用于图像预处理的模块。transforms
提供了一组常用的图像变换方法,可以对图像进行数据增强、归一化、裁剪、缩放等操作。transforms
主要用于将图像数据转换成适合模型输入的格式。常用的Transforms
以下是一些常用的
transforms
操作:
transforms.Compose
:将多个变换组合起来。transforms.Resize
:调整图像大小。transforms.CenterCrop
:从图像中心裁剪。transforms.RandomCrop
:随机裁剪图像。transforms.RandomHorizontalFlip
:随机水平翻转图像。transforms.ToTensor
:将PIL图像或Numpy数组转换为张量,并将像素值归一化到[0, 1]。transforms.Normalize
:用均值和标准差归一化张量。transforms.ColorJitter
:随机改变图像的亮度、对比度和饱和度。transforms.RandomRotation
:随机旋转图像。
python
from torchvision import transforms
from PIL import Image
# 定义图像预处理变换
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转10度
transforms.ColorJitter(brightness=0.5), # 随机改变亮度
transforms.ToTensor(), # 转换为张量并归一化到[0, 1]
transforms.Normalize((0.5,), (0.5,)) # 用均值0.5和标准差0.5归一化
])
# 加载图像
image = Image.open("path_to_image.jpg").convert("RGB")
# 应用预处理变换
transformed_image = transform(image)
# 检查变换后的图像
print(transformed_image.size())
如果现有的
transforms
无法满足需求,可以自定义变换。只需实现__call__
方法即可
python
import torch
class CustomTransform:
def __call__(self, sample):
# 自定义变换逻辑,例如将图像转换为灰度图
return transforms.functional.rgb_to_grayscale(sample)
# 使用自定义变换
transform = transforms.Compose([
transforms.Resize((128, 128)),
CustomTransform(),
transforms.ToTensor()
])
image = Image.open("path_to_image.jpg").convert("RGB")
transformed_image = transform(image)
print(transformed_image.size())
4 综合数据读取和数据预处理
以下是一个综合示例,展示如何定义数据集并使用各种
transforms
进行图像预处理和数据增强。
python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
class CustomCSVImageDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.data_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.data_frame)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
image = Image.open(img_name).convert('RGB')
label = self.data_frame.iloc[idx, 1]
if self.transform:
image = self.transform(image)
return image, label
# 定义图像预处理和数据增强
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 示例数据
csv_file = './data/labels.csv'
root_dir = './data/images'
# 创建数据集
dataset = CustomCSVImageDataset(csv_file=csv_file, root_dir=root_dir, transform=transform)
# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
# 迭代DataLoader
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {batch_idx}:")
print("数据大小:", data.size())
print("标签大小:", labels.size())