Pytorch学习--DataLoader的使用

一、DataLoader简介

DataLoader官网

重要参数:画红框的参数

  • dataset:
    • 作用:表示要加载的数据集。DataLoader通过该参数从数据集中读取数据。
    • 类型:Dataset,即PyTorch定义的Dataset类,用于封装数据并提供数据索引的功能。
  • batch_size:
    • 作用:指定每次加载的数据样本数量(即每个批次的数据量)。默认值为1。
    • 类型:int(可选),默认为1。设置为大于1的值时,可以加速训练,因为数据将被批量处理。
  • shuffle:
    • 作用:是否在每个epoch结束后打乱数据顺序。如果设置为True,数据会在每个epoch重新随机排列。默认值是False,即数据不打乱。
    • 类型:bool(可选),是否打乱数据。
  • sampler:
    • 作用:定义从数据集中提取数据的策略。可以传入一个Sampler类的实例,自定义数据抽样的方式。注意,如果指定了sampler,则不能再使用shuffle。
    • 类型:Sampler或Iterable(可选),用于控制数据抽样。
  • batch_sampler:
    • 作用:与sampler类似,但batch_sampler返回的是一批次的索引,而不是单个样本索引。此参数与batch_size、shuffle和drop_last互斥,不能同时使用。
    • 类型:Sampler或Iterable(可选),专门用于批次索引的抽样。
  • num_workers:
    • 作用:指定用于数据加载的子进程数量。0表示在主进程中进行数据加载。较大的值可以加速数据加载,但需要在进程间共享数据。
    • 类型:int(可选),默认为0。
  • drop_last:
    • 作用:是否丢弃最后一个未满批次的数据。当数据集的大小不能整除batch_size时,最后一个批次的大小可能会小于batch_size。如果将drop_last设为True,则丢弃这个不完整的批次。
    • 类型:bool(可选),默认为False。

二、代码初识

python 复制代码
import torchvision.datasets
from torch.utils.data import DataLoader

train_data=torchvision.datasets.CIFAR10(root="datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)
train_loader=DataLoader(dataset=train_data,batch_size=4,shuffle=True)

img,target=train_data[0]
print(img.shape)
print(target)

for data in train_loader:
    imgs,targets=data
    print(imgs.shape)
    print(targets)

因为这里采取的是随机抽样

三、使用tensorboard可视化

python 复制代码
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

train_data=torchvision.datasets.CIFAR10(root="datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#shuffle会在epoch中表现出来
train_loader=DataLoader(dataset=train_data,batch_size=4,shuffle=True)


img,target=train_data[0]
writer=SummaryWriter("logs")

step=0

for epoch in range(2):
    for data in train_loader:
        imgs,targets=data
        #注意:这里是add_images,不是add_image
        writer.add_images("epoch{}".format(epoch),imgs,step)
        step+=1
writer.close()
相关推荐
searchforAI1 分钟前
B站视频怎么转文字稿?AI自动总结要点+生成思维导图教程
人工智能·笔记·学习·ai·语音识别·知识管理·视频总结
“码”力全开3 分钟前
【架构深探】基于Docker与GB28181/RTSP的边缘计算AI视频管理平台:异构算力调度与源码交付实践
人工智能·docker·架构
伶俜663 分钟前
鸿蒙实战(二) ArkUI AI 相机:从零实现实时滤镜与人脸贴纸
人工智能·数码相机
FBI HackerHarry浩4 分钟前
修改Pycharm2023.2.5连接数据库创建的SQL文件保存的默认位置
python·pycharm
老徐聊GEO4 分钟前
AI搜索获客:亲测有效的实践案例分享
大数据·人工智能·python
只做人间不老仙5 分钟前
C++ grpc 拦截器示例学习
开发语言·c++·学习
用户337922545685 分钟前
从字节跳动 DeerFlow 源码看 Agent 平台设计(一):什么是 Agent?一个成熟 Agent 平台的 8 个核心组件
人工智能
踏着七彩祥云的小丑6 分钟前
Go学习第7天:Map集合 + 递归函数 + 类型转换
开发语言·学习·golang·go
fan65404146 分钟前
本地服务企业GEO优化中的跨平台信息一致性校验工具设计
人工智能
一切皆是因缘际会7 分钟前
LLM温度Temperature底层采样机理
人工智能·机器学习·ai·架构