tensorflow Keras 模型的保存与加载

这段内容的核心是讲解 Keras模型的保存与加载 ------Keras模型继承自tf.Module,因此完全兼容tf.saved_model.save(),但Keras封装了更简洁的model.save()方法,还额外支持保存「训练相关配置(损失、优化器、指标)」,加载也更便捷,无需依赖原始Python类就能恢复模型并复用。

下面结合你已学的tf.Module保存知识,逐点拆解Keras模型保存/加载的逻辑、警告含义和核心优势:

一、核心前提:Keras模型与tf.Module保存的关系

Keras模型(tf.keras.Model)继承自tf.keras.layers.Layer,而Layer又继承自tf.Module,因此:

  • ✅ Keras模型完全兼容tf.saved_model.save(model, path)(和之前tf.Module的SavedModel保存方式一致);
  • ✅ Keras额外提供model.save(path)------写法更简单,且可选择性保存训练配置(损失函数、优化器、评估指标);
  • ✅ 加载时用tf.keras.models.load_model(path),无需手动创建模型实例(比如不用先写MySequentialModel()),直接加载即可用。

二、逐行解析代码:保存→加载→使用

1. 保存Keras模型:my_sequential_model.save("exname_of_file")
python 复制代码
my_sequential_model.save("exname_of_file")
# 输出:INFO:tensorflow:Assets written to: exname_of_file/assets
  • 底层逻辑

    这行代码等价于tf.saved_model.save(my_sequential_model, "exname_of_file"),但做了增强:

    • 基础功能(和tf.Module一致):保存模型的计算图 (由call方法+@tf.function固化)和权重 (变量值,存在exname_of_file/variables/目录);
    • 增强功能(Keras独有):如果模型调用过model.compile()(编译,指定损失、优化器、指标),会额外保存「训练配置」(比如optimizer='adam'loss='mse')到SavedModel中。
  • 输出提示Assets written toassets目录用于保存模型的额外资源(比如NLP模型的词汇表、图像模型的标签文件),本例无额外资源,所以是空目录,可忽略。

  • 对比tf.Module的SavedModel保存:

    python 复制代码
    # tf.Module的保存方式(之前学的)
    tf.saved_model.save(my_module, "path")
    # Keras模型的简化方式(效果一致,且可保存训练配置)
    my_keras_model.save("path")
2. 加载Keras模型:reconstructed_model = tf.keras.models.load_model("exname_of_file")
python 复制代码
reconstructed_model = tf.keras.models.load_model("exname_of_file")
# 警告:WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
  • 核心作用 :直接从SavedModel文件加载出可调用的模型,无需依赖原始Python类(比如MySequentialModel) ------这是Keras加载的最大优势(对比tf.Module加载:需要先创建同结构的模型实例,再restore权重)。

  • 警告含义(关键)

    警告的原因是:原模型my_sequential_model没有调用model.compile()(编译),因此保存时没有记录"训练配置"(比如用什么优化器、什么损失函数)。

    • 影响:仅影响「训练」------加载后的模型无法直接调用model.fit()(训练),需要手动model.compile(optimizer='adam', loss='mse')
    • 不影响:完全不影响「预测/推断」(model(x)),因为预测只需要计算图和权重,不需要训练配置。
  • 对比tf.Module的加载:

    python 复制代码
    # tf.Module的加载方式(之前学的)
    new_module = MySequentialModule()  # 必须先创建同结构的实例
    checkpoint = tf.train.Checkpoint(model=new_module)
    checkpoint.restore("path")        # 再恢复权重
    
    # Keras模型的加载方式(无需原始类,直接加载)
    new_model = tf.keras.models.load_model("path")
3. 加载后模型的使用:结果与原模型完全一致
python 复制代码
reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))
# 输出:tf.Tensor([[ -7.7209134, -11.065    ]], dtype=float32)
  • 核心逻辑
    加载后的模型reconstructed_model和原模型my_sequential_model的「计算图+权重」完全一致,因此对相同输入的预测结果也完全相同。
    • 权重:加载时自动从exname_of_file/variables/恢复;
    • 计算图:加载时自动从exname_of_file/saved_model.pb恢复;
    • 无需手动恢复权重(对比tf.Module的checkpoint.restore()),Keras已自动完成。

三、Keras SavedModel的额外优势(对比tf.Module)

文中提到"Keras SavedModels还可以保存指标、损失和优化器状态",这是Keras的核心增强,举个例子:

python 复制代码
# 1. 原模型编译(指定训练配置)
my_sequential_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),  # 优化器
    loss=tf.keras.losses.MeanSquaredError(),                 # 损失函数
    metrics=[tf.keras.metrics.Accuracy()]                    # 评估指标
)

# 2. 保存模型(此时会同时保存计算图+权重+训练配置)
my_sequential_model.save("compiled_model")

# 3. 加载模型(无警告,模型已编译)
loaded_model = tf.keras.models.load_model("compiled_model")

# 4. 加载后可直接训练(无需手动compile)
# loaded_model.fit(x_train, y_train, epochs=10)
  • 此时加载模型不会有警告,因为保存了训练配置;
  • 甚至能保存优化器的「状态」(比如训练到第5个epoch中断,保存模型后加载,优化器的学习率、动量等状态会恢复,可继续从第5个epoch训练)------这是tf.Module的检查点也能做到的,但Keras整合得更简洁。

四、关键注意事项(避免踩坑)

  1. 自定义层的加载 :如果模型包含自定义层(比如FlexibleDense),加载时需要确保自定义层的代码在当前环境中(或通过custom_objects参数指定),否则会报错:

    python 复制代码
    # 加载含自定义层的模型,需指定custom_objects
    reconstructed_model = tf.keras.models.load_model(
        "exname_of_file",
        custom_objects={"FlexibleDense": FlexibleDense}
    )
  2. 保存格式 :Keras的model.save()默认保存为SavedModel格式(和tf.saved_model一致),也可指定保存为HDF5格式(model.save("model.h5")),但推荐SavedModel(跨环境兼容性更好);

  3. 预测不受编译影响 :无论模型是否编译,加载后都能正常model(x)做预测,只有训练需要编译。

总结

Keras模型的保存/加载核心优势是:

  1. 写法极简model.save(path)替代tf.saved_model.savetf.keras.models.load_model(path)替代"创建实例+恢复权重";
  2. 功能增强:可保存训练配置(损失、优化器、指标),支持续训;
  3. 脱离原始类 :加载时无需依赖自定义模型/层的Python代码(除非是自定义层,需指定custom_objects);
  4. 兼容底层 :完全继承tf.Module的SavedModel特性(计算图+权重),同时保留Keras的高层训练功能。

简单说:Keras把tf.Module的"保存计算图+权重"和"保存训练配置"整合为一步,加载时又把"恢复结构+恢复权重+恢复训练配置"整合为一步,大幅简化了模型的保存、共享和部署流程。

相关推荐
再__努力1点36 分钟前
【50】OpenCV背景减法技术解析与实现
开发语言·图像处理·人工智能·python·opencv·算法·计算机视觉
c骑着乌龟追兔子38 分钟前
Day 29 机器学习管道 pipeline
人工智能·机器学习
努力也学不会java39 分钟前
【docker】Docker Image(镜像)
java·运维·人工智能·机器学习·docker·容器
zhangfeng113340 分钟前
suppr.wilddata.cn 文献检索,用中文搜 PubMed 一种基于大语言模型的智能搜索引擎构建方法
人工智能·搜索引擎·语言模型
大千AI助手41 分钟前
高维空间中的高效导航者:球树(Ball Tree)算法深度解析
人工智能·算法·机器学习·数据挖掘·大千ai助手·球树·ball-tree
新知图书41 分钟前
使用FastGPT知识库构建智能客服的示例
人工智能·ai agent·智能体·大模型应用开发·大模型应用
生信大表哥44 分钟前
GPT-5-Codex VS Gemini 3 VS Claude Sonnet 4.5 新手小白入门学习教程
人工智能·gpt·学习·rstudio·数信院生信服务器
子午1 小时前
【植物识别系统】Python+TensorFlow+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
ONLYOFFICE1 小时前
ONLYOFFICE 文档与桌面编辑器 9.2 版本更新说明
人工智能·编辑器·onlyoffice