DataLoader的用法

答案来自 千问

这段代码的目的是加载 CIFAR-10 测试数据集,并利用 TensorBoard 将数据加载器(DataLoader)中的图片可视化。


1. 数据集定义

python 复制代码
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
  • 含义:定义了一个数据集对象。
  • 参数解析
    • ./dataset: 数据存放的文件夹。
    • train=False: 指定加载的是测试集(通常包含 10,000 张图片),而不是训练集。
    • transform=ToTensor(): 将图片转换为 PyTorch 张量(Tensor),并将像素值从 0-255 归一化到 0-1。

2. 数据加载器定义

python 复制代码
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
  • 含义:将上面的数据集包装成一个可迭代的加载器,用于批量读取数据。
  • 关键参数
    • batch_size=64: 每次取出 64 张图片打包在一起。
    • shuffle=True: 打乱数据顺序。虽然这是测试集,但这里强制打乱了(通常测试集不打乱以保证复现性,但这代码特意想看随机样本)。
    • drop_last=True: 这一点很重要 。如果测试集的总数(10,000)不能被 batch_size(64)整除,最后剩下的那一小批(不够 64 张的)将被丢弃 。这样可以保证你拿到的每一批数据形状都是固定的 [64, C, H, W]

3. 单独查看一张图片

python 复制代码
img, target = test_data[0]
  • 含义:直接通过索引访问数据集(不是 DataLoader)。
  • 作用 :取出数据集中的第一张图片和对应的标签。
  • 注意 :这里取出的是经过 ToTensor() 处理后的张量。test_data 是一个列表式的对象,支持下标访问。

4. 核心:For 循环与 Writer 部分

python 复制代码
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step = step + 1
这个循环会取出所有数据吗?

答案是:是的,它会遍历取出 test_data 中的所有数据(除了可能被丢弃的最后一小批)。

详细运行逻辑:
  1. 遍历机制
    test_loader 是一个迭代器。for 循环开始时,它会向 test_data 索要数据。
  2. 批次打包
    DataLoader 每次从 test_data 中取出 64 张图片(根据 batch_size),将它们堆叠成一个 4 维张量 imgs(形状通常是 [64, 3, 32, 32]),同时将 64 个标签堆叠成 targets
  3. 可视化
    writer.add_images(...) 将这 64 张图片记录到 TensorBoard 的日志中。
  4. 计数
    step 每次加 1。这意味着在 TensorBoard 中,第 1 个批次显示在 Step 0,第 2 个批次显示在 Step 1,以此类推。
  5. 结束
    循环会一直持续,直到 test_loadertest_data 里的所有图片(10,000 张)都取完并处理完毕。

总结

这段代码的整体功能是:

  1. 加载 CIFAR-10 测试集。
  2. 打乱数据顺序。
  3. 每次拿出 64 张图片。
  4. 将这些图片写入 TensorBoard 日志文件。
  5. 运行结束后,你可以打开 TensorBoard 查看到测试集中所有的图片样本(分批次展示)。

DataLoader 是迭代器?

简单直接的回答是: DataLoader 本身不是一个迭代器(Iterator) ,但它是一个可迭代对象(Iterable)

为了让你彻底理解,我们需要区分这两个概念,并看看它们在 for 循环中是如何工作的。

1. 核心概念区分

在 Python 中,这两个概念有严格的定义:

  • 可迭代对象 (Iterable):

    • 定义:只要可以用 for...in... 来循环的对象,都是可迭代对象。
    • 原理:它内部实现了 __iter__() 方法,该方法会返回一个迭代器。
    • 类比:歌单。歌单本身不是播放器,但它包含了歌曲列表。
  • 迭代器 (Iterator):

    • 定义:是具体执行读取操作的对象,它知道如何每次返回一个数据,并在数据耗尽时抛出 StopIteration 异常。
    • 原理:它内部实现了 __iter__()__next__() 方法。
    • 类比:MP3 播放器。它负责实际的"播放下一首"动作。

2. DataLoader 的工作机制

回到你的代码:

python 复制代码
for data in test_loader:
    ...

在这个循环中,发生了以下事情:

  1. test_loader (DataLoader 实例) 首先是一个可迭代对象
  2. for 循环开始时,它调用了 test_loader.__iter__() 方法。
  3. 这个方法会创建并返回一个迭代器(Iterator)。这个迭代器负责实际的数据读取工作(比如从硬盘读取图片、进行批处理)。
  4. 循环每次执行,都会调用该迭代器的 __next__() 方法,从而拿到下一个 batch 的数据。
  5. 当所有批次读完后,迭代器抛出 StopIteration,循环结束。

3. 总结与类比

概念 PyTorch 对应 类比 功能
可迭代对象 DataLoader 歌单 存储数据的列表,告诉播放器有哪些歌。
迭代器 DataLoader.__iter__() 返回的对象 播放器 实际执行播放(读取数据)的动作,每次播一首(一个 Batch)。

结论:

虽然我们在口语中常说"遍历 DataLoader",但从技术上讲,for 循环每次从 DataLoader 中取出的,是它内部生成的"迭代器"所提供的数据。

这也解释了为什么不能 多次直接遍历同一个 DataLoader 对象(除非重新创建),因为一旦迭代器耗尽,DataLoader 就会认为任务完成,除非再次调用 __iter__() 重新生成一个新的迭代器。

相关推荐
科技社几秒前
咪咕互娱亮相数字中国峰会:“精品游戏+轻量终端”组合,打开数字娱乐新想象
人工智能
m0_4954964129 分钟前
mysql处理复杂SQL性能_InnoDB优化器与MyISAM差异
jvm·数据库·python
数智化精益手记局1 小时前
拆解物料管理erp系统的核心功能,看物料管理erp系统如何解决库存积压与缺料难题
大数据·网络·人工智能·安全·信息可视化·精益工程
Flying pigs~~1 小时前
RAG 完整面试指南:原理、优化、幻觉解决方案
人工智能·prompt·rag·智能体·检索增强生成·rag优化
博.闻广见1 小时前
AI_概率统计-2.常见分布
人工智能·机器学习
企业架构师老王1 小时前
2026制造业安全生产隐患识别AI方案:从主流产品对比看企业级AI Agent的非侵入式落地路径
人工智能·安全·ai
forEverPlume1 小时前
PHP怎么使用Eloquent Attribute Composition属性组合_Laravel通过组合构建复杂属性【方法】
jvm·数据库·python
Aleeeeex1 小时前
RAG 那点事:从 8 份企业文档到能用的问答系统,全过程拆给你看
人工智能·python·ai编程
冬奇Lab1 小时前
一天一个开源项目(第87篇):Tank-OS —— Red Hat 工程师用一个周末,把 AI Agent 塞进了一个可启动的 Linux 镜像
人工智能·开源·资讯
小糖学代码1 小时前
LLM系列:2.pytorch入门:8.神经网络的损失函数(criterion)
人工智能·深度学习·神经网络