PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

二、完整 DDP 训练 Demo

1. 基础训练脚本 (ddp_demo.py)

python 复制代码
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    """简单的CNN模型"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(9216, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)

def prepare_dataloader(rank, world_size, batch_size=32):
    """准备分布式数据加载器"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader

def train(rank, world_size, epochs=2):
    """训练函数"""
    setup(rank, world_size)
    
    # 设置当前设备
    torch.cuda.set_device(rank)
    
    # 初始化模型、优化器等
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # 混合精度训练
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shuffle
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            
            optimizer.zero_grad()
            
            # 混合精度训练
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    
    cleanup()

if __name__ == "__main__":
    # 单机多卡启动时,torchrun会自动设置这些环境变量
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

bash 复制代码
# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
python 复制代码
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
python 复制代码
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
python 复制代码
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    python 复制代码
    if rank == 0:  # 只在主进程保存
        torch.save(model.state_dict(), "model.pth")
  3. 多机训练

    bash 复制代码
    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
    # 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    python 复制代码
    for i, (data, target) in enumerate(train_loader):
        if i % 2 == 0:
            optimizer.zero_grad()
        # 前向和反向...
        if i % 2 == 1:
            optimizer.step()
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    python 复制代码
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
相关推荐
九章云极AladdinEdu6 小时前
GitHub新手生存指南:AI项目版本控制与协作实战
人工智能·pytorch·opencv·机器学习·github·gpu算力
z are11 小时前
PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环
人工智能·pytorch·python
点云SLAM14 小时前
Pytorch中cuda相关操作详见和代码示例
人工智能·pytorch·python·深度学习·3d·cuda·多gpu训练
cwn_1 天前
Sequential 损失函数 反向传播 优化器 模型的使用修改保存加载
人工智能·pytorch·python·深度学习·机器学习
老鱼说AI1 天前
Transformer Masked loss原理精讲及其PyTorch逐行实现
人工智能·pytorch·python·深度学习·transformer
空中湖1 天前
PyTorch武侠演义 第一卷:初入江湖 第5章:玉如意的秘密
人工智能·pytorch·neo4j
许愿与你永世安宁2 天前
PyTorch中神经网络的模型构建
人工智能·pytorch·python
F_D_Z2 天前
模型的存储、加载和部署
人工智能·pytorch·python·深度学习
梦星辰.2 天前
PyTorch数据选取与索引详解:从入门到高效实践
人工智能·pytorch·python
fouen2 天前
DenseNet详解,附模型代码(pytorch)
人工智能·pytorch·python·深度学习·计算机视觉