114_PyTorch 进阶:模型保存与读取的两大方式及“陷阱”避坑指南

在 PyTorch 中,保存和加载模型主要有两种方式。虽然都能达到目的,但它们在存储内容和使用便捷性上有着本质的区别。本文将结合实战代码,带你掌握这些核心技巧。

1. 方式一:保存整个网络模型结构与参数

这种方式最为直观,它会把模型的结构定义权重参数全部打包保存到一个文件中。

代码实现:

Python

复制代码
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式一:保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 加载方式一:
model = torch.load("vgg16_method1.pth")
print(model)
  • 优点:加载方便,一行代码搞定。
  • 缺点:文件体积较大;存在"陷阱"(见下文)。

2. 方式二:只保存模型参数(官方推荐)

这种方式只保存模型中每个层的权重和偏置(即 state_dict)。加载时需要先创建一个模型实例,再将参数"填"进去。

代码实现:

Python

复制代码
# 保存方式二:将模型参数保存为字典格式(官方推荐,体积小)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 加载方式二:
# 1. 必须先新建一个网络模型实例
vgg16_new = torchvision.models.vgg16(pretrained=False)
# 2. 将保存的参数字典加载到模型中
vgg16_new.load_state_dict(torch.load("vgg16_method2.pth"))

print(vgg16_new)
  • 优点:更安全、更节省空间,是工业界和官方推荐的标准做法。

3. 核心避坑指南:自定义模型的加载"陷阱"

文件特别提到了一个新手常犯的错误:当你用"方式一"加载自定义模型时,如果当前环境没有该模型的类定义,程序会报错。

错误场景:

如果你在 A.py 中定义并保存了模型,在 B.py 中直接 torch.load,会提示 AttributeError

解决方案:

在加载代码前,必须确保模型类的定义是可访问的。

Python

复制代码
import torch

# 陷阱:必须把模型的定义复制过来,或者从其他文件 import 进来
class Tudui(torch.nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)

    def forward(self, x):
        return self.conv1(x)

# 这样才能成功加载方式一保存的文件
model = torch.load("tudui_method1.pth")
print(model)

或者使用更优雅的导入方式:

Python

复制代码
from model_save import * # 确保 Tudui 类被正确引入
model = torch.load("tudui_method1.pth")

4. 总结:如何选择保存方式?

  • 方式一:适合快速实验,想省去重新定义结构的麻烦(但要记住加载时引入类定义)。
  • 方式二:适合生产环境,不仅体积小,且能更灵活地将参数加载到结构略有差异的模型中。

💡 学习小结

掌握了模型的保存与读取,你就真正拥有了"复用"训练成果的能力。无论是模型的版本控制,还是将模型部署到云端或移动端,这都是必经之路。

相关推荐
混沌福王21 分钟前
Electron三端统一架构:运行时Adapter、IPC能力边界与分层设计
人工智能·agent·ai编程
说了很好21 分钟前
马尔可夫扩散链+损失函数推导,手把手实现原生Diffusion
人工智能
聂二AI落地内参24 分钟前
合同抽取别停在 JSON:标准规则和交易日历才是硬仗
人工智能
冬哥聊AI26 分钟前
滴滴Agent岗二面:RAG 系统的 LLM 幻觉怎么治?从两类根源讲到四道防线
人工智能
lyshlc32 分钟前
# AI Agent的推迟判定协议:不确定性下的最优策略
人工智能
用户3299016750536 分钟前
用zod在运行时兜住AI返回的JSON
人工智能
George37537 分钟前
第一章:本体论是什么(以及它不是什么)
人工智能
贵慜_Derek37 分钟前
《从零实现 Agent 系统》连载 32|闭集 IE 与小模型:分类、意图与字段抽取
人工智能·架构·agent
IT_陈寒1 小时前
Java 并行流把我坑惨了,这6小时加班值了
前端·人工智能·后端
火山引擎开发者社区2 小时前
告别长期密码:火山引擎云数据库 MySQL IAM 鉴权全解析
人工智能