PyTorch transforms的用法

答案来自 千问

需要注意每个方法的输入输出是什么类型的数据,方法需要的参数等。

torchvision.transforms 是 PyTorch 中用于图像预处理和数据增强的核心模块。它的主要作用是将原始的 PIL 图像或 NumPy 数组转换为模型可接受的张量(Tensor)格式,并进行标准化等操作。

简单来说,它就是深度学习中的"图像处理工具箱"。

🧰 核心用法:组合变换 (Compose)

transforms 最常用的模式是使用 transforms.Compose。你可以把多个处理步骤像流水线一样组合在一起,图像会按顺序依次经过这些处理步骤。

python 复制代码
import torchvision.transforms as transforms

# 定义一个处理流水线
transform = transforms.Compose([
    transforms.Resize((224, 224)),      # 1. 调整图像大小
    transforms.ToTensor(),              # 2. 转换为张量
    transforms.Normalize(mean=[0.5], std=[0.5])  # 3. 标准化
])

🛠️ 常用工具 (Transforms) 介绍

以下是一些最常用的图像处理工具:

1. 转换格式:ToTensor

这是几乎所有流程的第一步或中间步骤。它的作用是:

  • 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量(Tensor)。
  • 自动将像素值从 [0, 255] 缩放到 [0.0, 1.0] 范围内。
2. 调整大小:Resize

将图像调整为指定的尺寸,以适应模型的输入要求。

  • transforms.Resize(256): 将图像的短边缩放到256,保持长宽比。
  • transforms.Resize((224, 224)): 将图像直接缩放到指定的宽和高。
3. 裁剪:CenterCrop 与 RandomCrop
  • CenterCrop: 从图像中心裁剪出指定大小的区域。常用于测试阶段,以获得稳定的输入。
  • RandomCrop: 从图像中随机位置裁剪。常用于训练阶段,作为一种数据增强手段,增加数据的多样性。
4. 翻转与旋转
  • RandomHorizontalFlip: 以指定的概率(默认0.5)随机水平翻转图像。这是最常用的数据增强方法之一。
  • RandomRotation : 随机旋转图像指定的角度,例如 transforms.RandomRotation(degrees=30) 会随机旋转 -30 到 +30 度之间的角度。
5. 色彩调整:ColorJitter

随机改变图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue),使模型对色彩变化更具鲁棒性。

6. 标准化:Normalize

使用给定的均值(mean)和标准差(std)对图像张量进行标准化,使数据分布更有利于模型训练。这一步必须ToTensor 之后执行,因为它是针对张量数值进行的运算。

  • 常用参数(基于 ImageNet 数据集统计值):
    • mean=[0.485, 0.456, 0.406]
    • std=[0.229, 0.224, 0.225]

📝 实际应用示例

下面展示如何为训练集和测试集构建不同的预处理流程。

python 复制代码
from torchvision import transforms

# 训练集的数据增强和预处理
# 目标:增加数据多样性,防止模型过拟合
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),        # 随机裁剪并缩放到224x224
    transforms.RandomHorizontalFlip(),       # 随机水平翻转
    transforms.ColorJitter(brightness=0.2,    # 随机调整亮度、对比度等
                           contrast=0.2, 
                           saturation=0.2),
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

# 测试集/验证集的预处理
# 目标:保持数据原始分布,进行一致的评估
test_transform = transforms.Compose([
    transforms.Resize(256),                  # 先将短边缩放到256
    transforms.CenterCrop(224),              # 再从中心裁剪出224x224区域
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

关键点总结:

  1. 顺序很重要ToTensor 必须在 Normalize 之前,因为 Normalize 需要输入是张量。
  2. 场景不同,策略不同 :训练时使用 Random 开头的变换来增强数据;测试时使用确定性的变换(如 CenterCrop)以保证结果可复现。
  3. 结合数据集使用 :定义好 transform 后,可以在创建数据集(如 torchvision.datasets.CIFAR10)时,将其作为参数传入,数据加载时就会自动应用这些变换。
相关推荐
码界筑梦坊几秒前
130-基于Python的体育用品销售数据可视化分析系统
开发语言·python·信息可视化·flask·毕业设计
2的n次方_几秒前
健身 Agent:不止视频,更有 AI 人物实时跟练交互
人工智能·音视频·交互·魔珐星云
前端不太难几秒前
CPU+GPU:开启AI推理新时代
人工智能·状态模式
chian-ocean1 分钟前
创业者实操:10 分钟搭建可商业化的交互型 AI 家电导购产品
人工智能
盼小辉丶1 分钟前
PyTorch强化学习实战——Atari游戏包装器
pytorch·深度学习·强化学习
海上彼尚2 分钟前
Nodejs也能写Agent - 6.基础篇 - Agent
前端·人工智能·后端·node.js
码界筑梦坊2 分钟前
131-基于Flask的美国新泽西州自动售货机销售数据可视化分析系统
开发语言·python·信息可视化·数据分析·flask·毕业设计
viperrrrrrrrrr72 分钟前
强化学习入门笔记
人工智能·强化学习
轻刀快马2 分钟前
AI 架构的文艺复兴:用操作系统“内存管理”重构 LLM 状态机 —— 深度解密 Claude Code
人工智能·架构
随身数智备忘录3 分钟前
拆解安全生产法三大核心功能,安全生产法如何解决责任不清与事故追责难
大数据·人工智能·安全