PyTorch的Dataloader模块解析

一、DataLoader核心概念

什么是DataLoader?

Dataloader是PyTorch中用于批量加载数据的工具,它将Dataset包装成一个可迭代对象,支持:

  • 批量处理(batching)
  • 数据打乱(shuffling)
  • 多进程加载(multiprocessing)

二、基本使用参数

python 复制代码
from torch.utils.data import DataLoader, Dataset

dataloader = DataLoader(
    dataset=dataset,      # 数据集对象
    batch_size=32,        # 批量大小
    shuffle=True,         # 是否打乱
    num_workers=4,        # 子进程数,0:主进程加载,调试用,>0:多进程并行加载
    pin_memory=True       # 锁页内存,加速CPU到GPU的数据传输
    drop_last=False       # 是否丢弃最后一个不完整的batch
)

for batch in dataloader:
    inputs, labels = batch
    # 训练代码...

三、工作流程

1、数据加载流程:

python 复制代码
Dataset → Sampler → BatchSampler → DataLoader → 模型
  1. 初始化阶段:创建worker进程;
  2. 索引生成:Sampler生成样本索引;
  3. 批量组织:BatchSampler将索引分组;
  4. 数据加载:Worker进程读取数据;
  5. 数据合并:collate_fn整理batch数据;
  6. 返回结果:将batch数据返回给主进程

2、Sampler和BatchSampler

(1)Sampler-控制样本顺序

python 复制代码
from torch.utils.data import RandomSampler, SequentialSampler

# 随机采样(等价于shuffle=True)
sampler = RandomSampler(dataset)
# 顺序采样(等价于shuffle=False)  
sampler = SequentialSampler(dataset)

dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

(2) 自定义Sampler

python 复制代码
class ImbalancedSampler(Sampler):
    def __init__(self, dataset, minority_class_indices):
        self.dataset = dataset
        self.minority_indices = minority_class_indices
        self.majority_indices = [i for i in range(len(dataset)) 
                               if i not in minority_class_indices]
    
    def __iter__(self):
        # 过采样少数类,保持类别平衡
        indices = []
        # 每个少数类样本采样2次
        for idx in self.minority_indices:
            indices.extend([idx] * 2)
        indices.extend(self.majority_indices)
        random.shuffle(indices)
        return iter(indices)

(3)BatchSampler-控制batch组织方式

python 复制代码
from torch.utils.data import BatchSampler

sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=32, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

3、collate_fn函数

作用:将多个样本整理成一个batch

python 复制代码
def custom_collate_fn(batch):
    """
    batch: list of (image, label, metadata) tuples
    返回: 整理后的batch数据
    """
    images, labels, metadatas = zip(*batch)
    
    # 处理图像
    images = torch.stack(images, dim=0)
    
    # 处理标签
    labels = torch.tensor(labels)
    
    # 处理变长数据
    metadatas = list(metadatas)  # 保持列表形式
    
    return images, labels, metadatas

dataloader = DataLoader(dataset, collate_fn=custom_collate_fn, batch_size=32)

处理变长序列的collate_fn

python 复制代码
def pad_collate_fn(batch):
    # batch: [(sequence, label), ...]
    sequences, labels = zip(*batch)
    
    # 计算最大长度
    max_len = max(len(seq) for seq in sequences)
    
    # 填充序列
    padded_sequences = []
    for seq in sequences:
        pad_size = max_len - len(seq)
        padded_seq = torch.cat([seq, torch.zeros(pad_size)])
        padded_sequences.append(padded_seq)
    
    return torch.stack(padded_sequences), torch.tensor(labels)

4、多进程数据加载

(1)Worker进程工程流程

python 复制代码
# 主进程
dataloader = DataLoader(dataset, num_workers=4)

# 内部工作:
# 1. 主进程创建4个worker进程
# 2. 每个worker有自己的数据集副本
# 3. Worker预加载多个batch到队列
# 4. 主进程从队列获取batch

(2)常见多进程问题及解决

问题1:内存泄漏

python 复制代码
# 错误:每个epoch都重新创建DataLoader
for epoch in range(epochs):
    dataloader = DataLoader(dataset, num_workers=4)  # 内存泄漏!
    
# 正确:只创建一次
dataloader = DataLoader(dataset, num_workers=4)
for epoch in range(epochs):
    for batch in dataloader:
        # ...

问题2:CUDA在多进程中的使用

python 复制代码
# 错误:在worker进程中初始化CUDA
def bad_worker_init(worker_id):
    torch.cuda.set_device(0)  # 错误!

# 正确:只在主进程使用CUDA
# Worker进程只负责数据加载,不涉及GPU操作

5、性能优化技巧

(1)选择合适的num_workers

python 复制代码
# 测试最佳worker数量
import time

def find_optimal_workers(dataset, max_workers=8):
    for n in range(max_workers + 1):
        loader = DataLoader(dataset, num_workers=n, batch_size=32)
        start = time.time()
        for batch in loader:
            pass
        duration = time.time() - start
        print(f"Workers: {n}, Time: {duration:.2f}s")

(2)使用pin_memory加速

python 复制代码
dataloader = DataLoader(
    dataset, 
    num_workers=4, 
    pin_memory=True,  # GPU训练时开启
    batch_size=32
)

(3)预加载策略

python 复制代码
# 使用prefetch_factor(PyTorch 1.7+)
dataloader = DataLoader(
    dataset,
    num_workers=4,
    prefetch_factor=2,  # 每个worker预加载2个batch
    batch_size=32
)

6、其他问题

(1)DataLoader中的num_workers设置多少合适?

通常设置为CPU核心数的2-4倍,但需要实际测试,设置过多会增加进程切换开销。

(2)shuffle=True是如何工作的?

在每个epoch开始时,RandomSampler会生成新的随机索引序列,实现数据打乱。

(3)多进程数据加载时,如何避免重复数据?

每个worker进程通过sampler获取不同的索引范围,使用不同的随机种子确保数据不重复。

(4)pin_memory为什么能加速?

锁页内存允许DMA直接传输到GPU,避免了额外的内存复制。

(5)如何处理变长数据?

使用自定义的collate_fn进行填充(padding)或打包(packing)

7、实际代码示例

python 复制代码
class AdvancedDataLoader:
    def __init__(self, dataset, config):
        self.dataset = dataset
        self.config = config
        
    def create_dataloader(self, is_training=True):
        # 选择sampler
        if is_training:
            sampler = RandomSampler(self.dataset)
        else:
            sampler = SequentialSampler(self.dataset)
            
        # 创建DataLoader
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.config.batch_size,
            sampler=sampler,
            num_workers=self.config.num_workers,
            pin_memory=self.config.use_gpu,
            drop_last=is_training,  # 训练时丢弃不完整batch
            collate_fn=self.pad_collate_fn if self.config.pad_sequences else None,
            persistent_workers=self.config.persistent_workers
        )
        return dataloader
相关推荐
是一个Bug1 小时前
Spring Boot 的全局异常处理器
spring boot·后端·python
dTTb1 小时前
python元组和字典
python
秋邱1 小时前
技术深耕:教育 AGI 的能力跃迁与安全加固
大数据·人工智能
一水鉴天1 小时前
整体设计 定稿 之16 三层智能合约体系实现设计和开发的实时融合
前端·人工智能·架构·智能合约
Peter_Monster1 小时前
LangChain到底是什么?
人工智能·langchain·大模型
HAPPY酷1 小时前
技术沟通的底层逻辑:用结构化方法提升协作效率
大数据·人工智能
java_logo1 小时前
Prometheus Docker 容器化部署指南
运维·人工智能·docker·容器·prometheus·ai编程
非著名架构师1 小时前
【光伏功率预测】EMD 分解 + PCA 降维 + LSTM 的联合建模与 Matlab 实现
人工智能·matlab·lstm·高精度光伏功率预测模型
Aspect of twilight1 小时前
KNN分类器与K-means无监督聚类详解
人工智能·机器学习·kmeans·knn