文章目录
引言
- 这里主要是看一下如何加载数据集,并且生成批次训练的数据。
- 最大的收获是,知道了如何在训练阶段提高模型训练的性能
- 增加batch_size
- 增加num_worker
- 数据预加载
正文
图像分类数据集
主要包介绍
-
这个模块主要是将如何加载数据集,并且生成一个迭代器,每一次访问都会俺批次生成数据。
-
具体应用到以下几个功能:
-
torchvision.datasets:获取数据集
-
这个包拥有很多用于计算机视觉处理的功能 ,这个包主要有一些公开常用的计算机的视觉数据集,比如说mnist还有fashion-mnist等。
-
这个包中的数据集可以直接被dataloader调用,会方便很多
-
dataset这个类还可以被继承实现,制作自己的dataset类
-
-
transforms:
- 图像预处理还有数据增强功能专用包,可以单独使用,也可以多个功能按照顺序进行组合compose,作为一个预处理函数。
-
utils.data.DataLoader
- 自动批量加载或训练数据的功能
-
主要流程
- 在加载数据集时,需要按照如下流程进行处理:
-
制定数据预处理的环节,并组合为完整的流程
- 使用transform实现图片的剪裁还有重置大小等基本预处理操作
- 将所有操作进行组合
-
获取数据集,并转为dataset类
- 继承或者直接使用torchvision.dataset类
-
生成批量获取数据集dataloader加载生活器
- 生成DataLoader实例
-
逐批次验证数据集
-
具体代码
python
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
# 逐批次遍历数据
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
练习
问题一
-
在加载训练参数的过程中,影响模型的性能的参数有哪些?
-
batch_size :表示加载到内存中的数据量,越大,所需要的内存越多,反之亦然。
-
DataLoader(num_workers = ?) :表示用于加载数据的线程数,线程越多,加载的越快 ,同样的需要的内存越多。
-
问题二
- pytorch中的数据迭代器的性能非常重要,有哪些方式可以改进它?
- DataLoader 的 persistent_workers 参数 :
- 控制在每一个训练epoch后不需要关闭或者重启数据加载工作的进程
- persistent_worker = True
- 使用数据预取Prefetching
- GPU在执行任务的同时,CPU可以预先加载下一批数据
- num_wokrer
- 提高加载数据的进程数量,提高运算效率
- pin_memory加速数据传输
- pin_memory = True
- 加速数据从CPU到GPU的过程
- DataLoader 的 persistent_workers 参数 :
pytorch提供的其他的数据集
图像分类数据集
CIFAR-10/CIFAR-100: 包含 10 类(CIFAR-10)或 100 类(CIFAR-100)的小图像。
MNIST: 手写数字数据集。
Fashion-MNIST: 与 MNIST 类似,但用于衣物分类。
ImageNet: 一个大规模的图像分类数据集。
SVHN (Street View House Numbers): 用于数字识别的街景房号数据集。
目标检测和分割数据集
COCO (Common Objects in Context): 用于多种视觉任务,包括目标检测、图像分割和标注。
VOC (Pascal Visual Object Classes): 包括图像分类、目标检测和图像分割任务。
Cityscapes: 用于城市场景理解,包括语义分割和实例分割。
其他
CelebA: 用于面部属性识别的大规模人脸属性数据集。
STL-10: 用于自我监督学习和图像分类的数据集。
Omniglot: 包含多种语言的字符,用于一次学习和其他语言任务。
EMNIST: 扩展的 MNIST 数据集,包括字母和数字。
总结
- 很多的东西,还是要自己系统地了解一下,不然很多东西都不了解,现在知道了。继续弄吧,这都是欠下的技术债。