114_PyTorch 进阶:模型保存与读取的两大方式及“陷阱”避坑指南

在 PyTorch 中,保存和加载模型主要有两种方式。虽然都能达到目的,但它们在存储内容和使用便捷性上有着本质的区别。本文将结合实战代码,带你掌握这些核心技巧。

1. 方式一:保存整个网络模型结构与参数

这种方式最为直观,它会把模型的结构定义权重参数全部打包保存到一个文件中。

代码实现:

Python

复制代码
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式一:保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 加载方式一:
model = torch.load("vgg16_method1.pth")
print(model)
  • 优点:加载方便,一行代码搞定。
  • 缺点:文件体积较大;存在"陷阱"(见下文)。

2. 方式二:只保存模型参数(官方推荐)

这种方式只保存模型中每个层的权重和偏置(即 state_dict)。加载时需要先创建一个模型实例,再将参数"填"进去。

代码实现:

Python

复制代码
# 保存方式二:将模型参数保存为字典格式(官方推荐,体积小)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 加载方式二:
# 1. 必须先新建一个网络模型实例
vgg16_new = torchvision.models.vgg16(pretrained=False)
# 2. 将保存的参数字典加载到模型中
vgg16_new.load_state_dict(torch.load("vgg16_method2.pth"))

print(vgg16_new)
  • 优点:更安全、更节省空间,是工业界和官方推荐的标准做法。

3. 核心避坑指南:自定义模型的加载"陷阱"

文件特别提到了一个新手常犯的错误:当你用"方式一"加载自定义模型时,如果当前环境没有该模型的类定义,程序会报错。

错误场景:

如果你在 A.py 中定义并保存了模型,在 B.py 中直接 torch.load,会提示 AttributeError

解决方案:

在加载代码前,必须确保模型类的定义是可访问的。

Python

复制代码
import torch

# 陷阱:必须把模型的定义复制过来,或者从其他文件 import 进来
class Tudui(torch.nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3)

    def forward(self, x):
        return self.conv1(x)

# 这样才能成功加载方式一保存的文件
model = torch.load("tudui_method1.pth")
print(model)

或者使用更优雅的导入方式:

Python

复制代码
from model_save import * # 确保 Tudui 类被正确引入
model = torch.load("tudui_method1.pth")

4. 总结:如何选择保存方式?

  • 方式一:适合快速实验,想省去重新定义结构的麻烦(但要记住加载时引入类定义)。
  • 方式二:适合生产环境,不仅体积小,且能更灵活地将参数加载到结构略有差异的模型中。

💡 学习小结

掌握了模型的保存与读取,你就真正拥有了"复用"训练成果的能力。无论是模型的版本控制,还是将模型部署到云端或移动端,这都是必经之路。

相关推荐
CoovallyAIHub1 小时前
把 Whisper、Moonshine、SenseVoice 统统装进手机:sherpa-onnx 离线语音部署框架,GitHub 10.9K Star
人工智能·架构
一只叫煤球的猫1 小时前
RAG 如何落地?从原理解释到工程实现
人工智能·后端·ai编程
AI营销快线1 小时前
AI营销获客难?原圈科技深度解析SaaS系统增长之道
大数据·人工智能
南滑散修2 小时前
机器学习(四):混合高斯模型GMM
人工智能·机器学习
柯儿的天空2 小时前
Mem0深度解析:给你的ai agent加上长期记忆,让ai从“健忘“到“过目不忘“
人工智能·gpt·自然语言处理·ai作画·aigc·ai编程·agi
FluxMelodySun2 小时前
机器学习(二十五) 降维:主成分分析(PCA)及特征值分解
人工智能·算法·机器学习
Cosolar2 小时前
Transformer训练与生成背后的数学基础
人工智能·后端·开源
GoCoding2 小时前
Triton + RISC-V
pytorch·openai·编译器