PyTorch transforms的用法

答案来自 千问

需要注意每个方法的输入输出是什么类型的数据,方法需要的参数等。

torchvision.transforms 是 PyTorch 中用于图像预处理和数据增强的核心模块。它的主要作用是将原始的 PIL 图像或 NumPy 数组转换为模型可接受的张量(Tensor)格式,并进行标准化等操作。

简单来说,它就是深度学习中的"图像处理工具箱"。

🧰 核心用法:组合变换 (Compose)

transforms 最常用的模式是使用 transforms.Compose。你可以把多个处理步骤像流水线一样组合在一起,图像会按顺序依次经过这些处理步骤。

python 复制代码
import torchvision.transforms as transforms

# 定义一个处理流水线
transform = transforms.Compose([
    transforms.Resize((224, 224)),      # 1. 调整图像大小
    transforms.ToTensor(),              # 2. 转换为张量
    transforms.Normalize(mean=[0.5], std=[0.5])  # 3. 标准化
])

🛠️ 常用工具 (Transforms) 介绍

以下是一些最常用的图像处理工具:

1. 转换格式:ToTensor

这是几乎所有流程的第一步或中间步骤。它的作用是:

  • 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量(Tensor)。
  • 自动将像素值从 [0, 255] 缩放到 [0.0, 1.0] 范围内。
2. 调整大小:Resize

将图像调整为指定的尺寸,以适应模型的输入要求。

  • transforms.Resize(256): 将图像的短边缩放到256,保持长宽比。
  • transforms.Resize((224, 224)): 将图像直接缩放到指定的宽和高。
3. 裁剪:CenterCrop 与 RandomCrop
  • CenterCrop: 从图像中心裁剪出指定大小的区域。常用于测试阶段,以获得稳定的输入。
  • RandomCrop: 从图像中随机位置裁剪。常用于训练阶段,作为一种数据增强手段,增加数据的多样性。
4. 翻转与旋转
  • RandomHorizontalFlip: 以指定的概率(默认0.5)随机水平翻转图像。这是最常用的数据增强方法之一。
  • RandomRotation : 随机旋转图像指定的角度,例如 transforms.RandomRotation(degrees=30) 会随机旋转 -30 到 +30 度之间的角度。
5. 色彩调整:ColorJitter

随机改变图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue),使模型对色彩变化更具鲁棒性。

6. 标准化:Normalize

使用给定的均值(mean)和标准差(std)对图像张量进行标准化,使数据分布更有利于模型训练。这一步必须ToTensor 之后执行,因为它是针对张量数值进行的运算。

  • 常用参数(基于 ImageNet 数据集统计值):
    • mean=[0.485, 0.456, 0.406]
    • std=[0.229, 0.224, 0.225]

📝 实际应用示例

下面展示如何为训练集和测试集构建不同的预处理流程。

python 复制代码
from torchvision import transforms

# 训练集的数据增强和预处理
# 目标:增加数据多样性,防止模型过拟合
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),        # 随机裁剪并缩放到224x224
    transforms.RandomHorizontalFlip(),       # 随机水平翻转
    transforms.ColorJitter(brightness=0.2,    # 随机调整亮度、对比度等
                           contrast=0.2, 
                           saturation=0.2),
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

# 测试集/验证集的预处理
# 目标:保持数据原始分布,进行一致的评估
test_transform = transforms.Compose([
    transforms.Resize(256),                  # 先将短边缩放到256
    transforms.CenterCrop(224),              # 再从中心裁剪出224x224区域
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

关键点总结:

  1. 顺序很重要ToTensor 必须在 Normalize 之前,因为 Normalize 需要输入是张量。
  2. 场景不同,策略不同 :训练时使用 Random 开头的变换来增强数据;测试时使用确定性的变换(如 CenterCrop)以保证结果可复现。
  3. 结合数据集使用 :定义好 transform 后,可以在创建数据集(如 torchvision.datasets.CIFAR10)时,将其作为参数传入,数据加载时就会自动应用这些变换。
相关推荐
kisshuan1239628 分钟前
ERM增强残差融合模块改进YOLOv26多尺度特征融合精度与边缘检测能力
人工智能·深度学习·yolo
猫头虎2 小时前
OpenClaw下载安装配置|Windows安装流程|macOS 安装流程|Telegram 集成使用|飞书集成使用|常见坑和注意事项保姆级教程
人工智能·windows·macos·开源·aigc·飞书·ai编程
TEC_INO2 小时前
Linux38:AT函数
人工智能·opencv·计算机视觉
做cv的小昊3 小时前
大语言模型系统:【CMU 11-868】课程学习笔记02——GPU编程基础1(GPU Programming Basics 1)
人工智能·笔记·学习·语言模型·llm·transformer·agent
一方热衷.6 小时前
YOLO26-Seg ONNXruntime C++/python推理
开发语言·c++·python
YMWM_8 小时前
如何将包路径添加到conda环境lerobot的python路径中呢?
人工智能·python·conda
星辰_mya8 小时前
关于ai——纯笔记
人工智能
智算菩萨8 小时前
GPT-5.4原生操控电脑揭秘:从Playwright脚本到屏幕截图识别,手把手搭建你的第一个自动化智能体
人工智能·gpt·ai·chatgpt·自动化
田里的水稻8 小时前
ubuntu22.04_openclaw_ROS2
人工智能·python·机器人
行走__Wz8 小时前
【刘二大人】《PyTorch深度学习实践》——PyTorch实现线性回归代码(自用)
pytorch·深度学习·线性回归