DataLoader的用法

答案来自 千问

这段代码的目的是加载 CIFAR-10 测试数据集,并利用 TensorBoard 将数据加载器(DataLoader)中的图片可视化。


1. 数据集定义

python 复制代码
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
  • 含义:定义了一个数据集对象。
  • 参数解析
    • ./dataset: 数据存放的文件夹。
    • train=False: 指定加载的是测试集(通常包含 10,000 张图片),而不是训练集。
    • transform=ToTensor(): 将图片转换为 PyTorch 张量(Tensor),并将像素值从 0-255 归一化到 0-1。

2. 数据加载器定义

python 复制代码
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
  • 含义:将上面的数据集包装成一个可迭代的加载器,用于批量读取数据。
  • 关键参数
    • batch_size=64: 每次取出 64 张图片打包在一起。
    • shuffle=True: 打乱数据顺序。虽然这是测试集,但这里强制打乱了(通常测试集不打乱以保证复现性,但这代码特意想看随机样本)。
    • drop_last=True: 这一点很重要 。如果测试集的总数(10,000)不能被 batch_size(64)整除,最后剩下的那一小批(不够 64 张的)将被丢弃 。这样可以保证你拿到的每一批数据形状都是固定的 [64, C, H, W]

3. 单独查看一张图片

python 复制代码
img, target = test_data[0]
  • 含义:直接通过索引访问数据集(不是 DataLoader)。
  • 作用 :取出数据集中的第一张图片和对应的标签。
  • 注意 :这里取出的是经过 ToTensor() 处理后的张量。test_data 是一个列表式的对象,支持下标访问。

4. 核心:For 循环与 Writer 部分

python 复制代码
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step = step + 1
这个循环会取出所有数据吗?

答案是:是的,它会遍历取出 test_data 中的所有数据(除了可能被丢弃的最后一小批)。

详细运行逻辑:
  1. 遍历机制
    test_loader 是一个迭代器。for 循环开始时,它会向 test_data 索要数据。
  2. 批次打包
    DataLoader 每次从 test_data 中取出 64 张图片(根据 batch_size),将它们堆叠成一个 4 维张量 imgs(形状通常是 [64, 3, 32, 32]),同时将 64 个标签堆叠成 targets
  3. 可视化
    writer.add_images(...) 将这 64 张图片记录到 TensorBoard 的日志中。
  4. 计数
    step 每次加 1。这意味着在 TensorBoard 中,第 1 个批次显示在 Step 0,第 2 个批次显示在 Step 1,以此类推。
  5. 结束
    循环会一直持续,直到 test_loadertest_data 里的所有图片(10,000 张)都取完并处理完毕。

总结

这段代码的整体功能是:

  1. 加载 CIFAR-10 测试集。
  2. 打乱数据顺序。
  3. 每次拿出 64 张图片。
  4. 将这些图片写入 TensorBoard 日志文件。
  5. 运行结束后,你可以打开 TensorBoard 查看到测试集中所有的图片样本(分批次展示)。

DataLoader 是迭代器?

简单直接的回答是: DataLoader 本身不是一个迭代器(Iterator) ,但它是一个可迭代对象(Iterable)

为了让你彻底理解,我们需要区分这两个概念,并看看它们在 for 循环中是如何工作的。

1. 核心概念区分

在 Python 中,这两个概念有严格的定义:

  • 可迭代对象 (Iterable):

    • 定义:只要可以用 for...in... 来循环的对象,都是可迭代对象。
    • 原理:它内部实现了 __iter__() 方法,该方法会返回一个迭代器。
    • 类比:歌单。歌单本身不是播放器,但它包含了歌曲列表。
  • 迭代器 (Iterator):

    • 定义:是具体执行读取操作的对象,它知道如何每次返回一个数据,并在数据耗尽时抛出 StopIteration 异常。
    • 原理:它内部实现了 __iter__()__next__() 方法。
    • 类比:MP3 播放器。它负责实际的"播放下一首"动作。

2. DataLoader 的工作机制

回到你的代码:

python 复制代码
for data in test_loader:
    ...

在这个循环中,发生了以下事情:

  1. test_loader (DataLoader 实例) 首先是一个可迭代对象
  2. for 循环开始时,它调用了 test_loader.__iter__() 方法。
  3. 这个方法会创建并返回一个迭代器(Iterator)。这个迭代器负责实际的数据读取工作(比如从硬盘读取图片、进行批处理)。
  4. 循环每次执行,都会调用该迭代器的 __next__() 方法,从而拿到下一个 batch 的数据。
  5. 当所有批次读完后,迭代器抛出 StopIteration,循环结束。

3. 总结与类比

概念 PyTorch 对应 类比 功能
可迭代对象 DataLoader 歌单 存储数据的列表,告诉播放器有哪些歌。
迭代器 DataLoader.__iter__() 返回的对象 播放器 实际执行播放(读取数据)的动作,每次播一首(一个 Batch)。

结论:

虽然我们在口语中常说"遍历 DataLoader",但从技术上讲,for 循环每次从 DataLoader 中取出的,是它内部生成的"迭代器"所提供的数据。

这也解释了为什么不能 多次直接遍历同一个 DataLoader 对象(除非重新创建),因为一旦迭代器耗尽,DataLoader 就会认为任务完成,除非再次调用 __iter__() 重新生成一个新的迭代器。

相关推荐
若风的雨2 小时前
AI优化控制相关的核心API分类总结
人工智能
工程师老罗2 小时前
PyTorch与TensorBoard兼容性问题解决方案
人工智能·pytorch·python
K姐研究社2 小时前
免费Nano Banana 制作PPT,SpeedAI 智能体一句话生成
人工智能·aigc·powerpoint
爱吃鱼的两包盐2 小时前
轻量化网络简介
人工智能·深度学习
曲幽2 小时前
FastAPI日志实战:从踩坑到优雅配置,让你的应用会“说话”
python·logging·fastapi·web·error·log·info
麦兜*2 小时前
深入解析云原生AI应用全栈架构:从Kubernetes智能调度与Istio服务网格到Knative事件驱动与Prometheus可观测性实战指南
人工智能·云原生·架构
小舞O_o2 小时前
CondaError: Run ‘conda init‘ before ‘conda activate‘
linux·python·conda
张较瘦_2 小时前
[论文阅读] AI | TOFU-D与COD:两款Dialogflow聊天机器人数据集,为质量与安全研究赋能
论文阅读·人工智能·机器人
HansenPole8252 小时前
深度学习基础知识
人工智能·深度学习