PyTorch多GPU训练模型——使用单GPU或CPU进行推理的方法

文章目录

  • [1 问题描述](#1 问题描述)
  • [2 模型保存方式](#2 模型保存方式)
  • [3 单块GPU上加载模型](#3 单块GPU上加载模型)
  • [4 CPU上加载模型](#4 CPU上加载模型)
  • [5 总结](#5 总结)

1 问题描述

PyTorch提供了非常便捷的多GPU网络训练方法:DataParallelDistributedDataParallel。在涉及到一些复杂模型时,基本都是采用多个GPU并行训练并保存模型。但在推理阶段往往只采用单个GPU或者CPU运行。这时怎么将多GPU环境下保存的模型权重加载到单GPU/CPU运行环境下的模型上成了一个关键的问题。

如果忽视环境问题直接加载往往会出现两类问题:

1 出现错误:IndexError: list index out of range

出现这个错误的原因是:现有模型参数是在多GPU上获得并保存的,因此在读入时默认会保存至对应的GPU上,但是目前推理环境中只有一块GPU,所以导致那些本来在其它GPU上的参数找不到自己应该去的GPU编号,出现了一个溢出错误,本质是GPU编号溢出。

2 出现错误:Missing key(s) in state_dict:

出现这个错误的原因是:由于模型训练和推理的环境不同,导致一些参数丢失,因此报错。目前在网上的一些解决策略是忽视这些丢失的参数,例如使用命令:model.load_state_dict(torch.load('model.pth'), strict=False)

来成功导入模型。这条命令可以让程序不报错并看似成功的导入模型参数。但实际上这条命令的含义是在导入模型参数时通过设置 strict=False 来忽略丢失的参数,也就是说那些丢失参数地方的模型权重初仍为初始化随机状态,等同于没有进行训练和学习,何谈推理与验证!!!

2 模型保存方式

不论是用哪种方式进行推理,在训练的时候要保证程序保存模型的方式是这样的:

python 复制代码
torch.save(model.state_dict(), "model.pth")

3 单块GPU上加载模型

将多GPU训练的权重文件加载到单GPU上:

python 复制代码
# 1 加载模型
model = Model()
# 2 指定运行设备,这里为单块GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3 将模型用DataParallel方法封装一次
model = torch.nn.DataParallel(model)
# 4 将模型读入到GPU设备上
model = model_E2E.to(device)
# 5 加载权重文件
model.load_state_dict(torch.load(weight_path, map_location=device))

通过上面的程序就可以实现将多块GPU上训练得到的权重文件加载到单块GPU环境下的模型中。这里有两点需要注意:

  • 在多GPU训练时,模型使用了 DataParallel
    DistributedDataParallel 方法,这两种并行化工具会修改模型的结构,将模型封装在一个新的模块中,通常名为:module。因此在权重文件中保存的模型是经过 DataParallel 封装后的结构。为了能够载入全部参数,需要通过步骤3使推理模型与原始多GPU训练模型在结构上保持一致。

  • 在步骤5加载模型参数时使用了map_location 参数。这个参数会告诉 PyTorch在加载模型时应该将张量放置在哪个设备上。设置map_location=device,那么无论模型原来是在哪个设备上训练的,现在都将放置在指定的设备device='cuda:0'上。

4 CPU上加载模型

在CPU上加载模型:

python 复制代码
from collections import OrderedDict

# 1 加载模型
model = Model()
# 2 指定设备CPU
device = "cpu"
# 3 读取权重文件
state_dict = torch.load(weight_path, map_location=device)
# 4 剥除权重文件中的module层
if next(iter(state_dict)).startswith("module."):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    state_dict = new_state_dict
# 5 加载权重文件
model.load_state_dict(state_dict)
# 6 将模型载入到CPU
model = model.to(device)

在CPU上加载模型的逻辑和GPU的差不多,核心都是因为原权重文件中的模型被封装成了module.Model,所以需要将这层外壳去掉,最后再进行读取并将模型加载到CPU上。

5 总结

在深度学习任务中训练与推理环境存在差异的情况十分常见 ,有差异的环境下实现网络权重文件的正确读取十分重要。实际操作中一定要确保正确的权重文件被读入,这是进行推理最基本的前提!最好在推理前做一些对比实验(例如:选取一部分数据,分别套用已有的程序进行训练和推理,对比二者的效果)来确保已经读入到正确的权重。

相关推荐
胡耀超2 小时前
DataOceanAI Dolphin(ffmpeg音频转化教程) 多语言(中国方言)语音识别系统部署与应用指南
python·深度学习·ffmpeg·音视频·语音识别·多模态·asr
HUIMU_2 小时前
DAY12&DAY13-新世纪DL(Deeplearning/深度学习)战士:破(改善神经网络)1
人工智能·深度学习
mit6.8243 小时前
[1Prompt1Story] 注意力机制增强 IPCA | 去噪神经网络 UNet | U型架构分步去噪
人工智能·深度学习·神经网络
Coovally AI模型快速验证4 小时前
YOLO、DarkNet和深度学习如何让自动驾驶看得清?
深度学习·算法·yolo·cnn·自动驾驶·transformer·无人机
科大饭桶4 小时前
昇腾AI自学Day2-- 深度学习基础工具与数学
人工智能·pytorch·python·深度学习·numpy
努力还债的学术吗喽5 小时前
2021 IEEE【论文精读】用GAN让音频隐写术骗过AI检测器 - 对抗深度学习的音频信息隐藏
人工智能·深度学习·生成对抗网络·密码学·音频·gan·隐写
weixin_507929916 小时前
第G7周:Semi-Supervised GAN 理论与实战
人工智能·pytorch·深度学习
AI波克布林8 小时前
发文暴论!线性注意力is all you need!
人工智能·深度学习·神经网络·机器学习·注意力机制·线性注意力
weixin_456904278 小时前
一文讲清楚Pytorch 张量、链式求导、正向传播、反向求导、计算图等基础知识
人工智能·pytorch·学习
Blossom.1189 小时前
把 AI 推理塞进「 8 位 MCU 」——0.5 KB RAM 跑通关键词唤醒的魔幻之旅
人工智能·笔记·单片机·嵌入式硬件·深度学习·机器学习·搜索引擎