PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

二、完整 DDP 训练 Demo

1. 基础训练脚本 (ddp_demo.py)

python 复制代码
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    """简单的CNN模型"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(9216, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)

def prepare_dataloader(rank, world_size, batch_size=32):
    """准备分布式数据加载器"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader

def train(rank, world_size, epochs=2):
    """训练函数"""
    setup(rank, world_size)
    
    # 设置当前设备
    torch.cuda.set_device(rank)
    
    # 初始化模型、优化器等
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # 混合精度训练
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shuffle
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            
            optimizer.zero_grad()
            
            # 混合精度训练
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    
    cleanup()

if __name__ == "__main__":
    # 单机多卡启动时,torchrun会自动设置这些环境变量
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

bash 复制代码
# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
python 复制代码
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
python 复制代码
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
python 复制代码
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    python 复制代码
    if rank == 0:  # 只在主进程保存
        torch.save(model.state_dict(), "model.pth")
  3. 多机训练

    bash 复制代码
    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
    # 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    python 复制代码
    for i, (data, target) in enumerate(train_loader):
        if i % 2 == 0:
            optimizer.zero_grad()
        # 前向和反向...
        if i % 2 == 1:
            optimizer.step()
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    python 复制代码
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
相关推荐
兮希yxx25 分钟前
conda配置pytorch虚拟环境
人工智能·pytorch·conda
程序员miki13 小时前
Pytorch的CUDA版本安装使用教程
人工智能·pytorch·python
Ronin-Lotus14 小时前
深度学习篇---pytorch数据集
人工智能·pytorch·深度学习
爱补鱼的猫猫14 小时前
pytorch可视化工具(训练评估:Tensorboard、swanlab)
人工智能·pytorch·python
九年义务漏网鲨鱼16 小时前
PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回
人工智能·pytorch·batch
星期天要睡觉18 小时前
深度学习——ResNet 卷积神经网络
pytorch·深度学习·神经网络
二向箔reverse19 小时前
从传统CNN到残差网络:用PyTorch实现更强大的图像分类模型
网络·pytorch·cnn
AI绘画哇哒哒1 天前
【值得收藏】手把手教你用PyTorch构建Transformer英汉翻译系统,从训练到推理
人工智能·pytorch·ai·语言模型·程序员·大模型·transformer
CH3_CH2_CHO1 天前
DAY02:【DL 第一弹】pytorch
人工智能·pytorch·python·深度学习·回归
过往入尘土1 天前
从 0 到 1 实现 PyTorch 食物图像分类:核心知识点与完整实
人工智能·pytorch·分类