深度学习-Pytorch如何保存和加载模型
用pytorch构建模型,并训练模型,得到一个优化的模型,那么如何保存模型?然后如何又加载模型呢?
pytorch 目前在深度学习具有重要的地位,比起早先的caffe,tensorflow,keras越来越受到欢迎,其他的深度学习框架越来越显得小众。
数据分析
经典算法
LLM应用
如何保存模型
用pytorch构建模型,并训练模型,得到一个优化的模型,那么如何保存模型?
通常模型的信息很多,有些对使用没有用处,只需要保存感兴趣的参数部分状态就行,第二个参数就是模型路径和名称。
python
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
如何加载模型
保存好模型后,如何加载模型,进行使用呢?
这里,需要加载模型的网络结构,当然也需要加载参数部分状态。
python
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))
加载模型后,如何使用呢?
这个模型是分类模型,把衣服分为10类,预测也是10类,如下,类似与前文的模型测试部分。
读者可以自行比较下:
[深度学习-Pytorch如何构建和训练模型-CSDN博客]
python
classes = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
]
model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
x = x.to(device)
pred = model(x)
predicted, actual = classes[pred[0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"
以上代码只是一个简单示例,示例代码中的表达式可以根据实际问题进行修改。
觉得有用 收藏 收藏 收藏
点个赞 点个赞 点个赞
End
DeepLearning文章:
GPT专栏文章:
GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案
GPT实战系列-LangChain + ChatGLM3构建天气查询助手
GPT实战系列-大模型为我所用之借用ChatGLM3构建查询助手
GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(二)
GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(一)
GPT实战系列-ChatGLM2部署Ubuntu+Cuda11+显存24G实战方案