在 PyTorch 中,保存和加载模型主要有两种方式。虽然都能达到目的,但它们在存储内容和使用便捷性上有着本质的区别。本文将结合实战代码,带你掌握这些核心技巧。
1. 方式一:保存整个网络模型结构与参数
这种方式最为直观,它会把模型的结构定义 和权重参数全部打包保存到一个文件中。
代码实现:
Python
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式一:保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
# 加载方式一:
model = torch.load("vgg16_method1.pth")
print(model)
- 优点:加载方便,一行代码搞定。
- 缺点:文件体积较大;存在"陷阱"(见下文)。
2. 方式二:只保存模型参数(官方推荐)
这种方式只保存模型中每个层的权重和偏置(即 state_dict)。加载时需要先创建一个模型实例,再将参数"填"进去。
代码实现:
Python
# 保存方式二:将模型参数保存为字典格式(官方推荐,体积小)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 加载方式二:
# 1. 必须先新建一个网络模型实例
vgg16_new = torchvision.models.vgg16(pretrained=False)
# 2. 将保存的参数字典加载到模型中
vgg16_new.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_new)
- 优点:更安全、更节省空间,是工业界和官方推荐的标准做法。
3. 核心避坑指南:自定义模型的加载"陷阱"
文件特别提到了一个新手常犯的错误:当你用"方式一"加载自定义模型时,如果当前环境没有该模型的类定义,程序会报错。
错误场景:
如果你在 A.py 中定义并保存了模型,在 B.py 中直接 torch.load,会提示 AttributeError。
解决方案:
在加载代码前,必须确保模型类的定义是可访问的。
Python
import torch
# 陷阱:必须把模型的定义复制过来,或者从其他文件 import 进来
class Tudui(torch.nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, 3)
def forward(self, x):
return self.conv1(x)
# 这样才能成功加载方式一保存的文件
model = torch.load("tudui_method1.pth")
print(model)
或者使用更优雅的导入方式:
Python
from model_save import * # 确保 Tudui 类被正确引入
model = torch.load("tudui_method1.pth")
4. 总结:如何选择保存方式?
- 方式一:适合快速实验,想省去重新定义结构的麻烦(但要记住加载时引入类定义)。
- 方式二:适合生产环境,不仅体积小,且能更灵活地将参数加载到结构略有差异的模型中。
💡 学习小结
掌握了模型的保存与读取,你就真正拥有了"复用"训练成果的能力。无论是模型的版本控制,还是将模型部署到云端或移动端,这都是必经之路。