PyTorch transforms的用法

答案来自 千问

需要注意每个方法的输入输出是什么类型的数据,方法需要的参数等。

torchvision.transforms 是 PyTorch 中用于图像预处理和数据增强的核心模块。它的主要作用是将原始的 PIL 图像或 NumPy 数组转换为模型可接受的张量(Tensor)格式,并进行标准化等操作。

简单来说,它就是深度学习中的"图像处理工具箱"。

🧰 核心用法:组合变换 (Compose)

transforms 最常用的模式是使用 transforms.Compose。你可以把多个处理步骤像流水线一样组合在一起,图像会按顺序依次经过这些处理步骤。

python 复制代码
import torchvision.transforms as transforms

# 定义一个处理流水线
transform = transforms.Compose([
    transforms.Resize((224, 224)),      # 1. 调整图像大小
    transforms.ToTensor(),              # 2. 转换为张量
    transforms.Normalize(mean=[0.5], std=[0.5])  # 3. 标准化
])

🛠️ 常用工具 (Transforms) 介绍

以下是一些最常用的图像处理工具:

1. 转换格式:ToTensor

这是几乎所有流程的第一步或中间步骤。它的作用是:

  • 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量(Tensor)。
  • 自动将像素值从 [0, 255] 缩放到 [0.0, 1.0] 范围内。
2. 调整大小:Resize

将图像调整为指定的尺寸,以适应模型的输入要求。

  • transforms.Resize(256): 将图像的短边缩放到256,保持长宽比。
  • transforms.Resize((224, 224)): 将图像直接缩放到指定的宽和高。
3. 裁剪:CenterCrop 与 RandomCrop
  • CenterCrop: 从图像中心裁剪出指定大小的区域。常用于测试阶段,以获得稳定的输入。
  • RandomCrop: 从图像中随机位置裁剪。常用于训练阶段,作为一种数据增强手段,增加数据的多样性。
4. 翻转与旋转
  • RandomHorizontalFlip: 以指定的概率(默认0.5)随机水平翻转图像。这是最常用的数据增强方法之一。
  • RandomRotation : 随机旋转图像指定的角度,例如 transforms.RandomRotation(degrees=30) 会随机旋转 -30 到 +30 度之间的角度。
5. 色彩调整:ColorJitter

随机改变图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue),使模型对色彩变化更具鲁棒性。

6. 标准化:Normalize

使用给定的均值(mean)和标准差(std)对图像张量进行标准化,使数据分布更有利于模型训练。这一步必须ToTensor 之后执行,因为它是针对张量数值进行的运算。

  • 常用参数(基于 ImageNet 数据集统计值):
    • mean=[0.485, 0.456, 0.406]
    • std=[0.229, 0.224, 0.225]

📝 实际应用示例

下面展示如何为训练集和测试集构建不同的预处理流程。

python 复制代码
from torchvision import transforms

# 训练集的数据增强和预处理
# 目标:增加数据多样性,防止模型过拟合
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),        # 随机裁剪并缩放到224x224
    transforms.RandomHorizontalFlip(),       # 随机水平翻转
    transforms.ColorJitter(brightness=0.2,    # 随机调整亮度、对比度等
                           contrast=0.2, 
                           saturation=0.2),
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

# 测试集/验证集的预处理
# 目标:保持数据原始分布,进行一致的评估
test_transform = transforms.Compose([
    transforms.Resize(256),                  # 先将短边缩放到256
    transforms.CenterCrop(224),              # 再从中心裁剪出224x224区域
    transforms.ToTensor(),                   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

关键点总结:

  1. 顺序很重要ToTensor 必须在 Normalize 之前,因为 Normalize 需要输入是张量。
  2. 场景不同,策略不同 :训练时使用 Random 开头的变换来增强数据;测试时使用确定性的变换(如 CenterCrop)以保证结果可复现。
  3. 结合数据集使用 :定义好 transform 后,可以在创建数据集(如 torchvision.datasets.CIFAR10)时,将其作为参数传入,数据加载时就会自动应用这些变换。
相关推荐
tzc_fly11 分钟前
深度范式转移:漂移模型(Drifting Models)解析
人工智能
小雨中_20 分钟前
3.5 ReMax:用 Greedy 作为基线的 REINFORCE + RLOO
人工智能·python·深度学习·机器学习·自然语言处理
TImCheng060932 分钟前
方法论:将AI深度嵌入工作流的“场景-工具-SOP”三步法
大数据·人工智能
geneculture35 分钟前
四维矩阵分析:人机互助超级个体与超级OPC关系研究——基于HI×AI、个体×团队、个体×OPC与波士顿矩阵的整合框架
人工智能·百度
智算菩萨43 分钟前
2026年春节后,AI大模型格局彻底变了——Claude 4.6、GPT-5.2与六大国产模型全面横评
人工智能·gpt·ai编程
overmind1 小时前
oeasy Python 116 用列表乱序shuffle来洗牌抓拍玩升级拖拉机
服务器·windows·python
A懿轩A1 小时前
【Java 基础编程】Java 枚举与注解从零到一:Enum 用法 + 常用注解 + 自定义注解实战
java·开发语言·python
狮子座明仔1 小时前
Agent World Model:给智能体造一个“矩阵世界“——无限合成环境驱动的强化学习
人工智能·线性代数·语言模型·矩阵
OpenMiniServer1 小时前
AI 大模型的本质:基于大数据的拟合,而非创造
大数据·人工智能
SmartBrain1 小时前
FastAPI实战(第二部分):用户注册接口开发详解
数据库·人工智能·python·fastapi