pth.tar的保存和读取

一、简介

在PyTorch中,.pt、.pth和.pth.tar都是保存训练好的模型的文件格式。主要区别在于:

  1. .pt是PyTorch1.6及以上版本中引入的保存格式,可以保存整个模型,包括模型结构、模型参数以及优化器状态等信息,是一个二进制文件。
  2. .pth是PyTorch旧版本中使用的模型文件格式,只保存了模型参数,没有保存模型和其他相关信息,是一个二进制文件。
  3. .pth.tar包括.pth文件以及其他信息,比如模型结构、优化器状态、超参数信息。

二、保存

使用torch.save进行保存,保存时传入保存的状态,名称

python 复制代码
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

三、读取

1、通过torch.load()函数加载

python 复制代码
checkpoint_path = "/home/user/msh/Project/SimCLR-master_old/runs/Jan03_19-04-59_user-X10DRi/checkpoint_0100.pth.tar"
checkpoint = torch.load(checkpoint_path)
print(checkpoint.keys())

运行结果如下:

python 复制代码
dict_keys(['epoch', 'arch', 'state_dict', 'optimizer'])

2、epoch存放的是训练的轮次,arch存放的是模型的名称,optimizer存放是优化器具体的参数,

python 复制代码
epoch = checkpoint['epoch']
print(epoch)
arch = checkpoint['arch']
print(arch)
optimizer = checkpoint['optimizer']

运行结果:

python 复制代码
100
resnet18

3、state_dict.keys()存放的是模型每一层结构的名称

python 复制代码
state_dict = checkpoint['state_dict']
print(state_dict.keys())

4、使用:先初始化模型,创建一个对象,然后使用load_state_dict()函数加载参数

python 复制代码
model = ResNetSimCLR(arch,160)
model.load_state_dict(state_dict)
相关推荐
C嘎嘎嵌入式开发2 小时前
(2)100天python从入门到拿捏
开发语言·python
Stanford_11062 小时前
如何利用Python进行数据分析与可视化的具体操作指南
开发语言·c++·python·微信小程序·微信公众平台·twitter·微信开放平台
white-persist4 小时前
Python实例方法与Python类的构造方法全解析
开发语言·前端·python·原型模式
Java 码农4 小时前
Centos7 maven 安装
java·python·centos·maven
lyx33136967594 小时前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化
倔强青铜三5 小时前
苦练Python第63天:零基础玩转TOML配置读写,tomllib模块实战
人工智能·python·面试
递归不收敛5 小时前
吴恩达机器学习课程(PyTorch 适配)学习笔记:3.3 推荐系统全面解析
pytorch·学习·机器学习
浔川python社5 小时前
《网络爬虫技术规范与应用指南系列》(xc—3):合规实操与场景落地
python
B站计算机毕业设计之家5 小时前
智慧交通项目:Python+YOLOv8 实时交通标志系统 深度学习实战(TT100K+PySide6 源码+文档)✅
人工智能·python·深度学习·yolo·计算机视觉·智慧交通·交通标志
IT森林里的程序猿6 小时前
基于机器学习方法的网球比赛胜负趋势预测
python·机器学习·django