PyTorch DDP多GPU训练实践问题总结

1 背景介绍

1.1 PyTorch

PyTorch 是目前全球最受欢迎的开源深度学习框架之一,由 Meta(原 Facebook)的人工智能研究团队(FAIR)于 2016 年推出。PyTorch 最显著的特点是采用动态计算图(Dynamic Computation Graph),凭借灵活的设计和直观的 Python 风格语法,已经成为学术研究和工业界开发的主流选择,是 NeurIPS 等顶级会议论文复现的首选。

PyTorch的核心组件:

1)torch.Tensor/torch.utils.data:张量(基本数据载体)和数据集(加载和预处理)。

2)torch.nn:神经网络。包含各种层(如线性层/卷积层)、损失函数和激活函数等构建模型的基础组件。本文介绍的 用于多GPU分布式训练的DDP模块即是torch.nnD的子模块之一。

3)torch.autograd/torch.optim:自动计算梯度和优化器。

1.2 DDP

DistributedDataParallel (DDP) 是PyTorch AI训练框架**torch.nn组件** 的一个子模块(torch.nn.parallel.DistributedDataParallel ,用于实现分布式数据并行训练,它通过多进程架构和高效的 AllReduce 通信算法,解决了传统单进程方案的瓶颈,能够在单机多卡甚至多机集群上实现接近线性的加速比。

DDP 的核心设计理念可以概括为**"分而治之,高效同步"**。

1)多进程架构 :与 DataParallel(单进程多线程)不同,DDP 为每个 GPU 启动一个独立的 Python 进程。这意味着每个进程拥有独立的解释器,完美避开了 Python 的全局解释器锁(GIL)限制。

2)去中心化通信:DDP 所有进程地位平等,通过高效的通信后端(如 NCCL)进行协作。

3)分桶机制 :DDP 会将模型参数分桶(Bucket),当一个桶内的梯度计算完成后,立即启动 AllReduce 通信。这种计算与通信重叠的策略,进一步隐藏了通信延迟。

4)Batch Normalization (BN) 同步 :DDP 支持 SyncBatchNorm,它能跨 GPU 同步统计信息(均值和方差),保证 BN 层的计算效果等同于在大 batch size 下训练。

1.3 本文目的

笔者之前用于AI去噪训练的代码框架只能支持单GPU模式的训练,最近花了点时间改造为基于DDP支持分布式多GPU模式,调测过程中碰到过一些问题,这里作为总结输出。

2 实践内容

训练代码中新增DDP分布式逻辑的过程相对简单,这里简单罗列相关过程。

2.1 环境初始化

1. 初始化分布式环境

def setup(rank, world_size):

os.environ["MASTER_ADDR"] = "localhost"

os.environ["MASTER_PORT"] = "12355"

初始化进程组

dist.init_process_group("nccl", rank=rank, world_size=world_size)

2. 清理分布式环境

def cleanup():

dist.destroy_process_group()

2.2 数据加载

创建数据集和数据加载器

train_dataset=AudioDataset(... ...)

#DistributedSampler确保每个GPU获得不同的数据子集,丢弃最后不足一个 batch

sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True,drop_last=True)

train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,pin_memory=True)

此外,训练循环中调用 sampler.set_epoch(epoch)确保每个epoch有不同的数据shuffle

2.3 执行流程

1)主进程

def main():

parser = get_parser()

args = parser.parse_args()

hps = utils.get_hparams(init=True, exp=args.exp_dir, cfg=args.config)

#检测可用GPU数量,为每个GPU启动一个独立的进程

world_size = torch.cuda.device_count()

if world_size > 1:

mp.spawn(run, args=(args, hps), nprocs=world_size, join=True)

else:

run(rank=0, args=args, hps=hps)

2)子进程:

def run(rank, args,hps):

... ... ... ...

world_size = args.world_size

#初始化分布式环境

if world_size > 1:

setup_dist(rank, world_size,args.master_port)

创建模型,省略

... ... ... ...

if world_size > 1:

logging.info(f"Using DDP at rank {rank}")

model = DDP(model, device_ids=[rank], find_unused_parameters=True)

定义损失函数和优化器,省略

创建数据集和数据加载器 ,省略

训练循环

for epoch in range(...):

sampler.set_epoch(epoch) # 确保每个epoch的shuffle不同

#训练循环中处理数据并更新模型 ,省略

... ... ... ...

#训练完后清理分布式环境

if world_size > 1:

torch.distributed.barrier()

cleanup_dist()

3 问题总结

3.1 DataLoader 中 sampler 和 shuffle=True互斥

1)问题现象:初次调测中出现如下错误告警

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 215, in join raise ProcessRaisedException(msg, error_index, failed_process.pid) torch.multiprocessing.spawn.ProcessRaisedException: -- Process 0 terminated with the following error: Traceback (most recent call last): File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap fn(i, *args) File "/home/tigerp/gtcrn-mo/train.py", line 286, in run train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,shuffle=True,pin_memory=True) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 359, in init raise ValueError("sampler option is mutually exclusive with shuffle") ValueError: sampler option is mutually exclusive with shuffle

2)原因分析:在使用 DataLoader 时同时指定了 sampler 和 shuffle=True,这两个选项是互斥的,因为 DistributedSampler 本身已经包含了数据打乱的功能。

3)解决方案:将上述代码"train_data_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler,shuffle=True,pin_memory=True)"中的shuffle参数设置为False即可。

3.2 训练中进程异常"terminated with signal SIGKILL"

1)问题现象:分布式训练到第二轮时总是出现如下进程异常信息

Traceback (most recent call last):

File "<string>", line 1, in <module>

File "/usr/local/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main

exitcode = _main(fd, parent_sentinel)

^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/multiprocessing/spawn.py", line 132, in _main

self = reduction.pickle.load(from_parent)

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

_pickle.UnpicklingError: pickle data was truncated

W1210 11:00:22.364000 3629 site-packages/torch/multiprocessing/spawn.py:169] Terminating process 3704 via signal SIGTERM

Traceback (most recent call last):

File "/home/tigerp/gtcrn-mo/train.py", line 402, in <module>

main()

File "/home/tigerp/gtcrn-mo/train.py", line 394, in main

mp.spawn(run, args=(args, hps), nprocs=world_size, join=True)

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 340, in spawn

return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 296, in start_processes

while not context.join():

^^^^^^^^^^^^^^

File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 196, in join

raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGKILL

(base) root@ubun:gtcrn-mo# /usr/local/anaconda3/lib/python3.12/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 22 leaked semaphore objects to clean up at shutdown

warnings.warn("resource_tracker: There appear to be %d "

2)原因分析:刚开始看到"pickle data was truncated"的告警信息(多进程间传输大型对象时,数据包被截断),一直以为是GPU内存不足或内存泄露导致,尝试对GPU的内存使用优化调整了半天,问题却一直重新;后来尝试同样的程序使用单GPU模型则能正常训练,可见训练代码逻辑应该没有问题;再后来观察到主机CPU内存似乎在随着训练时间推移而不断减少,查看系统dmesg日志,发现实际原因为主机内存不足导致训练进程被系统oom-kill:

373070.636083\] \[ 124320\] 0 124320 72383204 14070046 13860849 27504 181693 119156736 192672 0 python \[373070.636086\] \[ 124321\] 0 124321 72371030 14107376 13876045 28224 203107 119349248 182160 0 python \[373070.636089\] \[ 136041\] 0 136041 3508343 1871825 1870673 1152 0 17018880 0 0 python \[373070.636091\] \[ 136042\] 0 136042 3348094 1711384 1710376 1008 0 15736832 0 0 python \[373070.636093\] oom-kill:constraint=CONSTRAINT_NONE,nodemask=(null),cpuset=/,mems_allowed=0,global_oom,task_memcg=/user.slice/user-0.slice/session-8.scope,task=python,pid=124321,uid=0 \[373070.636361\] Out of memory: Killed process 124321 (python) total-vm:289484120kB, anon-rss:55504180kB, file-rss:112896kB, shmem-rss:812428kB, UID:0 pgtables:116552kB oom_score_adj:0 \[373073.843173\] oom_reaper: reaped process 124321 (python), now anon-rss:0kB, file-rss:0kB, shmem-rss:27672kB

虽然单GPU训练正常,但使用两块GPU的分布式训练会同时启用两个独立进程同时训练(加载两倍的数据量),后者的内存开销是前者的两倍以上(还有跨GPU的通信开销)。

3)解决方案:通过减小单次加载的训练数据量来降低系统内存开销,问题解决。

3.3 模型训练时出现 training_loss: nan

1)问题现象:模型开始训练时training_loss值还正常,几个轮次之后就出现"training_loss: nan"

2)原因分析:"training_loss: nan" 意味着损失函数在计算过程中产生了非数值(Not a Number)错误,表明训练过程出现了数值不稳定现象。进一步分析,是因为前面3.2部分为解决分布式训练的GPU内存占用高问题而启用了半精度(FP16)混合训练机制,FP16导致loss函数的数值范围精度不足导致溢出。

for input in data_loader:

try:

with torch_autocast(dtype=torch.float16, enabled=hps_train.fp16_run):

3)解决办法:关闭FP16训练模型(将hps_train.fp16_run配置为False),问题解决。

3.4 模型加载提示"Missing key(s) in state_dict"

1)问题现象:分布式训练出来的模型,推理加载模型参数时提示如下错误:

Traceback (most recent call last): File "/home/tigerp/gtcrn-mo/blindtest.py", line 95, in <module> main(args) File "/home/tigerp/gtcrn-mo/blindtest.py", line 42, in main model.load_state_dict(torch.load(model_path, map_location=device)) File "/usr/local/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2593, in load_state_dict raise RuntimeError( RuntimeError: Error(s) in loading state_dict for GTMOD: Missing key(s) in state_dict: "erb.erb_fc.weight", "erb.ierb_fc.weight", "encoder.en_convs.0.conv.weight", "encoder.en_convs.0.conv.bias", "encoder.en_convs.0.bn.weight", "encoder.en_convs.0.bn.bias" ...... ......

2)原因分析:问题出在加载模型状态字典时,某些键在保存的模型中缺失。这通常是因为在分布式训练过程中,模型的状态字典可能包含了一些额外的前缀(如 module.),而在单 GPU 推理时,这些前缀不存在。

3)解决办法:在加载模型状态字典时,移除 module. 前缀。

base_dir = os.getcwd()

model_path = os.path.join(base_dir, f"{exp}/ckp", pth)

state_dict = torch.load(model_path, map_location=device)

if args.world_size > 1:

去除module.前缀

new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

else:

new_state_dict = state_dict

model.load_state_dict(new_state_dict)

相关推荐
破烂pan2 小时前
2025年下半年AI应用架构演进:从RAG到Agent再到MCP的生态跃迁
人工智能·架构·ai应用
9527(●—●)2 小时前
windows系统python开发pip命令使用(菜鸟学习)
开发语言·windows·python·学习·pip
数字会议深科技2 小时前
深科技 | 高端会议室效率升级指南:无纸化会议系统的演进与价值
大数据·人工智能·会议系统·无纸化·会议系统品牌·综合型系统集成商·会议室
曦云沐2 小时前
轻量却强大:Fun-ASR-Nano-2512 语音识别模型上手指南
人工智能·语音识别·asr·fun-asr-nano
森叶2 小时前
手搓一个 Windows 注册表清理器:从开发到 EXE 打包全流程
windows·python
少年白char2 小时前
【AI漫剧】开源自动化AI漫剧生成工具 - 从文字到影像:AI故事视频创作的全新可能
运维·人工智能·自动化
容智信息2 小时前
容智Report Agent智能体驱动财务自动化,从核算迈向价值创造
大数据·运维·人工智能·自然语言处理·自动化·政务
Allen正心正念20253 小时前
AWS专家Greg Coquillo提出的8层Agentic AI架构分析
人工智能·架构·aws
JoannaJuanCV3 小时前
自动驾驶—CARLA仿真(25)synchronous_mode demo
人工智能·机器学习·自动驾驶·carla