pytorch & tensorflow 保存和加载模型

1. Pytorch

1.1.1 save网络结构和参数:

注意最后一行为"self.state_dict()"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/2E_model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.state_dict(),save_file_path)

1.1.2 对应的加载模型参数:

注意对应"agent.load_state_dict(checkpoint)"

python 复制代码
    def load(self,agent,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        checkpoint = torch.load(model_path)
        agent.load_state_dict(checkpoint)
        agent.eval()

1.2.1 保存整个模型

注意为"torch.save(self.model,save_file_path)"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.model,save_file_path)

1.2.2 加载整个模型

注意"self.model = torch.load(model_path)"

python 复制代码
    def load(self,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        self.model = torch.load(model_path)
        self.model.eval()

如果没对应上会报错:torch.nn.modules.module.ModuleAttributeError: object has no attribute 'copy',参考此链接

2. Tensorflow

2.1 保存模型

python 复制代码
    def save(self,time):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path='model/model_'+time+'_'+self.name+'/weights_'+self.name
        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):os.makedirs(save_path)
        self.saver.save(self.sess,save_path)

2.2 加载模型

python 复制代码
    def load(self,model_path):
        meta_path = 'weights_'+self.name+'.meta'

        mata_path_dir = os.path.join(model_path,meta_path)

        self.saver = tf.compat.v1.train.import_meta_graph(mata_path_dir)
        a=model_path+'/'
        self.saver.restore(self.sess, tf.train.latest_checkpoint(a))
相关推荐
大师兄带你刨AI11 分钟前
「AI产业」| 《2025中国低空经济商业洞察报告(商业无人机应用篇)》
大数据·人工智能
lul~11 分钟前
[科研理论]无人机底层控制算法PID、LQR、MPC解析
c++·人工智能·无人机
摆烂z15 分钟前
机器学习-黑马笔记
人工智能·笔记·机器学习
硅谷秋水22 分钟前
TASTE-Rob:推进面向任务的手-目标交互视频生成,实现可通用的机器人操作
人工智能·深度学习·机器学习·计算机视觉·机器人·交互
yzx99101325 分钟前
柑橘检测模型
服务器·人工智能·深度学习·算法
啊哈哈哈哈哈啊哈哈1 小时前
G1周打卡——GAN入门
pytorch·深度学习·生成对抗网络
神齐的小马1 小时前
机器学习 [白板推导](六)[核方法、指数族分布]
人工智能·机器学习
孚为智能科技1 小时前
集装箱残损识别系统如何检测残损?它的识别率能达到多少?
大数据·图像处理·人工智能·计算机视觉·视觉检测
小白学大数据1 小时前
爬取汽车之家评论并利用NLP进行关键词提取
人工智能·自然语言处理·汽车
biubiubiu07061 小时前
AI中的Prompt
人工智能·prompt