python打卡day38

Dataset和Dataloader类

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

**作业:**了解下cifar数据集,尝试获取其中一张图片

python 复制代码
import torch
import torchvision
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子以确保结果可重复
torch.manual_seed(42)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))  # CIFAR的标准化参数
])
train_dataset = torchvision.datasets.CIFAR10(
    root='./dataCIFAR',  # 数据存放的路径
    train=True,     # 使用训练集
    download=True,  # 如果没有数据,就下载
    transform=transform
)

# 定义类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# 随机选择一张图片
idx = torch.randint(0, len(train_dataset), size=(1,))
img, label = train_dataset[idx]

# 反标准化函数
def denormalize(x):
    mean = torch.tensor([0.4914, 0.4822, 0.4465])
    std = torch.tensor([0.2470, 0.2435, 0.2616])
    # CIFAR-10是彩色图像,需要对所有通道进行反标准化
    return x * std[:, None, None] + mean[:, None, None]

# 显示图片
plt.figure()
plt.imshow(denormalize(img).permute(1, 2, 0))  # 调整通道顺序以正确显示彩色图像
plt.title(f'Label: {classes[label]}')
plt.axis('off')
plt.show()


# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

@浙大疏锦行

相关推荐
databook12 小时前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar13 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户83562907805114 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_14 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
数据智能老司机20 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机21 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机21 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机21 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i1 天前
drf初步梳理
python·django
每日AI新事件1 天前
python的异步函数
python