TensorFlow2 Python深度学习 - 模型保存与加载

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 模型保存与加载

在 TensorFlow 2 中,模型的保存和加载是非常简便的操作,使用了 tf.keras API 进行高效管理。TensorFlow 支持保存和加载完整的模型,包括权重、优化器、训练配置等,使得模型可以在不同的环境中被复用、共享或部署。

TensorFlow2 keras提供保存和加载完整模型(包括模型架构、权重、训练配置等)的save()和load_model()方法来实现模型保存和加载。

我们看一个具体示例:

通过save()方法保存模型,注意保存的模型文件后缀是.keras

复制代码
import tensorflow as tf
from keras import Input, layers
from sklearn.datasets import load_iris
​
# 1,加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度
y = iris.target  # 标签:0-Setosa, 1-Versicolour, 2-Virginica
​
# 2,构建分类模型
model = tf.keras.models.Sequential([
    Input(shape=(X.shape[1],)),  # 输入层
    layers.Dense(16, activation='relu'),  # 隐藏层
    layers.Dense(3, activation='softmax')  # 输出层 3个神经元,对应3个类别
])
​
# 3,模型编译
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # 多分类交叉熵损失函数
    metrics=['accuracy']  # 评估指标:准确率
)
​
# 4,模型保存
model.save('iris_model.keras')

运行完成后,先生成iris_model.keras模型文件:

然后其他地方需要用到这个模型定义的时候,我们只需要使用keras提供的load_model()方法加载这个模型文件即可。

复制代码
from keras.src.saving import load_model
from sklearn.datasets import load_iris
​
# 1,加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度
y = iris.target  # 标签:0-Setosa, 1-Versicolour, 2-Virginica
​
# 2,加载模型
model = load_model('iris_model.keras')
​
# 4,模型训练
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)
print(f"最终损失: {history.history['loss'][-1]:.4f}, 最终准确率: {history.history['accuracy'][-1]:.4f}")

我们来运行测试下:

相关推荐
青青草原羊村懒大王11 小时前
python基础知识三
开发语言·python
傻啦嘿哟11 小时前
Python高效实现Word转HTML:从基础到进阶的全流程方案
人工智能·python·tensorflow
思通数科多模态大模型12 小时前
扑灭斗殴的火苗:AI智能守护如何为校园安全保驾护航
大数据·人工智能·深度学习·安全·目标检测·计算机视觉·数据挖掘
wu_jing_sheng012 小时前
深度学习入门:揭开神经网络的神秘面纱(附PyTorch实战)
python
Ace_317508877613 小时前
淘宝店铺全量商品接口实战:分类穿透采集与增量同步的技术方案
大数据·数据库·python
xixixi7777713 小时前
了解一下LSTM:长短期记忆网络(改进的RNN)
人工智能·深度学习·机器学习
能来帮帮蒟蒻吗13 小时前
深度学习(1)—— 基本概念
人工智能·深度学习
LeonDL16813 小时前
基于YOLO11深度学习的电动车头盔检测系统【Python源码+Pyqt5界面+数据集+安装使用教程+训练代码】【附下载链接】
人工智能·python·深度学习·pyqt5·yolo数据集·电动车头盔检测系统·yolo11深度学习
carver w13 小时前
彻底理解传统卷积,深度可分离卷积
人工智能·深度学习·计算机视觉
xier_ran13 小时前
深度学习:从零开始手搓一个浅层神经网络(Single Hidden Layer Neural Network)
人工智能·深度学习·神经网络