一文详解PyTorch DDP

PyTorch DDP 指的是 DistributedDataParallel ,是 PyTorch 官方提供的、用于多 GPU / 多机器分布式训练的核心并行方案

一句话概括:

DDP = 多进程 + 参数同步 + 高性能数据并行训练


一、DDP 是用来解决什么问题的?

当你训练的模型:

  • 单卡 显存不够
  • 单卡 训练太慢
  • 充分利用多张 GPU / 多台机器

就需要 并行训练

DDP 就是 PyTorch 推荐、也是工业界事实标准的并行方式。


二、DDP 的核心思想(非常重要)

1️⃣ 数据并行(Data Parallel)

  • 每个 GPU 一个进程
  • 每个进程一份完整模型
  • 每个进程只处理一部分数据

示意:

复制代码
GPU0: Model + Data shard 0
GPU1: Model + Data shard 1
GPU2: Model + Data shard 2
GPU3: Model + Data shard 3

2️⃣ 反向传播时自动同步梯度(All-Reduce)

loss.backward() 时:

  • 每个进程算出自己的梯度
  • 使用 NCCL / Gloo 等通信后端
  • 自动做 All-Reduce
  • 得到 所有 GPU 的平均梯度
  • 再各自 optimizer.step()

✔️ 模型参数始终保持一致


三、DDP 和 DataParallel(DP)的区别

这是高频面试 & 实战必考点 👇

对比项 DataParallel (DP) DistributedDataParallel (DDP)
并行方式 单进程多线程 多进程
性能 ❌ 慢 快(官方推荐)
GPU 利用率
通信方式 主卡聚合 All-Reduce
可扩展性 支持多机多卡
是否推荐 已不推荐 ⭐⭐⭐⭐⭐

结论:只要是多 GPU,一律用 DDP


四、DDP 的基本使用流程(核心代码结构)

1️⃣ 初始化进程组

python 复制代码
import torch.distributed as dist

dist.init_process_group(
    backend="nccl",   # GPU 用 nccl
    init_method="env://"
)

2️⃣ 设置当前进程使用的 GPU

python 复制代码
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

3️⃣ 包裹模型

python 复制代码
model = MyModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank]
)

4️⃣ 使用 DistributedSampler(非常关键)

python 复制代码
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(dataset)
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=32
)

每个进程只拿到 自己那一份数据,不重不漏。


5️⃣ 正常训练(几乎不用改)

python 复制代码
for epoch in range(epochs):
    sampler.set_epoch(epoch)  # 保证 shuffle 一致
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

五、DDP 的关键概念速记

概念 含义
world_size 总进程数(通常 = GPU 数)
rank 全局进程编号
local_rank 当前机器上的 GPU 编号
backend 通信后端(nccl / gloo
All-Reduce 梯度同步算法

六、为什么 DDP 性能这么好?

  1. 多进程,避免 GIL
  2. 通信和反向传播重叠
  3. NCCL 针对 GPU 高度优化
  4. 没有主卡瓶颈

这也是:

  • LLaMA
  • Qwen
  • Stable Diffusion
  • 各类工业训练框架

全部使用 DDP 的原因

相关推荐
袁气满满~_~16 小时前
Python数据分析学习
开发语言·笔记·python·学习
deephub16 小时前
构建自己的AI编程助手:基于RAG的上下文感知实现方案
人工智能·机器学习·ai编程·rag·ai编程助手
AI营销干货站16 小时前
工业B2B获客难?原圈科技解析2026五大AI营销增长引擎
人工智能
程序员老刘·17 小时前
重拾Eval能力:D4rt为Flutter注入AI进化基因
人工智能·flutter·跨平台开发·客户端开发
kebijuelun17 小时前
FlashInfer-Bench:把 AI 生成的 GPU Kernel 放进真实 LLM 系统的“闭环引擎”
人工智能·gpt·深度学习·机器学习·语言模型
Deepoch17 小时前
Deepoc具身模型开发板:让炒菜机器人成为您的智能厨师
人工智能·机器人·开发板·具身模型·deepoc·炒菜机器人·厨房机器人
axinawang17 小时前
二、信息系统与安全--考点--浙江省高中信息技术学考(Python)
python·浙江省高中信息技术
Elastic 中国社区官方博客17 小时前
Elastic:DevRel 通讯 — 2026 年 1 月
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
寻星探路17 小时前
【算法专题】滑动窗口:从“无重复字符”到“字母异位词”的深度剖析
java·开发语言·c++·人工智能·python·算法·ai
Dxy123931021617 小时前
python连接minio报错:‘SSL routines‘, ‘ssl3_get_record‘, ‘wrong version number‘
开发语言·python·ssl