Transformers

PyTorch Transforms

一、什么是 transforms

transforms 是 PyTorch 中用来处理和预处理图像数据的一组工具。它帮助我们在训练机器学习模型之前,对数据进行规范化和增强处理。能将数据从原始格式转换为更适合模型训练的格式。常见的图像数据变换操作包括:

  1. ToTensor:

    • 将图像从普通格式(例如像素值在 0-255)转换为 PyTorch 能处理的 Tensor 格式,并归一化到 0 到 1 之间。
    • 例如,彩色图像被转换为一个 [3, H, W] 的张量,其中 H 和 W 是图像的高和宽,3 表示 RGB 三通道。
  2. Normalize:

    • 调整张量图像中的像素值,使其均值为零,标准差为一,更易于模型处理。
    • 通常需要提供均值和标准差参数,对每个图像通道分别应用。
  3. RandomCrop:

    • 随机裁剪图像到指定大小,帮助模型学习不同视图,防止过拟合。
  4. RandomHorizontalFlip:

    • 随机水平翻转图像,提高模型对图像左右变换后的识别能力。
  5. Resize:

    • 将图像调整为指定尺寸,确保所有输入图像具有统一大小。
  6. Compose:

    • 用于组合多个变换操作,可以连贯地对每张图像进行一系列处理。

举例说明

假设有一个猫和狗的图片数据集,准备训练一个图像识别模型。原始图片大小不一,就可以使用 transforms 来规范化这些图片:

python 复制代码
import torchvision.transforms as transforms

# 定义一个变换序列
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 将图像调整到 128x128 像素
    transforms.ToTensor(),         # 转换成 Tensor 格式并归一化
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 对像素值进行归一化
])

# 应用于你的数据集载入器

在这个例子里,transform 先调整图片大小,接着转换成 Tensor 格式,最后进行标准化处理。

总结

• transforms 是处理和增强图像数据的强大工具,尤其在模型训练阶段。

• 通过数据增强(如裁剪、翻转),可以提高模型的泛化能力和鲁棒性。

• 组合变换简化数据预处理过程,确保模型接收到统一标准化的数据。

二、简单的使用

ini 复制代码
import torchvision.transforms as transforms
from PIL import Image

#使用ToTensor()将PIL图像转换为张量
#为什么要使用张量类型?
img_path = "hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)
print(img)
tensor_tran = transforms.ToTensor()
tentor_img = tensor_tran(img)

截图:

解释 transforms.ToTensor() 的实例及 __call__ 方法

tensor_tran是transforms.ToTensor()类的一个实例 在定义 tensor_tran = transforms.ToTensor() 后,它具有(继承) call 方法,使其可以像函数一样调用

更详细的解释

    • transforms.ToTensor 是一个类,负责将图像转换为张量。
    • 类通常定义了一组属性和方法,包括特殊的 __call__ 方法。
  1. 实例

    • tensor_tran = transforms.ToTensor() 创建了 transforms.ToTensor 类的一个实例。
    • 实例 tensor_tran 继承了类所定义的所有功能,包括 __call__ 方法。
  2. 调用行为

    • tensor_tran(img) 涉及调用 tensor_tran__call__ 方法。
    • 这表示 tensor_tran 可以像函数一样使用,用来处理输入数据 img

Ps:在 Python 中,call 方法是一个特殊的方法,它允许对象像函数一样被调用。在定义了一个类并为这个类实现 call 方法之后,创建的对象实例就可以直接使用括号 () 来调用,就像调用函数一样。 这是 Python 中一种让对象更具灵活性和可操作性的技术。

ruby 复制代码
class MyCallable:    
    def __call__(self, x):        
        return x * x
obj = MyCallable()
result = obj(5)  # instance is called like a function
print(result)  # outputs: 25

• 在这个例子中,MyCallable 类定义了一个 call 方法。通过实例化 MyCallable 后,obj(5) 实际上调用的是 call 方法。

读取Tensor图片

import 复制代码
from PIL import Image
from torch.utils.tensorboard import SummaryWriter  
img_path = "hymenoptera_data/train/ants/0013035.jpg"
# 打开图像
img = Image.open(img_path)
# 创建一个 SummaryWriter 实例
writer = SummaryWriter("logs")
# 创建 ToTensor 变换
tensor_tran = transforms.ToTensor()
# 将图像转换为张量
tensor_img = tensor_tran(img)
# 将转换后的张量写入 TensorBoard
writer.add_image("Tensor_img", tensor_img)  
# 关闭写入器
writer.close()
相关推荐
golitter.3 小时前
pytorch的 Size[3] 和 Size[3,1] 区别
人工智能·pytorch·python
盼小辉丶3 小时前
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
pytorch·深度学习·transformer
旧时光巷8 小时前
【机器学习③】 | CNN篇
人工智能·pytorch·python·机器学习·cnn·卷积神经网络·lenet-5
wow_DG16 小时前
【Pytorch✨】LSTM01 入门
人工智能·pytorch·lstm
旧时光巷20 小时前
【深度学习②】| DNN篇
人工智能·pytorch·深度学习·dnn·模型训练·手写数字识别·深度神经网络
一碗白开水一20 小时前
【模型细节】FPN经典网络模型 (Feature Pyramid Networks)详解及其变形优化
网络·人工智能·pytorch·深度学习·计算机视觉
瘦的可以下饭了1 天前
Tensorboard
pytorch
爱分享的飘哥1 天前
第三十七章:文生图的炼金术:Stable Diffusion完整工作流深度解析
人工智能·pytorch·stable diffusion·文生图·ai绘画·代码实战·cfg
wow_DG2 天前
【PyTorch✨】01 初识PyTorch
人工智能·pytorch·python