Pytorch笔记一之 cpu模型保存、加载与推理

Pytorch笔记一之 cpu模型保存、加载与推理

1.保存模型

首先,在加载模型之前,我们需要了解如何保存模型。PyTorch 提供了两种保存模型的方法:保存整个模型和仅保存模型的状态字典(state dict)。推荐使用第二种方式,因为它更灵活且体积较小。

python 复制代码
import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 实例化模型并进行训练
model = SimpleNN()
# 模型训练过程(省略)

# 保存模型的状态字典
torch.save(model.state_dict(), 'simple_nn.pth')

2. 加载模型

一旦你保存了模型,接下来就可以加载它。在加载过程中,确保模型的架构与训练时一致。以下是加载模型的步骤:

  • 1.创建一个模型实例
  • 2.调用 load_state_dict() 方法加载状态字典

代码示例如下:

python 复制代码
# 重新定义模型架构
model = SimpleNN()

# 加载模型状态字典
model.load_state_dict(torch.load('simple_nn.pth', map_location=torch.device('cpu')))

3. 在 CPU 上进行推理

完成模型加载后,接下来就可以使用模型进行推理。以下是一个简单的示例:

python 复制代码
# 模拟输入数据
input_data = torch.randn(1, 10)

# 在 CPU 上进行推理
with torch.no_grad():  # 禁用梯度计算,节省内存
    output = model(input_data)

print(output)
相关推荐
猫猫与橙子10 小时前
记录使用AI工具来完成中文形近字识别
人工智能
Eric.Lee202110 小时前
机器人:sim2real 技术必要性
人工智能·深度学习·机器人·机器人仿真·mujoco·sim2real
江上鹤.14810 小时前
Day 49 预训练模型
人工智能·深度学习·机器学习
zuozewei10 小时前
7D-AI系列:Transformer 与深度学习核心概念
人工智能·深度学习·transformer
QT 小鲜肉11 小时前
【Linux命令大全】001.文件管理之mattrib命令(实操篇)
linux·运维·服务器·chrome·笔记
乐迪信息11 小时前
乐迪信息:异物入侵识别算法上线,AI摄像机保障智慧煤矿生产稳定
大数据·运维·人工智能·物联网·安全
CareyWYR11 小时前
每周AI论文速递(251222-251226)
人工智能
玄同76511 小时前
Python 真零基础入门:从 “什么是编程” 到 LLM Prompt 模板生成
人工智能·python·语言模型·自然语言处理·llm·nlp·prompt
虹科网络安全11 小时前
艾体宝洞察 | 生成式AI上线倒计时:Redis如何把“延迟”与“幻觉”挡在生产线之外?
数据库·人工智能·redis
Java后端的Ai之路11 小时前
【神经网络基础】-深度学习框架学习指南
人工智能·深度学习·神经网络·机器学习