PyTorch中 torch.utils.data.DataLoader 的详细解析和读取点云数据示例

一、DataLoader 是什么?

torch.utils.data.DataLoader 是 PyTorch 中用于加载数据的核心接口,它支持:

  • 批量读取(batch)
  • 数据打乱(shuffle)
  • 多线程并行加载(num_workers)
  • 自动将数据打包成 batch
  • 数据预处理和增强(搭配 Dataset 使用)

二、常见参数详解

参数 含义
dataset 传入的 Dataset 对象(如自定义或 torchvision.datasets
batch_size 每个 batch 的样本数量
shuffle 是否打乱数据(通常训练集为 True)
num_workers 并行加载数据的线程数(越大越快,但依机器决定)
drop_last 是否丢弃最后一个不足 batch_size 的 batch
pin_memory 若为 True,会将数据复制到 CUDA 的 page-locked 内存中(加速 GPU 训练)
collate_fn 自定义打包 batch 的函数(可用于变长序列、图神经网络等)
sampler 控制数据采样策略,不能与 shuffle 同时使用
persistent_workers 若为 True,worker 在 epoch 间保持运行状态(提高效率,PyTorch 1.7+)

三、基本使用示例

搭配 Dataset 使用

python 复制代码
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = [i for i in range(100)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

for batch in loader:
    print(batch)

四、自定义 collate_fn 示例

适用于:变长数据(如文本、点云)或特殊处理需求

python 复制代码
from torch.nn.utils.rnn import pad_sequence

def my_collate_fn(batch):
    # 假设每个样本是 list 或 tensor(变长)
    batch = [torch.tensor(item) for item in batch]
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded

loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate_fn)

五、使用注意事项

  1. Windows 平台注意:

    • 设置 num_workers > 0 时,必须使用:

      python 复制代码
      if __name__ == '__main__':
          DataLoader(...)
  2. 过多线程数可能导致瓶颈:

    • 通常 num_workers = cpu_count() // 2 较稳定
  3. GPU 加速:

    • 训练时推荐设置 pin_memory=True 可提高 GPU 训练数据传输效率。
  4. 不要同时设置 shuffle=Truesampler

    • 否则会报错,二者功能冲突。

六、训练中的典型使用方式

python 复制代码
for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

七、调试技巧与加速建议

场景 建议
数据加载慢 增加 num_workers
GPU 等数据 设置 pin_memory=True
Dataset 中有耗时操作 考虑预处理或使用缓存
debug 模式 设置 num_workers=0,禁用多进程

八、与 TensorDataset、ImageFolder 配合

python 复制代码
from torchvision.datasets import ImageFolder
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ImageFolder(root='your/image/folder', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

九、点云数据处理场景应用实例

点云数据处理 场景中,使用 torch.utils.data.DataLoader 时,常遇到如下需求:

  • 每帧点云大小不同(变长 Tensor)
  • 点云数据 + 标签(如语义、实例)
  • 使用 .bin.pcd.npy 等格式加载
  • 数据增强(如旋转、裁剪、噪声)
  • GPU 加速 + 批量训练

1. 点云数据 Dataset 示例(以 .npy 文件为例)

python 复制代码
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class PointCloudDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.files = sorted([f for f in os.listdir(root_dir) if f.endswith('.npy')])
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        point_cloud = np.load(os.path.join(self.root_dir, self.files[idx]))  # shape: [N, 3] or [N, 6]
        point_cloud = torch.tensor(point_cloud, dtype=torch.float32)

        if self.transform:
            point_cloud = self.transform(point_cloud)

        return point_cloud

2. 自定义 collate_fn(处理变长点云)

python 复制代码
def collate_pointcloud_fn(batch):
    """
    输入: List of [N_i x 3] tensors
    输出: 
        - 合并后的 [B x N_max x 3] tensor
        - 每个样本的真实点数 list
    """
    max_points = max(pc.shape[0] for pc in batch)
    padded = torch.zeros((len(batch), max_points, batch[0].shape[1]))
    lengths = []

    for i, pc in enumerate(batch):
        lengths.append(pc.shape[0])
        padded[i, :pc.shape[0], :] = pc

    return padded, torch.tensor(lengths)

3. 加载器构建示例

python 复制代码
dataset = PointCloudDataset("/path/to/your/pointclouds")

loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_pointcloud_fn
)

for batch_points, batch_lengths in loader:
    # batch_points: [B, N_max, 3]
    # batch_lengths: [B]
    print(batch_points.shape)

4. 可选扩展功能

功能 实现方法
点云旋转/缩放 自定义 transform(例如随机旋转矩阵乘点云)
加载 .pcd 使用 open3d, pypcd, 或 pclpy
同时加载标签 在 Dataset 中返回 (point_cloud, label),修改 collate_fn
voxel downsampling 使用 open3d.geometry.VoxelDownSample
GPU 加速 point_cloud = point_cloud.cuda(non_blocking=True)

5. 训练循环中使用

python 复制代码
for epoch in range(num_epochs):
    for batch_pc, batch_len in loader:
        batch_pc = batch_pc.to(device)
        # 可用 batch_len 做 mask 或 attention mask
        out = model(batch_pc)
        ...

相关推荐
Dylan的码园8 小时前
稀疏 MoE 与原生多模态双驱:2025 大模型技术演进全景
人工智能·机器学习·ai作画·数据挖掘·boosting·oneflow
_-CHEN-_8 小时前
Prompt Manager: 让你的 AI 提示词管理更专业
人工智能·prompt
weixin_397578028 小时前
Transformer 架构 “Attention Is All You Need“
人工智能
檀越剑指大厂8 小时前
AI 当主程还能远程开发?TRAE SOLO 的实用体验与cpolar内网突破
人工智能
河码匠8 小时前
Django rest framework 自定义url
后端·python·django
哥只是传说中的小白8 小时前
无需验证手机Sora2也能用!视频生成,创建角色APi接入教程,开发小白也能轻松接入
数据库·人工智能
cnxy1888 小时前
Python Web开发新时代:FastAPI vs Django性能对比
前端·python·fastapi
lkbhua莱克瓦248 小时前
参数如何影响着大语言模型
人工智能·llm·大语言模型
neardi临滴科技8 小时前
从算法逻辑到芯端落地:YOLO 目标检测的进化与瑞芯微实践
算法·yolo·目标检测
小雨下雨的雨8 小时前
Flutter跨平台开发实战:鸿蒙系列-循环交互艺术系列——瀑布流:不规则网格的循环排布算法
算法·flutter·华为·交互·harmonyos·鸿蒙系统