Transformers

PyTorch Transforms

一、什么是 transforms

transforms 是 PyTorch 中用来处理和预处理图像数据的一组工具。它帮助我们在训练机器学习模型之前,对数据进行规范化和增强处理。能将数据从原始格式转换为更适合模型训练的格式。常见的图像数据变换操作包括:

  1. ToTensor:

    • 将图像从普通格式(例如像素值在 0-255)转换为 PyTorch 能处理的 Tensor 格式,并归一化到 0 到 1 之间。
    • 例如,彩色图像被转换为一个 [3, H, W] 的张量,其中 H 和 W 是图像的高和宽,3 表示 RGB 三通道。
  2. Normalize:

    • 调整张量图像中的像素值,使其均值为零,标准差为一,更易于模型处理。
    • 通常需要提供均值和标准差参数,对每个图像通道分别应用。
  3. RandomCrop:

    • 随机裁剪图像到指定大小,帮助模型学习不同视图,防止过拟合。
  4. RandomHorizontalFlip:

    • 随机水平翻转图像,提高模型对图像左右变换后的识别能力。
  5. Resize:

    • 将图像调整为指定尺寸,确保所有输入图像具有统一大小。
  6. Compose:

    • 用于组合多个变换操作,可以连贯地对每张图像进行一系列处理。

举例说明

假设有一个猫和狗的图片数据集,准备训练一个图像识别模型。原始图片大小不一,就可以使用 transforms 来规范化这些图片:

python 复制代码
import torchvision.transforms as transforms

# 定义一个变换序列
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 将图像调整到 128x128 像素
    transforms.ToTensor(),         # 转换成 Tensor 格式并归一化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 对像素值进行归一化
])

# 应用于你的数据集载入器

在这个例子里,transform 先调整图片大小,接着转换成 Tensor 格式,最后进行标准化处理。

总结

• transforms 是处理和增强图像数据的强大工具,尤其在模型训练阶段。

• 通过数据增强(如裁剪、翻转),可以提高模型的泛化能力和鲁棒性。

• 组合变换简化数据预处理过程,确保模型接收到统一标准化的数据。

二、简单的使用

ini 复制代码
import torchvision.transforms as transforms
from PIL import Image

#使用ToTensor()将PIL图像转换为张量
#为什么要使用张量类型?
img_path = "hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)
print(img)
tensor_tran = transforms.ToTensor()
tentor_img = tensor_tran(img)

截图:

解释 transforms.ToTensor() 的实例及 __call__ 方法

tensor_tran是transforms.ToTensor()类的一个实例 在定义 tensor_tran = transforms.ToTensor() 后,它具有(继承) call 方法,使其可以像函数一样调用

更详细的解释

    • transforms.ToTensor 是一个类,负责将图像转换为张量。
    • 类通常定义了一组属性和方法,包括特殊的 __call__ 方法。
  1. 实例

    • tensor_tran = transforms.ToTensor() 创建了 transforms.ToTensor 类的一个实例。
    • 实例 tensor_tran 继承了类所定义的所有功能,包括 __call__ 方法。
  2. 调用行为

    • tensor_tran(img) 涉及调用 tensor_tran__call__ 方法。
    • 这表示 tensor_tran 可以像函数一样使用,用来处理输入数据 img

Ps:在 Python 中,call 方法是一个特殊的方法,它允许对象像函数一样被调用。在定义了一个类并为这个类实现 call 方法之后,创建的对象实例就可以直接使用括号 () 来调用,就像调用函数一样。 这是 Python 中一种让对象更具灵活性和可操作性的技术。

ruby 复制代码
class MyCallable:    
    def __call__(self, x):        
        return x * x
obj = MyCallable()
result = obj(5)  # instance is called like a function
print(result)  # outputs: 25

• 在这个例子中,MyCallable 类定义了一个 call 方法。通过实例化 MyCallable 后,obj(5) 实际上调用的是 call 方法。

读取Tensor图片

import 复制代码
from PIL import Image
from torch.utils.tensorboard import SummaryWriter  
img_path = "hymenoptera_data/train/ants/0013035.jpg"
# 打开图像
img = Image.open(img_path)
# 创建一个 SummaryWriter 实例
writer = SummaryWriter("logs")
# 创建 ToTensor 变换
tensor_tran = transforms.ToTensor()
# 将图像转换为张量
tensor_img = tensor_tran(img)
# 将转换后的张量写入 TensorBoard
writer.add_image("Tensor_img", tensor_img)  
# 关闭写入器
writer.close()
相关推荐
Francek Chen4 小时前
【深度学习计算机视觉】09:语义分割和数据集
人工智能·pytorch·深度学习·计算机视觉·数据集·语义分割
虚行7 小时前
PyCharm中搭建PyTorch和YOLOv10开发环境
pytorch·yolo·pycharm
StarPrayers.19 小时前
基于PyTorch的CIFAR10加载与TensorBoard可视化实践
人工智能·pytorch·python·深度学习·机器学习
西柚小萌新2 天前
【深入浅出PyTorch】--3.1.PyTorch组成模块1
人工智能·pytorch·python
西猫雷婶2 天前
random.shuffle()函数随机打乱数据
开发语言·pytorch·python·学习·算法·线性回归·numpy
深栈2 天前
机器学习:线性回归
人工智能·pytorch·python·机器学习·线性回归·sklearn
蒋星熠2 天前
TensorFlow与PyTorch深度对比分析:从基础原理到实战选择的完整指南
人工智能·pytorch·python·深度学习·ai·tensorflow·neo4j
、、、、南山小雨、、、、2 天前
Pytorch强化学习demo
pytorch·深度学习·机器学习·强化学习
炘东5922 天前
vscode连接算力平台
pytorch·vscode·深度学习·gpu算力
西猫雷婶2 天前
pytorch基本运算-torch.normal()函数输出多维数据时,如何绘制正态分布函数图
人工智能·pytorch·python·深度学习·神经网络·机器学习·线性回归