文章目录
- [一、状态字典(State Dictionary)](#一、状态字典(State Dictionary))
- [二、序列化模型(Serialized Model)](#二、序列化模型(Serialized Model))
- 三、示例代码
一、状态字典(State Dictionary)
这种保存形式将模型的参数保存为一个字典,其中包含了所有模型的权重和偏置等参数。状态字典保存了模型在训练过程中学到的参数值,而不包含模型的结构。可以使用这个字典来加载模型的参数,并将其应用于相同结构的模型。
在 PyTorch 中,您可以使用 torch.save() 函数将模型的状态字典保存到文件中,例如:
python
torch.save(model.state_dict(), 'model.pth')
然后,可以使用 torch.load() 函数加载状态字典并将其应用于相同结构的模型:
python
model = MyModel() # 创建模型对象
model.load_state_dict(torch.load('model.pth'))
这种保存形式非常适用于仅保存和加载模型的参数,而不需要保存和加载模型的结构。
二、序列化模型(Serialized Model)
这种保存形式将整个模型(包括模型的结构、参数等)保存为一个文件。序列化模型保存了模型的完整信息,可以完全恢复模型的状态,包括模型的结构、权重、偏置以及其他相关参数。
在 PyTorch 中,您可以使用 torch.save() 函数直接保存整个模型对象,例如:
python
torch.save(model, 'model.pth')
然后,您可以使用 torch.load() 函数加载整个序列化模型:
python
model = torch.load('model.pth')
这种保存形式适用于需要保存和加载完整模型信息的情况,包括模型的结构和参数。
三、示例代码
python
import torch
class LinearNet(torch.nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
torch.nn.Sigmoid(),
torch.nn.Linear(in_features= 5, out_features=5, bias=True),
torch.nn.Sigmoid(),
torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
)
def forward(self,x):
return self.net(x)
square_net = LinearNet(1,1)
# square_net.load_state_dict(torch.load('weight.pth')) #直接加载已经训练好的权重
if __name__ == '__main__':
# print(square_net(torch.tensor([3.16],dtype=torch.float32)))
# save 方式1
torch.save(square_net.state_dict(), "./w1.pth")
my_state_dict = torch.load("./w1.pth")
print("纯state_dict:\n", my_state_dict)
print("type:", type(my_state_dict))
# save 方式2
torch.save(square_net, "./w2.pth")
my_state_dict = torch.load("./w2.pth")
print("\n\n模型结构:\n", my_state_dict)
print("type:", type(my_state_dict))
# 执行结果
'''
纯state_dict:
OrderedDict([('net.0.weight', tensor([[ 0.0820],
[-0.6923],
[ 0.5066],
[-0.8931],
[ 0.0460]])), ('net.0.bias', tensor([ 0.1455, 0.5106, 0.2347, 0.4903, -0.6838])), ('net.2.weight', tensor([[-0.4055, -0.2721, 0.3770, -0.2285, 0.3025],
[-0.0416, 0.0133, -0.3834, -0.2151, 0.1454],
[ 0.0749, -0.3664, -0.1901, -0.2829, 0.3957],
[-0.3567, 0.2668, 0.3343, -0.3351, -0.3808],
[ 0.4375, 0.1000, 0.1185, 0.2295, -0.3997]])), ('net.2.bias', tensor([-0.2405, -0.2751, 0.1928, 0.3970, -0.0005])), ('net.4.weight', tensor([[-0.4388, -0.2654, 0.3038, 0.2008, 0.0381]])), ('net.4.bias', tensor([0.1847]))])
模型结构:
LinearNet(
(net): Sequential(
(0): Linear(in_features=1, out_features=5, bias=True)
(1): Sigmoid()
(2): Linear(in_features=5, out_features=5, bias=True)
(3): Sigmoid()
(4): Linear(in_features=5, out_features=1, bias=True)
)
)
'''