深度学习-Pytorch如何保存和加载模型

深度学习-Pytorch如何保存和加载模型

用pytorch构建模型,并训练模型,得到一个优化的模型,那么如何保存模型?然后如何又加载模型呢?

pytorch 目前在深度学习具有重要的地位,比起早先的caffe,tensorflow,keras越来越受到欢迎,其他的深度学习框架越来越显得小众。

数据分析

数据分析-Pandas如何转换产生新列

数据分析-Pandas如何统计数据概况

数据分析-Pandas如何轻松处理时间序列数据

数据分析-Pandas如何选择数据子集

数据分析-Pandas如何重塑数据表-CSDN博客

经典算法

经典算法-遗传算法的python实现

经典算法-模拟退火算法的python实现

经典算法-粒子群算法的python实现-CSDN博客

LLM应用

大模型查询工具助手之股票免费查询接口

Python技巧-终端屏幕打印光标和文字控制

如何保存模型

用pytorch构建模型,并训练模型,得到一个优化的模型,那么如何保存模型?

通常模型的信息很多,有些对使用没有用处,只需要保存感兴趣的参数部分状态就行,第二个参数就是模型路径和名称。

python 复制代码
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

如何加载模型

保存好模型后,如何加载模型,进行使用呢?

这里,需要加载模型的网络结构,当然也需要加载参数部分状态。

python 复制代码
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

加载模型后,如何使用呢?

这个模型是分类模型,把衣服分为10类,预测也是10类,如下,类似与前文的模型测试部分。

读者可以自行比较下:

[深度学习-Pytorch如何构建和训练模型-CSDN博客]

python 复制代码
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"

以上代码只是一个简单示例,示例代码中的表达式可以根据实际问题进行修改。

觉得有用 收藏 收藏 收藏

点个赞 点个赞 点个赞

End


DeepLearning文章:

深度学习-Pytorch数据集构造和分批加载-CSDN博客

深度学习-Pytorch如何构建和训练模型-CSDN博客

GPT专栏文章:

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

GPT实战系列-LangChain + ChatGLM3构建天气查询助手

大模型查询工具助手之股票免费查询接口

GPT实战系列-简单聊聊LangChain

GPT实战系列-大模型为我所用之借用ChatGLM3构建查询助手

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(二)

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(一)

GPT实战系列-ChatGLM2模型的微调训练参数解读

GPT实战系列-如何用自己数据微调ChatGLM2模型训练

GPT实战系列-ChatGLM2部署Ubuntu+Cuda11+显存24G实战方案

GPT实战系列-Baichuan2本地化部署实战方案

GPT实战系列-Baichuan2等大模型的计算精度与量化

GPT实战系列-GPT训练的Pretraining,SFT,Reward Modeling,RLHF

GPT实战系列-探究GPT等大模型的文本生成-CSDN博客

相关推荐
GocNeverGiveUp4 分钟前
机器学习2-NumPy
人工智能·机器学习·numpy
B站计算机毕业设计超人1 小时前
计算机毕业设计PySpark+Hadoop中国城市交通分析与预测 Python交通预测 Python交通可视化 客流量预测 交通大数据 机器学习 深度学习
大数据·人工智能·爬虫·python·机器学习·课程设计·数据可视化
学术头条1 小时前
清华、智谱团队:探索 RLHF 的 scaling laws
人工智能·深度学习·算法·机器学习·语言模型·计算语言学
18号房客1 小时前
一个简单的机器学习实战例程,使用Scikit-Learn库来完成一个常见的分类任务——**鸢尾花数据集(Iris Dataset)**的分类
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·sklearn
feifeikon1 小时前
机器学习DAY3 : 线性回归与最小二乘法与sklearn实现 (线性回归完)
人工智能·机器学习·线性回归
游客5201 小时前
opencv中的常用的100个API
图像处理·人工智能·python·opencv·计算机视觉
古希腊掌管学习的神1 小时前
[机器学习]sklearn入门指南(2)
人工智能·机器学习·sklearn
Ven%1 小时前
如何在防火墙上指定ip访问服务器上任何端口呢
linux·服务器·网络·深度学习·tcp/ip
凡人的AI工具箱2 小时前
每天40分玩转Django:Django国际化
数据库·人工智能·后端·python·django·sqlite
IT猿手2 小时前
最新高性能多目标优化算法:多目标麋鹿优化算法(MOEHO)求解TP1-TP10及工程应用---盘式制动器设计,提供完整MATLAB代码
开发语言·深度学习·算法·机器学习·matlab·多目标算法