114_PyTorch 进阶:模型保存与读取的两大方式及“陷阱”避坑指南

在 PyTorch 中,保存和加载模型主要有两种方式。虽然都能达到目的,但它们在存储内容和使用便捷性上有着本质的区别。本文将结合实战代码,带你掌握这些核心技巧。

1. 方式一:保存整个网络模型结构与参数

这种方式最为直观,它会把模型的结构定义权重参数全部打包保存到一个文件中。

代码实现:

Python

复制代码
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式一:保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 加载方式一:
model = torch.load("vgg16_method1.pth")
print(model)
  • 优点:加载方便,一行代码搞定。
  • 缺点:文件体积较大;存在"陷阱"(见下文)。

2. 方式二:只保存模型参数(官方推荐)

这种方式只保存模型中每个层的权重和偏置(即 state_dict)。加载时需要先创建一个模型实例,再将参数"填"进去。

代码实现:

Python

复制代码
# 保存方式二:将模型参数保存为字典格式(官方推荐,体积小)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 加载方式二:
# 1. 必须先新建一个网络模型实例
vgg16_new = torchvision.models.vgg16(pretrained=False)
# 2. 将保存的参数字典加载到模型中
vgg16_new.load_state_dict(torch.load("vgg16_method2.pth"))

print(vgg16_new)
  • 优点:更安全、更节省空间,是工业界和官方推荐的标准做法。

3. 核心避坑指南:自定义模型的加载"陷阱"

文件特别提到了一个新手常犯的错误:当你用"方式一"加载自定义模型时,如果当前环境没有该模型的类定义,程序会报错。

错误场景:

如果你在 A.py 中定义并保存了模型,在 B.py 中直接 torch.load,会提示 AttributeError

解决方案:

在加载代码前,必须确保模型类的定义是可访问的。

Python

复制代码
import torch

# 陷阱:必须把模型的定义复制过来,或者从其他文件 import 进来
class Tudui(torch.nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)

    def forward(self, x):
        return self.conv1(x)

# 这样才能成功加载方式一保存的文件
model = torch.load("tudui_method1.pth")
print(model)

或者使用更优雅的导入方式:

Python

复制代码
from model_save import * # 确保 Tudui 类被正确引入
model = torch.load("tudui_method1.pth")

4. 总结:如何选择保存方式?

  • 方式一:适合快速实验,想省去重新定义结构的麻烦(但要记住加载时引入类定义)。
  • 方式二:适合生产环境,不仅体积小,且能更灵活地将参数加载到结构略有差异的模型中。

💡 学习小结

掌握了模型的保存与读取,你就真正拥有了"复用"训练成果的能力。无论是模型的版本控制,还是将模型部署到云端或移动端,这都是必经之路。

相关推荐
乐维_lwops15 小时前
从 “救火运维” 到 “自动驾驶”:运维智能体到底解决了什么?
运维·人工智能·运维智能体
TheRouter15 小时前
AI Agent 记忆体系建设实战:短期、长期与工作记忆的工程实现
数据库·人工智能·oracle
weixin_4684668515 小时前
MoneyPrinterTurbo 短视频自动化生产实战指南
运维·人工智能·自动化·大模型·音视频·moneyprinter
Omics Pro15 小时前
首个!外源天然产物综合性代谢图谱
数据库·人工智能·算法·机器学习·r语言
LilySesy15 小时前
【与AI+】英语day7——工作流与增强工具
人工智能·sap·abap·机器翻译
voidmort15 小时前
3. 微调(Fine-tuning)与强化学习(RL)的核心思想
python·深度学习·算法
彬鸿科技16 小时前
bhSDR Studio/Matlab入门指南(十一):AI数据集采集实验界面全解析
人工智能·matlab·软件定义无线电
云烟成雨TD16 小时前
Spring AI Alibaba 1.x 系列【63】AI Agent 长期记忆
java·人工智能·spring
武雄(小星Ai)16 小时前
2026年AI Agent框架选型指南:LangGraph vs CrewAI vs Claude SDK vs OpenAI SDK
人工智能·aigc·agent
狒狒热知识16 小时前
2026年AI传播新闻软文营销发布当下178软文网领衔发展路径
大数据·人工智能