DataLoader的使用

示例代码:

python 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# batch_size=4 取test_data[0]到test_data[3] 返回 打包好的img0-3, 打包好的target0-3(shuffle=True随机抓取)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    # print(imgs.shape)
    # print(targets)
    writer.add_images("test_data_drop_last", imgs, step)
    step = step+1

writer.close()
python 复制代码
# batch_size=4 取test_data[0]到test_data[3] 返回 打包好的img0-3, 打包好的target0-3(随机抓取)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

设置drop_last=False 结果,最后一步不足64张任然进行了保留

设置drop_last=True后

最后一步不足64张进行了舍去,所以只有155步

相关推荐
逻辑驱动的ken25 分钟前
Java高频面试考点场景题30
java·开发语言·深度学习·面试·职场和发展
AI人工智能+40 分钟前
营业执照识别技术通过计算机视觉与人工智能技术,实现企业证照信息的自动化采集
人工智能·深度学习·ocr·营业执照识别
七夜zippoe1 小时前
OpenClaw 上下文管理:Token 优化策略
大数据·人工智能·深度学习·token·openclaw
web守墓人2 小时前
【深度学习】Pytorch gpu加速原理探究
人工智能·pytorch·深度学习
沪漂阿龙2 小时前
面试题:循环神经网络(RNN)是什么?词嵌入、时序建模、梯度消失、LSTM/GRU 一文讲透
人工智能·rnn·深度学习·gru·lstm
坐望云起2 小时前
机器学习笔记 - 基于C++的深度学习 四、实现梯度下降
笔记·深度学习·机器学习
源码之家2 小时前
计算机毕业设计:Python基于知识图谱的医疗问答系统 Neo4j 机器学习 BERT 深度学习 ECharts(建议收藏)✅
python·深度学习·机器学习·信息可视化·数据分析·知识图谱·课程设计
沪漂阿龙2 小时前
面试题:传统序列模型详解——RNN、LSTM、GRU 原理、区别、优缺点一文讲透
人工智能·rnn·深度学习·gru·lstm
栈溢出了2 小时前
GAT(Graph Attention Network)学习笔记
人工智能·深度学习·算法·机器学习
:mnong3 小时前
论文研读:基于深度学习的制造成本估算特征可视化研究
人工智能·深度学习·制造