DataLoader的用法

答案来自 千问

这段代码的目的是加载 CIFAR-10 测试数据集,并利用 TensorBoard 将数据加载器(DataLoader)中的图片可视化。


1. 数据集定义

python 复制代码
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
  • 含义:定义了一个数据集对象。
  • 参数解析
    • ./dataset: 数据存放的文件夹。
    • train=False: 指定加载的是测试集(通常包含 10,000 张图片),而不是训练集。
    • transform=ToTensor(): 将图片转换为 PyTorch 张量(Tensor),并将像素值从 0-255 归一化到 0-1。

2. 数据加载器定义

python 复制代码
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
  • 含义:将上面的数据集包装成一个可迭代的加载器,用于批量读取数据。
  • 关键参数
    • batch_size=64: 每次取出 64 张图片打包在一起。
    • shuffle=True: 打乱数据顺序。虽然这是测试集,但这里强制打乱了(通常测试集不打乱以保证复现性,但这代码特意想看随机样本)。
    • drop_last=True: 这一点很重要 。如果测试集的总数(10,000)不能被 batch_size(64)整除,最后剩下的那一小批(不够 64 张的)将被丢弃 。这样可以保证你拿到的每一批数据形状都是固定的 [64, C, H, W]

3. 单独查看一张图片

python 复制代码
img, target = test_data[0]
  • 含义:直接通过索引访问数据集(不是 DataLoader)。
  • 作用 :取出数据集中的第一张图片和对应的标签。
  • 注意 :这里取出的是经过 ToTensor() 处理后的张量。test_data 是一个列表式的对象,支持下标访问。

4. 核心:For 循环与 Writer 部分

python 复制代码
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step = step + 1
这个循环会取出所有数据吗?

答案是:是的,它会遍历取出 test_data 中的所有数据(除了可能被丢弃的最后一小批)。

详细运行逻辑:
  1. 遍历机制
    test_loader 是一个迭代器。for 循环开始时,它会向 test_data 索要数据。
  2. 批次打包
    DataLoader 每次从 test_data 中取出 64 张图片(根据 batch_size),将它们堆叠成一个 4 维张量 imgs(形状通常是 [64, 3, 32, 32]),同时将 64 个标签堆叠成 targets
  3. 可视化
    writer.add_images(...) 将这 64 张图片记录到 TensorBoard 的日志中。
  4. 计数
    step 每次加 1。这意味着在 TensorBoard 中,第 1 个批次显示在 Step 0,第 2 个批次显示在 Step 1,以此类推。
  5. 结束
    循环会一直持续,直到 test_loadertest_data 里的所有图片(10,000 张)都取完并处理完毕。

总结

这段代码的整体功能是:

  1. 加载 CIFAR-10 测试集。
  2. 打乱数据顺序。
  3. 每次拿出 64 张图片。
  4. 将这些图片写入 TensorBoard 日志文件。
  5. 运行结束后,你可以打开 TensorBoard 查看到测试集中所有的图片样本(分批次展示)。

DataLoader 是迭代器?

简单直接的回答是: DataLoader 本身不是一个迭代器(Iterator) ,但它是一个可迭代对象(Iterable)

为了让你彻底理解,我们需要区分这两个概念,并看看它们在 for 循环中是如何工作的。

1. 核心概念区分

在 Python 中,这两个概念有严格的定义:

  • 可迭代对象 (Iterable):

    • 定义:只要可以用 for...in... 来循环的对象,都是可迭代对象。
    • 原理:它内部实现了 __iter__() 方法,该方法会返回一个迭代器。
    • 类比:歌单。歌单本身不是播放器,但它包含了歌曲列表。
  • 迭代器 (Iterator):

    • 定义:是具体执行读取操作的对象,它知道如何每次返回一个数据,并在数据耗尽时抛出 StopIteration 异常。
    • 原理:它内部实现了 __iter__()__next__() 方法。
    • 类比:MP3 播放器。它负责实际的"播放下一首"动作。

2. DataLoader 的工作机制

回到你的代码:

python 复制代码
for data in test_loader:
    ...

在这个循环中,发生了以下事情:

  1. test_loader (DataLoader 实例) 首先是一个可迭代对象
  2. for 循环开始时,它调用了 test_loader.__iter__() 方法。
  3. 这个方法会创建并返回一个迭代器(Iterator)。这个迭代器负责实际的数据读取工作(比如从硬盘读取图片、进行批处理)。
  4. 循环每次执行,都会调用该迭代器的 __next__() 方法,从而拿到下一个 batch 的数据。
  5. 当所有批次读完后,迭代器抛出 StopIteration,循环结束。

3. 总结与类比

概念 PyTorch 对应 类比 功能
可迭代对象 DataLoader 歌单 存储数据的列表,告诉播放器有哪些歌。
迭代器 DataLoader.__iter__() 返回的对象 播放器 实际执行播放(读取数据)的动作,每次播一首(一个 Batch)。

结论:

虽然我们在口语中常说"遍历 DataLoader",但从技术上讲,for 循环每次从 DataLoader 中取出的,是它内部生成的"迭代器"所提供的数据。

这也解释了为什么不能 多次直接遍历同一个 DataLoader 对象(除非重新创建),因为一旦迭代器耗尽,DataLoader 就会认为任务完成,除非再次调用 __iter__() 重新生成一个新的迭代器。

相关推荐
江瀚视野几秒前
多家银行向甲骨文断贷,巨头甲骨文这是怎么了?
大数据·人工智能
码界筑梦坊1 分钟前
325-基于Python的校园卡消费行为数据可视化分析系统
开发语言·python·信息可视化·django·毕业设计
ccLianLian3 分钟前
计算机基础·cs336·损失函数,优化器,调度器,数据处理和模型加载保存
人工智能·深度学习·计算机视觉·transformer
asheuojj4 分钟前
2026年GEO优化获客效果评估指南:如何精准衡量TOP5关
大数据·人工智能·python
多恩Stone4 分钟前
【RoPE】Flux 中的 Image Tokenization
开发语言·人工智能·python
callJJ6 分钟前
Spring AI ImageModel 完全指南:用 OpenAI DALL-E 生成图像
大数据·人工智能·spring·openai·springai·图像模型
铁蛋AI编程实战8 分钟前
2026 大模型推理框架测评:vLLM 0.5/TGI 2.0/TensorRT-LLM 1.8/DeepSpeed-MII 0.9 性能与成本防线对比
人工智能·机器学习·vllm
23遇见9 分钟前
CANN ops-nn 仓库高效开发指南:从入门到精通
人工智能
SAP工博科技9 分钟前
SAP 公有云 ERP 多工厂多生产线数据统一管理技术实现解析
大数据·运维·人工智能
芷栀夏12 分钟前
CANN ops-math:异构计算场景下基础数学算子的深度优化与硬件亲和设计解析
人工智能·cann