[Pytorch] 保存模型与加载模型

1、保存模型

python 复制代码
# 定义模型
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output) #调用网络

# 保存模型
torch.save(model, 'BPNetModel0.pth')

2、加载模型

python 复制代码
import torch

## 读取模型
model = torch.load('BPNetModel0.pth')

3、保存模型参数

python 复制代码
 #调用网络
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output)

# 保存模型
torch.save({'model': model.state_dict()}, 'BPNetModel0.pth')

4、加载参数

python 复制代码
# 读取模型
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])
相关推荐
大任视点18 小时前
可梦AI获首批企业好评,蜜糖网络入驻共启AI短剧工业化
人工智能
鸢尾掠地平18 小时前
Python中常用内置函数上【含代码理解】
开发语言·python
萧鼎18 小时前
Python 图像处理利器:Pillow 深度详解与实战应用
图像处理·python·pillow
高洁0118 小时前
大模型-详解 Vision Transformer (ViT)
人工智能·python·深度学习·算法·transformer
科技峰行者18 小时前
亚马逊云科技与OpenAI战略合作深度分析:算力联盟重塑AI产业格局
人工智能
说私域19 小时前
O2O行业风口下的运营策略与定制开发AI智能名片S2B2C商城小程序的应用研究
人工智能·小程序
慕慕涵雪月光白19 小时前
在Ubuntu系统上安装英伟达(NVIDIA)RTX 3070 Ti的驱动程序
linux·运维·人工智能·ubuntu
柳鲲鹏19 小时前
OpenCV:BGR/RGB转I420(颜色失真),再转NV12
人工智能·opencv·计算机视觉
无风听海19 小时前
神经网络之线性变换
人工智能·深度学习·神经网络
陈果然DeepVersion19 小时前
Java大厂面试真题:Spring Boot+Kafka+AI智能客服场景全流程解析(九)
java·人工智能·spring boot·微服务·kafka·面试题·rag