需要使用到的函数
在 PyTorch 中,torch.save()
和 torch.load()
是用于保存和加载模型的核心函数。
torch.save() 函数
-
主要用途:将模型或模型的状态字典(state_dict)保存到文件中。
-
语法:
python
torch.save(obj, f, pickle_module=pickle, pickle_protocol=None, _use_new_zipfile_serialization=True)
-
obj
: 要保存的对象,可以是整个模型(nn.Module
)或模型的状态字典(state_dict
)。 -
f
: 保存文件的路径。可以是一个字符串路径(如'model.pth'
或'model.pkl'
)或一个打开的文件对象。 -
pickle_module
: 默认是pickle
,用于序列化对象。你可以使用其他兼容的序列化模块。 -
pickle_protocol
:pickle
协议版本。默认值为None
,表示使用最高可用协议版本。 -
_use_new_zipfile_serialization
: 默认值为True
,控制是否使用新的序列化格式(推荐使用)。
python
# 保存整个模型
torch.save(model, 'model.pth')
# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
torch.load() 函数
-
主要用途:从文件中加载保存的模型或模型的状态字典。
-
语法:
python
torch.load(f, map_location=None, pickle_module=pickle)
-
f
: 要加载的文件路径。可以是一个字符串路径或一个打开的文件对象。 -
map_location
: 控制如何将存储位置映射到当前设备。例如,map_location='cuda:0'
表示将模型加载到 GPU 上。 -
pickle_module
: 默认是pickle
,用于反序列化对象。
python
# 加载整个模型
model = torch.load('model.pth', map_location='cpu') # 加载到 CPU
# 加载模型的状态字典
model_state_dict = torch.load('model_state_dict.pth', map_location='cuda:0') # 加载到 GPU
加载状态字典到模型
-
加载状态字典后,通常需要将其加载到一个已经实例化的模型中。可以使用
model.load_state_dict()
方法: -
语法:
python
model.load_state_dict(state_dict, strict=True)
-
state_dict
: 从文件中加载的模型状态字典。 -
strict
: 默认为True
,表示严格加载状态字典中的所有键。如果设置为False
,可以忽略不匹配的键。
python
# 实例化模型
model = SimpleModel()
# 加载状态字典
model_state_dict = torch.load('model_state_dict.pth')
# 将状态字典加载到模型中
model.load_state_dict(model_state_dict)
注意事项
-
设备映射 :使用
torch.load()
时,可以指定map_location
参数来控制模型加载到的设备(如 CPU 或 GPU)。 -
自定义类:保存和加载整个模型时,需要确保自定义的模型类在加载代码中已经定义,否则会报错。
-
兼容性 :
torch.save()
和torch.load()
使用pickle
序列化,可能会受到 Python 版本和 PyTorch 版本的影响。建议使用相同版本的 PyTorch 和 Python 进行保存和加载。 -
推荐使用状态字典 :保存和加载状态字典(
state_dict
)比保存整个模型更灵活和可移植。这样可以避免保存自定义类的依赖关系。
通过以上方法,你可以灵活地保存和加载 PyTorch 模型,无论是 .pth
还是 .pkl
格式,都可以根据需要选择合适的保存方式。
保存和读取.pth格式的预训练模型
保存
python
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
x = self.fc(x)
return x
# 创建模型实例
model = SimpleModel()
# 假设已经训练了模型,这里只是演示保存
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的参数
torch.save(model.state_dict(), 'model_state_dict.pth')
读取
python
# 如果保存的是整个模型
loaded_model = torch.load('model.pth')
# 如果保存的是模型参数
model_load = SimpleModel() # 先实例化模型结构
model_load.load_state_dict(torch.load('model_state_dict.pth'))
###########################################################################
# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练模型
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))
# 将模型转移到 GPU
model.to(device)
# 示例输入数据
input_data = torch.randn(1, 10).to(device) # 确保输入数据也在 GPU 上
# 前向传播
output = model(input_data)
print(output)
在使用 model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))
读取模型时,已经指定了 map_location=device
,这确保了模型的参数(张量)被加载到指定的设备上。但是,是否还需要调用 model.to(device)
取决于具体的情况。
详细分析
-
map_location=device
的作用:-
map_location=device
参数用于指定加载的张量应该被放置到哪个设备上。当你加载模型的状态字典时,这个参数确保所有张量(如模型的权重和偏置)被加载到指定的设备(CPU 或 GPU)。 -
这个参数主要用于处理加载时的设备映射,特别是在加载存储在不同设备上的模型时(例如,从 GPU 上保存的模型加载到 CPU 上或反之)。
-
2 .model.to(device)
的作用:
-
model.to(device)
用于将整个模型(包括模型的参数、缓冲区等)转移到指定的设备上。这是一个递归操作,会遍历模型的所有子模块并将其转移到目标设备。 -
如果模型在加载时已经将所有张量加载到了正确的设备上(通过
map_location=device
),那么调用model.to(device)
是冗余的,但它不会产生负面影响。
具体情况分析
-
加载到 CPU : -你在 如果 CPU 上加载模型,并且使用
map_location='cpu'
,那么模型的张量已经被加载到 CPU 上。在这种情况下,调用model.to('cpu')
是不必要的,因为模型已经在 CPU 上了。 -
加载到 GPU:
-
如果你在 GPU 上加载模型,并且使用
map_location='cuda'
或map_location=device
(其中device
是 GPU),那么模型的张量已经被加载到 GPU 上。但是,模型对象本身(如模型的结构)可能仍然在 CPU 上。 -
此,调用
model.to(device)
可以确保模型的所有部分(包括模型的结构和参数)都正确地在 GPU 上。
-
推荐做法
为了确保模型及其所有组成部分都在正确的设备上,建议在加载模型后调用 model.to(device)
。这样可以避免潜在的设备不一致问题。
保存和读取.pkl格式的预训练模型
保存
python
import torch
import torch.nn as nn
# 义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
x = self.fc(x)
return x
# 创建模型实例
model = SimpleModel()
# 保存整个模型
with open('model.pkl', 'wb') as f:
torch.save(model, f)
# 或者只保存模型的参数
with open('model_state_dict.pkl', 'wb') as f:
torch.save(model.state_dict(), f)
读取
python
# 如果保存的是整个模型
with open('model.pkl', 'rb') as f:
loaded_model = torch.load(f)
# 如果保存的是模型参数
model_load = SimpleModel() # 先实例化模型结构
with open('model_state_dict.pkl', 'rb') as f:
model_load.load_state_dict(torch.load(f))
两种格式的区别 :
-
pth 格式 :
- 是 PyTorch 推荐的模型保存格式。它使用 Python 的 pickle 模块来序列化模型对象。对于模型的存储来说,它能够较好地保存和加载模型的结构以及参数。当你想要完整地保存和恢复一个模型的训练状态(包括模型结构、参数、优化器等时),使用.pth 格式很方便。
-
pkl 格式 :
- 本质上也是使用 pickle 序列化对象。它是一种通用的 Python 对象序列化格式。在 PyTorch 的早期版本中,pkl 格式被广泛用于保存模型。但是使用 pkl 格式时,可能会受到 Python 版本的限制。因为不同 Python 版本之间,pickle 序列化后的对象在反序列化时可能会出现兼容性问题。例如,你在 Python 3.7 环境下用 pickle 保存了一个模型,然后在 Python 3.8 环境下尝试加载时,可能会因为 pickle 协议版本或者对象结构差异等原因导致加载失败。而.pth 格式会更好地处理这些兼容性问题。
注意事项 :
-
当保存整个模型时,如果自定义了模型类,加载模型时也需要提供相同的自定义类定义。否则加载时会出现错误,因为无法识别自定义类的结构。
-
如果只保存模型参数(state_dict),在加载时必须先实例化一个与保存时相同的模型结构,然后将保存的参数加载到这个结构中。这样可以避免保存自定义类的依赖关系,增加模型的可移植性,但前提是你要清楚地知道模型的结构。