一、DataLoader 是什么?
torch.utils.data.DataLoader
是 PyTorch 中用于加载数据的核心接口,它支持:
- 批量读取(batch)
- 数据打乱(shuffle)
- 多线程并行加载(num_workers)
- 自动将数据打包成 batch
- 数据预处理和增强(搭配 Dataset 使用)
二、常见参数详解
参数 | 含义 |
---|---|
dataset |
传入的 Dataset 对象(如自定义或 torchvision.datasets ) |
batch_size |
每个 batch 的样本数量 |
shuffle |
是否打乱数据(通常训练集为 True) |
num_workers |
并行加载数据的线程数(越大越快,但依机器决定) |
drop_last |
是否丢弃最后一个不足 batch_size 的 batch |
pin_memory |
若为 True,会将数据复制到 CUDA 的 page-locked 内存中(加速 GPU 训练) |
collate_fn |
自定义打包 batch 的函数(可用于变长序列、图神经网络等) |
sampler |
控制数据采样策略,不能与 shuffle 同时使用 |
persistent_workers |
若为 True,worker 在 epoch 间保持运行状态(提高效率,PyTorch 1.7+) |
三、基本使用示例
搭配 Dataset 使用
python
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = [i for i in range(100)]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
for batch in loader:
print(batch)
四、自定义 collate_fn 示例
适用于:变长数据(如文本、点云)或特殊处理需求
python
from torch.nn.utils.rnn import pad_sequence
def my_collate_fn(batch):
# 假设每个样本是 list 或 tensor(变长)
batch = [torch.tensor(item) for item in batch]
padded = pad_sequence(batch, batch_first=True, padding_value=0)
return padded
loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate_fn)
五、使用注意事项
-
Windows 平台注意:
-
设置
num_workers > 0
时,必须使用:pythonif __name__ == '__main__': DataLoader(...)
-
-
过多线程数可能导致瓶颈:
- 通常
num_workers = cpu_count() // 2
较稳定
- 通常
-
GPU 加速:
- 训练时推荐设置
pin_memory=True
可提高 GPU 训练数据传输效率。
- 训练时推荐设置
-
不要同时设置
shuffle=True
和sampler
:- 否则会报错,二者功能冲突。
六、训练中的典型使用方式
python
for epoch in range(num_epochs):
for i, batch in enumerate(train_loader):
inputs, labels = batch
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
七、调试技巧与加速建议
场景 | 建议 |
---|---|
数据加载慢 | 增加 num_workers |
GPU 等数据 | 设置 pin_memory=True |
Dataset 中有耗时操作 | 考虑预处理或使用缓存 |
debug 模式 | 设置 num_workers=0 ,禁用多进程 |
八、与 TensorDataset、ImageFolder 配合
python
from torchvision.datasets import ImageFolder
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
dataset = ImageFolder(root='your/image/folder', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
九、点云数据处理场景应用实例
在 点云数据处理 场景中,使用 torch.utils.data.DataLoader
时,常遇到如下需求:
- 每帧点云大小不同(变长 Tensor)
- 点云数据 + 标签(如语义、实例)
- 使用
.bin
、.pcd
或.npy
等格式加载 - 数据增强(如旋转、裁剪、噪声)
- GPU 加速 + 批量训练
1. 点云数据 Dataset 示例(以 .npy
文件为例)
python
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class PointCloudDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.files = sorted([f for f in os.listdir(root_dir) if f.endswith('.npy')])
self.transform = transform
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
point_cloud = np.load(os.path.join(self.root_dir, self.files[idx])) # shape: [N, 3] or [N, 6]
point_cloud = torch.tensor(point_cloud, dtype=torch.float32)
if self.transform:
point_cloud = self.transform(point_cloud)
return point_cloud
2. 自定义 collate_fn
(处理变长点云)
python
def collate_pointcloud_fn(batch):
"""
输入: List of [N_i x 3] tensors
输出:
- 合并后的 [B x N_max x 3] tensor
- 每个样本的真实点数 list
"""
max_points = max(pc.shape[0] for pc in batch)
padded = torch.zeros((len(batch), max_points, batch[0].shape[1]))
lengths = []
for i, pc in enumerate(batch):
lengths.append(pc.shape[0])
padded[i, :pc.shape[0], :] = pc
return padded, torch.tensor(lengths)
3. 加载器构建示例
python
dataset = PointCloudDataset("/path/to/your/pointclouds")
loader = DataLoader(
dataset,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True,
collate_fn=collate_pointcloud_fn
)
for batch_points, batch_lengths in loader:
# batch_points: [B, N_max, 3]
# batch_lengths: [B]
print(batch_points.shape)
4. 可选扩展功能
功能 | 实现方法 |
---|---|
点云旋转/缩放 | 自定义 transform (例如随机旋转矩阵乘点云) |
加载 .pcd |
使用 open3d , pypcd , 或 pclpy |
同时加载标签 | 在 Dataset 中返回 (point_cloud, label) ,修改 collate_fn |
voxel downsampling | 使用 open3d.geometry.VoxelDownSample |
GPU 加速 | point_cloud = point_cloud.cuda(non_blocking=True) |
5. 训练循环中使用
python
for epoch in range(num_epochs):
for batch_pc, batch_len in loader:
batch_pc = batch_pc.to(device)
# 可用 batch_len 做 mask 或 attention mask
out = model(batch_pc)
...