网络模型的保存和读取

1.保存

python 复制代码
import torch
import torchvision
from torch import nn

# 加载一个内置的 VGG16 模型结构(不带预训练权重)
vgg16 = torchvision.models.vgg16(pretrained=False)

# --- 保存方式 1:保存整个模型结构 + 模型参数 ---
# 这种方式不仅保存了权重,还保存了模型的定义。
# 优点:调用极其方便,加载时不需要重新定义模型类。
torch.save(vgg16, "vgg16_method1.pth")

# --- 保存方式 2:仅保存模型参数(官方推荐) ---
# .state_dict() 会将模型各层的权重和偏置提取为一个字典(OrderedDict)。
# 优点:文件体积小,更灵活,且在跨环境加载时更稳定。
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# --- 陷阱示例:自定义模型 ---
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

tudui = Tudui()
# 使用方式 1 保存自定义模型
torch.save(tudui, "tudui_method1.pth")
1. 方式 1:整体保存(Model + Params)
  • 原理 :它是基于 Python 的 pickle 序列化机制。它尝试把整个对象丢进一个文件里。

  • 局限性 :虽然看似省事,但在不同的 Python 文件中恢复(load)这个模型时,如果当前命名空间里没有 vgg16 的类定义,程序会因为找不到对应的类而报错。

2. 方式 2:字典保存(State Dict)------ 官方推荐
  • 原理:它只记录每个神经元的"记忆"(权重数值),而不关心房子的"图纸"(网络结构)。

  • 加载逻辑

    1. 你得先亲手盖一个一样的"房子"(实例化模型)。

    2. 然后把"记忆"灌进去:model.load_state_dict(torch.load("path"))

  • 优势:它是工业界的标准做法。当你将模型从研究环境部署到生产环境(如 C++ 环境)时,这种纯字典格式是最安全的。

3. 关于那个"陷阱"的详解

代码末尾提到的 tudui_method1.pth 有一个隐患:

场景模拟 :如果你在 train.py 中运行了上面的代码,然后想在 test.py 中直接使用 torch.load("tudui_method1.pth") 而不在 test.py 中重新写一遍 class Tudui,你会得到一个 AttributeError

原因torch.save(model, path) 并没有把源代码存进去,它只存了一个"指向该类的引用"。如果加载文件时 Python 找不到 Tudui 这个类,它就不知道该如何还原这个对象。

2 加载

python 复制代码
import torch
from model_save import * # 核心:解决陷阱的关键
# 方式1-》保存方式1,加载模型
import torchvision
from torch import nn

# 加载方式1:直接加载整个模型对象
model = torch.load("vgg16_method1.pth")
# print(model) # 取消注释可查看 VGG16 结构

# 方式2:加载方式2,即仅加载参数
# 第一步:先创建一个一模一样的模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
# 第二步:将保存的参数字典加载到模型结构中
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# 注意:不能直接 model = torch.load("vgg16_method2.pth"),因为那样加载出来只是一个字典
# print(vgg16)

# --- 陷阱1 及其 解决方案 ---
# 当加载自定义模型(如 Tudui)时,必须让当前程序知道 Tudui 类的定义。

# 虽然这里把类定义注释掉了:
# class Tudui(nn.Module):
#     ...

# 但因为在代码最上方执行了 `from model_save import *`
# 只要 model_save.py 中有 Tudui 的定义,下面的加载就不会报错。
model = torch.load('tudui_method1.pth')
print(model)
1. 关键的导入
  • from model_save import * :这是这段代码能跑通的"大功臣"。在使用 torch.save(model, path)(方式1)保存自定义模型时,模型文件并不包含源代码。如果你在另一个文件中加载它,Python 必须知道这个类的结构。通过这一行,你把上一节课定义的 Tudui 类引入了当前空间。
2. 加载方式 1 (全量加载)
  • model = torch.load("vgg16_method1.pth")

    • 操作:一步到位。

    • 结果model 直接就是一个完整的、带权重的网络对象。

    • 前提:对于官方模型(如 VGG16),PyTorch 内置了定义,所以不需要额外 import 就能加载。

3. 加载方式 2 (参数加载)
  • vgg16 = torchvision.models.vgg16(...):先造一个"空壳"模型。

  • vgg16.load_state_dict(...)

    • 操作 :把 torch.load 出来的"权重字典"灌进"空壳"里。

    • 核心:这是最安全、最推荐的加载方式,因为它不依赖于复杂的对象序列化,只依赖于数值。

4. 攻克"陷阱"
  • model = torch.load('tudui_method1.pth')

    • 这是针对自定义模型的全量加载。

    • 如果你把代码第一行的 from model_save import * 删掉,这一行会立刻报错:AttributeError: Can't get attribute 'Tudui' on <module '__main__'>

    • 结论:方式 1 虽然看起来简单(不用手动实例化模型),但它要求你在加载时,必须保证类定义在当前环境下是可用的。

相关推荐
AI 编程助手GPT3 小时前
【实战】Codex 接管电脑 + Claude Routines 云端值守:一次 Bug 排查的“无人化”闭环
人工智能·gpt·ai·chatgpt·bug
UltraLAB-F3 小时前
有限元分析内存需求深度解析:刚度矩阵、求解器与硬件配置
人工智能·ai·硬件架构
MediaTea3 小时前
Scikit-learn:特征矩阵与目标变量
人工智能·python·机器学习·矩阵·scikit-learn
qyr67893 小时前
全球AI服务器DAC线缆市场发展趋势与未来趋势展望
大数据·人工智能·数据分析·汽车·ai服务器·ai服务器dac线缆
郝学胜-神的一滴3 小时前
深度学习入门:极简神经网络搭建与参数计算全攻略
人工智能·pytorch·python·深度学习·神经网络·机器学习
重生之我要成为代码大佬3 小时前
pytorch与视觉检测
人工智能·pytorch·深度学习·大模型·视觉检测
初圣魔门首席弟子3 小时前
深度学习:学习率(Learning Rate)超通俗讲解
人工智能
LONGZETECH3 小时前
破解汽车实训难题!龙泽科技仿真软件,助力院校教学与大赛备赛
人工智能·科技·架构·汽车·汽车仿真教学软件
workflower3 小时前
机器人城市应用-室外总坪清扫
运维·人工智能·机器人·集成测试·人机交互·软件需求
前端不太难3 小时前
鸿蒙游戏 + AI:自动测试与自动发布
人工智能·游戏·harmonyos