DAY38 Dataset和DataLoader

@浙大疏锦行

python 复制代码
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)
python 复制代码
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
python 复制代码
train_dataset=datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset=datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
python 复制代码
import matplotlib.pyplot as plt
sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()
image,label=train_dataset[sample_idx]
python 复制代码
from torchvision import datasets, transforms
class MNIST(Dataset):
    def __init____init__(self,root,train=True,transform=None):
        self.data,self.targets=fetch_mnist_data(root,train)
        self.transform=transform

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        img,target=self.data[idx],self.targets[idx]

        if self.transform is not None:
            img=self.transform(img)
    
        return img,target
python 复制代码
def imshow(img):
    img=img*0.3081+0.1307
    nping=img.numpy()
    plt.imshow(nping[0],cmap='gray')
    plt.show()

print(f"Label:{label}")
imshow(image)
python 复制代码
train_loader=DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader=DataLoader(
    test_dataset,
    batch_size=1000
)

下载cifar数据集并获取其中一张图片

python 复制代码
import torchvision
import numpy as np
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

train_dataset=torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

sample_idx=0
image,label=train_dataset[sample_idx]

# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")

def imshow(img):
    img=img*0.5+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.axis('off')

imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()
相关推荐
zzb15807 分钟前
Agent案例-智能文档问答助手
java·人工智能·笔记·python
HP-Patience14 分钟前
【Python爬虫常见错误】- AJAX动态加载数据爬取
爬虫·python·ajax
青瓷程序设计19 分钟前
【基于 YOLO的咖啡豆果实成熟度检测系统】+ Python+算法模型+目标检测+2026原创
python·算法·yolo
天才测试猿19 分钟前
Python接口自动化测试之Token详解及应用
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·接口测试
童园管理札记28 分钟前
2026实测|GPT-4.5+Agent智能体:3小时搭建企业级客服系统,附完整源码与部署教程(二)
人工智能·python
:mnong34 分钟前
附图报价系统设计分析3
python·openvino
AmyLin_200135 分钟前
【pdf2md-2:关键核心】PDF 转 Markdown 技术拆解:两阶段流水线、四级标题检测与段落智能合并
windows·python·pdf·pip·pdf2md
薛不痒38 分钟前
Llamafactory的使用(1)
人工智能·python·llama
不喝水的鱼儿39 分钟前
KT Qwen3.5-35B-A3B 记录
java·前端·python
小陈工1 小时前
Python Web开发入门(三):配置文件管理与环境变量最佳实践
开发语言·jvm·数据库·python·oracle·性能优化·开源