PyTorch DDP多GPU训练实践问题总结

1 背景介绍

1.1 PyTorch

PyTorch 是目前全球最受欢迎的开源深度学习框架之一,由 Meta(原 Facebook)的人工智能研究团队(FAIR)于 2016 年推出。PyTorch 最显著的特点是采用动态计算图(Dynamic Computation Graph),凭借灵活的设计和直观的 Python 风格语法,已经成为学术研究和工业界开发的主流选择,是 NeurIPS 等顶级会议论文复现的首选。

PyTorch的核心组件:

1)torch.Tensor/torch.utils.data:张量(基本数据载体)和数据集(加载和预处理)。

2)torch.nn:神经网络。包含各种层(如线性层/卷积层)、损失函数和激活函数等构建模型的基础组件。本文介绍的 用于多GPU分布式训练的DDP模块即是torch.nnD的子模块之一。

3)torch.autograd/torch.optim:自动计算梯度和优化器。

1.2 DDP

DistributedDataParallel (DDP) 是PyTorch AI训练框架**torch.nn组件** 的一个子模块(torch.nn.parallel.DistributedDataParallel ,用于实现分布式数据并行训练,它通过多进程架构和高效的 AllReduce 通信算法,解决了传统单进程方案的瓶颈,能够在单机多卡甚至多机集群上实现接近线性的加速比。

DDP 的核心设计理念可以概括为**"分而治之,高效同步"**。

1)多进程架构 :与 DataParallel(单进程多线程)不同,DDP 为每个 GPU 启动一个独立的 Python 进程。这意味着每个进程拥有独立的解释器,完美避开了 Python 的全局解释器锁(GIL)限制。

2)去中心化通信:DDP 所有进程地位平等,通过高效的通信后端(如 NCCL)进行协作。

3)分桶机制 :DDP 会将模型参数分桶(Bucket),当一个桶内的梯度计算完成后,立即启动 AllReduce 通信。这种计算与通信重叠的策略,进一步隐藏了通信延迟。

4)Batch Normalization (BN) 同步 :DDP 支持 SyncBatchNorm,它能跨 GPU 同步统计信息(均值和方差),保证 BN 层的计算效果等同于在大 batch size 下训练。

1.3 本文目的

笔者之前用于AI去噪训练的代码框架只能支持单GPU模式的训练,最近花了点时间改造为基于DDP支持分布式多GPU模式,调测过程中碰到过一些问题,这里作为总结输出。

2 实践内容

训练代码中新增DDP分布式逻辑的过程相对简单,这里简单罗列相关过程。

2.1 环境初始化

1. 初始化分布式环境

def setup(rank, world_size):

os.environ["MASTER_ADDR"] = "localhost"

os.environ["MASTER_PORT"] = "12355"

初始化进程组

dist.init_process_group("nccl", rank=rank, world_size=world_size)

2. 清理分布式环境

def cleanup():

dist.destroy_process_group()

2.2 数据加载

创建数据集和数据加载器

train_dataset=AudioDataset(... ...)

#DistributedSampler确保每个GPU获得不同的数据子集,丢弃最后不足一个 batch

sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True,drop_last=True)

train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,pin_memory=True)

此外,训练循环中调用 sampler.set_epoch(epoch)确保每个epoch有不同的数据shuffle

2.3 执行流程

1)主进程

def main():

parser = get_parser()

args = parser.parse_args()

hps = utils.get_hparams(init=True, exp=args.exp_dir, cfg=args.config)

#检测可用GPU数量,为每个GPU启动一个独立的进程

world_size = torch.cuda.device_count()

if world_size > 1:

mp.spawn(run, args=(args, hps), nprocs=world_size, join=True)

else:

run(rank=0, args=args, hps=hps)

2)子进程:

def run(rank, args,hps):

... ... ... ...

world_size = args.world_size

#初始化分布式环境

if world_size > 1:

setup_dist(rank, world_size,args.master_port)

创建模型,省略

... ... ... ...

if world_size > 1:

logging.info(f"Using DDP at rank {rank}")

model = DDP(model, device_ids=[rank], find_unused_parameters=True)

定义损失函数和优化器,省略

创建数据集和数据加载器 ,省略

训练循环

for epoch in range(...):

sampler.set_epoch(epoch) # 确保每个epoch的shuffle不同

#训练循环中处理数据并更新模型 ,省略

... ... ... ...

#训练完后清理分布式环境

if world_size > 1:

torch.distributed.barrier()

cleanup_dist()

3 问题总结

3.1 DataLoader 中 sampler 和 shuffle=True互斥

1)问题现象:初次调测中出现如下错误告警

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 215, in join raise ProcessRaisedException(msg, error_index, failed_process.pid) torch.multiprocessing.spawn.ProcessRaisedException: -- Process 0 terminated with the following error: Traceback (most recent call last): File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap fn(i, *args) File "/home/tigerp/gtcrn-mo/train.py", line 286, in run train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,shuffle=True,pin_memory=True) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 359, in init raise ValueError("sampler option is mutually exclusive with shuffle") ValueError: sampler option is mutually exclusive with shuffle

2)原因分析:在使用 DataLoader 时同时指定了 sampler 和 shuffle=True,这两个选项是互斥的,因为 DistributedSampler 本身已经包含了数据打乱的功能。

3)解决方案:将上述代码"train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,shuffle=True,pin_memory=True)"中的shuffle参数设置为False即可。

3.2 训练中进程异常"terminated with signal SIGKILL"

1)问题现象:分布式训练到第二轮时总是出现如下进程异常信息

Traceback (most recent call last):

File "<string>", line 1, in <module>

File "/usr/local/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main

exitcode = _main(fd, parent_sentinel)

^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main

self = reduction.pickle.load(from_parent)

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

_pickle.UnpicklingError: pickle data was truncated

W1210 11:00:22.364000 3629 site-packages/torch/multiprocessing/spawn.py:169] Terminating process 3704 via signal SIGTERM

Traceback (most recent call last):

File "/home/tigerp/gtcrn-mo/train.py", line 402, in <module>

main()

File "/home/tigerp/gtcrn-mo/train.py", line 394, in main

mp.spawn(run, args=(args, hps), nprocs=world_size, join=True)

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 340, in spawn

return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 296, in start_processes

while not context.join():

^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 196, in join

raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGKILL

(base) root@ubun:gtcrn-mo# /usr/local/anaconda3/lib/python3.12/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 22 leaked semaphore objects to clean up at shutdown

warnings.warn("resource_tracker: There appear to be %d "

2)原因分析:刚开始看到"pickle data was truncated"的告警信息(多进程间传输大型对象时,数据包被截断),一直以为是GPU内存不足或内存泄露导致,尝试对GPU的内存使用优化调整了半天,问题却一直重新;后来尝试同样的程序使用单GPU模型则能正常训练,可见训练代码逻辑应该没有问题;再后来观察到主机CPU内存似乎在随着训练时间推移而不断减少,查看系统dmesg日志,发现实际原因为主机内存不足导致训练进程被系统oom-kill:

373070.636083\] \[ 124320\] 0 124320 72383204 14070046 13860849 27504 181693 119156736 192672 0 python \[373070.636086\] \[ 124321\] 0 124321 72371030 14107376 13876045 28224 203107 119349248 182160 0 python \[373070.636089\] \[ 136041\] 0 136041 3508343 1871825 1870673 1152 0 17018880 0 0 python \[373070.636091\] \[ 136042\] 0 136042 3348094 1711384 1710376 1008 0 15736832 0 0 python \[373070.636093\] oom-kill:constraint=CONSTRAINT_NONE,nodemask=(null),cpuset=/,mems_allowed=0,global_oom,task_memcg=/user.slice/user-0.slice/session-8.scope,task=python,pid=124321,uid=0 \[373070.636361\] Out of memory: Killed process 124321 (python) total-vm:289484120kB, anon-rss:55504180kB, file-rss:112896kB, shmem-rss:812428kB, UID:0 pgtables:116552kB oom_score_adj:0 \[373073.843173\] oom_reaper: reaped process 124321 (python), now anon-rss:0kB, file-rss:0kB, shmem-rss:27672kB

虽然单GPU训练正常,但使用两块GPU的分布式训练会同时启用两个独立进程同时训练(加载两倍的数据量),后者的内存开销是前者的两倍以上(还有跨GPU的通信开销)。

3)解决方案:通过减小单次加载的训练数据量来降低系统内存开销,问题解决。

3.3 模型训练时出现 training_loss: nan

1)问题现象:模型开始训练时training_loss值还正常,几个轮次之后就出现"training_loss: nan"

2)原因分析:"training_loss: nan" 意味着损失函数在计算过程中产生了非数值(Not a Number)错误,表明训练过程出现了数值不稳定现象。进一步分析,是因为前面3.2部分为解决分布式训练的GPU内存占用高问题而启用了半精度(FP16)混合训练机制,FP16导致loss函数的数值范围精度不足导致溢出。

for input in data_loader:

try:

with torch_autocast(dtype=torch.float16, enabled=hps_train.fp16_run):

3)解决办法:关闭FP16训练模型(将hps_train.fp16_run配置为False),问题解决。

3.4 模型加载提示"Missing key(s) in state_dict"

1)问题现象:分布式训练出来的模型,推理加载模型参数时提示如下错误:

Traceback (most recent call last): File "/home/tigerp/gtcrn-mo/blindtest.py", line 95, in <module> main(args) File "/home/tigerp/gtcrn-mo/blindtest.py", line 42, in main model.load_state_dict(torch.load(model_path, map_location=device)) File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2593, in load_state_dict raise RuntimeError( RuntimeError: Error(s) in loading state_dict for GTMOD: Missing key(s) in state_dict: "erb.erb_fc.weight", "erb.ierb_fc.weight", "encoder.en_convs.0.conv.weight", "encoder.en_convs.0.conv.bias", "encoder.en_convs.0.bn.weight", "encoder.en_convs.0.bn.bias" ...... ......

2)原因分析:问题出在加载模型状态字典时,某些键在保存的模型中缺失。这通常是因为在分布式训练过程中,模型的状态字典可能包含了一些额外的前缀(如 module.),而在单 GPU 推理时,这些前缀不存在。

3)解决办法:在加载模型状态字典时,移除 module. 前缀。

base_dir = os.getcwd()

model_path = os.path.join(base_dir, f"{exp}/ckp", pth)

state_dict = torch.load(model_path, map_location=device)

if args.world_size > 1:

去除module.前缀

new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

else:

new_state_dict = state_dict

model.load_state_dict(new_state_dict)

相关推荐
意疏3 分钟前
节点小宝4.0 正式发布:一键直达,重新定义远程控制!
人工智能
一个无名的炼丹师12 分钟前
GraphRAG深度解析:从原理到实战,重塑RAG检索增强生成的未来
人工智能·python·rag
Yan-英杰35 分钟前
BoostKit OmniAdaptor 源码深度解析
网络·人工智能·网络协议·tcp/ip·http
AI街潜水的八角40 分钟前
基于Pytorch深度学习神经网络MNIST手写数字识别系统源码(带界面和手写画板)
pytorch·深度学习·神经网络
用户8356290780511 小时前
用Python轻松管理Word页脚:批量处理与多节文档技巧
后端·python
用泥种荷花1 小时前
【LangChain学习笔记】Message
人工智能
阿里云大数据AI技术1 小时前
一套底座支撑多场景:高德地图基于 Paimon + StarRocks 轨迹服务实践
人工智能
云擎算力平台omniyq.com1 小时前
CES 2026观察:从“物理AI”愿景看行业算力基础设施演进
人工智能
进击的松鼠1 小时前
LangChain 实战 | 快速搭建 Python 开发环境
python·langchain·llm
小北方城市网1 小时前
第1课:架构设计核心认知|从0建立架构思维(架构系列入门课)
大数据·网络·数据结构·python·架构·数据库架构