答案来自 千问
需要注意每个方法的输入输出是什么类型的数据,方法需要的参数等。
torchvision.transforms 是 PyTorch 中用于图像预处理和数据增强的核心模块。它的主要作用是将原始的 PIL 图像或 NumPy 数组转换为模型可接受的张量(Tensor)格式,并进行标准化等操作。
简单来说,它就是深度学习中的"图像处理工具箱"。
🧰 核心用法:组合变换 (Compose)
transforms 最常用的模式是使用 transforms.Compose。你可以把多个处理步骤像流水线一样组合在一起,图像会按顺序依次经过这些处理步骤。
python
import torchvision.transforms as transforms
# 定义一个处理流水线
transform = transforms.Compose([
transforms.Resize((224, 224)), # 1. 调整图像大小
transforms.ToTensor(), # 2. 转换为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 3. 标准化
])
🛠️ 常用工具 (Transforms) 介绍
以下是一些最常用的图像处理工具:
1. 转换格式:ToTensor
这是几乎所有流程的第一步或中间步骤。它的作用是:
- 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量(Tensor)。
- 自动将像素值从
[0, 255]缩放到[0.0, 1.0]范围内。
2. 调整大小:Resize
将图像调整为指定的尺寸,以适应模型的输入要求。
transforms.Resize(256): 将图像的短边缩放到256,保持长宽比。transforms.Resize((224, 224)): 将图像直接缩放到指定的宽和高。
3. 裁剪:CenterCrop 与 RandomCrop
- CenterCrop: 从图像中心裁剪出指定大小的区域。常用于测试阶段,以获得稳定的输入。
- RandomCrop: 从图像中随机位置裁剪。常用于训练阶段,作为一种数据增强手段,增加数据的多样性。
4. 翻转与旋转
- RandomHorizontalFlip: 以指定的概率(默认0.5)随机水平翻转图像。这是最常用的数据增强方法之一。
- RandomRotation : 随机旋转图像指定的角度,例如
transforms.RandomRotation(degrees=30)会随机旋转 -30 到 +30 度之间的角度。
5. 色彩调整:ColorJitter
随机改变图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue),使模型对色彩变化更具鲁棒性。
6. 标准化:Normalize
使用给定的均值(mean)和标准差(std)对图像张量进行标准化,使数据分布更有利于模型训练。这一步必须 在 ToTensor 之后执行,因为它是针对张量数值进行的运算。
- 常用参数(基于 ImageNet 数据集统计值):
mean=[0.485, 0.456, 0.406]std=[0.229, 0.224, 0.225]
📝 实际应用示例
下面展示如何为训练集和测试集构建不同的预处理流程。
python
from torchvision import transforms
# 训练集的数据增强和预处理
# 目标:增加数据多样性,防止模型过拟合
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并缩放到224x224
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(brightness=0.2, # 随机调整亮度、对比度等
contrast=0.2,
saturation=0.2),
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], # 标准化
std=[0.229, 0.224, 0.225])
])
# 测试集/验证集的预处理
# 目标:保持数据原始分布,进行一致的评估
test_transform = transforms.Compose([
transforms.Resize(256), # 先将短边缩放到256
transforms.CenterCrop(224), # 再从中心裁剪出224x224区域
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], # 标准化
std=[0.229, 0.224, 0.225])
])
关键点总结:
- 顺序很重要 :
ToTensor必须在Normalize之前,因为Normalize需要输入是张量。 - 场景不同,策略不同 :训练时使用
Random开头的变换来增强数据;测试时使用确定性的变换(如CenterCrop)以保证结果可复现。 - 结合数据集使用 :定义好
transform后,可以在创建数据集(如torchvision.datasets.CIFAR10)时,将其作为参数传入,数据加载时就会自动应用这些变换。