pytorch学习(十六):终——解释一些bug

1.torch.tensortorchvision.transforms.ToTensor()的区别

torch.tensortorchvision.transforms.ToTensor() 在 PyTorch 中虽然都用于创建或转换数据为张量(Tensor),但它们的使用场景和目的有所不同。

torch.tensor

torch.tensor 是 PyTorch 中用于创建张量的基本函数。这个函数可以接受各种类型的数据(如列表、NumPy数组等)作为输入,并返回一个具有指定数据类型(dtype)和是否需要在GPU上(device)的张量。这个函数主要用于直接创建或初始化张量,而不是在数据预处理或数据加载的流程中。

torchvision.transforms.ToTensor()

torchvision.transforms.ToTensor() 是 PyTorch 视觉库(torchvision)中的一个转换操作,主要用于将 PIL 图像或 NumPy ndarray(通常是图像数据)转换为 PyTorch 张量。这个转换不仅仅是简单的数据类型转换,它还会执行以下操作:

  1. 将图像数据从 [0, 255] 归一化到 [0.0, 1.0]:这是通过除以 255 来实现的,因为图像通常以 8 位无符号整数(uint8)的形式存储,其值域为 [0, 255]。

  2. 将图像数据从 H x W x C 转换为 C x H x W:这里 H 是高度,W 是宽度,C 是颜色通道数(例如,RGB 图像有 3 个通道)。这种转换是因为 PyTorch 期望的输入张量维度顺序是通道数在前(C x H x W),而 PIL 图像和 NumPy ndarray 的默认维度顺序是高度在前(H x W x C)。

使用场景的区别

  • torch.tensor:当你需要直接创建或初始化一个张量,比如定义一个模型参数、操作或计算的中间结果时,你会使用这个函数。

  • torchvision.transforms.ToTensor():当你处理图像数据时,特别是在数据加载和预处理阶段,你需要将图像数据(无论是 PIL 图像还是 NumPy ndarray)转换为 PyTorch 张量,并准备它们以供模型训练或推理使用。此时,你会使用这个函数。

2.torchvision.transforms.Resize((32,32))报错

该函数只能接受PIL图像。

不可以opencv直接读取然后传给transform,需要用PIL读取。

python 复制代码
# -*- coding: utf-8 -*-  
# File created on 2024/8/12 
# 作者:酷尔
# 公众号:酷尔计算机
import cv2
import torchvision
import torch
from PIL import Image
img=Image.open('dog.png')

transform=torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]
)
img=transform(img)
print(img.shape)

3.用GPU训练的模型,测试模型也要把照片转为GPU格式

img=img.cuda()

python 复制代码
# -*- coding: utf-8 -*-  
# File created on 2024/8/12 
# 作者:酷尔
# 公众号:酷尔计算机
import cv2
import torchvision
import torch
from PIL import Image
from torch import nn
img=Image.open('dog.png')

transform=torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32,32)),
        torchvision.transforms.ToTensor()
    ]
)
img=transform(img)
# print(img.shape)
class Wang(nn.Module):
    def __init__(self):
        super(Wang, self).__init__()
        self.model1=nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10),
        )

    def forward(self,x):
        x=self.model1(x)
        return x
model=torch.load("./models/wang_0.pth")
# print(model)
img=torch.reshape(img,(1,3,32,32))
img=img.cuda()
model.eval()
with torch.no_grad():
    output=model(img)
print(output)

print(output.argmax(1))
相关推荐
灵智工坊LingzhiAI14 分钟前
人体坐姿检测系统项目教程(YOLO11+PyTorch+可视化)
人工智能·pytorch·python
DKPT5 小时前
Java桥接模式实现方式与测试方法
java·笔记·学习·设计模式·桥接模式
zzc9217 小时前
Adobe Illustrator设置的颜色和显示的颜色不对应问题
adobe·bug·illustrator·错误·配色·透明度·底色
好好研究7 小时前
学习栈和队列的插入和删除操作
数据结构·学习
新中地GIS开发老师8 小时前
新发布:26考研院校和专业大纲
学习·考研·arcgis·大学生·遥感·gis开发·地理信息科学
SH11HF9 小时前
小菜狗的云计算之旅,学习了解rsync+sersync实现数据实时同步(详细操作步骤)
学习·云计算
Frank学习路上9 小时前
【IOS】XCode创建firstapp并运行(成为IOS开发者)
开发语言·学习·ios·cocoa·xcode
Chef_Chen10 小时前
从0开始学习计算机视觉--Day07--神经网络
神经网络·学习·计算机视觉
X_StarX12 小时前
【Unity笔记02】订阅事件-自动开门
笔记·学习·unity·游戏引擎·游戏开发·大学生
MingYue_SSS12 小时前
开关电源抄板学习
经验分享·笔记·嵌入式硬件·学习