1. 环境准备
确保已安装 TensorFlow:
bash
pip install tensorflow
2. 训练一个简单的分类模型
(1) 导入库 & 加载数据
python
import tensorflow as tf
# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 数据预处理:归一化像素值到 [0, 1]
train_images = train_images / 255.0
test_images = test_images / 255.0
(2) 构建模型
python
# 定义一个简单的 Sequential 模型
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)), # 将 28x28 图像展平
tf.keras.layers.Dense(128, activation='relu'), # 全连接层,128个神经元
tf.keras.layers.Dense(10) # 输出层,10个神经元对应10个数字类别 (0-9)
])
(3) 编译模型
python
model.compile(
optimizer='adam', # 使用 Adam 优化器
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 损失函数
metrics=['accuracy'] # 评估指标:准确率
)
(4) 训练模型
python
# 训练 5 个 epochs
model.fit(train_images, train_labels, epochs=5)
(5) 评估模型
python
# 在测试集上评估
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"\n测试准确率: {test_acc}")
(6) 保存模型 (用于后续部署)
python
# 保存整个模型 (架构 + 权重 + 优化器状态)
model.save('my_mnist_model.keras') # 或者使用 .h5 格式 (model.save('my_model.h5'))
3. 模型部署示例 (使用 TensorFlow Lite)
(1) 转换模型为 TensorFlow Lite 格式
python
# 加载之前保存的 Keras 模型
loaded_model = tf.keras.models.load_model('my_mnist_model.keras')
# 创建一个转换器
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
# 转换模型
tflite_model = converter.convert()
# 保存转换后的模型
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
(2) 在 Python 中使用 TFLite 模型进行推理
python
# 加载 TFLite 模型并分配张量
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
# 获取输入和输出张量详情
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# 选择一个测试样本 (例如第一个样本)
input_data = test_images[0:1] # 保持 batch 维度 (1, 28, 28)
input_data = input_data.astype(np.float32) # 确保数据类型匹配
# 设置输入张量
interpreter.set_tensor(input_details[0]['index'], input_data)
# 运行推理
interpreter.invoke()
# 获取输出结果
output_data = interpreter.get_tensor(output_details[0]['index'])
prediction = np.argmax(output_data) # 取概率最大的索引作为预测结果
print(f"预测数字: {prediction}, 真实标签: {test_labels[0]}")
说明
- 训练:示例展示了加载数据、构建模型、编译、训练、评估和保存模型的基本流程。
- 部署 :展示了将 Keras 模型转换为轻量级的 TensorFlow Lite (
.tflite) 格式,并在 Python 环境中加载该模型进行单样本推理的过程。TFLite 模型特别适合在移动端 (Android, iOS) 和嵌入式设备上部署。 - 实际部署 :在移动设备上使用 TFLite 模型通常需要调用相应平台 (Android 使用
InterpreterAPI, iOS 使用TFLInterpreter) 的接口。部署时还需考虑性能优化 (如使用 GPU/Hexagon 委托) 和模型量化压缩。