pytorch & tensorflow 保存和加载模型

1. Pytorch

1.1.1 save网络结构和参数:

注意最后一行为"self.state_dict()"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/2E_model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.state_dict(),save_file_path)

1.1.2 对应的加载模型参数:

注意对应"agent.load_state_dict(checkpoint)"

python 复制代码
    def load(self,agent,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        checkpoint = torch.load(model_path)
        agent.load_state_dict(checkpoint)
        agent.eval()

1.2.1 保存整个模型

注意为"torch.save(self.model,save_file_path)"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.model,save_file_path)

1.2.2 加载整个模型

注意"self.model = torch.load(model_path)"

python 复制代码
    def load(self,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        self.model = torch.load(model_path)
        self.model.eval()

如果没对应上会报错:torch.nn.modules.module.ModuleAttributeError: object has no attribute 'copy',参考此链接

2. Tensorflow

2.1 保存模型

python 复制代码
    def save(self,time):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path='model/model_'+time+'_'+self.name+'/weights_'+self.name
        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):os.makedirs(save_path)
        self.saver.save(self.sess,save_path)

2.2 加载模型

python 复制代码
    def load(self,model_path):
        meta_path = 'weights_'+self.name+'.meta'

        mata_path_dir = os.path.join(model_path,meta_path)

        self.saver = tf.compat.v1.train.import_meta_graph(mata_path_dir)
        a=model_path+'/'
        self.saver.restore(self.sess, tf.train.latest_checkpoint(a))
相关推荐
山烛10 小时前
深入解析 YOLO v2
人工智能·yolo·计算机视觉·目标跟踪·yolov2
GISer_Jing10 小时前
AI/CICD/Next/React Native&Taro内容
人工智能·react native·taro
声网10 小时前
阿里发布「夸克 AI 眼镜」:融合阿里购物、地图、支付生态;苹果拟收购计算机视觉初创 Prompt AI丨日报
人工智能·计算机视觉·prompt
IT_陈寒10 小时前
Java性能调优实战:7个让GC效率提升50%的关键参数设置
前端·人工智能·后端
爱看科技10 小时前
微美全息(NASDAQ:WIMI)融合区块链+AI+IoT 三大技术,解锁物联网入侵检测新范式
人工智能·物联网·区块链
华为云开发者联盟11 小时前
华为开发者空间携手乐知行:轻松实现智能网联小车数据可视化系
人工智能·华为开发者空间
云卓SKYDROID11 小时前
飞控信号模块技术要点与难点分析
人工智能·无人机·航电系统·高科技·云卓科技
文火冰糖的硅基工坊11 小时前
[嵌入式系统-101]:AIoT(人工智能物联网)开发板
人工智能·物联网·重构·架构
说私域12 小时前
开源AI智能名片链动2+1模式S2B2C商城小程序在个性化与小众化消费崛起中的营销宣传策略研究
人工智能·小程序
AI小云12 小时前
【Python与AI基础】Python编程基础:读写CSV文件
人工智能·python