DAY38 Dataset和DataLoader

@浙大疏锦行

python 复制代码
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)
python 复制代码
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
python 复制代码
train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
python 复制代码
import matplotlib.pyplot as plt
sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()
image,label=train_dataset[sample_idx]
python 复制代码
from torchvision import datasets, transforms
class MNIST(Dataset):
    def __init____init__(self,root,train=True,transform=None):
        self.data,self.targets=fetch_mnist_data(root,train)
        self.transform=transform

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        img,target=self.data[idx],self.targets[idx]

        if self.transform is not None:
            img=self.transform(img)
    
        return img,target
python 复制代码
def imshow(img):
    img=img*0.3081+0.1307
    nping=img.numpy()
    plt.imshow(nping[0],cmap='gray')
    plt.show()

print(f"Label:{label}")
imshow(image)
python 复制代码
train_loader=DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader=DataLoader(
    test_dataset,
    batch_size=1000
)

下载cifar数据集并获取其中一张图片

python 复制代码
import torchvision
import numpy as np
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

train_dataset=torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

sample_idx=0
image,label=train_dataset[sample_idx]

# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")

def imshow(img):
    img=img*0.5+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.axis('off')

imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()
相关推荐
HDO清风1 分钟前
CASIA-HWDB2.x 数据集DGRL文件解析(python)
开发语言·人工智能·pytorch·python·目标检测·计算机视觉·restful
weixin_499771553 分钟前
Python上下文管理器(with语句)的原理与实践
jvm·数据库·python
weixin_452159556 分钟前
高级爬虫技巧:处理JavaScript渲染(Selenium)
jvm·数据库·python
多米Domi01112 分钟前
0x3f 第48天 面向实习的八股背诵第五天 + 堆一题 背了JUC的题,java.util.Concurrency
开发语言·数据结构·python·算法·leetcode·面试
深蓝海拓18 分钟前
PySide6从0开始学习的笔记(二十六) 重写Qt窗口对象的事件(QEvent)处理方法
笔记·python·qt·学习·pyqt
纠结哥_Shrek18 分钟前
外贸选品工程师的工作流程和方法论
python·机器学习
小汤圆不甜不要钱20 分钟前
「Datawhale」RAG技术全栈指南 Task 5
python·llm·rag
A懿轩A1 小时前
【Java 基础编程】Java 变量与八大基本数据类型详解:从声明到类型转换,零基础也能看懂
java·开发语言·python
Tansmjs1 小时前
使用Python自动收发邮件
jvm·数据库·python