PyTorch Distributed 源码导读笔记(基于 torch/distributed/distributed_c10d.py)
说明: 重点围绕:
distributed_c10d.py、后端注册、init_process_group、collective 调用链、ProcessGroup抽象类、NCCL 后端。
目录
- [第 0 章:基础概念总览](#第 0 章:基础概念总览)
- [第 1 章:Backend 注册系统](#第 1 章:Backend 注册系统)
- [第 2 章:init_process_group ------ 分布式启动总入口](#第 2 章:init_process_group —— 分布式启动总入口)
- [第 3 章:Collective 包装函数 ------ Python 层的"包装 + 路由"](#第 3 章:Collective 包装函数 —— Python 层的“包装 + 路由”)
- [第 4 章:new_group ------ 多进程组管理的基础](#第 4 章:new_group —— 多进程组管理的基础)
- [第 5 章:_World ------ PyTorch 分布式的全局"大脑"](#第 5 章:_World —— PyTorch 分布式的全局“大脑”)
- [第 6 章:"包装 + 路由"中的 Routing 究竟是什么](#第 6 章:“包装 + 路由”中的 Routing 究竟是什么)
- [第 7 章:Python → C++ → Backend 的完整调用链](#第 7 章:Python → C++ → Backend 的完整调用链)
- [第 8 章:ProcessGroup 抽象类与 Work 接口(自定义后端必须实现什么)](#第 8 章:ProcessGroup 抽象类与 Work 接口(自定义后端必须实现什么))
第 0 章:基础概念总览
在正式看源码前,先把几个核心名词统一一下(后面会频繁出现):
0.1 rank
- rank = 进程编号。
- 在一个分布式训练中,每个参与的进程都有一个整型编号:0, 1, ..., world_size-1。
- rank 在外部由启动器(如
torchrun,mpirun, Slurm + 启动脚本)或环境变量指定。
0.2 world_size
- world_size = 总进程数。
- 比如 8 张 GPU,每张 1 个进程,则 world_size = 8。
0.3 group / ProcessGroup
- group(进程组) = "一组一起做通信的 rank 集合"。
- ProcessGroup (C++ 类)是 PyTorch 对"进程组"的实现抽象:
- 一个 ProcessGroup 管一组 rank;
- 提供 allreduce / broadcast / allgather / barrier 等接口;
- 可以自己提供继承ProcessGroup的实现并注册。
0.4 backend(后端)
- backend = 具体通信实现方式。
- 常见后端:
- NCCL:NVIDIA GPU 集合通信库;
- Gloo:CPU/网络集合通信库;
- MPI:传统 MPI;
- 其他(UCC/XCCL,自定义后端)。
- backend 决定了 ProcessGroup 的具体类型并继承实现:
ProcessGroupNCCLProcessGroupGlooProcessGroupMPI
0.5 Store
- Store = 所有进程共享的"小型键值对字典"(控制面)。
- 用途:
- 初始化阶段:用于 rendezvous(让进程互相发现);
- 存 rank / world_size 信息;
- 实现 barrier 等简单同步;
- 实现:
TCPStore:通过 TCP server 共享;FileStore:通过共享文件系统;PrefixStore:在一个 Store 基础上加前缀,实现命名空间隔离;- 测试/本地时还可以有内存 HashStore/FakeStore。
重要:Store 不传大 tensor,主要传"控制信息"。
0.6 rendezvous(会合 / 集合点)
- 多个分布式进程启动时,需要一个"集合点":
- 谁来参加(rank)?
- 一共有几个人(world_size)?
- 通信地址是啥(MASTER_ADDR / PORT)?
- rendezvous 就是这一套"让所有进程在 Store 处会合,然后交换身份信息"的过程。
init_process_group里如果没传 store,会调用rendezvous(init_method, ...)来完成这些工作。
0.7 collective(集合通信)
不同于点对点的 send/recv,collective 是"整个 group 一起干一件事如":
all_reduce:每个 rank 有一个 tensor,先聚合(如求和),然后每个 rank 都拿到结果;broadcast:一个 src rank 的 tensor 发给所有人;all_gather:每个 rank 提供一个 tensor,最终大家都有所有 rank 的 tensor 列表;reduce_scatter:先规约,再把结果按块分散给各个 rank;barrier:所有人都到这行才能往下执行。
0.8 Work(异步句柄)
- 每次 collective 一般是异步的,返回一个
Work对象:work.is_completed()work.wait()
- 对于 NCCL,Work 里会包含 CUDA events、streams 等,用于检测完成与同步。
0.9 Routing(路由)
- 在"包装 + 路由"这句话里:
- 包装 :Python 层提供友好的 API(如
dist.all_reduce); - 路由 :Python 层决定:
- 使用哪个 group(default_pg / 子 group);
- 使用哪个 backend(NCCL/Gloo/MPI/UCCL);
- 调用哪个 C++ 的 ProcessGroup 方法;
- 包装 :Python 层提供友好的 API(如
- Python 不做实际通信,只做 路由 + 状态管理,这里的路由指的不是网络层面的路由,而是上层消息到底层真实硬件发射经过的调用路径。
第 1 章:Backend 注册系统
distributed_c10d.py 的首要任务之一,是维护一个"后端注册表"。
1.1 _backend 字典:后端注册表
python
_backend: dict[str, Callable[..., ProcessGroup]] = {}
含义:
- key:后端名称(通常是字符串
"nccl"、"gloo"等,或者Backend.NCCL这样的枚举值); - value:构造这个后端 ProcessGroup 的函数(构造器),例如:
python
# 举例:构造 NCCL 的 ProcessGroup
def _default_pg_nccl_constructor(store, rank, world_size, **kwargs) -> ProcessGroup:
...
return process_group_nccl
1.2 _register_backend(name, func)
注册后端:
python
def _register_backend(name: str, func: Callable[..., ProcessGroup]) -> None:
if name in _backend:
raise RuntimeError(f"backend '{name}' already registered")
_backend[name] = func
典型使用:
python
_register_backend("nccl", _default_pg_nccl_constructor)
_register_backend("gloo", _default_pg_gloo_constructor)
_register_backend("mpi", _default_pg_mpi_constructor)
这一步就是:
把不同 backend 的"构造函数"挂到全局字典 _backend 里。
1.3 _get_backend(name)
python
def _get_backend(name: str) -> Callable[..., ProcessGroup]:
return _backend[name]
在 init_process_group 中,会用这个函数按名称拿到构造器:
python
pg_constructor = _get_backend(backend)
pg = pg_constructor(store, rank, world_size, timeout=timeout, group_name=group_name)
1.4 Backend 枚举类
源码中有一个 Backend 枚举,类似:
python
class Backend(str, Enum):
GLOO = "gloo"
NCCL = "nccl"
MPI = "mpi"
UCC = "ucc"
XCCL = "xccl"
作用:
- 统一 backend 名称(避免大小写、拼写差异);
- 方便和 C++ 层的 backend type 映射;
- 用户可以写
backend="nccl"或backend=Backend.NCCL。
1.5 Backend 系统总结
Backend 注册系统完成的工作:
- 定义各种后端名字(枚举)。
- 为不同名字注册对应的 ProcessGroup 构造器。
- 在
init_process_group时,通过名字找到构造器并创建 ProcessGroup。
Python 层无需知道 NCCL/Gloo/MPI 的细节,只需通过名字把调用路由到对应 C++ backend。
第 2 章:init_process_group ------ 分布式启动总入口
几乎所有使用 torch.distributed 的代码都会先调用:
python
dist.init_process_group(backend, init_method, ...)
这是整个分布式系统的"开机按钮"。
2.1 函数签名与参数
简化后的定义:
python
def init_process_group(
backend,
init_method=None,
timeout=default_pg_timeout,
world_size=None,
rank=None,
store=None,
group_name="",
):
...
关键参数:
| 参数 | 含义 |
|---|---|
| backend | 通信后端("nccl" / "gloo" / "mpi" / ...) |
| init_method | rendezvous 方式("env://" / "tcp://ip:port" / "file://path") |
| world_size | 总进程数(可从环境变量/rdzv 获得) |
| rank | 当前进程编号(可从环境变量/rdzv 获得) |
| store | 用户自建 Store 时可跳过 rendezvous |
| group_name | 默认进程组名字(一般不用关心) |
| timeout | rendezvous 和后端创建的超时时间 |
2.2 整体流程概览
init_process_group 的主要步骤(简化)是:
- 检查是否已经初始化(不允许重复初始化默认 PG)。
- 标准化 backend 参数(转为
Backend枚举)。 - 如果用户没传
store:- 调用
rendezvous(init_method, rank, world_size, timeout); - 得到
(store, rank, world_size)。
- 调用
- 用
_get_backend(backend)找到对应的后端构造器。 - 调用构造器创建
ProcessGroup:pg = pg_constructor(store, rank, world_size, ...)
- 调
_world.set_default_pg(pg, store, backend):- 注册 global default ProcessGroup;
- 注册 default Store;
- 记录 backend 类型。
- 返回 None,初始化完成。
2.3 rendezvous 的职责(再强化一次)
当 store is None 时:
python
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
rendezvous 函数负责:
- 根据
init_method构造一个 Store(TCPStore / FileStore 等); - 所有进程在这个 Store 处"集合"(会合);
- 每个进程最终得到:
- 自己的
rank; - 全局
world_size; - 一个可用的 Store。
- 自己的
2.4 创建 ProcessGroup
python
backend = Backend(backend)
pg_constructor = _get_backend(backend)
pg = pg_constructor(store, rank, world_size, timeout=timeout, group_name=group_name)
这一步是真正"选择后端 + 创建通信上下文"的地方。
- NCCL:创建
ProcessGroupNCCL,内部初始化 communicator、streams; - Gloo:创建
ProcessGroupGloo,建立 TCP 连接; - MPI:使用 MPI 通信。
2.5 注册 default_pg
python
_world.set_default_pg(pg, store, backend)
其中 _World 是一个全局状态管理器(下一章详细讲),这一步完成:
- 设置
default_pg; - 设置
default_store; - 记录当前 backend 类型。
此后:
python
dist.all_reduce(tensor)
如果不传 group,就会使用这个 default_pg。
第 3 章:Collective 包装函数 ------ Python 层的"包装 + 路由"
distributed_c10d.py 提供了很多你熟悉的函数:
all_reducebroadcastall_gatherreduce_scatterbarrier- ...
这些函数本身 不做通信,而是"包装 + 路由"------通过 ProcessGroup 把调用交给 C++ 层。
3.1 典型例子:all_reduce
源码结构大致类似:
python
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if group is None:
group = _get_default_group()
work = group.allreduce([tensor], op=op)
if async_op:
return work
else:
work.wait()
return work
解析:
- 选择 group
group=None时,使用_get_default_group()→_world.default_pg。
- 调用 C++ ProcessGroup
group.allreduce([tensor], op=op)是 C++ 层的虚函数调用。
- 处理异步 / 同步
- 若
async_op=False:立即work.wait(); - 若
async_op=True:返回Work对象,让用户自己wait()。
- 若
3.2 其它常见 collective 的结构
基本套路都一致:
broadcast
python
def broadcast(tensor, src, group=None, async_op=False):
group = group or _get_default_group()
work = group.broadcast([tensor], src=src)
return work if async_op else (work.wait() or work)
- src rank 的 tensor 作为数据源,覆盖所有其他 rank 的 tensor。
all_gather
python
def all_gather(tensor_list, tensor, group=None, async_op=False):
group = group or _get_default_group()
work = group.allgather([tensor_list], [tensor])
return work if async_op else (work.wait() or work)
tensor_list长度 = world_size,存放所有 rank 的 tensor。
reduce_scatter
python
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
group = group or _get_default_group()
work = group.reduce_scatter(output, input_list, op=op)
return work if async_op else (work.wait() or work)
barrier
python
def barrier(group=None, async_op=False):
group = group or _get_default_group()
if group.has_barrier():
work = group.barrier()
else:
work = _store_based_barrier(...)
return work if async_op else (work.wait() or work)
- 如果 backend 自己实现了 barrier,就用后端的;
- 否则可以用 Store 模拟 barrier(store-based barrier)。
3.3 小结
- Python 核心三件事:
- 选 group(默认 group 或子 group);
- 调用
ProcessGroup的对应 C++ 方法; - 包装 Work,处理 async/sync。
- 真正的通信细节全部在 C++ backend。
第 4 章:new_group ------ 多进程组管理的基础
实际训练中经常需要多个"通信域":
- world group:所有 rank;
- data parallel group:所有 GPU;
- tensor parallel group:每几个 GPU 为一组;
- pipeline parallel group;
- MoE 专家组等。
new_group() 就是用来在默认 world group 的基础上创建子组。
4.1 new_group 的作用
例如:
python
pg = dist.new_group(ranks=[0,1,2,3])
dist.all_reduce(tensor, group=pg)
只有 rank 0/1/2/3 会参与这次 all_reduce。
4.2 new_group 的典型实现流程(简化)
python
def new_group(ranks=None, backend=None, timeout=default_pg_timeout, pg_options=None):
# 1. 确定 ranks
if ranks is None:
ranks = list(range(get_world_size()))
default_pg = _get_default_group()
# 2. 如果 ranks 就是整个 world,直接返回 default_pg(优化)
if set(ranks) == set(range(default_pg.size())):
return default_pg
# 3. 使用 PrefixStore 构造该 group 的 store
pg_store = PrefixStore(str(_world.next_group_id), _world.default_store)
# 4. 选 backend(没指定就继承默认 backend)
backend = backend or _world.backend
pg_constructor = _get_backend(backend)
# 5. 创建 ProcessGroup(注意 group 内的 rank 会重新编号 0..group_world_size-1)
pg = pg_constructor(pg_store, rank_in_group, group_world_size, timeout=timeout, pg_options=pg_options)
# 6. 在 _World 里登记
_world.pg_map[pg] = pg_store
_world.pg_group_ranks[pg] = ranks
_world.next_group_id += 1
return pg
4.3 PrefixStore 的意义
- 多个 group 共享一个底层 Store;
- 用前缀区分不同 group 的 key 空间:
text
group0 的 key: "0/allreduce/..."
group1 的 key: "1/allreduce/..."
这就保证了:
各个 group 在控制面上互不干扰。
4.4 new_group 对复杂并行方式的支持
例如 world_size=8 时:
text
world group: [0,1,2,3,4,5,6,7]
TP group:
g0: [0,1]
g1: [2,3]
g2: [4,5]
g3: [6,7]
Expert group:
e0: [0,4]
e1: [1,5]
e2: [2,6]
e3: [3,7]
每一个都是一个独立的 ProcessGroup,通过 new_group() 与 PrefixStore 创建和隔离。
第 5 章:_World ------ PyTorch 分布式的全局"大脑"
_World 是 distributed_c10d.py 中的一个内部类,用于管理整个进程中的所有分布式状态。
5.1 _World 内部成员(典型)
python
class _World:
def __init__(self):
self.default_pg = None
self.default_store = None
self.backend = None
self.pg_map = {}
self.pg_group_ranks = {}
self.next_group_id = 0
...
5.2 default_pg(默认 ProcessGroup)
default_pg是全局默认 group:- 包含所有 rank:[0...world_size-1];
- 在
init_process_group中由第一个 ProcessGroup 设置;
- 大多数 API(如
dist.all_reduce不传 group)都会使用它。
5.3 default_store(默认 Store)
default_store是默认的 Store(TCPStore/FileStore),在 init 时设置;- 所有子 group 的 Store 都是
PrefixStore(prefix, default_store)的形式; - 这样既共享底层存储,又通过前缀隔离 group。
5.4 backend(默认后端类型)
backend记录默认 ProcessGroup 的后端类型(如 NCCL/Gloo);new_group未传 backend 时,会自动继承该 backend。
5.5 pg_map 和 pg_group_ranks
pg_map[pg] = pg_store:- 记录每个 ProcessGroup 对应的 Store(通常是 PrefixStore)。
pg_group_ranks[pg] = ranks_list:- 记录这个 group 里包含哪些 rank。
这些信息用于:
- 后续 debug;
- group 销毁或清理资源;
- 上层框架需要查询某个 group 的成员。
5.6 next_group_id
- 每次
new_group时递增:- 用于给 PrefixStore 生成唯一前缀;
- 确保不同 group 的 key 不会冲突。
5.7 生命周期管理:destroy_process_group
典型逻辑(简化):
- 销毁 default_pg 或某个 pg:
- 调用 C++ 层
ProcessGroup的析构,释放 NCCL/Gloo 资源; - 从
_world.pg_map/_world.pg_group_ranks中删除; - 若是 default_pg,则清空
default_pg和default_store。
- 调用 C++ 层
如果忘了销毁,PyTorch 会打印类似:
text
Warning: ProcessGroupNCCL was not destroyed properly
第 6 章:"包装 + 路由"中的 Routing 究竟是什么
之前说过一句话:
distributed_c10d.py只是"包装 + 路由"。
这里的 路由(routing) 是指:
决定"这次 collective 调用应该交给哪个 ProcessGroup / 哪个 backend 的 C++ 实现去处理"。
6.1 路由的三个维度
-
选择 group
- 默认:
_get_default_group(); - 指定:使用
new_group()返回的子组;
- 默认:
-
通过 group 找到 backend
- 每个 group(ProcessGroup 实例)内部知道自己属于哪个 backend 类型(NCCL/Gloo/MPI);
-
调用正确的 C++ 方法
- 比如
group.allreduce()对应的是:ProcessGroupNCCL::allreduce或ProcessGroupGloo::allreduce或ProcessGroupMPI::allreduce或
- 比如
6.2 用 all_reduce 举一个完整路由过程
python
dist.all_reduce(tensor)
路由过程:
-
Python 层:
- 如果没指定 group → 从
_World中取default_pg;
- 如果没指定 group → 从
-
default_pg是某种具体 ProcessGroup 实现:ProcessGroupNCCL或Gloo或MPI;
-
Python 调:
pythonwork = default_pg.allreduce([tensor], op) -
实际执行的是对应 C++ 子类的重写函数。
6.3 总结
- 包装 :提供友好的 Python API(参数、默认值、
async_op等)。 - 路由 :
- 选对 ProcessGroup(group);
- 选对 backend;
- 调对 C++ 实现方法。
核心:Python 不做任何通信,只负责"选对对象 + 调对函数"。
第 7 章:Python → C++ → Backend 的完整调用链
以 all_reduce 为例。
7.1 顶层: Python 代码
python
import torch.distributed as dist
dist.all_reduce(tensor)
7.2 Python 层:distributed_c10d.py
python
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if group is None:
group = _get_default_group()
work = group.allreduce([tensor], op=op)
if async_op:
return work
else:
work.wait()
return work
_get_default_group()从_World拿到default_pg。group.allreduce是 C++ 层方法。
7.3 C++ 层:ProcessGroup 抽象类
在 C++ 中有抽象类:
cpp
class ProcessGroup {
public:
struct Work {
virtual bool isCompleted() = 0;
virtual bool wait() = 0;
virtual ~Work() {}
};
virtual std::shared_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) = 0;
virtual std::shared_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) = 0;
virtual std::shared_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) = 0;
virtual std::shared_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) = 0;
virtual std::shared_ptr<Work> barrier(
const BarrierOptions& opts) = 0;
virtual ~ProcessGroup() {}
};
ProcessGroup是虚类;ProcessGroupNCCL,ProcessGroupGloo,ProcessGroupMPI继承并实现上述接口。
7.4 C++ 层:具体后端实现(以 NCCL 为例)
cpp
std::shared_ptr<Work> ProcessGroupNCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
// 1. 检查参数
// 2. 找到/创建 NCCL communicator
// 3. 为每个 tensor 选择 CUDA stream
// 4. 调用 ncclAllReduce(...)
// 5. 创建 WorkNCCL 对象(内部有 CUDA event)
// 6. 返回 WorkNCCL
}
7.5 底层库:NCCL / Gloo / MPI
例如 NCCL:
ncclAllReduce(...)在 GPU 上启动 ring/tree 算法的 kernel;- 使用 NVLink / PCIe / IB 进行 GPU-GPU 通信;
- 用 CUDA event 通知任务完成。
7.6 Work 的回流
- C++ 返回
std::shared_ptr<Work>(比如WorkNCCL)给 Python; - Python 封装成 Python 对象;
work.wait()调到 C++ 的Work::wait()(例如等待 CUDA event)。
7.7 总结调用链
text
你的代码:
dist.all_reduce(t)
↓ Python 包装 + 路由
distributed_c10d.py:
group = _get_default_group()
work = group.allreduce([t], op)
work.wait()
↓ C++ 调度层
ProcessGroup::allreduce (虚函数)
↓ C++ 后端实现
ProcessGroupNCCL::allreduce
→ ncclAllReduce(...)
→ 创建 WorkNCCL
↓ 底层库 + 硬件
NCCL + CUDA + NVLink/PCIe/IB
Python 层的角色:全程只是路由和 API 封装,没有一行真正的通信逻辑。
第 8 章:ProcessGroup 抽象类与 Work 接口(自定义后端必须实现什么)
当你要写一个自定义后端(比如 Fake Backend)时,必须遵守 ProcessGroup 抽象类的接口。
8.1 ProcessGroup 抽象类(C++)
见上一章 7.3 中的定义,这里再强调几点:
- 每一个 collective 接口(allreduce/broadcast/allgather/reduce_scatter/barrier)都必须返回
std::shared_ptr<Work>; Work对象提供isCompleted()和wait()两个接口;- 上层 Python 假设这些接口语义正确(
wait()返回时代表该 collective 已经"完成")。
8.2 自定义后端需要实现collective函数
最通用的是:
allreducebroadcastallgatherreduce_scatterbarrier
如果你的训练脚本没有用到 point-to-point(send/recv),可以先不实现 send/recv。
8.3 Work 接口
每个 collective 调用都必须返回一个 Work 对象,支持:
cpp
class Work {
public:
virtual bool isCompleted() = 0; // 是否完成
virtual bool wait() = 0; // 阻塞等待完成
virtual ~Work() {}
};
8.4 ProcessGroup 自身也要保存基本信息
rank_:当前进程在该 group 内的 rank;size_:该 group 的 world_size;store_:关联的 Store(可能是 PrefixStore)。
这些通常在构造函数里被设置。
8.5 Backend 的最小实现需求
要让 PyTorch 正常工作,你至少要做到:
-
定义
ProcessGroupfake类,继承ProcessGroup; -
实现上面提到的关键 collective 函数(即便只是 no-op);
-
定义
Work类,实现wait()/isCompleted(); -
提供一个 C++ 构造函数
new_process_group(...); -
在 Python 中注册:
python_register_backend("fake", _default_pg_fake_constructor)并在构造器中调用 C++ 的
new_process_group。 -
确保
new_group也能创建 PG,并更新_World中的pg_map/pg_group_ranks。
总结
- Backend 注册系统 :后端是插件,统一挂到
_backend字典。 - init_process_group :负责 rendezvous、创建 Store、构造默认
ProcessGroup、更新_World。 - Collective 包装函数:Python 只做包装 + 路由,真实通信在 C++ backend。
- new_group + PrefixStore:让你可以在一个进程中管理多个通信域(world/TP/PP/MoE)。
- _World :是所有
ProcessGroup/ Store / backend 的管理类。 - Routing:决定"这次调用用哪个 group / backend / C++ 实现",Python 不做通信。
- Python → C++ → Backend 链路 :你可以精准地从
dist.all_reduce追踪到 NCCL kernel。 - ProcessGroup 抽象类 & Work:定义了所有后端必须实现的统一接口。