PyTorch Dataloader工作原理 之 default collate_fn操作

场景设定

假设我们有一个 Dataset,它的 __getitem__ 方法返回一个包含两项的元组:一个 (3, 32, 32) 的图像张量和一个整数标签。

python 复制代码
import torch
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    def __len__(self):
        return 10 # 数据集里有10个样本

    def __getitem__(self, idx):
        # 模拟返回一个图像张量和一个标签
        # 图像形状: (通道, 高, 宽)
        image_tensor = torch.randn(3, 32, 32)
        label = idx % 2 # 标签是 0 或 1
        return image_tensor, label

dataset = MyImageDataset()

现在,我们创建一个 DataLoader,设置 batch_size=4,并且 指定 collate_fn,这样它就会使用默认的 default_collate

python 复制代码
from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=4, shuffle=False)

当我们在 for 循环中迭代 data_loader 时,default_collate 会在幕后执行以下步骤:


default_collate 的工作流程

第 1 步:获取样本列表 (List of Samples)

DataLoader 首先从 dataset 中获取一个批次大小(这里是 4)的样本。它会得到一个列表,列表的每个元素都是 __getitem__ 的返回值。

这个列表(我们称之为 batch_list)看起来像这样:

复制代码
batch_list = [
    ( <Tensor_0 shape=(3,32,32)>, 0 ),  # 第一个样本 (image_0, label_0)
    ( <Tensor_1 shape=(3,32,32)>, 1 ),  # 第二个样本 (image_1, label_1)
    ( <Tensor_2 shape=(3,32,32)>, 0 ),  # 第三个样本 (image_2, label_2)
    ( <Tensor_3 shape=(3,32,32)>, 1 )   # 第四个样本 (image_3, label_3)
]

这是一个长度为 4 的列表,每个元素都是一个元组。

第 2 步:转置列表 (Transpose the List)

这是最关键的一步,也就是"堆叠对应部分 "的第一阶段。default_collate 会"转置"这个列表。它将所有元组的第 0 个元素收集在一起,形成一个新的列表;然后将所有元组的第 1 个元素收集在一起,形成另一个新的列表。

  • 原始结构 : List[Tuple[Image, Label]]
  • 转置后结构 : Tuple[List[Image], List[Label]]

转置后的结果如下:

python 复制代码
# 这是一个概念上的表示,不是实际代码
transposed_batch = (
    # 所有样本的第0个元素 (图像)
    [ <Tensor_0>, <Tensor_1>, <Tensor_2>, <Tensor_3> ],

    # 所有样本的第1个元素 (标签)
    [ 0, 1, 0, 1 ]
)

现在我们得到了一个元组,元组里有两个列表:一个是包含 4 个图像张量的列表,另一个是包含 4 个整数标签的列表。

第 3 步:打包/堆叠每一部分 (Collate/Stack Each Part)

现在,default_collate 会遍历转置后元组中的每一个列表,并尝试将它们"打包"成一个批次张量。

对于第一个列表(图像张量列表):

  • 输入是 [ <Tensor_0 shape=(3,32,32)>, <Tensor_1 shape=(3,32,32)>, ... ]
  • default_collate 发现这些都是 PyTorch 张量,于是它会使用 torch.stack(list_of_tensors, dim=0)
  • torch.stack 会创建一个新的维度(批次维度)在第 0 维,然后将这些张量沿着这个新维度拼接起来。
  • 结果 : 一个单一的张量,形状为 (4, 3, 32, 32)。这里的 4 就是 batch_size

对于第二个列表(标签列表):

  • 输入是 [0, 1, 0, 1]
  • default_collate 发现这些是 Python 的数字。它会先将它们转换成一个张量列表 [tensor(0), tensor(1), tensor(0), tensor(1)]
  • 然后,它同样对这个张量列表使用 torch.stack
  • 结果 : 一个单一的张量,形状为 (4,)
第 4 步:返回最终批次

最后,default_collate 将打包好的各个部分重新组合成一个元组(保持原始的结构),并返回。

所以,你在 for 循环中得到的 batch 变量实际上是:

python 复制代码
# batch[0] 是图像批次, batch[1] 是标签批次
batch = (
    <Tensor shape=(4, 3, 32, 32)>,
    <Tensor shape=(4,)>
)

这就是为什么你可以这样解包:
for images_batch, labels_batch in data_loader:


如果样本是字典呢?

default_collate 同样智能地处理字典。如果你的 __getitem__ 返回一个字典:

python 复制代码
def __getitem__(self, idx):
    return {
        'image': torch.randn(3, 32, 32),
        'label': idx % 2
    }

那么 default_collate 的工作流程会是:

  1. 获取样本列表:

    复制代码
    batch_list = [
        {'image': <Tensor_0>, 'label': 0},
        {'image': <Tensor_1>, 'label': 1},
        {'image': <Tensor_2>, 'label': 0},
        {'image': <Tensor_3>, 'label': 1}
    ]
  2. 转置/重组 : 它会将所有字典的 'image' 值收集到一个列表,将所有 'label' 值收集到另一个列表。

  3. 打包/堆叠:

    • [<Tensor_0>, <Tensor_1>, ...] 被堆叠成一个 (4, 3, 32, 32) 的张量。
    • [0, 1, 0, 1] 被堆叠成一个 (4,) 的张量。
  4. 返回最终批次 : 它会返回一个字典,字典的键和样本字典的键相同,但值是堆叠后的批次张量。

    python 复制代码
    batch = {
        'image': <Tensor shape=(4, 3, 32, 32)>,
        'label': <Tensor shape=(4,)>
    }

总结

default_collate 的"堆叠对应部分"是一个递归的过程:

  1. 它检查一批样本的数据结构(是元组、字典还是其他)。
  2. 它"转置"这个结构,将所有样本的"第一部分"放在一起,所有样本的"第二部分"放在一起,以此类推。
  3. 对于每个集合起来的部分,它会根据其数据类型应用 torch.stack(如果是张量、数字等)或递归地调用 collate 过程(如果是更复杂的嵌套结构)。

这个过程的前提是,所有待堆叠的张量必须具有完全相同的形状 。如果形状不一(例如,变长的文本序列),torch.stack 就会失败,这就是为什么在这种情况下你需要提供一个自定义的 collate_fn 来进行 padding 等操作。

相关推荐
熬夜敲代码的小N18 小时前
仓颉ArrayList动态数组源码分析:从底层实现到性能优化
数据结构·python·算法·ai·性能优化
yumgpkpm18 小时前
Hadoop大数据平台在中国AI时代的后续发展趋势研究CMP(类Cloudera CDP 7.3 404版华为鲲鹏Kunpeng)
大数据·hive·hadoop·python·zookeeper·oracle·cloudera
进击的炸酱面18 小时前
第五章 神经网络
人工智能·深度学习·神经网络
沉默媛19 小时前
如何下载安装以及使用labelme,一个可以打标签的工具,实现数据集处理,详细教程
图像处理·人工智能·python·yolo·计算机视觉
HMS Core19 小时前
【FAQ】HarmonyOS SDK 闭源开放能力 — Push Kit
linux·python·华为·harmonyos
CODE_RabbitV19 小时前
【1min 速通 -- PyTorch 张量数据类型】张量类型的获取、转化与判别
人工智能·pytorch·python
程序猿202319 小时前
Python每日一练---第九天:H指数
开发语言·python
武陵悭臾20 小时前
Python应用开发学习:Pygame中实现切换开关及鼠标拖动连续填充功能
python·学习·程序人生·个人开发·pygame
JELEE.20 小时前
Django中的clean()方法和full_clean()方法
后端·python·django
2401_8414956420 小时前
【LeetCode刷题】移动零
数据结构·python·算法·leetcode·数组·双指针法·移动零