Python打卡训练营学习记录Day38

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

**作业:**了解下cifar数据集,尝试获取其中一张图片

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
import numpy as np
 
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
 
 
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作
 
# 先归一化,再标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10数据集的均值和标准差(R/G/B通道)
])
 
 
# 2. 加载CIFAR-10数据集,如果没有会自动下载(CIFAR-10是32x32彩色图像,共10类)
train_dataset = datasets.CIFAR10(  # 替换为CIFAR10数据集类
    root='./data',       # 数据存储目录(与原MNIST路径一致,会自动新建CIFAR10子目录)
    train=True,          # True加载训练集(50000张),False加载测试集(10000张)
    download=True,       # 本地无数据时自动下载(约163MB,首次运行需等待)
    transform=transform  # 沿用原预处理管道(注意:CIFAR-10是3通道,建议后续调整Normalize的均值和标准差)
)
 
test_dataset = datasets.CIFAR10(  # 替换为CIFAR10数据集类
    root='./data',       # 与训练集共用存储目录
    train=False,         # 加载测试集用于模型评估
    transform=transform  # 保持与训练集相同的预处理
    # download=True      # 若训练集已下载,测试集可省略(或保留以确保完整性)
)
 
 
import matplotlib.pyplot as plt
 
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
 
 
 
# CIFAR-10数据集的简化版本(32x32彩色图像,10类)
class CIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_cifar10_data(root, train) # 假设fetch_cifar10_data用于加载CIFAR-10数据
        self.transform = transform # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx): # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None: # 如果有预处理操作
            img = self.transform(img) # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
            
        return img, target  # 返回处理后的图像和标签
 
 
# 可视化原始图像(需要反归一化,适配CIFAR-10的3通道彩色图像)
def imshow(img):
    # 使用CIFAR-10的标准差和均值进行反标准化(顺序对应R/G/B通道)
    img = img * torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1) + torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    npimg = img.numpy()
    #调整维度顺序(PyTorch张量是[C,H,W],plt需要[H,W,C])并移除灰度映射
    plt.imshow(np.transpose(npimg, (1, 2, 0))) 
    plt.show()
 
print(f"Label: {label}")
imshow(image)

@浙大疏锦行

相关推荐
西岸行者5 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意5 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码5 天前
嵌入式学习路线
学习
毛小茛5 天前
计算机系统概论——校验码
学习
babe小鑫5 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms5 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下5 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。5 天前
2026.2.25监控学习
学习
im_AMBER5 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J5 天前
从“Hello World“ 开始 C++
c语言·c++·学习