Pytorch分布式训练,其他GPU进程占用GPU0的原因

问题

最近跑师兄21年的论文代码,代码里使用了Pytorch分布式训练,在单机8卡的情况下,运行代码,出现如下问题。

也就是说GPU(1..7)上的进程占用了GPU0,这导致GPU0占的显存太多,以至于我的batchsize不能和原论文保持一致。

解决方法

我一点一点进行debug。

首先,在数据加载部分,由于没有将local_rankworld_size传入get_cifar_iter函数,导致后续使用DALI创建pipeline时使用了默认的local_rank=0,因此会在GPU0上多出该GPU下的进程

其次,在使用torch.load加载模型权重时,没有设置map_location,于是会默认加载到GPU0上,下图我选择将模型权重加载到cpu。虽然,这会使训练速度变慢,但为了和论文的batchsize保持一致也不得不这样做了。-.-

参考文献

  1. nn.parallel.DistributedDataParallel多卡训练,第一张卡会多出进程?
相关推荐
七颗糖很甜1 分钟前
开源雷达NEXRAD Level 3 数据完整获取与 Python 处理教程
大数据·python·算法
SuAluvfy1 分钟前
PyTorch 基础:数据操作与数据预处理
人工智能·pytorch·python
ydmy10 分钟前
Embedding层(个人理解)
python·深度学习·embedding
qq_3300379913 分钟前
mysql在高并发下如何优化索引更新_mysql锁策略与调整
jvm·数据库·python
u01091476015 分钟前
如何排查SQL存储过程内存溢出_优化大数据量临时表使用
jvm·数据库·python
2301_7735536216 分钟前
mysql如何优化mysql在多核CPU下的性能_调整线程并发数
jvm·数据库·python
源码之家20 分钟前
计算机毕业设计:Python股票智能分析预测平台 Flask框架 数据分析 可视化 机器学习 随机森林 大数据(建议收藏)✅
python·机器学习·数据分析·django·flask·课程设计
a95114164220 分钟前
PHP如何批量处理AI请求_队列系统搭建【技巧】
jvm·数据库·python
sinat_3834373621 分钟前
如何实现SQL简单数据的映射查询_使用CASE表达式替换
jvm·数据库·python
2401_8359568121 分钟前
JavaScript 中实现基于分组的前端产品筛选功能
jvm·数据库·python