6.6 day38

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

了解下cifar数据集,尝试获取其中一张图片

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
 
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 先归一化,再标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
 
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
 
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform
)
 
import matplotlib.pyplot as plt
 
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
 
 
# minist数据集的简化版本
class MNIST(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签
        self.transform = transform # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx): # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None: # 如果有预处理操作
            img = self.transform(img) # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
            
        return img, target  # 返回处理后的图像和标签
 
# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
    plt.show()
 
print(f"Label: {label}")
imshow(image)
相关推荐
上去我就QWER13 分钟前
Python下常用开源库
python·1024程序员节
程序员杰哥1 小时前
Pytest之收集用例规则与运行指定用例
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·pytest
Jyywww1212 小时前
Python基于实战练习的知识点回顾
开发语言·python
朝朝辞暮i2 小时前
从0开始学python(day2)
python
程序员黄同学3 小时前
Python中的列表推导式、字典推导式和集合推导式的性能和应用场景?
开发语言·python
AI小云3 小时前
【Python高级编程】类和实例化
开发语言·人工智能·python
道之极万物灭3 小时前
Python uv虚拟环境管理工具详解
开发语言·python·uv
高洁013 小时前
【无标题】大模型-模型压缩:量化、剪枝、蒸馏、二值化 (2
人工智能·python·深度学习·神经网络·知识图谱
一晌小贪欢3 小时前
Python爬虫第10课:分布式爬虫架构与Scrapy-Redis
分布式·爬虫·python·网络爬虫·python爬虫·python3
代码AI弗森4 小时前
Python × NumPy」 vs 「JavaScript × TensorFlow.js」生态全景图
javascript·python·numpy