深度学习框架PyTorch笔记(四)数据转换 Data Transformation

深度学习框架PyTorch笔记(四)数据转换 Data Transformation

​ 在Pytorch中,数据转换(Data Transformation)是一种在加载的样本数据时对数据预处理的机制,将原始数据转换为合适模型训练的格式主要通过torchvision.transforms提供的工具完成。数据转换不仅可以实现基本的数据预处理(如归一化,大小调整等),还能帮助进行数据增强(如随机裁剪、翻转等),提高模型的泛化能力。其作用如下:

  • 数据预处理: 将原始图片(PIL Image、numpy array)转换为Tensor,并做归一化,以匹配模式输入要求。例如:输入的样本图像需要调整为固定大小,张量格式并归一化到0,1
  • 数据增强: 在训练时随机裁剪、翻转、改变颜色等,增加数据多样性,提升模型泛化能力。例如:通过随机旋转、裁剪和裁剪增加数据样本的变种,避免过拟合。
  • 统一处理: 把多种变换组合成一个pipeline,与DataLoader无缝衔接。可以动态地对数据进行处理,简化数据加载的复杂度。

环境准备:

bash 复制代码
pip install torch torchvision

导入常用模块:

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms  # 推荐使用 v2
from PIL import Image

如果你习惯旧 API,可以通过 from torchvision import transforms 导入,但 v2 的功能更强大且兼容旧用法。

1.基础变换操作

变换函数名称 描述 实例
transforms.ToTensor() 将PIL图像或NumPy数组转换为PyTorch张量,并自动将像素值归一化到 0, 1 transform = transforms.ToTensor()
transforms.Normalize(mean, std) 对图像进行标准化,使数据符合零均值和单位方差。 transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transforms.Resize(size) 调整图像尺寸,确保输入到网络的图像大小一致。 transform = transforms.Resize((256, 256))
transforms.CenterCrop(size) 从图像中心裁剪指定大小的区域。 transform = transforms.CenterCrop(224)

1.1 ToTensor

将PIL图形或Numpy数组转换为PyTorch张量。同时将像素值从\(0,255\) 归一化为\(0,1\)。

python 复制代码
from torchvision import transforms
transform = transforms.ToTensor()

1.2 Normalize

​ 对数据进行标准化,使其符合特定的均值和标准差。通常用于图像数据,将其像素值归一化为零均值和单位方差。

python 复制代码
transform = transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到 [-1, 1]

1.3 Resize

调整样本图形的大小,确保输入到网络图像大小一致。

python 复制代码
transform=transforms.Resize((128, 128)) # 将图像调整为 128x128

1.4 CenterCrop

从图像中心裁剪指定大小的区域。

python 复制代码
transform = transforms.CenterCrop(128)  # 裁剪 128x128 的区域

2. 数据增强操作(Data Augmentation)

变换函数名称 描述 实例
transforms.RandomHorizontalFlip(p) 随机水平翻转图像。 transform = transforms.RandomHorizontalFlip(p=0.5)
transforms.RandomRotation(degrees) 随机旋转图像。 transform = transforms.RandomRotation(degrees=45)
transforms.ColorJitter(brightness, contrast, saturation, hue) 调整图像的亮度、对比度、饱和度和色调。 transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
transforms.RandomCrop(size) 随机裁剪指定大小的区域。 transform = transforms.RandomCrop(224)
transforms.RandomResizedCrop(size) 随机裁剪图像并调整到指定大小。 transform = transforms.RandomResizedCrop(224)

2.1 RandomCrop

从图形中随机裁剪指定大小。

pyhton 复制代码
transform = transforms.RandomCrop(128)

2.2 RandomHorizontalFlip

以一定概率水平翻转图形。

python 复制代码
transform = transforms.RandomHorizontalFlip(p=0.5)  # 50% 概率翻转

2.3 RandomRotation

随机将图像旋转一定角度。

python 复制代码
transform = transforms.RandomRotation(degrees=30)  # 随机旋转 -30 到 +30 度

2.4 ColorJitter

随机改变图像的亮度、对比度、饱和度或色调。

python 复制代码
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5)

3.组合变换

​ 这些转换可以通过Compose组合在一起,以便对图像进行一系列的转换。Compose类允许你创建一个包含多个转换操作的列表,这些操作将按照定义的顺序应用于输入数据中。

变换函数名称 描述 实例
transforms.Compose() 将多个变换组合在一起,按照顺序依次应用。 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((256, 256))])

通过transforms.Compose将多个变换组合起来。

python 复制代码
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

4.与Dataset和DataLoader配合

​ PyTorch的Data set 类负责加载数据,transform作为参数参入,在__getitem__中被调用。

python 复制代码
class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

创建DataLoader:

python 复制代码
train_dataset = ImageDataset(train_paths, train_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

5.完整案例

python 复制代码
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import datasets, transforms


# 原始和增强后的图像可视化
transform_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

# 加载数据集
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_augment)

# 显示图像
def show_images(dataset):
    fig, axs = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        image, label = dataset[i]
        axs[i].imshow(image.squeeze(0), cmap='gray')  # 将 (1, H, W) 转为 (H, W)
        axs[i].set_title(f"Label: {label}")
        axs[i].axis('off')
    plt.show()

show_images(dataset)

显示效果如下: