pytorch小记(七):pytorch中的保存/加载模型操作

pytorch小记(七):pytorch中的保存/加载模型操作

  • [1. 加载模型参数 (`state_dict`)](#1. 加载模型参数 (state_dict))
    • [1.1 保存模型参数](#1.1 保存模型参数)
    • [1.2 加载模型参数](#1.2 加载模型参数)
    • [1.3 常见变种](#1.3 常见变种)
      • [1.3.1 指定加载设备](#1.3.1 指定加载设备)
      • [1.3.2 非严格加载(跳过部分层)](#1.3.2 非严格加载(跳过部分层))
      • [1.3.3 打印加载的参数](#1.3.3 打印加载的参数)
  • [2. 加载整个模型](#2. 加载整个模型)
    • [2.1 保存整个模型](#2.1 保存整个模型)
    • [2.2 加载整个模型](#2.2 加载整个模型)
    • [2.3 注意事项](#2.3 注意事项)
  • [3. 总结](#3. 总结)
  • [4. 加载模型的完整代码示例](#4. 加载模型的完整代码示例)
    • [4.1 保存和加载参数](#4.1 保存和加载参数)
    • [4.2 保存和加载整个模型](#4.2 保存和加载整个模型)
    • [4.3 加载到不同设备](#4.3 加载到不同设备)
    • [4.4 忽略部分参数(非严格加载)](#4.4 忽略部分参数(非严格加载))
    • [5. 检查模型是否加载成功](#5. 检查模型是否加载成功)

在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict)加载整个模型。以下是加载模型的所有相关操作及其详细步骤:


1. 加载模型参数 (state_dict)

当仅保存了模型的参数时(使用 model.state_dict() 保存),加载模型的步骤如下:

1.1 保存模型参数

python 复制代码
torch.save(model.state_dict(), 'model.pth')
  • 文件内容:只保存模型的参数(权重和偏置)。
  • 优点
    • 节省存储空间。
    • 灵活性更高,可以与不同的模型架构配合使用。
  • 缺点
    • 需要手动重新定义模型结构。

1.2 加载模型参数

  1. 重新定义模型架构:

    python 复制代码
    model = MyModel()  # 替换为你的模型类
  2. 加载参数:

    python 复制代码
    state_dict = torch.load('model.pth')  # 加载参数字典
    model.load_state_dict(state_dict)    # 加载参数到模型
  3. 选择运行设备:

    python 复制代码
    model.to('cuda')  # 如果需要运行在 GPU 上

1.3 常见变种

1.3.1 指定加载设备

  • 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用 map_location

    python 复制代码
    state_dict = torch.load('model.pth', map_location='cpu')

1.3.2 非严格加载(跳过部分层)

  • 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用 strict=False

    python 复制代码
    model.load_state_dict(state_dict, strict=False)

1.3.3 打印加载的参数

  • 可以检查参数字典的内容:

    python 复制代码
    print(state_dict.keys())

2. 加载整个模型

当模型是通过 torch.save(model) 保存时,文件包含了模型的结构和参数,加载更为简单。

2.1 保存整个模型

python 复制代码
torch.save(model, 'model_full.pth')
  • 文件内容:包含模型的架构和参数。
  • 优点
    • 无需重新定义模型结构。
    • 直接加载并使用。
  • 缺点
    • 文件依赖于保存时的代码版本(如模型定义)。
    • 文件体积较大。

2.2 加载整个模型

python 复制代码
model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上运行

2.3 注意事项

  • 动态定义的模型
    • 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
    • 确保在加载时导入了与保存时相同的模型类。

3. 总结

操作 使用场景 优点 缺点
保存参数 (state_dict) 推荐大多数情况 文件小、灵活性高 需要手动定义模型架构
保存整个模型 模型复杂且固定时 不需要重新定义模型,直接加载 文件大、依赖保存时的代码版本

4. 加载模型的完整代码示例

4.1 保存和加载参数

python 复制代码
import torch
import torch.nn as nn

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')

# 加载参数
model = MyModel()  # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 运行在 GPU

4.2 保存和加载整个模型

python 复制代码
# 保存整个模型
torch.save(model, 'model_full.pth')

# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda')  # 运行在 GPU

4.3 加载到不同设备

python 复制代码
# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)

# 加载到 GPU
model.to('cuda')

4.4 忽略部分参数(非严格加载)

python 复制代码
# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 检查模型是否加载成功

  1. 验证权重是否加载

    python 复制代码
    for name, param in model.named_parameters():
        print(f"{name}: {param.data}")
  2. 进行推理验证

    python 复制代码
    x = torch.randn(1, 10).to('cuda')  # 假设输入维度为 10
    output = model(x)
    print(output)

通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。

相关推荐
Coder_Boy_6 分钟前
Spring AI 源码核心分析
java·人工智能·spring
net3m337 分钟前
websocket下发mp3帧数据时一个包被分包为几个子包而导致mp3解码失败而播放卡顿有杂音或断播的解决方法
开发语言·数据库·python
java1234_小锋8 分钟前
[免费]基于Python的天气预报(天气预测分析)(Django+sklearn机器学习+selenium爬虫)可视化系统【论文+源码+SQL脚本】
爬虫·python·selenium·天气预报·天气预测
雪花desu8 分钟前
GraphRAG
人工智能
Qhumaing9 分钟前
解决因为jupyter notebook修改路径下没有c.NotebookApp.notebook_dir而无法修改目录问题
ide·python·jupyter
38242782710 分钟前
python3网络爬虫开发实战 第2版:使用aiohttp
开发语言·爬虫·python
云老大TG:@yunlaoda36012 分钟前
华为云国际站代理商MSGSMS的服务质量如何?
大数据·数据库·人工智能·华为云
Elaine33614 分钟前
基于 Qwen2.5 与 LLaMA-Factory 的 LoRA 微调实战
人工智能·lora·微调·llama·llama-factory
热爱专研AI的学妹23 分钟前
【高级教程】联网搜索网页阅读api使用cURL从接口调试到复杂场景实战
服务器·数据库·人工智能·搜索引擎
Yuer202523 分钟前
为什么要用rust做算子执行引擎
人工智能·算法·数据挖掘·rust