TensorFlow花卉图片分类器模型训练

花卉图片分类器:Keras 训练并导出 TFLite

TensorFlow Lite Model Maker由于依赖库与新版本的Python不兼容的问题,我们将方案转为用 TensorFlow/Keras 训练一个花卉图片分类模型,并把训练好的模型转换为 TensorFlow Lite 的 .tflite 文件。

它不依赖 tflite-model-maker,因此可以避开 scann、旧版 TensorFlow、旧版 Python 之间常见的安装冲突。整体流程是:下载/读取图片数据集 -> 构建 Keras 模型 -> 训练与评估 -> 保存 .keras 模型 -> 转换并导出 .tflite 模型 -> 简单测试 TFLite 推理结果。

1. 安装依赖

如果当前 Jupyter kernel 里还没有安装 TensorFlow,请先行安装。可以使用Anaconda创建一个新的环境进行安装。

建议使用 Python 3.10 或 3.11或更高版本。这个版本只需要 TensorFlow,不需要安装 tflite-model-maker

python 复制代码
# 可选:如果当前 notebook 环境还没有安装依赖,取消下一行开头的 # 后运行。
# %pip install -r requirements-modern.txt

requirements-modern.txt的内容

tensorflow>=2.15

matplotlib>=3.7

numpy>=1.23

2. 导入库并设置参数

这里会导入训练、数据读取、模型转换需要的库。FLOWER_URL 是 TensorFlow 官方示例花卉数据集的下载地址。

如果你不指定自己的图片目录,程序会自动下载这个数据集,并缓存到用户目录下的 Keras 数据集缓存位置,例如 Windows 上通常是 C:\Users\你的用户名\.keras\datasets\

python 复制代码
import tarfile
from pathlib import Path

import numpy as np
import tensorflow as tf

# TensorFlow 官方花卉数据集。第一次运行时会自动下载,之后会复用本地缓存。
FLOWER_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

print("TensorFlow 版本:", tf.__version__)
python 复制代码
# 数据目录配置:
# - DATA_DIR = None:自动下载并使用 TensorFlow 官方 flowers 数据集。
# - DATA_DIR = r"D:\path\to\my_images":使用你自己的图片分类目录。
#
# 自定义图片目录需要按类别分文件夹,例如:
# my_images/
#   daisy/
#     1.jpg
#   roses/
#     2.jpg
DATA_DIR = None

# 导出目录。训练完成后会在这里生成 model.tflite、labels.txt 和 flower_classifier.keras。
EXPORT_DIR = "exported_flower_model"

# 训练参数。教程演示可以先用 3 到 5 个 epoch;如果使用自己的数据,可以适当增加。
EPOCHS = 5
BATCH_SIZE = 32
IMAGE_SIZE = 224
LEARNING_RATE = 1e-3

# TFLite 量化方式:
# - "dynamic":默认推荐,模型更小,通常最容易成功。
# - "float16":适合部分支持 float16 的设备。
# - "int8":体积更小,但需要代表性数据集,转换要求更严格。
# - "none":不量化,保留浮点模型。
QUANTIZATION = "dynamic"

# 固定随机种子,方便训练/验证划分尽量可复现。
SEED = 123

3. 读取并划分数据集

load_flower_datasets 完成三件事:

  1. 如果 DATA_DIRNone,自动下载并解压官方花卉数据集。
  2. 使用 image_dataset_from_directory 按文件夹名生成分类标签。
  3. 将数据划分为训练集、验证集和测试集,并开启缓存、打乱与预取,加快训练过程。
python 复制代码
def load_flower_datasets(data_dir, image_size, batch_size, seed):
    # 如果没有传入自定义数据目录,就下载 TensorFlow 官方 flower_photos 数据集。
    if data_dir is None:
        archive_path = tf.keras.utils.get_file(
            "flower_photos.tgz",
            FLOWER_URL,
            extract=False,
        )
        archive_path = Path(archive_path)

        # Keras 可能已经缓存了解压后的目录;先检查常见位置,避免重复解压。
        candidates = [
            archive_path.parent / "flower_photos",
            archive_path.parent / "flower_photos_extracted" / "flower_photos",
        ]
        data_dir = next((path for path in candidates if path.exists()), None)
        if data_dir is None:
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(archive_path.parent / "flower_photos_extracted")
            data_dir = archive_path.parent / "flower_photos_extracted" / "flower_photos"
    else:
        data_dir = Path(data_dir)

    # 从目录读取图片。目录下的每个子文件夹会被当作一个类别。
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=seed,
        image_size=(image_size, image_size),
        batch_size=batch_size,
    )
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=seed,
        image_size=(image_size, image_size),
        batch_size=batch_size,
    )
    class_names = train_ds.class_names

    # 原始 validation 部分再拆成验证集和测试集:验证集用于训练过程中观察效果,测试集用于最后评估。
    val_batches = int(tf.data.experimental.cardinality(val_ds).numpy())
    test_ds = val_ds.take(val_batches // 2)
    val_ds = val_ds.skip(val_batches // 2)

    # cache/prefetch 可以减少数据读取等待;shuffle 只用于训练集。
    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000, seed=seed).prefetch(autotune)
    val_ds = val_ds.cache().prefetch(autotune)
    test_ds = test_ds.cache().prefetch(autotune)
    return train_ds, val_ds, test_ds, class_names
python 复制代码
# 加载数据集并查看类别名称。
train_ds, val_ds, test_ds, class_names = load_flower_datasets(
    DATA_DIR,
    IMAGE_SIZE,
    BATCH_SIZE,
    SEED,
)

print("类别数量:", len(class_names))
print("类别名称:", class_names)

4. 构建并训练 Keras 模型

这里使用迁移学习:底座模型是 ImageNet 预训练的 MobileNetV2,它已经学过很多通用图像特征。我们冻结底座模型,只训练最后新增的分类层。

这样做的好处是训练速度快、需要的数据量少,也更适合后续转换为移动端可用的 TFLite 模型。

python 复制代码
def build_model(num_classes, image_size, learning_rate):
    # 输入图片尺寸固定为 IMAGE_SIZE x IMAGE_SIZE x 3。
    inputs = tf.keras.Input(shape=(image_size, image_size, 3), name="image")

    # MobileNetV2 有自己的预处理方式,这里把像素值转换到模型期望的范围。
    x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)

    # include_top=False 表示不要 ImageNet 原始的 1000 类分类头,只保留特征提取部分。
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(image_size, image_size, 3),
        include_top=False,
        weights="imagenet",
        pooling="avg",
    )

    # 冻结预训练模型参数,只训练后面的 Dense 分类层。
    base_model.trainable = False
    x = base_model(x, training=False)
    x = tf.keras.layers.Dropout(0.2)(x)

    # 输出维度等于类别数量,softmax 输出每个类别的概率。
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax", name="predictions")(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )
    return model
python 复制代码
# 创建模型并打印结构。第一次运行会下载 MobileNetV2 的 ImageNet 预训练权重。
model = build_model(len(class_names), IMAGE_SIZE, LEARNING_RATE)
model.summary()
python 复制代码
# 开始训练。history 中会保存每个 epoch 的 loss、accuracy、val_loss、val_accuracy。
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
python 复制代码
# 使用测试集评估模型。测试集没有参与训练,用于更客观地观察最终效果。
loss, accuracy = model.evaluate(test_ds)
print(f"test_loss={loss:.4f}, test_accuracy={accuracy:.4f}")

5. 转换为 TensorFlow Lite 模型

训练完成后,先把 Keras 模型保存为 .keras 文件,再使用 tf.lite.TFLiteConverter.from_keras_model(model) 转换为 .tflite

本 notebook 默认使用动态范围量化 dynamic,通常可以减小模型体积,并且不需要额外准备复杂的校准数据。

python 复制代码
def convert_to_tflite(model, quantization, representative_ds):
    # 从 Keras 模型创建 TFLite 转换器。
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    if quantization == "dynamic":
        # 动态范围量化:最常用、最容易成功的压缩方式。
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif quantization == "float16":
        # float16 量化:权重使用半精度浮点数,适合部分移动端/GPU 场景。
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
    elif quantization == "int8":
        # int8 全整数量化:体积更小,但需要代表性数据集校准输入分布。
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        def representative_data_gen():
            for images, _ in representative_ds.take(100):
                for image in images:
                    yield [tf.expand_dims(tf.cast(image, tf.float32), 0)]

        converter.representative_dataset = representative_data_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
    elif quantization != "none":
        raise ValueError(f"Unsupported quantization mode: {quantization}")

    return converter.convert()
python 复制代码
# 创建导出目录。
export_dir = Path(EXPORT_DIR)
export_dir.mkdir(parents=True, exist_ok=True)

# 保存标签文件。部署时需要 labels.txt 把模型输出编号映射回类别名称。
labels_path = export_dir / "labels.txt"
labels_path.write_text("\n".join(class_names) + "\n", encoding="utf-8")

# 保存 Keras 原始模型,便于以后继续训练或重新转换。
keras_path = export_dir / "flower_classifier.keras"
model.save(keras_path)

# 转换并保存 TFLite 模型。
tflite_model = convert_to_tflite(model, QUANTIZATION, train_ds)
tflite_path = export_dir / "model.tflite"
tflite_path.write_bytes(tflite_model)

print(f"已保存 Keras 模型: {keras_path}")
print(f"已保存 TFLite 模型: {tflite_path}")
print(f"已保存标签文件: {labels_path}")

6. 简单测试导出的 TFLite 模型

最后用 tf.lite.Interpreter 加载刚导出的 .tflite 文件,取几张测试图片做推理,确认模型文件可以正常运行。

这里的 smoke test 不是完整评估,只是快速检查:模型能否加载、输入输出张量是否正常、预测流程是否能跑通。

python 复制代码
def smoke_test_tflite(tflite_path, test_ds, class_names):
    # 加载 TFLite 模型并分配张量内存。
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # 从测试集中取 8 张图片做快速推理。
    images, labels = next(iter(test_ds.unbatch().batch(8)))
    input_data = tf.cast(images, input_details["dtype"]).numpy()

    # 如果模型是 uint8 输入,需要按照量化参数把图片转换到对应范围。
    if input_details["dtype"] == np.uint8:
        scale, zero_point = input_details["quantization"]
        if scale:
            input_data = images.numpy() / scale + zero_point
            input_data = np.clip(input_data, 0, 255).astype(np.uint8)

    predictions = []
    for image in input_data:
        interpreter.set_tensor(input_details["index"], np.expand_dims(image, 0))
        interpreter.invoke()
        predictions.append(interpreter.get_tensor(output_details["index"])[0])

    predicted_ids = np.argmax(np.asarray(predictions), axis=1)
    for expected, predicted in zip(labels.numpy()[:5], predicted_ids[:5]):
        print(f"真实类别={class_names[expected]}, 预测类别={class_names[predicted]}")
python 复制代码
# 运行 TFLite 快速测试。
smoke_test_tflite(tflite_path, test_ds, class_names)

参考文献

相关推荐
ZC跨境爬虫20 小时前
跟着 MDN 学CSS day_40:(Flexbox实战技能测试)
前端·css·ui·html·tensorflow
ZC跨境爬虫1 天前
跟着 MDN 学CSS day_39:(Flexbox 弹性盒子核心机制)
前端·css·ui·html·tensorflow
ZC跨境爬虫3 天前
跟着 MDN 学CSS day_34:(CSS 布局全面解析)
前端·css·ui·html·tensorflow
ZC跨境爬虫3 天前
跟着 MDN 学CSS day_35:浮动布局完全指南
前端·css·ui·html·tensorflow
ZC跨境爬虫4 天前
跟着 MDN 学CSS day_29:(掌握文本与字体样式的核心艺术)
前端·css·ui·html·tensorflow
ZC跨境爬虫4 天前
跟着 MDN 学CSS day_30:(玩转列表样式,从基础到进阶)
前端·css·html·tensorflow·媒体
ZC跨境爬虫5 天前
跟着 MDN 学CSS day_25:(高级区块效果)
前端·css·html·tensorflow·媒体
weixin_468466855 天前
PyTorch 与 TensorFlow 实战选型与应用场景指南
人工智能·pytorch·深度学习·算法·机器学习·tensorflow·深度学习框架
之歆5 天前
Day22_CSS 函数完全指南:从变量到数学计算的现代样式编程
开发语言·前端·javascript·css·tensorflow·less