前言
学习和了解 Google 官方提供的 LiteRT 在 Android 的部署,以经典的手写数字识别模型为例,逐步了解 LiteRT 的使用方法。
什么是 LiteRT
LiteRT(简称 Lite Runtime,以前称为 TensorFlow Lite)是 Google 面向设备端 AI 的高性能运行时。
主要特性
-
针对设备端机器学习进行了优化:LiteRT 解决了五项关键的 ODML 约束条件:延迟时间(无需往返服务器)、隐私性(没有个人数据离开设备)、连接性(无需连接到互联网)、大小(缩减了模型和二进制文件大小)和功耗(高效推理和缺少网络连接)。
-
支持多平台:与 Android 和 iOS 设备、嵌入式 Linux 和微控制器兼容。
-
多框架模型选项:AI Edge 提供了一些工具,可将 TensorFlow、PyTorch 和 JAX 模型转换为 FlatBuffers 格式 (.tflite),让您能够在 LiteRT 上使用各种先进的模型。您还可以使用可处理量化和元数据的模型优化工具。
-
支持多种语言:包括适用于 Java/Kotlin、Swift、Objective-C、C++ 和 Python 的 SDK。
-
高性能:通过 GPU 和 iOS Core ML 等专用代理实现硬件加速。
可以看到 LiteRT 还是很有前景的。
为什么选择 LiteRT
之前一直在使用 PyTorchLite 作为 Android 端推理模型的框架,但是依赖 SDK 后包体积爆炸式增长。日常开发的时候,编译和安装都太耗时。 再有 LiteRT 支持将 PyTorch 模型转换为自身可用的格式,而且是 Google 官方支持的亲儿子,因此选择 LiteRT 似乎是自然而然的事情。
可以看到 PyTorchLite 这个 SDK 一点也不 Lite, 其 so
文件的体积居然有 57.6 MB ,比模型文件还大。而 TensonFlowLite 即 LiteRT 需要的 so
只有 536 KB,两者之间相差近乎 110 倍,因此选择 LiterRT 势在必行。其实另一个重要原因是 PyTorch 官方已经不再维护 PyTorchLite 了,pytorch/android-demo-app 这个仓库在 2024 年 9 月已经声明不再维护,官方推荐 ExecuTorch 。
LiteRT 入门
添加依赖
gradle
// LiteRT dependencies for Google Play services
implementation 'com.google.android.gms:play-services-tflite-java:16.4.0'
// Optional: include LiteRT Support Library
implementation 'com.google.android.gms:play-services-tflite-support:16.4.0'
初始化模型及 InterpreterApi
SDK 初始化
kotlin
fun init(context: Context) {
TfLite.initialize(context).addOnSuccessListener {
Log.d(TAG, "ver ${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
Log.d(TAG, "ver ${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
}.addOnFailureListener {
Log.d(TAG, "Initialized TFLite fail.")
}
}
SDK 初始化完成之后,就可以执行模型的初始化
模型及推理接口的初始化
kotlin
fun createInterpreterApi(context: Context, modelName: String): InterpreterApi? {
val model = loadModelFile(context.assets, modelName)
val interpreterOption = InterpreterApi.Options()
.setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
interpreter = InterpreterApi.create(model, interpreterOption)
return interpreter
}
InterpreterApi 就是 SDK 提供的统一接口,方便获取推理模型的信息,执行推理操作。
kotlin
interpreter?.let { inter ->
val inputShape = inter.getInputTensor(0).shape()
inputImageWidth = inputShape[1]
inputImageHeight = inputShape[2]
modelInputSize = 4 * inputImageWidth * inputImageHeight
}
通过 interpreter
我们可以获取到推理模型各方面的信息
kotlin
val inputShape = inter.getInputTensor(0).shape()
Log.d(TAG, "input shape = ${inputShape.contentToString()}")
Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}")
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")
shell
DigitClassifier D input shape = [1, 28, 28, 1]
DigitClassifier D elem shape = 784
DigitClassifier D output shape = [1, 10]
可以看到当前推理模型如下信息
- 输入是 1x28x28x1 的 4 维数组。如果你对神经网路稍有了解的话,应该不陌生。这是手写数字识别时,MNIST 数据集提供的标准输入,单通道 28x28 像素的手写数字图像。
- 输入元素的总个数
784 = 1*28*28*1
- 模型输出是 1x10 的数组,恰好是 0~9 是 10 个数字的概率。

图中就是 MNIST 数据集的示例,图中每一个手写数字的图片都是 28x28 像素的单通道黑白图片
执行推理操作
使用推理模型进行运算,需要做的事情和之前一样,给模型提供合适的输入数据,根据返回结果的数据类型展现我们需要的信息。对于手写数字识别的场景,输入就是图像信息,而输出则是模型对于 0~9 这 10 个数字的判断概率。
kotlin
private fun classify(bitmap: Bitmap) {
val resizedImage = bitmap.scale(inputImageWidth, inputImageHeight)
val byteBuffer = convertBitmapToByteBuffer(resizedImage)
val output = Array(1) { FloatArray(10) }
interpreter?.run(byteBuffer, output)
val result = output[0]
Log.d(TAG, "result = ${result.contentToString()}")
val maxIndex = result.indices.maxBy { result[it] }
}
- 首先,进行图片缩放。由于模型预期输入的图片大小是 28x28 像素,因此需要基于输入的 Bitmap 进行裁剪
- 其次,将输入 ARGB 格式的 Bitmap 转换适用于 TensorFlow 的类型
- 定义模型的输出大小,即用于存放 0~9 这个几个数字概率的一维数组即可
- 最后,执行推理,获取返回值中概率最大的那个值。

shell
result = [5.4527607E-8, 4.035255E-9, 7.9443926E-5, 3.279577E-4, 0.0015575942, 0.99790514, 5.318362E-6, 3.163248E-5, 1.296262E-5, 7.9982296E-5]
根据打印的日志,返回概率和我们实际输入的图像信息是一致的,模型准确的推理出了数字 5 的概率最大。
可以看到,这个模型整体在表现还是不错的,可以非常精确的识别出 0~9 这几个手写数字的内容。
如果你之前使用过 TensorFlow Lite ,官方也提供了迁移文档
小结
无论是 PyTorch 还是 TensorFlow ,运用已经训练好的模型进行推理操作,其原理和执行步骤本质是相同的,都是依据模型输入输入的格式提供数据,将模型输出的数据用合适的方式展现出来。无论是在 PC 还是移动端,都是类似的做法。无非就是各个平台提供的接口有些许差异,需要我们进行适配。