使用pytorch保存和加载预训练的模型方法

需要使用到的函数

在 PyTorch 中,torch.save()torch.load() 是用于保存和加载模型的核心函数。

torch.save() 函数

  • 主要用途:将模型或模型的状态字典(state_dict)保存到文件中。

  • 语法

python 复制代码
torch.save(obj, f, pickle_module=pickle, pickle_protocol=None, _use_new_zipfile_serialization=True)
  • obj: 要保存的对象,可以是整个模型(nn.Module)或模型的状态字典(state_dict)。

  • f: 保存文件的路径。可以是一个字符串路径(如 'model.pth''model.pkl')或一个打开的文件对象。

  • pickle_module: 默认是 pickle,用于序列化对象。你可以使用其他兼容的序列化模块。

  • pickle_protocol: pickle 协议版本。默认值为 None,表示使用最高可用协议版本。

  • _use_new_zipfile_serialization: 默认值为 True,控制是否使用新的序列化格式(推荐使用)。

python 复制代码
# 保存整个模型
torch.save(model, 'model.pth')

# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

torch.load() 函数

  • 主要用途:从文件中加载保存的模型或模型的状态字典。

  • 语法

python 复制代码
torch.load(f, map_location=None, pickle_module=pickle)
  • f: 要加载的文件路径。可以是一个字符串路径或一个打开的文件对象。

  • map_location: 控制如何将存储位置映射到当前设备。例如,map_location='cuda:0' 表示将模型加载到 GPU 上。

  • pickle_module: 默认是 pickle,用于反序列化对象。

python 复制代码
# 加载整个模型
model = torch.load('model.pth', map_location='cpu')  # 加载到 CPU

# 加载模型的状态字典
model_state_dict = torch.load('model_state_dict.pth', map_location='cuda:0')  # 加载到 GPU

加载状态字典到模型

  • 加载状态字典后,通常需要将其加载到一个已经实例化的模型中。可以使用 model.load_state_dict() 方法:

  • 语法

python 复制代码
model.load_state_dict(state_dict, strict=True)
  • state_dict: 从文件中加载的模型状态字典。

  • strict: 默认为 True,表示严格加载状态字典中的所有键。如果设置为 False,可以忽略不匹配的键。

python 复制代码
# 实例化模型
model = SimpleModel()

# 加载状态字典
model_state_dict = torch.load('model_state_dict.pth')

# 将状态字典加载到模型中
model.load_state_dict(model_state_dict)

注意事项

  • 设备映射 :使用 torch.load() 时,可以指定 map_location 参数来控制模型加载到的设备(如 CPU 或 GPU)。

  • 自定义类:保存和加载整个模型时,需要确保自定义的模型类在加载代码中已经定义,否则会报错。

  • 兼容性torch.save()torch.load() 使用 pickle 序列化,可能会受到 Python 版本和 PyTorch 版本的影响。建议使用相同版本的 PyTorch 和 Python 进行保存和加载。

  • 推荐使用状态字典 :保存和加载状态字典(state_dict)比保存整个模型更灵活和可移植。这样可以避免保存自定义类的依赖关系。

通过以上方法,你可以灵活地保存和加载 PyTorch 模型,无论是 .pth 还是 .pkl 格式,都可以根据需要选择合适的保存方式。

保存和读取.pth格式的预训练模型

保存

python 复制代码
import torch
import torch.nn as nn

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc(x)
        return x

# 创建模型实例
model = SimpleModel()

# 假设已经训练了模型,这里只是演示保存
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的参数
torch.save(model.state_dict(), 'model_state_dict.pth')

读取

python 复制代码
# 如果保存的是整个模型
loaded_model = torch.load('model.pth')
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
model_load.load_state_dict(torch.load('model_state_dict.pth'))
###########################################################################
# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载预训练模型
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))

# 将模型转移到 GPU
model.to(device)

# 示例输入数据
input_data = torch.randn(1, 10).to(device)  # 确保输入数据也在 GPU 上

# 前向传播
output = model(input_data)
print(output)

在使用 model.load_state_dict(torch.load('model_state_dict.pth', map_location=device)) 读取模型时,已经指定了 map_location=device,这确保了模型的参数(张量)被加载到指定的设备上。但是,是否还需要调用 model.to(device) 取决于具体的情况。

详细分析

  1. map_location=device 的作用

    • map_location=device 参数用于指定加载的张量应该被放置到哪个设备上。当你加载模型的状态字典时,这个参数确保所有张量(如模型的权重和偏置)被加载到指定的设备(CPU 或 GPU)。

    • 这个参数主要用于处理加载时的设备映射,特别是在加载存储在不同设备上的模型时(例如,从 GPU 上保存的模型加载到 CPU 上或反之)。

2 .model.to(device) 的作用

  • model.to(device) 用于将整个模型(包括模型的参数、缓冲区等)转移到指定的设备上。这是一个递归操作,会遍历模型的所有子模块并将其转移到目标设备。

  • 如果模型在加载时已经将所有张量加载到了正确的设备上(通过 map_location=device),那么调用 model.to(device) 是冗余的,但它不会产生负面影响。

具体情况分析

  • 加载到 CPU : -你在 如果 CPU 上加载模型,并且使用 map_location='cpu',那么模型的张量已经被加载到 CPU 上。在这种情况下,调用 model.to('cpu') 是不必要的,因为模型已经在 CPU 上了。

  • 加载到 GPU

    • 如果你在 GPU 上加载模型,并且使用 map_location='cuda'map_location=device(其中 device 是 GPU),那么模型的张量已经被加载到 GPU 上。但是,模型对象本身(如模型的结构)可能仍然在 CPU 上。

    • 此,调用 model.to(device) 可以确保模型的所有部分(包括模型的结构和参数)都正确地在 GPU 上。

推荐做法

为了确保模型及其所有组成部分都在正确的设备上,建议在加载模型后调用 model.to(device)。这样可以避免潜在的设备不一致问题。

保存和读取.pkl格式的预训练模型

保存

python 复制代码
import torch
import torch.nn as nn

# 义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc(x)
        return x

# 创建模型实例
model = SimpleModel()

# 保存整个模型
with open('model.pkl', 'wb') as f:
    torch.save(model, f)
# 或者只保存模型的参数
with open('model_state_dict.pkl', 'wb') as f:
    torch.save(model.state_dict(), f)

读取

python 复制代码
# 如果保存的是整个模型
with open('model.pkl', 'rb') as f:
    loaded_model = torch.load(f)
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
with open('model_state_dict.pkl', 'rb') as f:
    model_load.load_state_dict(torch.load(f))

两种格式的区别

  • pth 格式

    • 是 PyTorch 推荐的模型保存格式。它使用 Python 的 pickle 模块来序列化模型对象。对于模型的存储来说,它能够较好地保存和加载模型的结构以及参数。当你想要完整地保存和恢复一个模型的训练状态(包括模型结构、参数、优化器等时),使用.pth 格式很方便。
  • pkl 格式

    • 本质上也是使用 pickle 序列化对象。它是一种通用的 Python 对象序列化格式。在 PyTorch 的早期版本中,pkl 格式被广泛用于保存模型。但是使用 pkl 格式时,可能会受到 Python 版本的限制。因为不同 Python 版本之间,pickle 序列化后的对象在反序列化时可能会出现兼容性问题。例如,你在 Python 3.7 环境下用 pickle 保存了一个模型,然后在 Python 3.8 环境下尝试加载时,可能会因为 pickle 协议版本或者对象结构差异等原因导致加载失败。而.pth 格式会更好地处理这些兼容性问题。

注意事项

  • 当保存整个模型时,如果自定义了模型类,加载模型时也需要提供相同的自定义类定义。否则加载时会出现错误,因为无法识别自定义类的结构。

  • 如果只保存模型参数(state_dict),在加载时必须先实例化一个与保存时相同的模型结构,然后将保存的参数加载到这个结构中。这样可以避免保存自定义类的依赖关系,增加模型的可移植性,但前提是你要清楚地知道模型的结构。

相关推荐
fanstuck34 分钟前
从知识图谱到精准决策:基于MCP的招投标货物比对溯源系统实践
人工智能·知识图谱
dqsh0638 分钟前
树莓派5+Ubuntu24.04 LTS串口通信 保姆级教程
人工智能·python·物联网·ubuntu·机器人
打小就很皮...2 小时前
编写大模型Prompt提示词方法
人工智能·语言模型·prompt
Aliano2172 小时前
Prompt(提示词)工程师,“跟AI聊天”
人工智能·prompt
weixin_445238122 小时前
第R8周:RNN实现阿尔兹海默病诊断(pytorch)
人工智能·pytorch·rnn
KingDol_MIni2 小时前
ResNet残差神经网络的模型结构定义(pytorch实现)
人工智能·pytorch·神经网络
sunshineine2 小时前
jupyter notebook运行简单程序
linux·windows·python
方博士AI机器人2 小时前
Python 3.x 内置装饰器 (4) - @dataclass
开发语言·python
万能程序员-传康Kk2 小时前
中国邮政物流管理系统(Django+mysql)
python·mysql·django
Logintern093 小时前
【每天学习一点点】使用Python的pathlib模块分割文件路径
开发语言·python·学习