PyTorch 模型保存与加载的三种常用方式

在深度学习的训练过程中,我们不可避免地要保存模型,这是一个非常好的习惯。接下来,文章将通过一个简单的神经网络模型,带你了解 PyTorch 中主要的模型保存与加载方式。

文章目录

为什么保存和加载模型很重要?

训练一个神经网络可能需要数小时甚至数天的时间,你需要认知到一点:时间是非常宝贵的,目前3090云服务器租赁一天的价格为 37.92 元。如果你的代码没有保存模型的模块,那就先不要开始,因为不保存基本等于没跑,你的效果再好也没有办法直接呈现给别人。如果你保存了模型,你就可以做到以下的事情:

  • 继续训练:通过保存检查点(checkpoint),你可以在意外中断后继续训练你的模型,这一点可能会节省你大量的时间。
  • 模型部署:训练好的模型可以被部署到生产环境中进行推理,比如 LLM,LoRA 等。
  • 分享模型:将训练好的模型分享给实验室其他成员或开源社区,以便进一步研究或复现结果。

代码示例

模型准备

为了演示,我们先定义一个简单的神经网络模型:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 64)   # 隐藏层到隐藏层
        self.fc3 = nn.Linear(64, 10)    # 隐藏层到输出层

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 实例化模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

方法一:保存和加载整个模型

保存模型

python 复制代码
torch.save(model, 'model.pth')

加载模型

python 复制代码
model = torch.load('model.pth')
print(model)

输出

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

这种方法非常简单直观,因为它保存了模型的整个结构和参数。

方法二:只保存模型的状态字典(state_dict)

保存模型状态字典

python 复制代码
torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型状态字典

需要注意的是,加载state_dict时你需要手动重新实例化模型。

python 复制代码
model = Net()  # 你需要先定义好模型架构
model.load_state_dict(torch.load('model_state_dict.pth'))
print(model)

输出

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

与保存整个模型相比,保存 state_dict 更加灵活,它只包含模型的参数,而不依赖于完整的模型定义,这意味着你可以在不同的项目中加载模型参数,甚至只加载部分模型的权重。举个例子,对于分类模型,即便你保存的是完整的网络参数,也可以仅导入特征提取层部分,当然,直接导入完整模型再拆分实际上是一样的。对于不完全匹配的模型,加载时可以通过设置 strict=False 来忽略某些不匹配的键:

python 复制代码
model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

这样,你可以灵活地只加载模型的某些部分。

使用 strict=False 加载模型

假设我们在原来的 Net 模型中新增了一个全连接层(fc4),此时如果我们直接加载之前保存的 state_dict,会因为 state_dict 中没有 fc4 的权重信息而导致报错。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 修改后的模型,新增了一层 fc4
class ModifiedNet(nn.Module):
    def __init__(self):
        super(ModifiedNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.fc4 = nn.Linear(10, 5)  # 新增的全连接层

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# 实例化模型
modified_model = ModifiedNet()

# 尝试加载之前保存的 state_dict,但忽略不匹配的层
modified_model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

# 输出模型结构
print(modified_model)

输出

python 复制代码
ModifiedNet(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=5, bias=True)
)

如果不设置 strict=False,将会报错,提示缺少 fc4 的权重:

python 复制代码
RuntimeError: Error(s) in loading state_dict for ModifiedNet: Missing key(s) in state_dict: "fc4.weight", "fc4.bias". 

注意,减少层也可以使用 strict=False。例如,如果修改后的网络只保留前两层,仍然可以成功加载原始的 state_dict,并跳过缺失的部分。

方法三:保存完整的训练状态(checkpoint)

有时候,你可能不仅仅需要保存模型参数,还需要保存训练进度,比如当前的轮数、优化器状态等。此时可以使用检查点保存更多信息。

保存检查点

python 复制代码
torch.save({
    'epoch': 100,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.01,
}, 'checkpoint.pth')

加载检查点

python 复制代码
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"Epoch: {epoch}, Loss: {loss}")

输出:

python 复制代码
Epoch: 100, Loss: 0.01

这种方式适合长时间训练时,可以从中断的地方继续训练 。但文件体积相比前面会更大,具体原因见《7. 探究模型参数与显存的关系以及不同精度造成的影响》,加载过程也稍微复杂一些,我们可以写一个函数来打包这个过程。

定义 checkpont 保存和加载的函数

python 复制代码
def save_checkpoint(model, optimizer, epoch, loss, filepath='checkpoint.pth'):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filepath)

def load_checkpoint(filepath, model, optimizer):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

# 保存
save_checkpoint(model, optimizer, 100, 0.01)

# 加载
epoch, loss = load_checkpoint('checkpoint.pth', model, optimizer)
print(f"Loaded checkpoint at epoch {epoch} with loss {loss}")
相关推荐
_.Switch13 分钟前
高级Python自动化运维:容器安全与网络策略的深度解析
运维·网络·python·安全·自动化·devops
AI极客菌43 分钟前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭1 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^1 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
测开小菜鸟1 小时前
使用python向钉钉群聊发送消息
java·python·钉钉
Power20246662 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k2 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫2 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
沉下心来学鲁班2 小时前
复现LLM:带你从零认识语言模型
人工智能·语言模型
数据猎手小k2 小时前
AndroidLab:一个系统化的Android代理框架,包含操作环境和可复现的基准测试,支持大型语言模型和多模态模型。
android·人工智能·机器学习·语言模型