pytorch & tensorflow 保存和加载模型

1. Pytorch

1.1.1 save网络结构和参数:

注意最后一行为"self.state_dict()"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/2E_model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.state_dict(),save_file_path)

1.1.2 对应的加载模型参数:

注意对应"agent.load_state_dict(checkpoint)"

python 复制代码
    def load(self,agent,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        checkpoint = torch.load(model_path)
        agent.load_state_dict(checkpoint)
        agent.eval()

1.2.1 保存整个模型

注意为"torch.save(self.model,save_file_path)"

python 复制代码
    def save(self,t):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path = 'model/model_' + t + '_'+self.name+'/'

        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_file_path=os.path.join(save_path, 'model.pth')

        torch.save(self.model,save_file_path)

1.2.2 加载整个模型

注意"self.model = torch.load(model_path)"

python 复制代码
    def load(self,model_path):
        model_pth = 'model.pth'
        model_path = os.path.join(model_path,model_pth)
        self.model = torch.load(model_path)
        self.model.eval()

如果没对应上会报错:torch.nn.modules.module.ModuleAttributeError: object has no attribute 'copy',参考此链接

2. Tensorflow

2.1 保存模型

python 复制代码
    def save(self,time):
        current_path = os.path.dirname(os.path.abspath(__file__))
        model_path='model/model_'+time+'_'+self.name+'/weights_'+self.name
        save_path = os.path.join(current_path,model_path)
        if not os.path.exists(save_path):os.makedirs(save_path)
        self.saver.save(self.sess,save_path)

2.2 加载模型

python 复制代码
    def load(self,model_path):
        meta_path = 'weights_'+self.name+'.meta'

        mata_path_dir = os.path.join(model_path,meta_path)

        self.saver = tf.compat.v1.train.import_meta_graph(mata_path_dir)
        a=model_path+'/'
        self.saver.restore(self.sess, tf.train.latest_checkpoint(a))
相关推荐
风铃喵游36 分钟前
让大模型调用MCP服务变得超级简单
前端·人工智能
booooooty1 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
PyAIExplorer1 小时前
基于 OpenCV 的图像 ROI 切割实现
人工智能·opencv·计算机视觉
风口猪炒股指标1 小时前
技术分析、超短线打板模式与情绪周期理论,在市场共识的形成、分歧、瓦解过程中缘起性空的理解
人工智能·博弈论·群体博弈·人生哲学·自我引导觉醒
ai_xiaogui2 小时前
一键部署AI工具!用AIStarter快速安装ComfyUI与Stable Diffusion
人工智能·stable diffusion·部署ai工具·ai应用市场教程·sd快速部署·comfyui一键安装
聚客AI3 小时前
Embedding进化论:从Word2Vec到OpenAI三代模型技术跃迁
人工智能·llm·掘金·日新计划
weixin_387545643 小时前
深入解析 AI Gateway:新一代智能流量控制中枢
人工智能·gateway
聽雨2373 小时前
03每日简报20250705
人工智能·社交电子·娱乐·传媒·媒体
二川bro4 小时前
飞算智造JavaAI:智能编程革命——AI重构Java开发新范式
java·人工智能·重构
acstdm4 小时前
DAY 48 CBAM注意力
人工智能·深度学习·机器学习