114_PyTorch 进阶:模型保存与读取的两大方式及“陷阱”避坑指南

在 PyTorch 中,保存和加载模型主要有两种方式。虽然都能达到目的,但它们在存储内容和使用便捷性上有着本质的区别。本文将结合实战代码,带你掌握这些核心技巧。

1. 方式一:保存整个网络模型结构与参数

这种方式最为直观,它会把模型的结构定义权重参数全部打包保存到一个文件中。

代码实现:

Python

复制代码
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式一:保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 加载方式一:
model = torch.load("vgg16_method1.pth")
print(model)
  • 优点:加载方便,一行代码搞定。
  • 缺点:文件体积较大;存在"陷阱"(见下文)。

2. 方式二:只保存模型参数(官方推荐)

这种方式只保存模型中每个层的权重和偏置(即 state_dict)。加载时需要先创建一个模型实例,再将参数"填"进去。

代码实现:

Python

复制代码
# 保存方式二:将模型参数保存为字典格式(官方推荐,体积小)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 加载方式二:
# 1. 必须先新建一个网络模型实例
vgg16_new = torchvision.models.vgg16(pretrained=False)
# 2. 将保存的参数字典加载到模型中
vgg16_new.load_state_dict(torch.load("vgg16_method2.pth"))

print(vgg16_new)
  • 优点:更安全、更节省空间,是工业界和官方推荐的标准做法。

3. 核心避坑指南:自定义模型的加载"陷阱"

文件特别提到了一个新手常犯的错误:当你用"方式一"加载自定义模型时,如果当前环境没有该模型的类定义,程序会报错。

错误场景:

如果你在 A.py 中定义并保存了模型,在 B.py 中直接 torch.load,会提示 AttributeError

解决方案:

在加载代码前,必须确保模型类的定义是可访问的。

Python

复制代码
import torch

# 陷阱:必须把模型的定义复制过来,或者从其他文件 import 进来
class Tudui(torch.nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)

    def forward(self, x):
        return self.conv1(x)

# 这样才能成功加载方式一保存的文件
model = torch.load("tudui_method1.pth")
print(model)

或者使用更优雅的导入方式:

Python

复制代码
from model_save import * # 确保 Tudui 类被正确引入
model = torch.load("tudui_method1.pth")

4. 总结:如何选择保存方式?

  • 方式一:适合快速实验,想省去重新定义结构的麻烦(但要记住加载时引入类定义)。
  • 方式二:适合生产环境,不仅体积小,且能更灵活地将参数加载到结构略有差异的模型中。

💡 学习小结

掌握了模型的保存与读取,你就真正拥有了"复用"训练成果的能力。无论是模型的版本控制,还是将模型部署到云端或移动端,这都是必经之路。

相关推荐
新知图书24 分钟前
LangGraph中的记忆存储
人工智能·langgraph·智能体设计·多智能体设计
冬奇Lab1 小时前
Claude Code 实战经验分享(上篇):从启动到并发协同
人工智能·ai编程·claude
minhuan1 小时前
多SKILL协同推理:双慢病联合决策:SKILL架构下糖尿病与高血压的协同诊疗体系.147
人工智能·慢病管理智能体·多skill协同推理·skill架构分析·双慢病决策
我叫张土豆1 小时前
从 SSE 到 Streamable HTTP:AI 时代的协议演进之路
人工智能·网络协议·http
冬奇Lab1 小时前
一天一个开源项目(第75篇):Hermes Agent - Nous Research 开源的自我进化 AI Agent
人工智能·开源·资讯
普密斯科技1 小时前
齿轮平面度与正反面智能检测方案:3D视觉技术破解精密制造品控难题
人工智能·计算机视觉·平面·3d·自动化·视觉检测
米猴设计师2 小时前
PS图案融合到褶皱布料上怎么弄?贴图教程
图像处理·人工智能·贴图·ps·nanobanana
123_不打狼2 小时前
基于UNET的语义分割
人工智能·语义分割
实在智能RPA2 小时前
Agent 如何处理流程中的异常情况?2026年AI Agent架构工程与自愈机制深度拆解
人工智能·ai·架构
十铭忘2 小时前
局部重绘3——FLUX-Fill的Lora训练
人工智能·深度学习·机器学习