TensorFlow 模型训练和简单部署示例

1. 环境准备

确保已安装 TensorFlow:

bash 复制代码
pip install tensorflow

2. 训练一个简单的分类模型

(1) 导入库 & 加载数据
python 复制代码
import tensorflow as tf

# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据预处理:归一化像素值到 [0, 1]
train_images = train_images / 255.0
test_images = test_images / 255.0
(2) 构建模型
python 复制代码
# 定义一个简单的 Sequential 模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)), # 将 28x28 图像展平
    tf.keras.layers.Dense(128, activation='relu'), # 全连接层,128个神经元
    tf.keras.layers.Dense(10) # 输出层,10个神经元对应10个数字类别 (0-9)
])
(3) 编译模型
python 复制代码
model.compile(
    optimizer='adam', # 使用 Adam 优化器
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 损失函数
    metrics=['accuracy'] # 评估指标:准确率
)
(4) 训练模型
python 复制代码
# 训练 5 个 epochs
model.fit(train_images, train_labels, epochs=5)
(5) 评估模型
python 复制代码
# 在测试集上评估
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"\n测试准确率: {test_acc}")
(6) 保存模型 (用于后续部署)
python 复制代码
# 保存整个模型 (架构 + 权重 + 优化器状态)
model.save('my_mnist_model.keras') # 或者使用 .h5 格式 (model.save('my_model.h5'))

3. 模型部署示例 (使用 TensorFlow Lite)

(1) 转换模型为 TensorFlow Lite 格式
python 复制代码
# 加载之前保存的 Keras 模型
loaded_model = tf.keras.models.load_model('my_mnist_model.keras')

# 创建一个转换器
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)

# 转换模型
tflite_model = converter.convert()

# 保存转换后的模型
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
(2) 在 Python 中使用 TFLite 模型进行推理
python 复制代码
# 加载 TFLite 模型并分配张量
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# 获取输入和输出张量详情
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 选择一个测试样本 (例如第一个样本)
input_data = test_images[0:1] # 保持 batch 维度 (1, 28, 28)
input_data = input_data.astype(np.float32) # 确保数据类型匹配

# 设置输入张量
interpreter.set_tensor(input_details[0]['index'], input_data)

# 运行推理
interpreter.invoke()

# 获取输出结果
output_data = interpreter.get_tensor(output_details[0]['index'])
prediction = np.argmax(output_data) # 取概率最大的索引作为预测结果

print(f"预测数字: {prediction}, 真实标签: {test_labels[0]}")

说明

  1. 训练:示例展示了加载数据、构建模型、编译、训练、评估和保存模型的基本流程。
  2. 部署 :展示了将 Keras 模型转换为轻量级的 TensorFlow Lite (.tflite) 格式,并在 Python 环境中加载该模型进行单样本推理的过程。TFLite 模型特别适合在移动端 (Android, iOS) 和嵌入式设备上部署。
  3. 实际部署 :在移动设备上使用 TFLite 模型通常需要调用相应平台 (Android 使用 Interpreter API, iOS 使用 TFLInterpreter) 的接口。部署时还需考虑性能优化 (如使用 GPU/Hexagon 委托) 和模型量化压缩。
相关推荐
好运的阿财2 小时前
OpenClaw四种角色详解
人工智能·python·程序人生·microsoft·开源·ai编程
买大橘子也用券2 小时前
2026红明谷
python·web安全
李昊哲小课2 小时前
Python办公自动化教程 - 第2章 单元格样式魔法 - 让表格变得美观专业
开发语言·python·excel·openpyxl
tryCbest2 小时前
Pip生成requirements.txt文件
python·pip
橘子编程2 小时前
编程语言全指南:从C到Rust
java·c语言·开发语言·c++·python·rust·c#
ego.iblacat2 小时前
Flask 框架
后端·python·flask
我送炭你添花2 小时前
边走边聊 Python 3.8:Win7 从入门到高手(目录)
开发语言·python
w_t_y_y2 小时前
工具篇(一)机器学习常用的python包
开发语言·python·信息可视化
徒 花2 小时前
Python知识学习07
windows·python·学习