DataLoader简介

DataLoader 是 PyTorch 中一个非常核心的数据加载工具,它的主要作用是将数据集(Dataset)包装成一个可迭代的对象,为模型训练提供批量、打乱、多进程等数据服务。

简单来说,DataLoader 就像一个"数据分餐机",自动帮你把大量数据分成一个个小批次(batch),送到模型面前。

🎯 核心作用

功能 说明 重要性
批量化 将数据分成固定大小的小批次 ⭐⭐⭐ 必须
打乱数据 每个epoch随机打乱数据顺序 ⭐⭐⭐ 重要
多进程加载 使用多个子进程并行加载数据 ⭐⭐ 提速用
自动索引 自动从Dataset中按索引取数据 ⭐⭐⭐ 自动化

💻 基本用法

1. 最简单的例子

python 复制代码
import torch
from torch.utils.data import DataLoader, TensorDataset

# 准备数据
X = torch.randn(100, 5)   # 100个样本,每个5个特征
y = torch.randn(100, 1)   # 100个标签

# 创建 Dataset 和 DataLoader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 使用 DataLoader
for batch_X, batch_y in dataloader:
    print(f"批次特征形状: {batch_X.shape}")  # torch.Size([16, 5])
    print(f"批次标签形状: {batch_y.shape}")  # torch.Size([16, 1])
    break  # 只打印第一个批次

2. 完整训练循环示例

python 复制代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# 1. 准备数据
X = torch.randn(1000, 10)   # 1000个样本,10个特征
y = torch.randn(1000, 1)    # 1000个标签
dataset = TensorDataset(X, y)

# 2. 创建 DataLoader
dataloader = DataLoader(
    dataset, 
    batch_size=32,      # 每批32个样本
    shuffle=True,       # 打乱顺序
    num_workers=2,      # 使用2个进程加载数据(Windows下建议设为0)
    drop_last=True      # 丢弃最后不足一批的数据
)

# 3. 简单模型
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. 训练循环
epochs = 5
for epoch in range(epochs):
    total_loss = 0
    for batch_X, batch_y in dataloader:
        # 前向传播
        pred = model(batch_X)
        loss = criterion(pred, batch_y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, 平均损失: {avg_loss:.4f}")

🔧 重要参数详解

参数 含义 常用值 示例
batch_size 每批样本数量 16, 32, 64, 128 batch_size=32
shuffle 是否打乱数据 True(训练集),False(测试集) shuffle=True
num_workers 并行加载进程数 0(Windows),2-8(Linux/Mac) num_workers=4
drop_last 丢弃最后不足一批的数据 True(避免批次大小不整),False(保留所有) drop_last=True
pin_memory 锁页内存,加速GPU传输 True(使用GPU时) pin_memory=True

📊 实际使用示例

训练集 vs 测试集

python 复制代码
# 划分训练集和测试集
n_samples = 1000
X = torch.randn(n_samples, 10)
y = torch.randn(n_samples, 1)

train_size = int(0.8 * n_samples)
test_size = n_samples - train_size

train_X, test_X = X[:train_size], X[train_size:]
train_y, test_y = y[:train_size], y[train_size:]

# 创建两个 DataLoader
train_dataset = TensorDataset(train_X, train_y)
test_dataset = TensorDataset(test_X, test_y)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  # 测试集不打乱

print(f"训练批次数: {len(train_loader)}")  # ceil(800/32) = 25
print(f"测试批次数: {len(test_loader)}")   # ceil(200/64) = 4

查看数据分布

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 查看一个epoch的批次数
num_batches = len(dataloader)
print(f"总样本数: {len(dataset)}")
print(f"Batch大小: {dataloader.batch_size}")
print(f"一个epoch的批次数: {num_batches}")

# 遍历查看每个批次的实际大小
for i, (batch_X, batch_y) in enumerate(dataloader):
    print(f"批次 {i}: X形状 {batch_X.shape}, y形状 {batch_y.shape}")

⚙️ num_workers 注意事项

python 复制代码
# Windows 上常见问题
dataloader = DataLoader(dataset, batch_size=32, num_workers=2)  
# 可能报错:BrokenPipeError

# Windows 解决方案:
if __name__ == '__main__':
    dataloader = DataLoader(dataset, batch_size=32, num_workers=0)  # 设为0
    # 或者使用 spwan 启动方式
    # 推荐:Windows 上保持 num_workers=0

🎨 高级用法:自定义采样器

python 复制代码
from torch.utils.data import SubsetRandomSampler

# 自定义采样策略
indices = list(range(1000))
train_indices = indices[:800]
val_indices = indices[800:]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=32, sampler=val_sampler)

📋 常用配置模板

python 复制代码
# 配置模板(可直接复制使用)
def create_dataloaders(X_train, y_train, X_test, y_test, batch_size=32):
    """创建训练和测试 DataLoader"""
    train_dataset = TensorDataset(X_train, y_train)
    test_dataset = TensorDataset(X_test, y_test)
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,      # Windows 设 0,Linux/Mac 可设 2-4
        pin_memory=False    # 使用 GPU 时可设为 True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,      # 测试集不打乱
        num_workers=0,
        pin_memory=False
    )
    
    return train_loader, test_loader

# 使用
train_loader, test_loader = create_dataloaders(X_train, y_train, X_test, y_test, batch_size=64)

💡 总结

场景 推荐配置
训练集 batch_size=32-128, shuffle=True, drop_last=True
验证/测试集 batch_size=64-256, shuffle=False, drop_last=False
Windows 系统 num_workers=0
Linux/Mac + 大内存 num_workers=4
使用 GPU pin_memory=True 可加速

一句话总结DataLoader 是连接数据集和训练循环的桥梁,它自动处理批次划分、数据打乱、并行加载,让数据喂给模型变得优雅高效。

相关推荐
qq_411262422 小时前
四博AI智能音响方案(基于四博小助手AITOYO2)
人工智能·macos·xcode
AI木马人2 小时前
7.计算机视觉:让AI拥有一双“火眼金睛”
人工智能·计算机视觉
亿电连接器替代品网2 小时前
工业防水连接器选型:Amphenol LTW替代方案详解
大数据·网络·人工智能·硬件工程·材料工程
多年小白2 小时前
谷歌第八代 TPU 来了:性能提升 124%
网络·人工智能·科技·深度学习·ai
带娃的IT创业者2 小时前
Claude Code Routines 深度解析:重新定义 AI 辅助编程的工作流自动化
运维·人工智能·自动化·ai编程·工作流·anthropic·claude code
冬至喵喵2 小时前
本体论在数仓 Data Agent 中的应用
人工智能
Jmayday2 小时前
Pytorch:张量的操作
人工智能·pytorch·python
guslegend2 小时前
AI生图第3节:gpt-image-2的提示词反解析与Json结构化生图
人工智能·gpt·json
我是发哥哈2 小时前
主流AI视频生成方案商用化能力横向评测
大数据·人工智能·学习·机器学习·chatgpt·音视频