【pytorch】多GPU同时训练模型

文章目录

  • [1. 基本原理](#1. 基本原理)
  • [2. Pytorch进行单机多卡训练步骤](#2. Pytorch进行单机多卡训练步骤)
    • [1. 指定GPU](#1. 指定GPU)
    • [2. 更改模型训练方式](#2. 更改模型训练方式)
    • [3. 更改权重保存方式](#3. 更改权重保存方式)

摘要:多GPU同时训练,能够解决单张GPU显存不足问题,同时加快模型训练。

1. 基本原理

单机多卡训练教程------DP模式

(1)将模型复制到各个GPU中,并将一个batch的数据 划分成mini_batch(平均分配) 并分发给每个GPU;
注意:这里的batch_size要大于device数。

(2)各个GPU独自完成mini_batch的前向传播,并把获得的output传递给GPU_0(主GPU) ;

(3) GPU_0整合各个GPU传递过来的output,并计算loss。此时GPU_0可以对这些loss进行一些聚合操作;

(4) GPU_0归并loss之后,并进行后向传播以及梯度下降从而完成模型参数的更新(此时只有GPU_0上的模型参数得到了更新),GPU_0将更新好的模型参数又传递给其余GPU;

以上就是DP模式下多卡GPU进行训练的方式。其实可以看到GPU_0不仅承担了前向传播的任务,还承担了收集loss,并进行梯度下降。因此在使用DP模式进行单机多卡GPU训练的时候会有一张卡的显存利用会比其他卡更多,那就是你设置的GPU_0。

2. Pytorch进行单机多卡训练步骤

只需要在你的代码中改三个地方就可实现

1. 指定GPU

如上所示,在导入各种库下面使用os.environ["CUDA_VISIBLE_DEVICES"]来指定可识别的GPU,该语句在程序开始前使用。

代码如下:

python 复制代码
import torch.nn as nn
import os
os.environ["CUDA_VISIBLE_DEVICES"]= 2,3,1'#指定该程序可以识别的物理GPU编号,这里的你主机上的2号GPU就是训练程序中的主GPUO,这里最好---定要自己指定你自己可以用的gpu号。

2. 更改模型训练方式

平常的模型训练方式只需要model.cuda()语句即可,在单机多卡训练中,只需要在该语句下面添加一行nn.DataParallel语句即可。

代码如下

python 复制代码
model.cuda()
model = nn.DataParallel(model,devise =[0,1,2])#在执行该语句之前最好加上model.cuda(),保证你的模型存在GPU上即可

3. 更改权重保存方式

对于数据,我们只需要按照平常的方式使用.cuda()放置在GPU上即可,内部batch的拆分已经被封装在了DataPanallel模块中。要注意的是,由于我们的model被nn.DataPanallel()包裹住了,所以如果想要储存模型的参数,需要使用:model.module.state_dict()的方式才能取出(不能直接是model.state_dict()

代码如下:

python 复制代码
'''
使用单机多卡训练的模型权重保存方式
'''
torch.save(model.module.state_dict(),f'best.pth')  

作为参考,将平常的权重保存方式也写上:

python 复制代码
'''
平常的权重保存方式
'''
torch.save(model.state_dict(),f'best.pth')  
相关推荐
CoovallyAIHub10 分钟前
一夜之间,大模型处理长文本的难题被DeepSeek新模型彻底颠覆!
深度学习·算法·计算机视觉
智驱力人工智能12 分钟前
疲劳驾驶检测提升驾驶安全 疲劳行为检测 驾驶员疲劳检测系统 疲劳检测系统价格
人工智能·安全·目标检测·目标跟踪·视觉检测
墨香幽梦客16 分钟前
塑胶制造生产ERP:有哪些系统值得关注
大数据·人工智能·制造
说私域16 分钟前
开源AI大模型、AI智能名片与S2B2C商城系统:个体IP打造与价值赋能的新范式
人工智能·tcp/ip·开源
北京耐用通信20 分钟前
打破协议壁垒:耐达讯自动化Modbus转Profinet网关实现光伏逆变器全数据采集
运维·人工智能·物联网·网络安全·自动化·信息与通信
信息快讯32 分钟前
【机器学习在智能水泥基复合材料中的应用与实践】
人工智能·机器学习·材料工程·复合材料·水泥基复合材料
AI technophile1 小时前
OpenCV计算机视觉实战(27)——深度学习与卷积神经网络
深度学习·opencv·计算机视觉
技术闲聊DD1 小时前
深度学习(10)-PyTorch 卷积神经网络
pytorch·深度学习·cnn
Juchecar1 小时前
如何理解“AI token 大宗商品化”?
人工智能
文火冰糖的硅基工坊1 小时前
[人工智能-大模型-29]:大模型应用层技术栈 - 第二层:Prompt 编排层(Prompt Orchestration)
人工智能·大模型·prompt·copilot