Google 端侧 AI 框架 LiteRT 初探

前言

学习和了解 Google 官方提供的 LiteRT 在 Android 的部署,以经典的手写数字识别模型为例,逐步了解 LiteRT 的使用方法。

什么是 LiteRT

LiteRT(简称 Lite Runtime,以前称为 TensorFlow Lite)是 Google 面向设备端 AI 的高性能运行时。

主要特性

  • 针对设备端机器学习进行了优化:LiteRT 解决了五项关键的 ODML 约束条件:延迟时间(无需往返服务器)、隐私性(没有个人数据离开设备)、连接性(无需连接到互联网)、大小(缩减了模型和二进制文件大小)和功耗(高效推理和缺少网络连接)。

  • 支持多平台:与 Android 和 iOS 设备、嵌入式 Linux 和微控制器兼容。

  • 多框架模型选项:AI Edge 提供了一些工具,可将 TensorFlow、PyTorch 和 JAX 模型转换为 FlatBuffers 格式 (.tflite),让您能够在 LiteRT 上使用各种先进的模型。您还可以使用可处理量化和元数据的模型优化工具。

  • 支持多种语言:包括适用于 Java/Kotlin、Swift、Objective-C、C++ 和 Python 的 SDK。

  • 高性能:通过 GPU 和 iOS Core ML 等专用代理实现硬件加速。

可以看到 LiteRT 还是很有前景的。

为什么选择 LiteRT

之前一直在使用 PyTorchLite 作为 Android 端推理模型的框架,但是依赖 SDK 后包体积爆炸式增长。日常开发的时候,编译和安装都太耗时。 再有 LiteRT 支持将 PyTorch 模型转换为自身可用的格式,而且是 Google 官方支持的亲儿子,因此选择 LiteRT 似乎是自然而然的事情。

可以看到 PyTorchLite 这个 SDK 一点也不 Lite, 其 so 文件的体积居然有 57.6 MB ,比模型文件还大。而 TensonFlowLite 即 LiteRT 需要的 so 只有 536 KB,两者之间相差近乎 110 倍,因此选择 LiterRT 势在必行。其实另一个重要原因是 PyTorch 官方已经不再维护 PyTorchLite 了,pytorch/android-demo-app 这个仓库在 2024 年 9 月已经声明不再维护,官方推荐 ExecuTorch

LiteRT 入门

添加依赖

gradle 复制代码
    // LiteRT dependencies for Google Play services
    implementation 'com.google.android.gms:play-services-tflite-java:16.4.0'
    // Optional: include LiteRT Support Library
    implementation 'com.google.android.gms:play-services-tflite-support:16.4.0'

初始化模型及 InterpreterApi

SDK 初始化
kotlin 复制代码
    fun init(context: Context) {
        TfLite.initialize(context).addOnSuccessListener {
            Log.d(TAG, "ver ${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
            Log.d(TAG, "ver ${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
        }.addOnFailureListener {
            Log.d(TAG, "Initialized TFLite fail.")
            
        }
    }

SDK 初始化完成之后,就可以执行模型的初始化

模型及推理接口的初始化
kotlin 复制代码
    fun createInterpreterApi(context: Context, modelName: String): InterpreterApi? {
        val model = loadModelFile(context.assets, modelName)
        val interpreterOption = InterpreterApi.Options()
            .setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
        interpreter = InterpreterApi.create(model, interpreterOption)
        return interpreter
    }

InterpreterApi 就是 SDK 提供的统一接口,方便获取推理模型的信息,执行推理操作。

kotlin 复制代码
                interpreter?.let { inter ->
                    val inputShape = inter.getInputTensor(0).shape()
                    inputImageWidth = inputShape[1]
                    inputImageHeight = inputShape[2]
                    modelInputSize = 4 * inputImageWidth * inputImageHeight 
                }

通过 interpreter 我们可以获取到推理模型各方面的信息

kotlin 复制代码
val inputShape = inter.getInputTensor(0).shape()
Log.d(TAG, "input  shape = ${inputShape.contentToString()}")
Log.d(TAG, "elem   shape = ${inter.getInputTensor(0).numElements()}")
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")    
shell 复制代码
DigitClassifier          D  input  shape = [1, 28, 28, 1]
DigitClassifier          D  elem   shape = 784
DigitClassifier          D  output shape = [1, 10]

可以看到当前推理模型如下信息

  • 输入是 1x28x28x1 的 4 维数组。如果你对神经网路稍有了解的话,应该不陌生。这是手写数字识别时,MNIST 数据集提供的标准输入,单通道 28x28 像素的手写数字图像。
  • 输入元素的总个数 784 = 1*28*28*1
  • 模型输出是 1x10 的数组,恰好是 0~9 是 10 个数字的概率。

图中就是 MNIST 数据集的示例,图中每一个手写数字的图片都是 28x28 像素的单通道黑白图片

执行推理操作

使用推理模型进行运算,需要做的事情和之前一样,给模型提供合适的输入数据,根据返回结果的数据类型展现我们需要的信息。对于手写数字识别的场景,输入就是图像信息,而输出则是模型对于 0~9 这 10 个数字的判断概率。

kotlin 复制代码
    private fun classify(bitmap: Bitmap) {
        
        val resizedImage = bitmap.scale(inputImageWidth, inputImageHeight)
        
        val byteBuffer = convertBitmapToByteBuffer(resizedImage)

    
        val output = Array(1) { FloatArray(10) }

        
        interpreter?.run(byteBuffer, output)
        
        val result = output[0]
        Log.d(TAG, "result = ${result.contentToString()}")
        val maxIndex = result.indices.maxBy { result[it] }
        
    }
  • 首先,进行图片缩放。由于模型预期输入的图片大小是 28x28 像素,因此需要基于输入的 Bitmap 进行裁剪
  • 其次,将输入 ARGB 格式的 Bitmap 转换适用于 TensorFlow 的类型
  • 定义模型的输出大小,即用于存放 0~9 这个几个数字概率的一维数组即可
  • 最后,执行推理,获取返回值中概率最大的那个值。
shell 复制代码
result = [5.4527607E-8, 4.035255E-9, 7.9443926E-5, 3.279577E-4, 0.0015575942, 0.99790514, 5.318362E-6, 3.163248E-5, 1.296262E-5, 7.9982296E-5]

根据打印的日志,返回概率和我们实际输入的图像信息是一致的,模型准确的推理出了数字 5 的概率最大。

可以看到,这个模型整体在表现还是不错的,可以非常精确的识别出 0~9 这几个手写数字的内容。

如果你之前使用过 TensorFlow Lite ,官方也提供了迁移文档

小结

无论是 PyTorch 还是 TensorFlow ,运用已经训练好的模型进行推理操作,其原理和执行步骤本质是相同的,都是依据模型输入输入的格式提供数据,将模型输出的数据用合适的方式展现出来。无论是在 PC 还是移动端,都是类似的做法。无非就是各个平台提供的接口有些许差异,需要我们进行适配。

参考文档

LiteRT 概览

相关推荐
go54631584658 分钟前
基于深度学习的食管癌右喉返神经旁淋巴结预测系统研究
图像处理·人工智能·深度学习·神经网络·算法
Blossom.11810 分钟前
基于深度学习的图像分类:使用Capsule Networks实现高效分类
人工智能·python·深度学习·神经网络·机器学习·分类·数据挖掘
宇称不守恒4.013 分钟前
2025暑期—05神经网络-卷积神经网络
深度学习·神经网络·cnn
格林威1 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现沙滩小人检测识别(C#代码UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉
kerli2 小时前
Android 嵌套滑动设计思想
android·客户端
巫婆理发2222 小时前
神经网络(多层感知机)(第二课第二周)
人工智能·深度学习·神经网络
xw33734095642 小时前
彩色转灰度的核心逻辑:三种经典方法及原理对比
人工智能·python·深度学习·opencv·计算机视觉
恣艺3 小时前
LeetCode 854:相似度为 K 的字符串
android·算法·leetcode
贝塔西塔3 小时前
PytorchLightning最佳实践基础篇
pytorch·深度学习·lightning·编程框架
阿华的代码王国3 小时前
【Android】相对布局应用-登录界面
android·xml·java