我们以手机相册的 "智能分类" 功能(识别图片中的物体类型)为例,演示如何使用 TensorFlow Lite 框架将端侧模型部署到 Android 设备上。该场景通用且覆盖端侧部署的核心步骤:模型准备→环境配置→代码集成→硬件加速→业务调用。
一、场景需求与模型选择
需求:手机相册扫描本地图片,自动分类为 "风景""宠物""美食" 等类别(支持 1000 + 类)。
模型选择:使用轻量级图像分类模型 MobileNetV3-Small(TFLite 格式),体积仅 5.4MB,精度接近 ResNet-50,但推理速度更快(手机端 < 30ms)。
二、模型准备与项目配置
- 获取 TFLite 模型
从 TensorFlow 官方模型库下载预训练的 MobileNetV3-Small TFLite 模型:
下载地址(选择.tflite格式)。
将模型文件重命名为mobilenet_v3_small.tflite,放入 Android 项目的app/src/main/assets/models/目录。 - 配置 Android 项目依赖
在build.gradle中添加 TFLite 依赖(支持硬件加速和量化模型):
gradle
groovy
dependencies {
// TFLite基础库
implementation 'org.tensorflow:tensorflow-lite:2.15.0'
// GPU加速代理(可选,需设备支持)
implementation 'org.tensorflow:tensorflow-lite-gpu-delegate:2.15.0'
// NNAPI加速代理(可选,Android 8.1+)
implementation 'org.tensorflow:tensorflow-lite-nnapi:2.15.0'
}
三、核心代码实现(端侧部署)
- 模型加载与初始化(含硬件加速)
创建ImageClassifier类,封装模型加载、推理和资源释放逻辑,支持自动选择最优加速方案(NNAPI→GPU→CPU 多线程):
kotlin
groovy
```kotlin
import android.content.Context
import android.graphics.Bitmap
import android.os.Build
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import java.nio.ByteBuffer
import java.nio.ByteOrder
class ImageClassifier(private val context: Context) {
private var interpreter: Interpreter? = null
private var nnApiDelegate: NnApiDelegate? = null
private var gpuDelegate: GpuDelegate? = null
// 模型输入参数(根据MobileNetV3配置)
private val inputWidth = 224
private val inputHeight = 224
private val inputChannel = 3
private val inputSize = inputWidth * inputHeight * inputChannel * 4 // FP32每通道4字节
init {
initializeModel()
}
/** 初始化模型(自动选择加速方案) */
private fun initializeModel() {
val options = Interpreter.Options().apply {
// 优先使用NNAPI(需Android 9+)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
nnApiDelegate = NnApiDelegate().also { addDelegate(it) }
}
// 次优使用GPU(需设备支持)
else if (CompatibilityList().isDelegateSupportedOnThisDevice) {
gpuDelegate = GpuDelegate(CompatibilityList().bestOptionsForThisDevice).also { addDelegate(it) }
}
// 兜底使用CPU多线程
else {
setNumThreads(4)
}
setAllowFp16PrecisionForFp32(true) // 允许FP16混合精度
}
// 从assets加载模型
val modelBuffer = context.assets.open("models/mobilenet_v3_small.tflite").use { inputStream ->
val buffer = ByteBuffer.allocateDirect(inputStream.available())
.order(ByteOrder.nativeOrder())
inputStream.read(buffer.array())
buffer
}
interpreter = Interpreter(modelBuffer, options)
}
/** 执行图像分类推理 */
fun classify(bitmap: Bitmap): List<Pair<String, Float>> {
// 1. 预处理图像(缩放+归一化)
val inputBuffer = bitmapToByteBuffer(bitmap)
// 2. 初始化输出缓冲区(1000类概率)
val output = Array(1) { FloatArray(1000) }
// 3. 运行推理
interpreter?.run(inputBuffer, output)
// 4. 后处理(取Top5类别)
return output[0].withIndex()
.sortedByDescending { it.value }
.take(5)
.map { "类别${it.index}" to it.value } // 实际需替换为真实类别名称(如"狗")
}
/** 释放模型资源 */
fun release() {
interpreter?.close()
nnApiDelegate?.close()
gpuDelegate?.close()
}
/** 图像转ByteBuffer(模型输入格式) */
private fun bitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputWidth, inputHeight, true)
val buffer = ByteBuffer.allocateDirect(inputSize)
.order(ByteOrder.nativeOrder())
buffer.rewind()
val pixels = IntArray(inputWidth * inputHeight)
resizedBitmap.getPixels(pixels, 0, inputWidth, 0, 0, inputWidth, inputHeight)
for (pixel in pixels) {
// 归一化到[0,1](MobileNetV3输入范围)
buffer.putFloat((pixel shr 16 and 0xFF) / 255.0f) // R
buffer.putFloat((pixel shr 8 and 0xFF) / 255.0f) // G
buffer.putFloat((pixel and 0xFF) / 255.0f) // B
}
return buffer
}
}
四、业务方调用示例
在相册应用的图片浏览界面,调用ImageClassifier实现智能分类:
```kotlin
class PhotoDetailActivity : AppCompatActivity() {
private lateinit var classifier: ImageClassifier
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
classifier = ImageClassifier(applicationContext)
// 从Intent获取图片路径
val imagePath = intent.getStringExtra("image_path")
val bitmap = BitmapFactory.decodeFile(imagePath)
// 执行分类
val result = classifier.classify(bitmap)
showClassificationResult(result)
}
private fun showClassificationResult(result: List<Pair<String, Float>>) {
// 在界面显示Top5类别(如"狗: 92.3%", "宠物: 6.5%")
val resultText = result.joinToString("\n") { "${it.first}: ${(it.second * 100).format(1)}%" }
findViewById<TextView>(R.id.tv_result).text = "分类结果:\n$resultText"
}
override fun onDestroy() {
classifier.release() // 释放模型资源
super.onDestroy()
}
}
五、关键优化点说明
硬件加速:
NNAPI(Android 9+):调用设备 GPU/TPU 专用加速器,推理速度提升 30%-50%;
GPU Delegate:通过 OpenGL/Vulkan 调用 GPU,适合不支持 NNAPI 的旧设备;
CPU 多线程:默认 4 线程,平衡速度与功耗。
模型压缩:
使用量化模型(如 INT8):体积从 15MB(FP32)缩小至 4MB,推理速度提升 2 倍,精度损失 < 2%;
若需更高精度,可保留 FP16 混合精度(体积 8MB,速度接近 INT8)。
内存优化:
使用ByteBuffer直接内存(避免 Java 堆内存拷贝);
复用输入 / 输出缓冲区(减少 GC 频率)。
六、测试与验证
性能测试:使用TFLite Benchmark工具(官方文档)测量推理延迟(目标 < 50ms);
精度验证:用 COCO 数据集测试模型准确率(MobileNetV3-Small 在 ImageNet 上的 Top-1 准确率约 67%);
设备兼容性:在主流 Android 机型(如小米、三星、Pixel)上测试,确保 NNAPI/GPU 加速正常。
通过以上步骤,业务方只需集成ImageClassifier类,传入图片即可实现端侧智能分类,无需关注底层模型加载、加速配置等细节。该方案可扩展至其他端侧场景(如目标检测、文本分类),只需替换模型和输入输出处理逻辑即可。