本文结构:
- 分布式训练概述
- 环境设置
- 数据并行(DDP)
- 模型并行
- 启动训练
- 性能优化建议
- 示例代码
- 参考资料和相关问题
以下是为您整理的PyTorch分布式训练教程指南:
一、PyTorch分布式训练核心概念
-
数据并行:通过分割数据集实现多GPU并行训练,主流方法包括:
DistributedDataParallel
(DDP):官方推荐的分布式训练接口DataParallel
(DP):单机多卡方案(已逐步被DDP取代)
-
模型并行:
- 流水线并行:将模型按层拆分到不同设备
- 张量并行:拆分单个运算的矩阵维度
-
通信协议:
- NCCL:NVIDIA GPU专用通信库
- Gloo:支持CPU和GPU的跨平台协议
二、DDP实战步骤
1. 环境初始化
python
import torch.distributed as dist
def setup(rank, world_size):
dist.init_process_group(
backend='nccl', # GPU推荐NCCL
init_method='env://',
rank=rank,
world_size=world_size
)
torch.cuda.set_device(rank)
2. 数据分片
通过DistributedSampler
实现数据集自动划分:
python
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
(数据划分原理见引用[3]中描述的补充采样机制)
3. 模型封装
python
model = NeuralNetwork().to(rank)
model = DDP(model, device_ids=[rank])
4. 训练循环
python
for epoch in range(epochs):
sampler.set_epoch(epoch) # 保证shuffle有效性
for batch in dataloader:
outputs = model(batch)
loss = criterion(outputs)
loss.backward()
optimizer.step()
三、多节点启动方法
引用[2]展示了多节点启动命令示例,推荐使用官方启动工具:
bash
# 单机多卡启动(4 GPU)
torchrun --nproc_per_node=4 train.py
# 多节点启动(需配置MASTER_ADDR)
torchrun --nnodes=2 --nproc_per_node=4 --master_addr=192.168.1.1 train.py
四、性能优化建议
-
通信优化:
- 使用梯度累积减少通信频率
- 设置
find_unused_parameters=False
(当模型有未使用参数时需设为True)
-
计算优化:
- 增大批次尺寸 B B B可提升计算效率(引用[4]中的 T c o m p T_{comp} Tcomp优化)
- 混合精度训练
pythonscaler = GradScaler() with autocast(): outputs = model(inputs) scaler.scale(loss).backward() scaler.step(optimizer)
-
网络优化:
- 使用RDMA高速网络降低 T s y n c T_{sync} Tsync(引用[4])
- 梯度压缩技术
五、完整代码示例
python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main(rank, world_size):
# 初始化进程组
setup(rank, world_size)
# 准备数据
dataset = MyDataset()
sampler = DistributedSampler(dataset, world_size, rank)
dataloader = DataLoader(dataset, sampler=sampler)
# 构建模型
model = DDP(MyModel().to(rank), device_ids=[rank])
# 训练循环
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
sampler.set_epoch(epoch)
for batch in dataloader:
inputs = batch.to(rank)
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)