TensorFlow 模型训练和简单部署示例

1. 环境准备

确保已安装 TensorFlow:

bash 复制代码
pip install tensorflow

2. 训练一个简单的分类模型

(1) 导入库 & 加载数据
python 复制代码
import tensorflow as tf

# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据预处理:归一化像素值到 [0, 1]
train_images = train_images / 255.0
test_images = test_images / 255.0
(2) 构建模型
python 复制代码
# 定义一个简单的 Sequential 模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)), # 将 28x28 图像展平
    tf.keras.layers.Dense(128, activation='relu'), # 全连接层,128个神经元
    tf.keras.layers.Dense(10) # 输出层,10个神经元对应10个数字类别 (0-9)
])
(3) 编译模型
python 复制代码
model.compile(
    optimizer='adam', # 使用 Adam 优化器
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 损失函数
    metrics=['accuracy'] # 评估指标:准确率
)
(4) 训练模型
python 复制代码
# 训练 5 个 epochs
model.fit(train_images, train_labels, epochs=5)
(5) 评估模型
python 复制代码
# 在测试集上评估
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"\n测试准确率: {test_acc}")
(6) 保存模型 (用于后续部署)
python 复制代码
# 保存整个模型 (架构 + 权重 + 优化器状态)
model.save('my_mnist_model.keras') # 或者使用 .h5 格式 (model.save('my_model.h5'))

3. 模型部署示例 (使用 TensorFlow Lite)

(1) 转换模型为 TensorFlow Lite 格式
python 复制代码
# 加载之前保存的 Keras 模型
loaded_model = tf.keras.models.load_model('my_mnist_model.keras')

# 创建一个转换器
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)

# 转换模型
tflite_model = converter.convert()

# 保存转换后的模型
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
(2) 在 Python 中使用 TFLite 模型进行推理
python 复制代码
# 加载 TFLite 模型并分配张量
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# 获取输入和输出张量详情
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 选择一个测试样本 (例如第一个样本)
input_data = test_images[0:1] # 保持 batch 维度 (1, 28, 28)
input_data = input_data.astype(np.float32) # 确保数据类型匹配

# 设置输入张量
interpreter.set_tensor(input_details[0]['index'], input_data)

# 运行推理
interpreter.invoke()

# 获取输出结果
output_data = interpreter.get_tensor(output_details[0]['index'])
prediction = np.argmax(output_data) # 取概率最大的索引作为预测结果

print(f"预测数字: {prediction}, 真实标签: {test_labels[0]}")

说明

  1. 训练:示例展示了加载数据、构建模型、编译、训练、评估和保存模型的基本流程。
  2. 部署 :展示了将 Keras 模型转换为轻量级的 TensorFlow Lite (.tflite) 格式,并在 Python 环境中加载该模型进行单样本推理的过程。TFLite 模型特别适合在移动端 (Android, iOS) 和嵌入式设备上部署。
  3. 实际部署 :在移动设备上使用 TFLite 模型通常需要调用相应平台 (Android 使用 Interpreter API, iOS 使用 TFLInterpreter) 的接口。部署时还需考虑性能优化 (如使用 GPU/Hexagon 委托) 和模型量化压缩。
相关推荐
love530love6 小时前
LiveTalking 数字人项目 Windows 部署完全指南(EPGF 架构)
人工智能·windows·python·架构·livetalking·epgf
遇事不決洛必達6 小时前
【Python基础】GIL 锁是什么及其对爬虫的影响
爬虫·python·线程·进程·gil锁
CryptoPP7 小时前
快速对接东京证券交易所API数据:实战指南与代码示例
开发语言·人工智能·windows·python·信息可视化·区块链
探物 AI8 小时前
把 MambaOut 塞进 YOLOv11:会有什么样的反应
python·yolo·计算机视觉
如竟没有火炬8 小时前
最大矩阵——单调栈
数据结构·python·线性代数·算法·leetcode·矩阵
阳区欠8 小时前
【LangChain】LLM基础介绍
开发语言·python·langchain
Cosolar8 小时前
保姆级 CrewAI 教程:从零构建多智能体协作系统
人工智能·python·架构
GDAL8 小时前
使用 uv 管理 Python 版本
python·uv·版本
真实的菜8 小时前
Redis 从入门到精通(十二):典型业务场景实战 —— 排行榜、限流器、秒杀系统、Session 共享
数据库·redis·python
cup119 小时前
[开源] Meta Assistant / 告别命令行,我为一堆 Python 脚本做了一个 Windows 任务栏的“家”
windows·python·工具·nuitka·脚本运行