一文详解PyTorch DDP

PyTorch DDP 指的是 DistributedDataParallel ,是 PyTorch 官方提供的、用于多 GPU / 多机器分布式训练的核心并行方案

一句话概括:

DDP = 多进程 + 参数同步 + 高性能数据并行训练


一、DDP 是用来解决什么问题的?

当你训练的模型:

  • 单卡 显存不够
  • 单卡 训练太慢
  • 充分利用多张 GPU / 多台机器

就需要 并行训练

DDP 就是 PyTorch 推荐、也是工业界事实标准的并行方式。


二、DDP 的核心思想(非常重要)

1️⃣ 数据并行(Data Parallel)

  • 每个 GPU 一个进程
  • 每个进程一份完整模型
  • 每个进程只处理一部分数据

示意:

复制代码
GPU0: Model + Data shard 0
GPU1: Model + Data shard 1
GPU2: Model + Data shard 2
GPU3: Model + Data shard 3

2️⃣ 反向传播时自动同步梯度(All-Reduce)

loss.backward() 时:

  • 每个进程算出自己的梯度
  • 使用 NCCL / Gloo 等通信后端
  • 自动做 All-Reduce
  • 得到 所有 GPU 的平均梯度
  • 再各自 optimizer.step()

✔️ 模型参数始终保持一致


三、DDP 和 DataParallel(DP)的区别

这是高频面试 & 实战必考点 👇

对比项 DataParallel (DP) DistributedDataParallel (DDP)
并行方式 单进程多线程 多进程
性能 ❌ 慢 快(官方推荐)
GPU 利用率
通信方式 主卡聚合 All-Reduce
可扩展性 支持多机多卡
是否推荐 已不推荐 ⭐⭐⭐⭐⭐

结论:只要是多 GPU,一律用 DDP


四、DDP 的基本使用流程(核心代码结构)

1️⃣ 初始化进程组

python 复制代码
import torch.distributed as dist

dist.init_process_group(
    backend="nccl",   # GPU 用 nccl
    init_method="env://"
)

2️⃣ 设置当前进程使用的 GPU

python 复制代码
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

3️⃣ 包裹模型

python 复制代码
model = MyModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank]
)

4️⃣ 使用 DistributedSampler(非常关键)

python 复制代码
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(dataset)
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=32
)

每个进程只拿到 自己那一份数据,不重不漏。


5️⃣ 正常训练(几乎不用改)

python 复制代码
for epoch in range(epochs):
    sampler.set_epoch(epoch)  # 保证 shuffle 一致
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

五、DDP 的关键概念速记

概念 含义
world_size 总进程数(通常 = GPU 数)
rank 全局进程编号
local_rank 当前机器上的 GPU 编号
backend 通信后端(nccl / gloo
All-Reduce 梯度同步算法

六、为什么 DDP 性能这么好?

  1. 多进程,避免 GIL
  2. 通信和反向传播重叠
  3. NCCL 针对 GPU 高度优化
  4. 没有主卡瓶颈

这也是:

  • LLaMA
  • Qwen
  • Stable Diffusion
  • 各类工业训练框架

全部使用 DDP 的原因

相关推荐
程序员cxuan4 小时前
为每个任务配一套 harness:Claude Code 里的动态工作流
人工智能
程序员cxuan4 小时前
Claude Fable 5 来了
人工智能·后端·程序员
云边云科技_云网融合4 小时前
云边云科技亮相 2026 WOD 制造业数智化博览会 云网融合赋能制造焕新
人工智能·科技·安全·制造
biter down4 小时前
从 0 到 1 搭建 Python 接口自动化测试框架(博客系统实战)
开发语言·python
Σίσυφος19004 小时前
激光三角 光平面标定-多高度误差分析
人工智能·计算机视觉·平面
JS菌4 小时前
手写一个 AI Agent 全栈项目:从沙箱执行到子智能体的完整实现
前端·人工智能·后端
lqqjuly4 小时前
前沿算法深度解析(二)
人工智能·算法·机器学习
Bode_20025 小时前
基于大数据分析的全生命周期质量追溯质量评估体系落地方案
大数据·人工智能
分布式存储与RustFS5 小时前
RustFS S3 Table 开源后,我重新梳理了一下 Iceberg 数据湖的选型思路
人工智能·开源·minio·dpu·rustfs·ai存储·s3 table
DevOpenClub6 小时前
用 Agent 搭建网页内容采集与结构化处理流水线
人工智能