PyTorch DataLoader 高级用法

好的,这是一个关于 PyTorch DataLoadersamplercollate_fn 等参数的非常好的问题。这些参数是 PyTorch 数据加载管道的核心,理解它们能让你高度自定义数据处理流程,以适应各种复杂的任务需求。

我将为你详细解释指定这些参数的方式以及它们之间的区别,内容将涵盖:

  1. DataLoader 的核心工作流程
  2. samplerbatch_sampler:控制数据采样的顺序和方式
    • 默认行为 (shuffle=True/False)
    • 指定 sampler:精细控制样本顺序
    • 指定 batch_sampler:精细控制批次构成
    • 三者之间的关系和互斥性
  3. collate_fn:自定义样本到批次的转换
    • 默认行为 (default_collate)
    • 指定自定义 collate_fn
  4. 总结与最佳实践

1. DataLoader 的核心工作流程

在深入细节之前,我们先理解 DataLoader 在一个 epoch 中是如何工作的:

  1. 启动迭代 : 当你写 for batch in data_loader: 时,DataLoader 的迭代器被创建。
  2. 获取索引 : DataLoader 首先向 sampler (或 batch_sampler ) 请求样本的索引 (indices)。
    • 如果使用 sampler,它会一个一个地返回样本索引。DataLoader 内部会根据 batch_sizedrop_last 将这些索引组成一个批次的索引列表。
    • 如果使用 batch_sampler,它会一次性返回一个完整的、已经组织好的批次索引列表。
  3. 获取数据 : DataLoader 使用上一步得到的批次索引列表,通过 dataset[index] 的方式从你的 Dataset 对象中获取一批数据样本。这时你会得到一个列表,列表中的每个元素都是 Dataset__getitem__ 方法返回的结果。例如 [sample1, sample2, sample3, ...].
  4. 整理批次 : DataLoader 将这个样本列表传递给 collate_fn 函数。
  5. 返回批次 : collate_fn 函数将样本列表处理(例如,堆叠成 Tensor)并返回一个最终的批次(batch)。这个批次就是你在 for 循环中接收到的 batch 变量。

一个简化的工作流程图

现在,我们来详细看 samplercollate_fn


2. samplerbatch_sampler:控制数据采样

sampler 的核心职责是生成一系列索引,决定了从数据集中抽取样本的顺序。

2.1 默认行为 (shuffle)

这是最简单、最常见的方式。你在创建 DataLoader 时,通过 shuffle 参数来控制。

python 复制代码
import torch
from torch.utils.data import TensorDataset, DataLoader

# 创建一个简单的数据集
data = torch.randn(10, 3) # 10个样本,每个样本3个特征
labels = torch.arange(10) # 标签为 0 到 9
dataset = TensorDataset(data, labels)

# 方式一:顺序采样
# shuffle=False (默认)
loader_seq = DataLoader(dataset, batch_size=4, shuffle=False)
print("顺序采样 (shuffle=False):")
for _, batch_labels in loader_seq:
    print(batch_labels.tolist())
# 输出:
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]

# 方式二:随机采样
# shuffle=True
loader_rand = DataLoader(dataset, batch_size=4, shuffle=True)
print("\n随机采样 (shuffle=True):")
for _, batch_labels in loader_rand:
    print(batch_labels.tolist())
# 输出 (每次可能不同):
# [3, 8, 1, 5]
# [0, 9, 2, 6]
# [7, 4]

工作原理:

  • shuffle=False: DataLoader 内部会使用 SequentialSampler,它按照 0, 1, 2, ... 的顺序生成索引。
  • shuffle=True: DataLoader 内部会使用 RandomSampler,它会在每个 epoch 开始时,将所有索引(0len(dataset)-1)随机打乱,然后按打乱后的顺序生成索引。

区别:

  • 这是最上层的抽象,简单直接。
  • 你无法进行更复杂的采样控制,比如类别均衡采样。
2.2 指定 sampler

当你需要比简单的"顺序"或"随机"更复杂的采样策略时,就需要手动创建一个 sampler 对象并传递给 DataLoader

重要 : 当你手动指定 sampler 时,必须将 shuffle 设置为 False (或不设置,默认为 False) 。因为 sampler 已经定义了索引的生成顺序,shuffle=True 会与之冲突。

PyTorch 内置了一些有用的 sampler

  • SequentialSampler: 按顺序采样,等同于 shuffle=False
  • RandomSampler: 随机采样,等同于 shuffle=True
  • SubsetRandomSampler: 在一个给定的索引子集内进行随机采样。常用于交叉验证。
  • WeightedRandomSampler: 根据每个样本的权重进行采样。常用于处理类别不均衡问题。

示例:使用 WeightedRandomSampler 进行类别均衡采样

假设我们有一个不均衡的数据集,类别 '0' 的样本远多于类别 '1'。我们希望在训练时,每个批次中类别 '0' 和 '1' 的样本数量大致相等。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class ImbalancedDataset(Dataset):
    def __init__(self):
        # 90个类别0的样本, 10个类别1的样本
        self.data = torch.randn(100, 5)
        self.labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

dataset = ImbalancedDataset()

# 计算每个样本的权重
# 类别0的权重: 1 / 90
# 类别1的权重: 1 / 10
class_counts = [90.0, 10.0]
num_samples = sum(class_counts)
class_weights = [num_samples / class_count for class_count in class_counts]

# 为数据集中的每个样本分配权重
sample_weights = [class_weights[label] for label in dataset.labels]
sample_weights = torch.DoubleTensor(sample_weights)

# 创建 WeightedRandomSampler
# num_samples: 每个epoch采样的总数
# replacement=True: 允许重复采样(对于不均衡问题通常需要)
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# 创建 DataLoader,注意 shuffle=False
# 因为 sampler 已经决定了采样顺序
loader_balanced = DataLoader(dataset, batch_size=10, sampler=sampler)

print("使用 WeightedRandomSampler 进行均衡采样:")
for epoch in range(2):
    print(f"Epoch {epoch+1}")
    for _, batch_labels in loader_balanced:
        print(f"  Batch labels: {batch_labels.tolist()}")
        print(f"  Class counts: 0={torch.sum(batch_labels == 0).item()}, 1={torch.sum(batch_labels == 1).item()}")

# 输出 (每次可能不同, 但类别1的比例会显著提高):
# Epoch 1
#   Batch labels: [1, 1, 1, 0, 1, 0, 1, 0, 1, 0]
#   Class counts: 0=4, 1=6
#   ... (其他批次)
# Epoch 2
#   ...

自定义 sampler : 你还可以通过继承 torch.utils.data.Sampler 并实现 __iter____len__ 方法来创建自己的采样器。

python 复制代码
from torch.utils.data import Sampler
import numpy as np

class EvenOddSampler(Sampler):
    """一个自定义的采样器,先采样所有偶数索引,再采样所有奇数索引"""
    def __init__(self, data_source):
        self.data_source = data_source
        self.even_indices = [i for i in range(len(data_source)) if i % 2 == 0]
        self.odd_indices = [i for i in range(len(data_source)) if i % 2 != 0]

    def __iter__(self):
        # 返回一个索引的迭代器
        return iter(self.even_indices + self.odd_indices)

    def __len__(self):
        return len(self.data_source)

# 使用自定义sampler
dataset_simple = TensorDataset(torch.arange(10))
custom_sampler = EvenOddSampler(dataset_simple)
loader_custom = DataLoader(dataset_simple, batch_size=4, sampler=custom_sampler)

print("\n使用自定义 EvenOddSampler:")
for batch in loader_custom:
    print(batch[0].tolist())
# 输出:
# [0, 2, 4, 6]
# [8, 1, 3, 5]
# [7, 9]
2.3 指定 batch_sampler

batch_sampler 是一个更底层的工具。它不像 sampler 那样一次返回一个索引,而是一次返回一个批次的索引列表

重要 : 当你手动指定 batch_sampler 时,以下参数将被忽略且必须不设置:batch_size, shuffle, sampler, drop_last 。因为 batch_sampler 已经完全接管了批次的形成方式。

这在你需要对批次内的样本构成有特殊要求时非常有用。例如,在 NLP 中,为了减少 padding,你可能希望将长度相近的句子放在同一个批次里。

示例:使用 BatchSampler

BatchSampler 是一个包装器,它接收一个 samplerbatch_sizedrop_last 参数,然后生成批次索引。这实际上是 DataLoader 内部的默认行为。

python 复制代码
from torch.utils.data import BatchSampler, SequentialSampler

dataset_simple = TensorDataset(torch.arange(20))

# 创建一个顺序采样器
seq_sampler = SequentialSampler(dataset_simple)

# 使用 BatchSampler 包装它
batch_sampler = BatchSampler(seq_sampler, batch_size=5, drop_last=False)

# 创建 DataLoader,注意其他参数都不能设置
loader_batch_sampler = DataLoader(dataset_simple, batch_sampler=batch_sampler)

print("\n使用 BatchSampler:")
for batch in loader_batch_sampler:
    print(batch[0].tolist())
# 输出:
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16, 17, 18, 19]

自定义 batch_sampler :这才是 batch_sampler 真正强大的地方。你可以继承 torch.utils.data.Sampler(注意,BatchSampler 也继承自 Sampler)并实现 __iter__,让它 yield 一个个批次的索引列表。

python 复制代码
class GroupLengthBatchSampler(Sampler):
    """
    一个自定义的BatchSampler,尝试将长度相近的样本分到同一个批次。
    这是一个简化的实现。
    """
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size
        # 假设 data_source 有一个 'lengths' 属性
        # 按长度排序索引
        self.sorted_indices = np.argsort([len(x) for x in self.data_source.texts])

    def __iter__(self):
        # 将排序后的索引分块
        for i in range(0, len(self.sorted_indices), self.batch_size):
            yield self.sorted_indices[i : i + self.batch_size]

    def __len__(self):
        return (len(self.sorted_indices) + self.batch_size - 1) // self.batch_size

# 假设有一个带文本的数据集
class TextDataset(Dataset):
    def __init__(self):
        self.texts = [
            "short", "a bit longer", "very very long sentence", "medium one",
            "tiny", "another medium one", "this is also a very long sentence"
        ]
    def __len__(self): return len(self.texts)
    def __getitem__(self, idx): return self.texts[idx]

text_dataset = TextDataset()
my_batch_sampler = GroupLengthBatchSampler(text_dataset, batch_size=2)

# collate_fn 在这里只是为了打印,后面会详细讲
loader_grouped = DataLoader(text_dataset, batch_sampler=my_batch_sampler, collate_fn=lambda x: x)

print("\n使用自定义 GroupLengthBatchSampler:")
for batch in loader_grouped:
    print(f"Batch: {batch}, Lengths: {[len(s) for s in batch]}")

# 输出 (按长度排序后的批次):
# Batch: ['short', 'tiny'], Lengths: [5, 4]
# Batch: ['medium one', 'a bit longer'], Lengths: [10, 12]
# Batch: ['another medium one', 'very very long sentence'], Lengths: [18, 25]
# Batch: ['this is also a very long sentence'], Lengths: [33]
2.4 三者关系总结
方式 作用 如何工作 互斥参数 适用场景
shuffle=True/False 控制是随机还是顺序采样 内部使用 RandomSamplerSequentialSampler 最简单、最常见的场景。
sampler=... 定义单个样本的抽取顺序 提供一个生成索引序列的迭代器 shuffle 需要复杂采样逻辑,如类别均衡、子集采样等。
batch_sampler=... 定义批次索引列表的生成方式 提供一个生成索引列表的迭代器 batch_size, shuffle, sampler, drop_last 需要控制批次内部的构成,如按长度分组以减少padding。

3. collate_fn:自定义样本到批次的转换

collate_fn 的职责是在 DataLoaderDataset 获取到一个样本列表后,将这个列表整理(collate)成一个批次

3.1 默认行为 (default_collate)

如果你不指定 collate_fnDataLoader 会使用 torch.utils.data.default_collate。它的行为是:

  • 它会尝试将输入样本列表中的每个元素(通常是元组,如 (data, label))的对应部分堆叠(stack)起来。
  • 它能处理 PyTorch Tensors, NumPy arrays, Python numbers 和 strings。
  • 对于 Tensors,它会使用 torch.stack 在第0维(批次维)上进行堆叠。
  • 这要求一个批次内的所有样本都有相同的形状。

默认行为失败的场景 :

当一个批次内的样本形状不同时,default_collate 会失败。最常见的例子是 NLP 中的变长序列或计算机视觉中的不同尺寸图像。

python 复制代码
# 失败的例子
dataset_variable_len = [torch.tensor([1,2,3]), torch.tensor([4,5])]
try:
    # 默认的 collate_fn 无法处理不同长度的 tensor
    loader_fail = DataLoader(dataset_variable_len, batch_size=2)
    for batch in loader_fail:
        pass
except RuntimeError as e:
    print(f"默认 collate_fn 失败: {e}")
# 输出: 默认 collate_fn 失败: stack expects each tensor to be equal size, but got [3] at entry 0 and [2] at entry 1
3.2 指定自定义 collate_fn

为了解决上述问题,你可以提供一个自定义的 collate_fn 函数。这个函数接收一个列表,列表中的每个元素都是 Dataset__getitem__ 的返回值。你需要在这个函数里实现将这个列表转换成一个批次的逻辑。

示例:为变长序列实现 padding

这是 collate_fn 最经典的应用。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class VariableLengthDataset(Dataset):
    def __init__(self):
        self.data = [
            (torch.tensor([1, 2, 3]), 0),
            (torch.tensor([4, 5]), 1),
            (torch.tensor([6, 7, 8, 9]), 0),
            (torch.tensor([10]), 1)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx] # 返回 (sequence, label)

def custom_collate_fn(batch):
    """
    自定义的 collate_fn 函数,用于处理变长序列。
    :param batch: 一个列表,其中每个元素是 Dataset 的 __getitem__ 返回值。
                  例如: [ (tensor([1,2,3]), 0), (tensor([4,5]), 1) ]
    """
    # 1. 将数据和标签分离
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # 2. 对序列进行 padding
    # pad_sequence 会自动将序列填充到该批次中最长序列的长度
    # batch_first=True 表示返回的 tensor 形状为 (batch_size, seq_len)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)

    # 3. 将标签转换为 Tensor
    labels = torch.LongTensor(labels)

    # 4. 返回处理好的批次
    return padded_sequences, labels

# 创建 DataLoader 并指定自定义的 collate_fn
dataset = VariableLengthDataset()
loader_padded = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=custom_collate_fn)

print("\n使用自定义 collate_fn 进行 padding:")
for seq_batch, label_batch in loader_padded:
    print("Padded Sequences Batch:")
    print(seq_batch)
    print("Labels Batch:")
    print(label_batch)
    print("-" * 20)

# 可能的输出:
# 使用自定义 collate_fn 进行 padding:
# Padded Sequences Batch:
# tensor([[ 6,  7,  8,  9],
#         [10,  0,  0,  0]])
# Labels Batch:
# tensor([0, 1])
# --------------------
# Padded Sequences Batch:
# tensor([[1, 2, 3],
#         [4, 5, 0]])
# Labels Batch:
# tensor([0, 1])
# --------------------

collate_fn 的区别:

  • 它不关心样本的抽取顺序 ,只关心拿到一批样本后如何组合
  • 它的功能与 sampler 是正交的、互补的。你可以同时使用自定义的 sampler 和自定义的 collate_fn

4. 总结与最佳实践

  1. 简单场景 : 如果你只需要顺序或随机打乱数据,并且所有数据样本形状一致,那么直接使用 DataLoadershuffle=True/Falsebatch_size 参数就足够了。

  2. 类别不均衡/特定顺序 : 如果你需要解决类别不均衡问题(使用 WeightedRandomSampler)或按特定规则(如先训练简单样本,后训练难样本)抽取数据,那么你需要自定义 sampler ,并记得设置 shuffle=False

  3. 优化批次内构成 : 如果你想通过将相似长度/大小的样本组合在一起以优化计算效率(例如,减少 NLP 中的 padding 或 CV 中可变尺寸图像的处理开销),那么你需要自定义 batch_sampler 。这是最底层的控制,它会覆盖 batch_size, shuffle, sampler 等参数。

  4. 处理可变数据 : 如果你的数据样本形状不一(如变长文本、不同尺寸的图片),导致默认的堆叠操作失败,那么你需要自定义 collate_fn。在这个函数里,你可以实现 padding、图像缩放等预处理,将一批异构的样本转换成一个规整的 Tensor 批次。

黄金组合 : 在复杂的任务中,你经常会同时使用这些工具。例如,在 NLP 任务中,一个高效的 DataLoader 可能会:

  • 使用一个自定义的 batch_sampler,它首先根据句子长度对所有样本进行粗略分组,然后在每个组内进行随机采样,最后形成批次。这能保证批次内长度相近,同时保留一定的随机性。
  • 使用一个自定义的 collate_fn,它接收 batch_sampler 给出的索引所对应的一批样本,然后对它们进行精确的 padding,并同时处理标签和其它元数据。

通过灵活组合 sampler, batch_samplercollate_fn,你可以为几乎任何数据类型和训练策略构建出高效、定制化的数据加载管道。

相关推荐
每月一号准时摆烂3 小时前
PS基本教学(三)——像素与分辨率的关系以及图片的格式
人工智能·计算机视觉
Lynnxiaowen3 小时前
今天我们开始学习python语句和模块
linux·运维·开发语言·python·学习
song150265372983 小时前
全自动视觉检测设备
人工智能·计算机视觉·视觉检测
2501_906519673 小时前
大语言模型的幻觉问题:机理、评估与抑制路径探析
人工智能
ThreeAu.3 小时前
pytest 实战:用例管理、插件技巧、断言详解
python·单元测试·pytest·测试开发工程师
ZKNOW甄知科技3 小时前
客户案例 | 派克新材x甄知科技,构建全场景智能IT运维体系
大数据·运维·人工智能·科技·低代码·微服务·制造
视觉语言导航3 小时前
CoRL-2025 | SocialNav-SUB:用于社交机器人导航场景理解的视觉语言模型基准测试
人工智能·机器人·具身智能
资源补给站4 小时前
服务器高效操作指南:Python 环境退出与 Linux 终端快捷键全解析
linux·服务器·python
一苓二肆4 小时前
代码加密技术
linux·windows·python·spring·eclipse