DAY38 Dataset和DataLoader

@浙大疏锦行

python 复制代码
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)
python 复制代码
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
python 复制代码
train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
python 复制代码
import matplotlib.pyplot as plt
sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()
image,label=train_dataset[sample_idx]
python 复制代码
from torchvision import datasets, transforms
class MNIST(Dataset):
    def __init____init__(self,root,train=True,transform=None):
        self.data,self.targets=fetch_mnist_data(root,train)
        self.transform=transform

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        img,target=self.data[idx],self.targets[idx]

        if self.transform is not None:
            img=self.transform(img)
    
        return img,target
python 复制代码
def imshow(img):
    img=img*0.3081+0.1307
    nping=img.numpy()
    plt.imshow(nping[0],cmap='gray')
    plt.show()

print(f"Label:{label}")
imshow(image)
python 复制代码
train_loader=DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader=DataLoader(
    test_dataset,
    batch_size=1000
)

下载cifar数据集并获取其中一张图片

python 复制代码
import torchvision
import numpy as np
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

train_dataset=torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

sample_idx=0
image,label=train_dataset[sample_idx]

# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")

def imshow(img):
    img=img*0.5+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.axis('off')

imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()
相关推荐
Michelle80232 小时前
24大数据 16-1 函数复习
python
dagouaofei2 小时前
AI自动生成PPT工具对比分析,效率差距明显
人工智能·python·powerpoint
ku_code_ku2 小时前
python bert_score使用本地模型的方法
开发语言·python·bert
祁思妙想3 小时前
linux常用命令
开发语言·python
流水落花春去也3 小时前
用yolov8 训练,最后形成训练好的文件。 并且能在后续项目使用
python
Serendipity_Carl3 小时前
数据可视化实战之链家
python·数据可视化·数据清洗
小裴(碎碎念版)3 小时前
文件读写常用操作
开发语言·爬虫·python
TextIn智能文档云平台3 小时前
图片转文字后怎么输入大模型处理
前端·人工智能·python
ujainu4 小时前
Python学习第一天:保留字和标识符
python·学习·标识符·保留字