DAY38 Dataset和DataLoader

@浙大疏锦行

python 复制代码
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)
python 复制代码
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
python 复制代码
train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
python 复制代码
import matplotlib.pyplot as plt
sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()
image,label=train_dataset[sample_idx]
python 复制代码
from torchvision import datasets, transforms
class MNIST(Dataset):
    def __init____init__(self,root,train=True,transform=None):
        self.data,self.targets=fetch_mnist_data(root,train)
        self.transform=transform

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        img,target=self.data[idx],self.targets[idx]

        if self.transform is not None:
            img=self.transform(img)
    
        return img,target
python 复制代码
def imshow(img):
    img=img*0.3081+0.1307
    nping=img.numpy()
    plt.imshow(nping[0],cmap='gray')
    plt.show()

print(f"Label:{label}")
imshow(image)
python 复制代码
train_loader=DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader=DataLoader(
    test_dataset,
    batch_size=1000
)

下载cifar数据集并获取其中一张图片

python 复制代码
import torchvision
import numpy as np
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

train_dataset=torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

sample_idx=0
image,label=train_dataset[sample_idx]

# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")

def imshow(img):
    img=img*0.5+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.axis('off')

imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()
相关推荐
JaydenAI42 分钟前
[LangChain之链]LangChain的Chain——由Runnable构建的管道
python·langchain
kali-Myon43 分钟前
2025春秋杯网络安全联赛冬季赛-day3
python·安全·web安全·ai·php·web·ctf
AbsoluteLogic1 小时前
Python——彻底明白Super() 该如何使用
python
小猪咪piggy1 小时前
【Python】(4) 列表和元组
开发语言·python
墨理学AI1 小时前
一文学会一点python数据分析-小白原地进阶(mysql 安装 - mysql - python 数据分析 - 学习阶段梳理)
python·mysql·数据分析
数研小生1 小时前
亚马逊商品列表API详解
前端·数据库·python·pandas
独好紫罗兰1 小时前
对python的再认识-基于数据结构进行-a005-元组-CRUD
开发语言·数据结构·python
jianghua0011 小时前
Python中的简单爬虫
爬虫·python·信息可视化
喵手2 小时前
Python爬虫实战:针对Python官网,精准提取出每一个历史版本的版本号、发布日期以及对应的文档/详情页链接等信息,并最终清洗为标准化的CSV文件!
爬虫·python·爬虫实战·零基础python爬虫教学·python官方数据采集·采集历史版本版本号等信息·导出csv文件
databook2 小时前
像搭积木一样思考:数据科学中的“自下而上”之道
python·数据挖掘·数据分析