Tensorflow2保存和加载模型

1、model.save() and model.load()

此种方法可保存模型的结构、参数等内容。加载模型后无需设置即可使用!

保存模型:

python 复制代码
model.save('my_model.h5')

加载模型:

python 复制代码
# 加载整个模型
loaded_model = tf.keras.models.load_model('my_model.h5')

注意,创建的模型不能使用自定义的loss函数等方法,否则导入时会出错!

示例:

python 复制代码
model_file = "data/model/multi_labels_model.h5"    # 模型文件路径
def model_handle(x_train, y_train):
    if os.path.exists(model_file):
        print("---load the model---")
        model = tf.keras.models.load_model(model_file) # 导入已存在的模型
    else:
        # 模型构建
        model = tf.keras.Sequential([
            tf.keras.layers.LSTM(128),
            tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
        ])
        # 编译模型,不能使用自定义函数方法,否则导入模型会有问题
        model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])
        
        history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1, 
                            callbacks=[PrintPredictionsCallback(x_train, y_train)])

        model.summary()
        model.save(model_file)
    return model

2、model.save_weight() and model.load_weight()

此方法只保存和加载模型的权重。

保存权重:

python 复制代码
# 只保存权重
model.save_weights('my_model_weights.h5')

加载权重:

python 复制代码
# 创建一个新的模型实例(确保架构与原始模型相同)
new_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(32,)),
    tf.keras.layers.Dense(1)
])
# new_model.build(input_shape=x_train.shape) # 如果模型创建时没有规定input_shape,需要创建
# 加载权重到新模型
new_model.load_weights('my_model_weights.h5')

此方法的模型可以使用自定义的函数方法。

注意:以H5格式加载子类模型的参数时,需要提前建立模型,规定输入网络的shape,否则会报错!

python 复制代码
ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.

示例:

python 复制代码
def model_handle(x_train, y_train):
    # 模型构建,多分类的激活函数使用sigmoid 或 softmax
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(128),
        tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    ])
    if os.path.exists(model_file):
        print("-----load model weights-----")
        model.build(input_shape=x_train.shape)  # 以H5格式加载子类模型的参数时,需要提前建立模型,规定输入网络的shape,否则会报错
        model.load_weights(model_file)
    else:
        # 编译模型,使用自定义loss函数
        model.compile(loss=custom_loss, optimizer='adam', metrics=['accuracy'])
        # model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])

        history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1, 
                            callbacks=[PrintPredictionsCallback(x_train, y_train)])

        model.summary()
        model.save_weights(model_file)

    return model

3、model.checkpoint

主要是用于模型的断点续训。用法参考如下:

python 复制代码
checkpoint_save_path = "./checkpoint/my_checkpoint.ckpt"

if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='val_loss')

history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()
相关推荐
ZC跨境爬虫29 分钟前
跟着 MDN 学CSS day_45:媒体查询入门指南——从语法到移动优先实践
前端·css·ui·html·tensorflow·媒体
lqqjuly41 分钟前
推荐系统技术解析(Recommendation Systems)
深度学习·推荐算法
zhangfeng11331 小时前
超算中心 高性能计算 slurm的linux版本 centos7,如何安装docker,如何安装torch2.4
linux·运维·服务器·开发语言·人工智能·机器学习·docker
搞科研的小刘选手1 小时前
【重庆大学主办】第三届智能感知与模式识别国际学术会议(IPPR 2026)
物联网·机器学习·计算机视觉·机器人·人机交互·感知·传感
老鱼说AI1 小时前
统计学习方法第八章:Boosting
人工智能·深度学习·神经网络·机器学习·学习方法·集成学习·boosting
钓了猫的鱼儿1 小时前
基于深度学习+AI的无人机森林火灾目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·无人机
ZC跨境爬虫1 小时前
跟着 MDN 学CSS day_47:(移动优先实战——从手机到宽屏的响应式进化)
前端·css·html·tensorflow·媒体
ZC跨境爬虫1 小时前
跟着 MDN 学CSS day_46:(响应式实战——用媒体查询打造双列布局)
前端·css·ui·html·tensorflow·媒体
imDwAaY2 小时前
机器学习入门:从感知机到逻辑回归,理解线性分类器与Softmax CS188 Note20 学习笔记
人工智能·笔记·python·学习·机器学习·逻辑回归
无负今日_tq2 小时前
【无标题】
人工智能·深度学习·条纹