一文详解PyTorch DDP

PyTorch DDP 指的是 DistributedDataParallel ,是 PyTorch 官方提供的、用于多 GPU / 多机器分布式训练的核心并行方案

一句话概括:

DDP = 多进程 + 参数同步 + 高性能数据并行训练


一、DDP 是用来解决什么问题的?

当你训练的模型:

  • 单卡 显存不够
  • 单卡 训练太慢
  • 充分利用多张 GPU / 多台机器

就需要 并行训练

DDP 就是 PyTorch 推荐、也是工业界事实标准的并行方式。


二、DDP 的核心思想(非常重要)

1️⃣ 数据并行(Data Parallel)

  • 每个 GPU 一个进程
  • 每个进程一份完整模型
  • 每个进程只处理一部分数据

示意:

复制代码
GPU0: Model + Data shard 0
GPU1: Model + Data shard 1
GPU2: Model + Data shard 2
GPU3: Model + Data shard 3

2️⃣ 反向传播时自动同步梯度(All-Reduce)

loss.backward() 时:

  • 每个进程算出自己的梯度
  • 使用 NCCL / Gloo 等通信后端
  • 自动做 All-Reduce
  • 得到 所有 GPU 的平均梯度
  • 再各自 optimizer.step()

✔️ 模型参数始终保持一致


三、DDP 和 DataParallel(DP)的区别

这是高频面试 & 实战必考点 👇

对比项 DataParallel (DP) DistributedDataParallel (DDP)
并行方式 单进程多线程 多进程
性能 ❌ 慢 快(官方推荐)
GPU 利用率
通信方式 主卡聚合 All-Reduce
可扩展性 支持多机多卡
是否推荐 已不推荐 ⭐⭐⭐⭐⭐

结论:只要是多 GPU,一律用 DDP


四、DDP 的基本使用流程(核心代码结构)

1️⃣ 初始化进程组

python 复制代码
import torch.distributed as dist

dist.init_process_group(
    backend="nccl",   # GPU 用 nccl
    init_method="env://"
)

2️⃣ 设置当前进程使用的 GPU

python 复制代码
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

3️⃣ 包裹模型

python 复制代码
model = MyModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank]
)

4️⃣ 使用 DistributedSampler(非常关键)

python 复制代码
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(dataset)
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=32
)

每个进程只拿到 自己那一份数据,不重不漏。


5️⃣ 正常训练(几乎不用改)

python 复制代码
for epoch in range(epochs):
    sampler.set_epoch(epoch)  # 保证 shuffle 一致
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

五、DDP 的关键概念速记

概念 含义
world_size 总进程数(通常 = GPU 数)
rank 全局进程编号
local_rank 当前机器上的 GPU 编号
backend 通信后端(nccl / gloo
All-Reduce 梯度同步算法

六、为什么 DDP 性能这么好?

  1. 多进程,避免 GIL
  2. 通信和反向传播重叠
  3. NCCL 针对 GPU 高度优化
  4. 没有主卡瓶颈

这也是:

  • LLaMA
  • Qwen
  • Stable Diffusion
  • 各类工业训练框架

全部使用 DDP 的原因

相关推荐
冬奇Lab5 小时前
Workflow 系列(04):Multi-Agent 协调——编排器边界、并发控制与上下文隔离
人工智能·工作流引擎
冬奇Lab5 小时前
每日一个开源项目(第147篇):HyperGraphRAG - 用超图表示 N 元关系,RAG 的第三代范式
人工智能·开源·graphql
甲维斯5 小时前
Github + 阿里云oss实现类似codex的自动更新!
人工智能
阿里云大数据AI技术7 小时前
光轮智能 × 阿里云:共建 Physical AI 云上数据、评测与持续学习基础设施
人工智能·机器学习
机器之心7 小时前
实锤了:Claude Code偷查用户,时区、中国AI实验室全是关键词
人工智能·openai
网易云信7 小时前
Cursor点燃个人开发者,企业级AI为何频频受挫?Agent工厂从提效工具到AI员工的跃迁
人工智能·开源
网易云信7 小时前
解锁触手可及的温暖:网易智企 x Wander Puffs AI 云游泡芙
人工智能
转转技术团队8 小时前
从 PRD 到可验证代码:AI 需求开发闭环实践
人工智能