1.数据集:
- 自定义数据集
- transforms中的类
如何将数据集和transforms结合在一起?
以CIFAR10为列
2.CIFAR10数据集的下载与导入
python
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
如果下载比较慢,可以把下载链接放到迅雷中进行下载。后创建dataset文件夹,将下载的数据集放入,即可运行。
导入的datasets和之前讲解的Dataset类是很相似的,实现了__getitem__()方法和__len__()方法。
3.将CIFAR10数据集的图片转换成tensor类型
python
import torchvision
dataset_transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
如果tensorboard的step不全可能是因为没加writer.close()
4.torchvision中的其他数据集
按住Ctrl键再点击可以查看源代码,找到url链接之后可以使用迅雷下载。