# %%
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
# Seed torch's global RNG so that later torch random draws are repeatable
# from run to run.
torch.manual_seed(seed=42)
# %%
# MNIST preprocessing: convert to a tensor, then standardize the single
# channel as (x - mean) / std. The constants are presumably MNIST's
# training-set pixel mean/std -- confirm if the data source changes.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# %%
# Constructor arguments shared by both MNIST splits.
_mnist_common = {'root': './data', 'transform': transform}

# Training split; allowed to download into ./data if missing.
train_dataset = datasets.MNIST(train=True, download=True, **_mnist_common)
# Test split: no download flag, so its files must already be under ./data
# (presumably fetched by the training-split call above -- confirm).
test_dataset = datasets.MNIST(train=False, **_mnist_common)

del _mnist_common
# %%
import matplotlib.pyplot as plt

# Draw one index uniformly from [0, len(train_dataset)) using torch's default
# RNG, then fetch that (image, label) pair.
sample_idx = int(torch.randint(low=0, high=len(train_dataset), size=(1,))[0])
image, label = train_dataset[sample_idx]
# %%
from torchvision import datasets, transforms
class MNIST(Dataset):
    """Map-style dataset over ``(data, targets)`` pairs, with an optional transform.

    Indexing returns ``(transform(data[idx]), targets[idx])``, or the raw
    ``data[idx]`` when no transform is given. ``len()`` is ``len(data)``.

    NOTE(review): the visible code builds its datasets with torchvision's
    ``datasets.MNIST`` rather than this class; confirm this class is still needed.
    """

    def __init__(self, root, train=True, transform=None):
        """Load samples for one split.

        Args:
            root: Location passed through to ``fetch_mnist_data``.
            train: Split selector passed through to ``fetch_mnist_data``.
            transform: Optional callable applied to each sample on access.
        """
        # NOTE(review): `fetch_mnist_data` is not defined in this file. It is
        # assumed to return (data, targets) indexable sequences of equal
        # length -- TODO confirm where it comes from.
        self.data, self.targets = fetch_mnist_data(root, train)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` at ``idx``, transforming the sample if set."""
        img, target = self.data[idx], self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, target
# %%
def imshow(img, mean=0.1307, std=0.3081):
    """Denormalize an MNIST image tensor and display channel 0 in grayscale.

    Reverses ``Normalize(mean, std)`` by computing ``img * std + mean``. The
    defaults equal the constants in the MNIST ``transform``. Calls
    ``plt.show()`` itself.

    Args:
        img: Normalized image tensor. Assumed channel-first, e.g. ``(1, H, W)``;
            only ``[0]`` is drawn -- TODO confirm for other callers.
        mean: Scalar mean that was subtracted during normalization.
        std: Scalar std the image was divided by during normalization.
    """
    img = img * std + mean
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray')
    plt.show()


print(f"Label:{label}")
imshow(image)

# %%
# Training batches of 64, reshuffled each epoch; evaluation batches of 1000
# in fixed dataset order (shuffle left at its default).
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000)
# 下载cifar数据集并获取其中一张图片
# (Download the CIFAR-10 dataset and fetch one of its images.)
# %%
import torchvision
import numpy as np
# NOTE(review): `transform` and `train_dataset` are rebound here, replacing the
# MNIST objects of the same names. Any later code expecting the MNIST versions
# will see the CIFAR-10 ones instead -- consider distinct names.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Standardize each of the 3 channels as (x - 0.5) / 0.5.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Take the very first training example.
sample_idx = 0
image, label = train_dataset[sample_idx]

# Human-readable class names, indexed by the integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
print(f"图片形状: {image.shape}")
print(f"标签: {label} - {classes[label]}")
def imshow(img, mean=0.5, std=0.5):
    """Denormalize a channel-first RGB image tensor and draw it without axes.

    Reverses ``Normalize(mean, std)`` by computing ``img * std + mean``. The
    defaults match the 0.5/0.5 CIFAR-10 normalization. The array is reordered
    from ``(C, H, W)`` to ``(H, W, C)`` for ``plt.imshow``. This function does
    NOT call ``plt.show()``, so the caller can add a title first.

    Args:
        img: Normalized tensor, assumed ``(3, H, W)`` -- TODO confirm for
            other callers.
        mean: Scalar mean that was subtracted during normalization.
        std: Scalar std the image was divided by during normalization.
    """
    img = img * std + mean
    npimg = img.numpy()
    # matplotlib clips float RGB data outside [0, 1] and logs a warning;
    # clip explicitly so small denormalization round-off is silent.
    npimg = np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0)
    plt.imshow(npimg)
    plt.axis('off')


imshow(image)
plt.title(f'Label: {classes[label]} ({label})')
plt.show()