pytorch使用DataParallel并行化保存和加载模型(单卡、多卡各种情况讲解)

话不多说,直接进入正题。

!!!不过要注意一点,本文保存模型采用的都是只保存模型参数的情况,而不是保存整个模型的情况。一定要看清楚再用啊!

1 单卡训练,单卡加载

复制代码
#保存模型
torch.save(model.state_dict(),'model.pt')

#加载模型
model=MyModel()#MyModel()是你定义的创建模型的函数,就是先初始化得到一个模型实例,之后再将模型参数加载到该实例上
model.load_state_dict(torch.load('model.pt'))

2 单卡训练,多卡加载

保存模型的过程同第一种情况一样,但是要注意,多卡加载模型时, 是先加载模型参数,再对模型做并行化处理。

复制代码
#保存模型
torch.save(model.state_dict(),'model.pt')


#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))

model=nn.DataParallel(model)#将模型进行并行化处理

3 多卡保存,单卡加载

方法一:

考虑到之后可能需要单卡加载你多卡训练的模型,所以建议在保存的时候,要去除模型参数字典里面的module,即使用model.module.state_dict()代替model.state_dict()来进行去除。

因为是单卡加载,所以还是要先加载 模型参数,再对模型做并行化处理。

复制代码
#保存模型
torch.save(model.module.state_dict(),'modle.pt')


#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))

model=nn.DataParallel(model)

方法二:

仍然使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行),加载的时候要注意:由于我们保存到 方式是以多卡方式保存的,所以无论加载之后的模型是 在答案卡上运行还是在多卡上运行,都要先把模型并行化处理,然后再去加载模型。

复制代码
#保存模型
torch.save(model.state_dict(),'model.pt')


#加载模型
model=MyModel()

model=nn.DataParallel(model)

model.load_state_dict(torch.load('model.pt'))

4 多卡保存,多卡加载

这里保存模型采用"多卡保存,单卡加载"的第二种方法,加载的时候,要先把模型做并行化(在多卡上并行),然后再加载。

复制代码
#保存模型
torch.save(model.state_dict(),'model.pt')

#加载模型
model=MyModel()

model=nn.DataParallel(model)

model.load_state_dict(torch.load('model.pt'))

希望以上内容能够帮助到你,这里是希望你能越来越好的 小白冲鸭 ~~~

相关推荐
耿雨飞2 分钟前
二、The Power of LLM Function Calling
人工智能·大模型
金能电力5 分钟前
金能电力领跑京东工业安全工器具赛道 2025年首季度数据诠释“头部效应”
人工智能·安全·金能电力安全工器具
程丞Q香6 分钟前
python——学生管理系统
开发语言·python·pycharm
WSSWWWSSW9 分钟前
神经网络如何表示数据
人工智能·深度学习·神经网络
多吃轻食37 分钟前
Jieba分词的原理及应用(三)
人工智能·深度学习·自然语言处理·中文分词·分词·jieba·隐马尔可夫
dragon_perfect1 小时前
ubuntu22.04上设定Service程序自启动,自动运行Conda环境下的Python脚本(亲测)
开发语言·人工智能·python·conda
明月看潮生1 小时前
青少年编程与数学 02-016 Python数据结构与算法 15课题、字符串匹配
python·算法·青少年编程·编程与数学
by————组态2 小时前
低代码 Web 组态
前端·人工智能·物联网·低代码·数学建模·组态
凡人的AI工具箱2 小时前
PyTorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)
人工智能·pytorch·python·深度学习·学习·生成对抗网络