DAY38 Dataset和DataLoader

@浙大疏锦行

python 复制代码
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)
python 复制代码
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
python 复制代码
train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
python 复制代码
import matplotlib.pyplot as plt
sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()
image,label=train_dataset[sample_idx]
python 复制代码
from torchvision import datasets, transforms
class MNIST(Dataset):
    def __init____init__(self,root,train=True,transform=None):
        self.data,self.targets=fetch_mnist_data(root,train)
        self.transform=transform

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        img,target=self.data[idx],self.targets[idx]

        if self.transform is not None:
            img=self.transform(img)
    
        return img,target
python 复制代码
def imshow(img):
    img=img*0.3081+0.1307
    nping=img.numpy()
    plt.imshow(nping[0],cmap='gray')
    plt.show()

print(f"Label:{label}")
imshow(image)
python 复制代码
train_loader=DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader=DataLoader(
    test_dataset,
    batch_size=1000
)

下载cifar数据集并获取其中一张图片

python 复制代码
import torchvision
import numpy as np
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

train_dataset=torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

sample_idx=0
image,label=train_dataset[sample_idx]

# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")

def imshow(img):
    img=img*0.5+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.axis('off')

imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()
相关推荐
寻星探路2 分钟前
【Python 全栈测开之路】Python 进阶:库的使用与第三方生态(标准库+Pip+实战)
java·开发语言·c++·python·ai·c#·pip
子夜江寒2 小时前
基于 OpenCV 的图像形态学与边缘检测
python·opencv·计算机视觉
少林码僧8 小时前
2.31 机器学习神器项目实战:如何在真实项目中应用XGBoost等算法
人工智能·python·算法·机器学习·ai·数据挖掘
智航GIS9 小时前
10.4 Selenium:Web 自动化测试框架
前端·python·selenium·测试工具
jarreyer9 小时前
摄像头相关记录
python
宝贝儿好9 小时前
【强化学习】第六章:无模型控制:在轨MC控制、在轨时序差分学习(Sarsa)、离轨学习(Q-learning)
人工智能·python·深度学习·学习·机器学习·机器人
大、男人9 小时前
python之asynccontextmanager学习
开发语言·python·学习
默默前行的虫虫10 小时前
nicegui文件上传归纳
python
一个没有本领的人10 小时前
UIU-Net运行记录
python
国强_dev10 小时前
Python 的“非直接原因”报错
开发语言·python