浅谈一谈pytorch中模型的几种保存方式、以及如何从中止的地方继续开始训练;

一、本文总共介绍3中pytorch模型的保存方式:1.保存整个模型;2.只保存模型参数;3.保存模型参数、优化器、学习率、epoch和其它的所有命令行相关参数以方便从上次中止训练的地方重新启动训练过程。

1.保存整个模型。这种保存方式最简单,保存内容包括模型结构、模型参数以及其它相关信息。代码如下:

python 复制代码
# 保存模型,PATH为模型的保存路径及模型命名
import torch
torch.save(model,PATH)

# 加载模型
model = torch.load(PATH)
  1. 只保存模型参数,不保存模型结构和其它相关信息。这种方式保存的模型,在加载模型前需要构建相同的模型结构,然后再将加载的模型参数赋值给对应的层。代码如下:
python 复制代码
# 只保存模型参数
torch.save(model.state_dict(), PATH)

# 创建相同结构的模型,然后加载模型参数
model = Model()   # 调用Model类实例化模型
model_dict = torch.load(PATH)
model.load_state_dict(model_dict) #加载模型参数

如果进行模型加载前,创建的模型结构发生了改变,和原来预训练的模型的结构不同,则需要遍历模型参数进行选择性赋值,例如下面的代码:

python 复制代码
from collections import OrderedDict

model = Unet()  # 实例化Unet模型
model_dict = torch.load(pretrained_pth, map_location="cpu")  # 加载模型时将参数映射到CPU上
new_state_dict = OrderedDict()  # 新建一个字典类型用来存储新的模型参数
# 改变模型结构名称,如果有,就去掉backbone.前缀
for k, v in model_dict["state_dict"].items():
    new_state_dict[k.replace("backbone.", "")] = v

model.load_state_dict(new_state_dict)  # 加载模型参数

注意上述代码中,有一个参数 map_location="cpu",这个参数是指定将模型参数映射到CPU上,这个参数一般在一下情况下比较适用:1. 当你在CPU上训练了一个模型,并且想将其加载到CPU上进行推断或者继续训练时,使用map_location="cpu"可以确保模型参数被正确地映射到CPU上;2.如果你的预训练模型是在GPU上训练的,但是你在没有GPU的环境中加载模型时,使用这个参数可以避免找不到GPU而导致的错误。 而如果你的代码没有指定map_location参数,则默认情况下pytorch会尝试将模型加载到当前可用设备上(通常是GPU)

  1. 保存模型必要参数,使下次训练可以从模型训练停止的地方继续训练,代码如下:
python 复制代码
# 将需要保存的参数打包成字典类型
save_file = {"model": model.state_dict(),
             "optimizer": optimizer.state_dict(),
             "lr_scheduler": lr_scheduler.state_dict(),
             "epoch": epoch,
             "args": args}     

# 保存模型和其它参数    
torch.save(save_file, "save_weights/model.pth")
    
python 复制代码
# 加载模型和必要的参数
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])  # 加载模型参数
optimizer.load_state_dict(checkpoint['optimizer'])  # 加载模型优化器
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])  # 加载模型学习策略
args.start_epoch = checkpoint['epoch'] + 1  # 加载模型训练epoch停止数

如果仅是进行模型推理,则只用加载模型参数即可,不用加载其它的东西。

相关推荐
区块block5 分钟前
DeFi中的自主代理:用AI重塑金融
人工智能·金融
数据科学作家10 分钟前
如何入门python机器学习?金融从业人员如何快速学习Python、机器学习?机器学习、数据科学如何进阶成为大神?
大数据·开发语言·人工智能·python·机器学习·数据分析·统计分析
GJGCY12 分钟前
金融智能体技术解读:十大应用场景与AI Agent架构设计思路
人工智能·经验分享·ai·金融·自动化
孤客网络科技工作室14 分钟前
Python - 100天从新手到大师:第五十八天 Python中的并发编程(1-3)
开发语言·python
文火冰糖的硅基工坊17 分钟前
[人工智能-大模型-57]:模型层技术 - 软件开发的不同层面(如底层系统、中间件、应用层等),算法的类型、设计目标和实现方式存在显著差异。
人工智能·算法·中间件
Coovally AI模型快速验证22 分钟前
突破性开源模型DepthLM问世:视觉语言模型首次实现精准三维空间理解
人工智能·语言模型·自然语言处理·ocr·音视频·ai编程
芯片SIPI设计33 分钟前
面向3D IC AI芯片中UCIe 电源传输与电源完整性的系统分析挑战与解决方案
人工智能·3d
计算衎38 分钟前
Jenkins上实现CI集成软件信息Teams群通知案例实现。
python·jenkins·1024程序员节·microsoft azure·teams消息群通知·微软 graph api
浆果020739 分钟前
【图像超分】论文复现:轻量化超分 | RLFN的Pytorch源码复现,跑通源码,整合到EDSR-PyTorch中进行训练、测试
人工智能·python·深度学习·超分辨率重建·1024程序员节
CV实验室1 小时前
TPAMI 2025 | 从分离到融合:新一代3D场景技术实现双重能力提升!
人工智能·计算机视觉·3d