1. 代码用法
python
def dist_init():
"""
初始化分布式训练环境 (Distributed Data Parallel, DDP)。
该函数负责:
1. 从环境变量中读取分布式训练的配置信息(rank, world_size 等)。
2. 设置当前进程使用的 GPU 设备。
3. 初始化 PyTorch 的分布式进程组 (process group),通常使用 NCCL 后端。
4. 设置随机种子以保证结果的可复现性,且不同进程的种子不同。
Returns:
dict: 包含分布式环境信息的字典,包括 rank, world_size, local_rank 等。
"""
# 从环境变量中获取本地 rank (当前节点上的 GPU 编号)
# 这些环境变量通常由 torchrun 或类似的启动脚本自动设置
local_rank = int(os.environ["LOCAL_RANK"])
# 获取当前节点上的进程数量 (GPU 数量)
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
# 设置当前进程使用的 CUDA 设备
# 这是非常关键的一步,确保每个进程只使用分配给它的那个 GPU
torch.cuda.set_device(local_rank)
# 初始化进程组
# "nccl" 是 NVIDIA GPU 上推荐的分布式后端,性能最好
# timeout 设置超时时间,防止初始化卡死
dist.init_process_group("nccl", timeout=timedelta(seconds=3600))
# 获取全局 rank (所有节点中所有进程的唯一 ID,从 0 开始)
rank = dist.get_rank()
# 获取全局 world size (所有节点中所有进程的总数)
world_size = dist.get_world_size()
# 计算当前节点在所有节点中的索引
node_idx = rank // local_world_size
# 计算总节点数
num_nodes = world_size // local_world_size
# 设置随机种子
# 关键点:使用 '42 + rank' 确保每个进程使用不同的种子
# 这对于数据加载器的数据增强等随机操作很重要,避免所有 GPU 训练完全相同的数据
seed = 42 + rank
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# dist.barrier() # 这是一个同步点,用于等待所有进程都到达这里。这里被注释掉了,通常在初始化后立即同步是个好习惯。
return dict(
rank=rank,
world_size=world_size,
local_rank=local_rank,
local_world_size=local_world_size,
node_idx=node_idx,
num_nodes=num_nodes
)
2. 下面解释一下代码中相关变量的含义:
在分布式训练(Distributed Data Parallel, DDP)中,这几个概念是核心,理解它们对于配置多机多卡训练至关重要。我们可以把分布式训练想象成一个"大型工厂",里面有多个"车间"(节点 Node),每个车间里有多个"工人"(GPU/进程 Process)。
以下是这些术语的详细解释:
2.1. 核心概念
| 术语 | 英文全称 | 含义 (工厂比喻) | 实际含义 | 范围 |
|---|---|---|---|---|
| Rank | Global Rank | 工人的全局工号 | 全局所有进程的唯一 ID | [0, World Size - 1] |
| World Size | World Size | 工厂的总工人数 | 全局所有进程(GPU)的总数 | ≥1\ge 1≥1 |
| Local Rank | Local Rank | 工人在本车间的工号 | 当前节点(机器)内的进程 ID | [0, Local World Size - 1] |
| Local World Size | Local World Size | 一个车间的工人数 | 当前节点(机器)内的进程(GPU)总数 | ≥1\ge 1≥1 |
| Node Index | Node Index | 车间编号 | 当前机器在集群中的编号 | [0, Num Nodes - 1] |
| Num Nodes | Number of Nodes | 车间总数 | 集群中参与训练的机器总数 | ≥1\ge 1≥1 |
2.2. 实际场景举例
假设我们要训练一个大模型,使用了 2 台服务器(节点) ,每台服务器上有 4 张 GPU。
- 机器 A (Node 0) : IP
192.168.1.100, 有 GPU0, 1, 2, 3 - 机器 B (Node 1) : IP
192.168.1.101, 有 GPU0, 1, 2, 3
总共参与训练的 GPU 数量 = 2×4=82 \times 4 = 82×4=8。
在这种情况下,每个进程的参数值如下:
机器 A (Node 0) 上的进程
| GPU (物理卡号) | Rank (全局ID) | Local Rank (本地ID) | World Size | Local World Size | Node Index | Num Nodes |
|---|---|---|---|---|---|---|
| GPU 0 | 0 | 0 | 8 | 4 | 0 | 2 |
| GPU 1 | 1 | 1 | 8 | 4 | 0 | 2 |
| GPU 2 | 2 | 2 | 8 | 4 | 0 | 2 |
| GPU 3 | 3 | 3 | 8 | 4 | 0 | 2 |
机器 B (Node 1) 上的进程
| GPU (物理卡号) | Rank (全局ID) | Local Rank (本地ID) | World Size | Local World Size | Node Index | Num Nodes |
|---|---|---|---|---|---|---|
| GPU 0 | 4 | 0 | 8 | 4 | 1 | 2 |
| GPU 1 | 5 | 1 | 8 | 4 | 1 | 2 |
| GPU 2 | 2 | 2 | 8 | 4 | 1 | 2 |
| GPU 3 | 7 | 3 | 8 | 4 | 1 | 2 |
2.3. 详细解释与用途
1. rank (Global Rank)
- 含义 : 全局唯一的 ID,用于标识整个分布式任务中的某一个进程。通常 Rank 0 被称为 Master 或 Chief 进程。
- 用途 :
- 打印日志 : 通常只让
rank == 0的进程打印日志、保存模型 checkpoint,避免 8 个进程同时打印导致控制台混乱或文件写入冲突。 - 数据切分 :
DistributedSampler会根据 rank 决定当前进程应该读取数据集的哪一部分,确保不同 GPU 训练的数据不重复。 - 通信 : 在
dist.send(tensor, dst=rank)点对点通信时,需要指定目标 rank。
- 打印日志 : 通常只让
2. world_size
- 含义 : 整个任务中总的并行进程数。通常等于
总机器数 * 每台机器的GPU数。 - 用途 :
- 平均 Loss : 计算 loss 时,如果用了
reduce='mean',底层通信库需要知道总共有多少个进程来做除法。 - 学习率调整 : 有时会根据
Batch Size * World Size来线性缩放学习率。
- 平均 Loss : 计算 loss 时,如果用了
3. local_rank
- 含义: 当前进程在当前机器上的序号。
- 用途 :
- 指定 GPU 设备 : 这是最重要的用途 。代码中
torch.cuda.set_device(local_rank)就是告诉当前进程:"你是这台机器上的第 0 号进程,你就去用第 0 号 GPU;你是第 1 号进程,就用第 1 号 GPU"。 - 如果没有这个设置,所有进程都会默认抢占 GPU 0,导致 Out Of Memory (OOM)。
- 指定 GPU 设备 : 这是最重要的用途 。代码中
4. local_world_size
- 含义 : 当前机器上启动了多少个进程。通常等于当前机器的 GPU 数量(
torch.cuda.device_count())。 - 用途 :
- 在一些异构集群中(有的机器 8 卡,有的 4 卡),这个值可能不同。但大多数情况是每台机器配置相同。
5. node_idx
- 含义: 当前机器的 ID。
- 计算公式 :
rank // local_world_size(假设每台机器 GPU 数相同)。 - 用途: 用于区分不同的物理节点,较少直接使用。
6. num_nodes
- 含义: 总共有多少台物理机器。
- 计算公式 :
world_size // local_world_size。
2.4. 代码中的体现
回到你提供的代码片段:
python
def dist_init():
# 1. 获取本地信息,这通常由 torchrun 自动注入环境变量
local_rank = int(os.environ["LOCAL_RANK"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
# 2. 关键:绑定当前进程到指定的 GPU
torch.cuda.set_device(local_rank)
# 3. 初始化通信后端
dist.init_process_group("nccl", ...)
# 4. 获取全局信息
rank = dist.get_rank()
world_size = dist.get_world_size()
# 5. 推算节点信息(假设每台机器 GPU 数一致)
node_idx = rank // local_world_size
num_nodes = world_size // local_world_size
# ...
这段代码是一个非常标准的模板,它不仅适用于单机多卡(单机 8 卡时,Node=1, Rank=0~7, Local Rank=0~7),也完美适配多机多卡场景。
3. 使用方法
比如:torchrun --nproc_per_node=1 train.py
torchrun 是 PyTorch 提供的用于启动分布式训练(Distributed Data Parallel, DDP)的命令行工具。
--nproc_per_node=1 :
- 含义 :指定 每个节点(机器)上启动的进程数量 。