网络模型的保存和读取

1.保存

python 复制代码
import torch
import torchvision
from torch import nn

# 加载一个内置的 VGG16 模型结构(不带预训练权重)
vgg16 = torchvision.models.vgg16(pretrained=False)

# --- 保存方式 1:保存整个模型结构 + 模型参数 ---
# 这种方式不仅保存了权重,还保存了模型的定义。
# 优点:调用极其方便,加载时不需要重新定义模型类。
torch.save(vgg16, "vgg16_method1.pth")

# --- 保存方式 2:仅保存模型参数(官方推荐) ---
# .state_dict() 会将模型各层的权重和偏置提取为一个字典(OrderedDict)。
# 优点:文件体积小,更灵活,且在跨环境加载时更稳定。
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# --- 陷阱示例:自定义模型 ---
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

tudui = Tudui()
# 使用方式 1 保存自定义模型
torch.save(tudui, "tudui_method1.pth")
1. 方式 1:整体保存(Model + Params)
  • 原理 :它是基于 Python 的 pickle 序列化机制。它尝试把整个对象丢进一个文件里。

  • 局限性 :虽然看似省事,但在不同的 Python 文件中恢复(load)这个模型时,如果当前命名空间里没有 vgg16 的类定义,程序会因为找不到对应的类而报错。

2. 方式 2:字典保存(State Dict)------ 官方推荐
  • 原理:它只记录每个神经元的"记忆"(权重数值),而不关心房子的"图纸"(网络结构)。

  • 加载逻辑

    1. 你得先亲手盖一个一样的"房子"(实例化模型)。

    2. 然后把"记忆"灌进去:model.load_state_dict(torch.load("path"))

  • 优势:它是工业界的标准做法。当你将模型从研究环境部署到生产环境(如 C++ 环境)时,这种纯字典格式是最安全的。

3. 关于那个"陷阱"的详解

代码末尾提到的 tudui_method1.pth 有一个隐患:

场景模拟 :如果你在 train.py 中运行了上面的代码,然后想在 test.py 中直接使用 torch.load("tudui_method1.pth") 而不在 test.py 中重新写一遍 class Tudui,你会得到一个 AttributeError

原因torch.save(model, path) 并没有把源代码存进去,它只存了一个"指向该类的引用"。如果加载文件时 Python 找不到 Tudui 这个类,它就不知道该如何还原这个对象。

2 加载

python 复制代码
import torch
from model_save import * # 核心:解决陷阱的关键
# 方式1-》保存方式1,加载模型
import torchvision
from torch import nn

# 加载方式1:直接加载整个模型对象
model = torch.load("vgg16_method1.pth")
# print(model) # 取消注释可查看 VGG16 结构

# 方式2:加载方式2,即仅加载参数
# 第一步:先创建一个一模一样的模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
# 第二步:将保存的参数字典加载到模型结构中
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# 注意:不能直接 model = torch.load("vgg16_method2.pth"),因为那样加载出来只是一个字典
# print(vgg16)

# --- 陷阱1 及其 解决方案 ---
# 当加载自定义模型(如 Tudui)时,必须让当前程序知道 Tudui 类的定义。

# 虽然这里把类定义注释掉了:
# class Tudui(nn.Module):
#     ...

# 但因为在代码最上方执行了 `from model_save import *`
# 只要 model_save.py 中有 Tudui 的定义,下面的加载就不会报错。
model = torch.load('tudui_method1.pth')
print(model)
1. 关键的导入
  • from model_save import * :这是这段代码能跑通的"大功臣"。在使用 torch.save(model, path)(方式1)保存自定义模型时,模型文件并不包含源代码。如果你在另一个文件中加载它,Python 必须知道这个类的结构。通过这一行,你把上一节课定义的 Tudui 类引入了当前空间。
2. 加载方式 1 (全量加载)
  • model = torch.load("vgg16_method1.pth")

    • 操作:一步到位。

    • 结果model 直接就是一个完整的、带权重的网络对象。

    • 前提:对于官方模型(如 VGG16),PyTorch 内置了定义,所以不需要额外 import 就能加载。

3. 加载方式 2 (参数加载)
  • vgg16 = torchvision.models.vgg16(...):先造一个"空壳"模型。

  • vgg16.load_state_dict(...)

    • 操作 :把 torch.load 出来的"权重字典"灌进"空壳"里。

    • 核心:这是最安全、最推荐的加载方式,因为它不依赖于复杂的对象序列化,只依赖于数值。

4. 攻克"陷阱"
  • model = torch.load('tudui_method1.pth')

    • 这是针对自定义模型的全量加载。

    • 如果你把代码第一行的 from model_save import * 删掉,这一行会立刻报错:AttributeError: Can't get attribute 'Tudui' on <module '__main__'>

    • 结论:方式 1 虽然看起来简单(不用手动实例化模型),但它要求你在加载时,必须保证类定义在当前环境下是可用的。

相关推荐
牧子川6 分钟前
009-Transformer-Architecture
人工智能·深度学习·transformer
covco25 分钟前
矩阵管理系统指南:拆解星链引擎的架构设计与全链路落地实践
大数据·人工智能·矩阵
沪漂阿龙29 分钟前
AI大模型面试题:支持向量机是什么?间隔最大化、软间隔、核函数、LinearSVC 全面拆解
人工智能·算法·支持向量机
lifewange30 分钟前
AI编写测试用例工具介绍
人工智能·测试用例
陕西字符33 分钟前
2026 西安 豆包获客优化技术深度解析:企来客科技 AI 全域获客系统测评
大数据·人工智能
掘金安东尼36 分钟前
GGUF、GPTQ、AWQ、EXL2、MLX、VMLX...运行大模型,为什么会有这么多格式?
人工智能
新知图书37 分钟前
市场分析报告自动化生成(使用千问)
人工智能·ai助手·千问·高效办公
无心水39 分钟前
【Hermes:安全、权限与生产环境】38、Hermes Agent 安全四层纵深:最小权限原则从理论到落地的完全指南
人工智能·安全·mcp协议·openclaw·养龙虾·hermes·honcho
旦莫1 小时前
AI驱动的纯视觉自动化测试:知识库里应该积累什么知识内容
人工智能·python·测试开发·pytest·ai测试
dfsj660111 小时前
第四章:深度学习革命
人工智能·深度学习