1. PyTorch
1.1.1 仅保存网络参数（state_dict）：
注意最后一行保存的是 self.state_dict()：它只包含各层的参数和缓冲区（buffers），不包含网络结构。
python
def save(self, t):
    """Save this network's parameters (its ``state_dict``) to disk.

    The file is written to
    ``<directory of this file>/model/2E_model_<t>_<self.name>/model.pth``;
    missing directories are created.

    Args:
        t: Run identifier embedded in the directory name (e.g. a timestamp).
            It is converted with ``str()``, so non-string tags such as ints
            also work.

    Returns:
        The absolute path of the written ``model.pth`` file.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    model_path = 'model/2E_model_' + str(t) + '_' + self.name + '/'
    save_path = os.path.join(current_path, model_path)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)
    save_file_path = os.path.join(save_path, 'model.pth')
    # Only parameters/buffers are stored, not the architecture: loading them
    # requires a network instance with the same structure.
    torch.save(self.state_dict(), save_file_path)
    return save_file_path
1.1.2 加载对应的模型参数：
注意关键在于 agent.load_state_dict(checkpoint)：由于 state_dict 不含网络结构，需要先构建一个与保存时结构相同的网络 agent，再把参数载入其中。
python
def load(self, agent, model_path, map_location=None):
    """Load a saved ``state_dict`` from ``<model_path>/model.pth`` into *agent*.

    Only parameters and buffers are restored. *agent* must therefore already
    be a network with the same architecture as the one that was saved.

    Args:
        agent: Module that receives the weights. It is switched to eval mode
            afterwards.
        model_path: Directory containing ``model.pth``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU onto a CPU-only machine. ``None``
            (the default) keeps ``torch.load``'s default behavior.
    """
    # NOTE(review): ``self`` is not used; the weights go into ``agent``.
    checkpoint_file = os.path.join(model_path, 'model.pth')
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    agent.load_state_dict(checkpoint)
    # Disable training-only behavior such as dropout and batch-norm updates.
    agent.eval()
1.2.1 保存整个模型
注意这里保存的是整个模型对象：torch.save(self.model, save_file_path)。整个模型通过 pickle 序列化，加载时必须能够导入模型类的定义。
python
def save(self, t):
    """Save the whole ``self.model`` object (architecture and weights).

    The model is pickled by ``torch.save`` into
    ``<directory of this file>/model/model_<t>_<self.name>/model.pth``;
    missing directories are created. Loading it later requires the model's
    class definition to be importable.

    Args:
        t: Run identifier embedded in the directory name (e.g. a timestamp).
            It is converted with ``str()``, so non-string tags such as ints
            also work.

    Returns:
        The absolute path of the written ``model.pth`` file.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    model_path = 'model/model_' + str(t) + '_' + self.name + '/'
    save_path = os.path.join(current_path, model_path)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)
    save_file_path = os.path.join(save_path, 'model.pth')
    torch.save(self.model, save_file_path)
    return save_file_path
1.2.2 加载整个模型
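注意这里直接用 self.model = torch.load(model_path) 得到整个模型。PyTorch 2.6 起 torch.load 默认 weights_only=True，加载整个模型时需显式传入 weights_only=False（该方式会执行 pickle 代码，仅对可信文件这样做）。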
python
def load(self, model_path, map_location=None):
    """Load a whole pickled model from ``<model_path>/model.pth`` into ``self.model``.

    The loaded model is put into eval mode.

    Args:
        model_path: Directory containing ``model.pth``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a model saved on a GPU onto a CPU-only machine. ``None`` (the
            default) keeps ``torch.load``'s default behavior.
    """
    import inspect

    model_file = os.path.join(model_path, 'model.pth')
    load_kwargs = {'map_location': map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a full nn.Module. Older releases have no such parameter, so pass it only
    # when torch.load accepts it.
    # SECURITY: weights_only=False executes arbitrary pickle code - only load
    # checkpoint files from trusted sources.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    self.model = torch.load(model_file, **load_kwargs)
    # Disable training-only behavior such as dropout and batch-norm updates.
    self.model.eval()
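如果保存方式与加载方式没有对应上（例如对保存了"整个模型"的文件调用 load_state_dict），会报错：torch.nn.modules.module.ModuleAttributeError: object has no attribute 'copy'。原因是 load_state_dict 期望传入的是参数字典，并会调用其 copy() 方法，而实际传入的却是模型对象。（原文此处附有参考链接，已丢失。）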
2. TensorFlow
2.1 保存模型
python
def save(self, time):
    """Save the session's variables as a checkpoint via ``self.saver``.

    The checkpoint prefix is
    ``<directory of this file>/model/model_<time>_<self.name>/weights_<self.name>``.
    Only the directory containing that prefix is created.

    Args:
        time: Run identifier embedded in the directory name (e.g. a
            timestamp). It is converted with ``str()``, so non-string tags
            also work.

    Returns:
        Whatever ``self.saver.save`` returns. For a TF1 ``Saver`` this is
        the checkpoint path.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    model_path = ('model/model_' + str(time) + '_' + self.name
                  + '/weights_' + self.name)
    save_path = os.path.join(current_path, model_path)
    # save_path is a checkpoint *prefix*, not a directory: the saver writes
    # files named weights_<name>.* next to it. Only the parent directory has
    # to exist; creating save_path itself would leave an empty directory with
    # the prefix's name behind.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    return self.saver.save(self.sess, save_path)
2.2 加载模型
python
def load(self, model_path):
    """Rebuild the saved TF1 graph and restore its most recent checkpoint.

    Reads ``<model_path>/weights_<self.name>.meta`` to recreate the graph and
    obtain a saver, which replaces ``self.saver``. Variable values are then
    restored into ``self.sess`` from the checkpoint that
    ``tf.train.latest_checkpoint`` reports for *model_path*.

    Args:
        model_path: Directory holding the ``.meta`` file and the checkpoint
            files.
    """
    # NOTE(review): import_meta_graph imports into the current default graph;
    # this presumably matches the graph self.sess was created on - confirm.
    meta_file = os.path.join(model_path, 'weights_' + self.name + '.meta')
    self.saver = tf.compat.v1.train.import_meta_graph(meta_file)

    checkpoint_dir = model_path + '/'
    newest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    self.saver.restore(self.sess, newest_checkpoint)