PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

二、完整 DDP 训练 Demo

1. 基础训练脚本 (ddp_demo.py)

python 复制代码
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    """简单的CNN模型"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(9216, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)

def prepare_dataloader(rank, world_size, batch_size=32):
    """准备分布式数据加载器"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader

def train(rank, world_size, epochs=2):
    """训练函数"""
    setup(rank, world_size)
    
    # 设置当前设备
    torch.cuda.set_device(rank)
    
    # 初始化模型、优化器等
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # 混合精度训练
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shuffle
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            
            optimizer.zero_grad()
            
            # 混合精度训练
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    
    cleanup()

if __name__ == "__main__":
    # 单机多卡启动时,torchrun会自动设置这些环境变量
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

bash 复制代码
# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
python 复制代码
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
python 复制代码
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
python 复制代码
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    python 复制代码
    if rank == 0:  # 只在主进程保存
        torch.save(model.state_dict(), "model.pth")
  3. 多机训练

    bash 复制代码
    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
    # 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    python 复制代码
    for i, (data, target) in enumerate(train_loader):
        if i % 2 == 0:
            optimizer.zero_grad()
        # 前向和反向...
        if i % 2 == 1:
            optimizer.step()
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    python 复制代码
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
相关推荐
没有梦想的咸鱼185-1037-16632 小时前
【遥感技术】从CNN到Transformer:基于PyTorch的遥感影像、无人机影像的地物分类、目标检测、语义分割和点云分类
pytorch·python·深度学习·机器学习·数据分析·cnn·transformer
Teacher.chenchong4 小时前
基于PyTorch深度学习无人机遥感影像目标检测、地物分类及语义分割实践技术应用
pytorch·深度学习·无人机
weixin_457340215 小时前
RTX5060 Ti显卡安装cuda版本PyTorch踩坑记录
人工智能·pytorch·python
zhurui_xiaozhuzaizai9 小时前
大模型里使用的pytorch dataset 和dataloader详细解析和介绍
人工智能·pytorch·python
鲸鱼24019 小时前
Pytorch工具箱2
人工智能·pytorch·python
西猫雷婶9 小时前
python学智能算法(三十九)|使用PyTorch模块的normal()函数绘制正态分布函数图
开发语言·人工智能·pytorch·python·深度学习·神经网络·学习
麻雀无能为力11 小时前
Pytorch框架笔记
人工智能·pytorch·笔记
起个名字费劲死了1 天前
Pytorch Yolov11 OBB 旋转框检测+window部署+推理封装 留贴记录
c++·人工智能·pytorch·python·深度学习·yolo·机器人
辞--忧1 天前
PyTorch 神经网络工具箱完全指南
pytorch·神经网络
辞--忧1 天前
PyTorch 数据处理与可视化全攻略
人工智能·pytorch