Pytorch分布式训练,其他GPU进程占用GPU0的原因

问题

最近跑师兄21年的论文代码,代码里使用了Pytorch分布式训练,在单机8卡的情况下,运行代码,出现如下问题。

也就是说GPU(1..7)上的进程占用了GPU0,这导致GPU0占的显存太多,以至于我的batchsize不能和原论文保持一致。

解决方法

我一点一点进行debug。

首先,在数据加载部分,由于没有将local_rankworld_size传入get_cifar_iter函数,导致后续使用DALI创建pipeline时使用了默认的local_rank=0,因此会在GPU0上多出该GPU下的进程

其次,在使用torch.load加载模型权重时,没有设置map_location,于是会默认加载到GPU0上,下图我选择将模型权重加载到cpu。虽然,这会使训练速度变慢,但为了和论文的batchsize保持一致也不得不这样做了。-.-

参考文献

  1. nn.parallel.DistributedDataParallel多卡训练,第一张卡会多出进程?
相关推荐
独行soc1 小时前
2025年渗透测试面试题总结-275(题目+回答)
网络·python·安全·web安全·网络安全·渗透测试·安全狮
番石榴AI3 小时前
java版的ocr推荐引擎——JiaJiaOCR 2.0重磅升级!纯Java CPU推理,新增手写OCR与表格识别
java·python·ocr
时光轻浅,半夏挽歌3 小时前
python不同格式文件的读写方式(json等)
python·json
测试人社区-千羽4 小时前
边缘计算场景下的智能测试挑战
人工智能·python·安全·开源·智能合约·边缘计算·分布式账本
抽象带篮子4 小时前
Pytorch Lightning 框架运行顺序
人工智能·pytorch·python
哇哈哈&4 小时前
安装wxWidgets3.2.0(编译高版本erlang的时候用,不如用rpm包),而且还需要高版本的gcc++19以上,已基本舍弃
linux·数据库·python
luod5 小时前
pymysql执行DDL语句
python
song5015 小时前
鸿蒙 Flutter 图像识别进阶:物体分类与花卉识别(含离线模型)
人工智能·分布式·python·flutter·3d·华为·分类
Mqh1807625 小时前
day 35 文件的拆分和使用
python
虚假程序设计6 小时前
pythonnet 调用C接口
c语言·python