pth.tar的保存和读取

一、简介

在PyTorch中,.pt、.pth和.pth.tar都是保存训练好的模型的文件格式。主要区别在于:

  1. .pt是PyTorch1.6及以上版本中引入的保存格式,可以保存整个模型,包括模型结构、模型参数以及优化器状态等信息,是一个二进制文件。
  2. .pth是PyTorch旧版本中使用的模型文件格式,只保存了模型参数,没有保存模型和其他相关信息,是一个二进制文件。
  3. .pth.tar包括.pth文件以及其他信息,比如模型结构、优化器状态、超参数信息。

二、保存

使用torch.save进行保存,保存时传入保存的状态,名称

python 复制代码
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

三、读取

1、通过torch.load()函数加载

python 复制代码
checkpoint_path = "/home/user/msh/Project/SimCLR-master_old/runs/Jan03_19-04-59_user-X10DRi/checkpoint_0100.pth.tar"
checkpoint = torch.load(checkpoint_path)
print(checkpoint.keys())

运行结果如下:

python 复制代码
dict_keys(['epoch', 'arch', 'state_dict', 'optimizer'])

2、epoch存放的是训练的轮次,arch存放的是模型的名称,optimizer存放是优化器具体的参数,

python 复制代码
epoch = checkpoint['epoch']
print(epoch)
arch = checkpoint['arch']
print(arch)
optimizer = checkpoint['optimizer']

运行结果:

python 复制代码
100
resnet18

3、state_dict.keys()存放的是模型每一层结构的名称

python 复制代码
state_dict = checkpoint['state_dict']
print(state_dict.keys())

4、使用:先初始化模型,创建一个对象,然后使用load_state_dict()函数加载参数

python 复制代码
model = ResNetSimCLR(arch,160)
model.load_state_dict(state_dict)
相关推荐
databook12 分钟前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar1 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户8356290780512 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_2 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
CoovallyAIHub4 小时前
开源的消逝与新生:从 TensorFlow 的落幕到开源生态的蜕变
pytorch·深度学习·llm
数据智能老司机8 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机9 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机9 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机9 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i10 小时前
drf初步梳理
python·django