torchvision中的数据使用

1、下载数据集

在pytorch官网中找到docs选择Domains,在该页面中有各种数据类型的数据集

在左边菜单栏中选择datasets

python 复制代码
import torchvision
train_set=torchvision.datasets.CIFAR10(root='/data',train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root='./data',train=False,download=True)

2、Dataloader的使用

Dataloader参数介绍

  • dataset :加载的数据集,必须是 torch.utils.data.Dataset 的子类实例。

  • batch_size:每个批次的数据样本数,默认值为1。

  • shuffle:是否在每个周期开始时打乱数据,默认为 False。

  • sampler:定义从数据集中抽取样本的策略,如果指定,则忽略 shuffle 参数。

  • num_workers:用于数据加载的子进程数量,默认为0,表示数据将在主进程中加载。

  • collate_fn:如何将多个数据样本整合成一个批次,通常不需要指定。

  • pin_memory:如果为 True,会将数据放置到 GPU 上去,默认为 False。

  • drop_last:如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次,默认为 False。

python 复制代码
test_loader=DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
#获取一张图片的信息
img,target=test_set[0]
print(img.shape)
print(target)

writer=SummaryWriter("dataloader")
#taet_loader是一个迭代对象,用for循环进行迭代
step=0
for data in test_loader:
    imgs,targets=data
    # print(imgs.shape)
    # print(targets)
    writer.add_image("test_data",imgs,step,dataformats='NCHW')
    step+=1

writer.close()

添加轮次

python 复制代码
for epoch in range(2):
    step=0
    for data in test_loader:
        imgs,targets=data
        # print(imgs.shape)
        # print(targets)
        writer.add_image("Epoch:{}".format(epoch),imgs,step,dataformats='NCHW')
        step+=1
相关推荐
Dfreedom.21 小时前
图像滤波:非线性滤波与边缘保留技术
图像处理·人工智能·opencv·计算机视觉·非线性滤波·图像滤波
小白跃升坊1 天前
基于1Panel的AI运维
linux·运维·人工智能·ai大模型·教学·ai agent
kicikng1 天前
走在智能体前沿:智能体来了(西南总部)的AI Agent指挥官与AI调度官实践
人工智能·系统架构·智能体协作·ai agent指挥官·ai调度官·应用层ai
测试者家园1 天前
测试用例智能生成:是效率革命,还是“垃圾进,垃圾出”的新挑战?
人工智能·职场和发展·测试用例·测试策略·质量效能·智能化测试·用例设计
GIS瞧葩菜1 天前
Cesium 轴拖拽 + 旋转圈拖拽 核心数学知识
人工智能·算法·机器学习
njsgcs1 天前
dqn和cnn有什么区别 dqn怎么保存训练经验到本地
人工智能·神经网络·cnn
AndrewHZ1 天前
【AI黑话日日新】什么是AI智能体?
人工智能·算法·语言模型·大模型·llm·ai智能体
cd_949217211 天前
九昆仑低碳科技:所罗门群岛全国森林碳汇项目开发合作白皮书
大数据·人工智能·科技
工程师老罗1 天前
目标检测数据标注的工具与使用方法
人工智能·目标检测·计算机视觉
yuankoudaodaokou1 天前
高校科研新利器:思看科技三维扫描仪助力精密研究
人工智能·python·科技