1. Loading a dataset with Dataset
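```python
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert each PIL image to a CHW float tensor with values in [0, 1]
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
# Download CIFAR10 (if not already present) into ./train_dataset
train_set = torchvision.datasets.CIFAR10(root="./train_dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./train_dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])  # each item is an (image, target) tuple

# Log the first 10 test images to TensorBoard under the log directory "p10"
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
```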
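This downloads the CIFAR10 dataset, and the first ten test images can then be viewed in TensorBoard (for example with `tensorboard --logdir=p10`).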
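To make the `(img, target)` pair printed by `print(test_set[0])` more concrete, here is a minimal sketch of a map-style `Dataset` that follows the same protocol: `__len__` plus a `__getitem__` that returns an image and a label. `FakeCIFAR10` and `to_tensor_like` are hypothetical names. The sketch uses random 32x32 RGB data instead of the real download and skips the PIL conversion that torchvision's `CIFAR10` performs. As a result, `to_tensor_like` only mimics what `ToTensor` does to an 8-bit image: it reorders HWC to CHW and scales values to [0, 1]. CHW is also the layout that `SummaryWriter.add_image` expects by default.

```python
import torch
from torch.utils.data import Dataset


class FakeCIFAR10(Dataset):
    """Hypothetical stand-in for torchvision.datasets.CIFAR10 (random data, no download)."""

    def __init__(self, num_items=10, transform=None):
        g = torch.Generator().manual_seed(0)
        # Raw images stored as height x width x channels, uint8 values 0..255
        self.data = torch.randint(0, 256, (num_items, 32, 32, 3), dtype=torch.uint8, generator=g)
        # One integer class label (0..9) per image
        self.targets = torch.randint(0, 10, (num_items,), generator=g).tolist()
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target  # same (image, target) shape of item as CIFAR10


def to_tensor_like(img_hwc_uint8):
    # Mimics ToTensor on an 8-bit image: HWC -> CHW, then scale to [0.0, 1.0]
    return img_hwc_uint8.permute(2, 0, 1).float() / 255.0


dataset = FakeCIFAR10(transform=to_tensor_like)
img, target = dataset[0]
print(len(dataset), img.shape, img.dtype, target)
```

Keeping the transform outside the dataset class mirrors how `transform=dataset_transform` is passed to `CIFAR10` above. The same dataset object can be reused with a different preprocessing pipeline.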
2. Loading data from a dataset with DataLoader
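```python
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(root="./train_dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 64 images per batch, reshuffled every epoch, loaded in the main process, last partial batch kept
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data  # imgs: [64, 3, 32, 32]; the last batch may be smaller
    writer.add_images("test_data", imgs, step)  # add_images logs a whole NCHW batch
    step = step + 1
writer.close()
```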
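Here we load images from CIFAR10 64 at a time. Each iteration of the loop yields one batch of 64 images together with their labels.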
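The batching arithmetic can be checked without downloading anything. This sketch builds a stand-in for the CIFAR10 test split with `TensorDataset`. It has 10,000 items like the real split, but uses tiny hypothetical 3x4x4 images so it runs quickly. It shows that with `batch_size=64` the loader yields ceil(10000 / 64) = 157 batches, with only 16 images in the last one, and that `drop_last=True` discards that partial batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the CIFAR10 test split: 10,000 items, tiny 3x4x4 "images"
num_items = 10_000
fake_imgs = torch.zeros(num_items, 3, 4, 4)
fake_targets = torch.randint(0, 10, (num_items,))
fake_test_data = TensorDataset(fake_imgs, fake_targets)


def batch_sizes(drop_last):
    # Same loader settings as above, except for the dataset and drop_last
    loader = DataLoader(fake_test_data, batch_size=64, shuffle=True,
                        num_workers=0, drop_last=drop_last)
    sizes = [imgs.shape[0] for imgs, targets in loader]
    return len(loader), sizes


n_keep, sizes_keep = batch_sizes(drop_last=False)
n_drop, sizes_drop = batch_sizes(drop_last=True)
print(n_keep, sizes_keep[-1])  # 157 16  (10000 = 156 * 64 + 16)
print(n_drop, sizes_drop[-1])  # 156 64  (the 16 leftover images are dropped)
```

Shuffling only changes which images land in each batch, not how many batches there are. That is why the counts are the same with `shuffle=True`.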