1.保存
python
import torch
import torchvision
from torch import nn
# 加载一个内置的 VGG16 模型结构(不带预训练权重)
vgg16 = torchvision.models.vgg16(pretrained=False)
# --- 保存方式 1:保存整个模型结构 + 模型参数 ---
# 这种方式不仅保存了权重,还保存了模型的定义。
# 优点:调用极其方便,加载时不需要重新定义模型类。
torch.save(vgg16, "vgg16_method1.pth")
# --- 保存方式 2:仅保存模型参数(官方推荐) ---
# .state_dict() 会将模型各层的权重和偏置提取为一个字典(OrderedDict)。
# 优点:文件体积小,更灵活,且在跨环境加载时更稳定。
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# --- 陷阱示例:自定义模型 ---
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
tudui = Tudui()
# 使用方式 1 保存自定义模型
torch.save(tudui, "tudui_method1.pth")
1. 方式 1:整体保存(Model + Params)
-
原理 :它是基于 Python 的
pickle序列化机制。它尝试把整个对象丢进一个文件里。 -
局限性 :虽然看似省事,但在不同的 Python 文件中恢复(load)这个模型时,如果当前命名空间里没有
vgg16的类定义,程序会因为找不到对应的类而报错。
2. 方式 2:字典保存(State Dict)------ 官方推荐
-
原理:它只记录每个神经元的"记忆"(权重数值),而不关心房子的"图纸"(网络结构)。
-
加载逻辑:
-
你得先亲手盖一个一样的"房子"(实例化模型)。
-
然后把"记忆"灌进去:
model.load_state_dict(torch.load("path"))。
-
-
优势:它是工业界的标准做法。当你将模型从研究环境部署到生产环境(如 C++ 环境)时,这种纯字典格式是最安全的。
3. 关于那个"陷阱"的详解
代码末尾提到的 tudui_method1.pth 有一个隐患:
场景模拟 :如果你在
train.py中运行了上面的代码,然后想在test.py中直接使用torch.load("tudui_method1.pth")而不在test.py中重新写一遍class Tudui,你会得到一个 AttributeError。
原因 : torch.save(model, path) 并没有把源代码存进去,它只存了一个"指向该类的引用"。如果加载文件时 Python 找不到 Tudui 这个类,它就不知道该如何还原这个对象。
2 加载
python
import torch
from model_save import * # 核心:解决陷阱的关键
# 方式1-》保存方式1,加载模型
import torchvision
from torch import nn
# 加载方式1:直接加载整个模型对象
model = torch.load("vgg16_method1.pth")
# print(model) # 取消注释可查看 VGG16 结构
# 方式2:加载方式2,即仅加载参数
# 第一步:先创建一个一模一样的模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
# 第二步:将保存的参数字典加载到模型结构中
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# 注意:不能直接 model = torch.load("vgg16_method2.pth"),因为那样加载出来只是一个字典
# print(vgg16)
# --- 陷阱1 及其 解决方案 ---
# 当加载自定义模型(如 Tudui)时,必须让当前程序知道 Tudui 类的定义。
# 虽然这里把类定义注释掉了:
# class Tudui(nn.Module):
# ...
# 但因为在代码最上方执行了 `from model_save import *`
# 只要 model_save.py 中有 Tudui 的定义,下面的加载就不会报错。
model = torch.load('tudui_method1.pth')
print(model)
1. 关键的导入
from model_save import *:这是这段代码能跑通的"大功臣"。在使用torch.save(model, path)(方式1)保存自定义模型时,模型文件并不包含源代码。如果你在另一个文件中加载它,Python 必须知道这个类的结构。通过这一行,你把上一节课定义的Tudui类引入了当前空间。
2. 加载方式 1 (全量加载)
-
model = torch.load("vgg16_method1.pth"):-
操作:一步到位。
-
结果 :
model直接就是一个完整的、带权重的网络对象。 -
前提:对于官方模型(如 VGG16),PyTorch 内置了定义,所以不需要额外 import 就能加载。
-
3. 加载方式 2 (参数加载)
-
vgg16 = torchvision.models.vgg16(...):先造一个"空壳"模型。 -
vgg16.load_state_dict(...):-
操作 :把
torch.load出来的"权重字典"灌进"空壳"里。 -
核心:这是最安全、最推荐的加载方式,因为它不依赖于复杂的对象序列化,只依赖于数值。
-
4. 攻克"陷阱"
-
model = torch.load('tudui_method1.pth'):-
这是针对自定义模型的全量加载。
-
如果你把代码第一行的
from model_save import *删掉,这一行会立刻报错:AttributeError: Can't get attribute 'Tudui' on <module '__main__'>。 -
结论:方式 1 虽然看起来简单(不用手动实例化模型),但它要求你在加载时,必须保证类定义在当前环境下是可用的。
-