PyTorch|保存与加载自己的模型

训练好一个模型之后,我们往往要对其进行保存,除非下次用时想再次训练一遍。

下面以一个简单的回归任务来详细讲解模型的保存和加载。

来看这样一组数据:

复制代码
x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)

画图:

复制代码
plt.scatter(x.numpy(),y.numpy())

很显然,x与y基本呈二次函数关系 ,那么接下来我们就来拟合整个函数

复制代码
import torchimport matplotlib.pyplot as pltimport torch.nn as nnimport torch.optim as optim

x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)net1=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))criterion=nn.MSELoss()optimizer=optim.SGD(net1.parameters(),lr=0.2)#训练模型for i in range(1000):    pred=net1(x)    loss=criterion(pred,y)    optimizer.zero_grad()    loss.backward()    optimizer.step()
#测试模型net1.eval()with torch.no_grad():    y1=net1(x)    plt.plot(x.numpy(),y1.numpy(),'r-')    plt.scatter(x.numpy(),y.numpy())

结果似乎不错!

这里我们得到了一个网络net1,它可以被当作一个二次函数,用于描述之前的x,y数据的关系

得到这个网络后,我们想保存它,主要有两种方式

1,保存整个网络,包括训练后的各个层的参数

​​​​​​​

复制代码
#保存整个网络,包括训练后的各个层的参数torch.save(net1,'net1weight.pkl')

2,只保存训练好的网络的参数,速度更快

​​​​​​​

复制代码
#只保存训练好的网络的参数,速度更快torch.save(net1.state_dict(),'net1_params.pkl')

假设我们按第一种方式保存,那么下次想要使用次网络时需要这样做:

复制代码
network=torch.load('net1weight.pkl')

#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'b-')    plt.scatter(x.numpy(),y.numpy())

假设我们按第二种方式保存,那么下次想要使用次网络时需要这样做:

复制代码
network=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))network.load_state_dict(torch.load('net1_params.pkl'))​​​​​​​

#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'g-')    plt.scatter(x.numpy(),y.numpy())

可以看出,第二次首先需要构造出一个一模一样的模型,接着再导入参数即可。当然,这只是个简单的回归模型,其它模型保存与加载同样如此。

总结一下:

模型保存与导入有两种方式:

方式一:​​​​​​​

复制代码
#模型保存torch.save(net1,'net1weight.pkl')#模型导入network=torch.load('net1weight.pkl')

方式二:​​​​​​​

复制代码
#模型保存torch.save(net1.state_dict(),'net1_params.pkl')#模型导入network.load_state_dict(torch.load('net1_params.pkl'))
相关推荐
产品何同学12 小时前
情绪经济下的AI应用怎么设计?6个APP原型设计案例拆解
人工智能·ai·产品经理·交友·ai应用·ai伴侣·情绪经济
猿小猴子12 小时前
主流 AI 大模型开源平台社区之一的 ModelScope (魔搭社区) 介绍
人工智能
LCHub低代码社区13 小时前
ima Copilot任务模式PPT生成功能全流程评测
人工智能·ai智能体·大禹智库
珑墨13 小时前
【大语言模型】从历史到未来
前端·人工智能·后端·ai·语言模型·自然语言处理·chatgpt
通义灵码13 小时前
使用Qoder开发一个AI皮肤分析小程序
人工智能·小程序
罗政13 小时前
【Excel批处理】一键批量AI提取身份证信息到excel表格,数据安全,支持断网使用
人工智能·excel
致Great13 小时前
使用 GRPO 和 OpenEnv 微调小型语言模型实现浏览器控制
数据库·人工智能·深度学习·语言模型·自然语言处理·agent·智能体
最新快讯13 小时前
快讯 | AI教父辛顿:人工智能明年将对大量工作岗位产生实质性冲击;Copilot整合效果不佳
人工智能·科技
说私域13 小时前
分享经济:智能名片链动2+1模式商城小程序驱动下的可持续增长引擎
大数据·人工智能·小程序