Pytorch个人学习记录总结 04

目录

torchvision

DataLoader


torchvision

transforams是对单张图片进行处理,而制作数据集的时候,是需要对图像进行批量处理的。因此本节是将torchvision中的datasetstransforms联合使用对数据集进行预处理操作。

  1. (torchvision官方文档地址:torchvision --- Torchvision 0.15 documentation

  2. torchvision.datasets中提供了内置数据集和自定义数据集所需的函数(DatasetFolder、ImageFolder、VisionDataset)(torchvision.datasets官方文档地址:Datasets --- Torchvision 0.15 documentation

  3. torchvision.models中包含了已经训练好的图像分类、图像分割、目标检测的神经网络模型。(torchvision.models的官方文档地址:Models and pre-trained weights --- Torchvision 0.15 documentation

  4. torchvision.transforms对图像进行转换和增强(torchvision.transforms的官方文档地址:Transforming and augmenting images --- Torchvision 0.15 documentation

  5. torchvision.utils包含各种实用工具,主要用于可视化(tensorboard是在torch.utils.tensorboard中)(torchvision.utils的官方文档地址:Utils --- Torchvision 0.15 documentation

    python 复制代码
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.transforms import transforms
    
    # 1. 用transforms设置图片转换方式
    data_transform = transforms.Compose([  # 用Compose将所有转换操作集合起来
        transforms.ToTensor()  # 因为CIFAR10数据集的每张图像size=(32,32)比较小,所以只进行ToTensor的操作
    ])
    
    # 2. 加载内置数据集CIFAR10,并设置transforms(download最好一直设置成True)
    #   1. root:(若要下载的话)表示数据集存放的根目录
    #   2. train=True 或者 False,分别表示是构造训练集train_set还是测试集test_set
    #   3. transform = data_transform,用自定义的data_transform对数据集中的每张图像进行预处理
    #   4. download=True 或者 False,分别表示是否从网上下载数据集到root中(如果root下已有数据集,尽管设置成True也不会再下载了,所以download最好一直设置成True)
    train_set = torchvision.datasets.CIFAR10('./dataset', train=True, transform=data_transform, download=True)
    test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=data_transform, download=True)
    
    # 3. 写进tensorboard查看
    writer = SummaryWriter('CIFAR10')
    for i in range(10):
        img, label = test_set[i]    # test_set[i]返回的依次是图像(PIL.Image)和类别(int)
        writer.add_image('test_set', img, i)
    
    writer.close()

    DataLoader

官方文档地址:torch.utils.data.DataLoader

python 复制代码
CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 
	sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, 
	pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, 
	multiprocessing_context=None, generator=None, *, prefetch_factor=2, 
	persistent_workers=False)

除了dataset(指明数据集的位置)之外的参数都设置了默认值。

torch.utils.data.DataLoader重点关注的参数有:

  • dataset (Dataset) :指明从哪个数据集加载数据(如上节中自定义的train_set
  • batch_size (int):每个批次(batch)加载多少样本。
  • shuffle (bool) :每轮(epoch)是否打乱样本的顺序。(最好设置成True)
  • num_workers (int) :有多少个子流程用于数据加载。0表示主进程加载。
  • (在Windows下只能设置成0,不然会出错!虽然default=0,但是最好还是手动再设置一下num_workers=0)
  • drop_last (bool) :如果数据集大小不能被batch_size整除,则最后一个批次将会不完整(即样本数<batch_size)。设置为True则删掉最后一个batch,False则保留(默认为False,即会保存最后那个不完整的批次)。
相关推荐
HyperAI超神经1 分钟前
软银/英伟达/红杉资本/贝佐斯等参投,机器人初创公司Skild AI融资14亿美元,打造通用基础模型
人工智能·深度学习·机器学习·机器人·ai编程
数说星榆1815 分钟前
边缘计算革命:终端设备的本地化智能
人工智能·边缘计算
菜菜小狗的学习笔记9 分钟前
黑马程序员java web学习笔记--后端进阶(一)AOP
java·笔记·学习
墨染天姬10 分钟前
【AI】KIMI2.5---开源榜第一
人工智能·开源
智驱力人工智能12 分钟前
实线变道检测 高架道路安全治理的工程化实践 隧道压实线监测方案 城市快速路压实线实时预警 压实线与车牌识别联动方案
人工智能·opencv·算法·安全·yolo·边缘计算
萤丰信息15 分钟前
智慧园区:以技术赋能,构筑安全便捷的现代化生态空间
大数据·人工智能·科技·安全·智慧城市·智慧园区
码农三叔18 分钟前
(7-3-01)电机与执行器系统:驱动器开发与控制接口(1)电机驱动电路+编码器与反馈
人工智能·单片机·嵌入式硬件·架构·机器人·人形机器人
光羽隹衡21 分钟前
计算机视觉--Opencv(模板匹配)
人工智能·opencv·计算机视觉
互联科技报22 分钟前
2026Ai短视频工具市场报告:行业规模、占有率及内容特工队AI排名
人工智能
小马爱打代码24 分钟前
Spring AI 进阶:RAG 技术原理拆解与本地知识库检索落地
人工智能·深度学习·spring