介绍
TensorFlow Lite 是 Google 开发的一个用于在移动设备和嵌入式系统上进行机器学习推断的轻量级框架。它是 TensorFlow 的一个变种,专门针对资源受限的环境进行了优化,以便在手机、物联网设备、嵌入式系统和边缘设备上运行深度学习模型。
TensorFlow Lite 支持各种硬件平台,包括 Android 和 iOS 设备、树莓派和其他单板计算机,以及各种嵌入式系统。它提供了用于将 TensorFlow 模型转换为适用于移动设备的轻量级模型格式的工具,并提供了针对移动设备的高效推断运行时。这使得开发者可以将经过训练的神经网络模型部署到移动设备上,以进行实时推断,例如图像分类、目标检测和语音识别等应用。
TensorFlow Lite 有助于在资源有限的环境中实现高性能的机器学习应用,同时减少了模型大小和内存占用。这对于移动应用程序和嵌入式系统非常有用,因为它们通常具有有限的计算和存储资源。
tfLite模型的生成过程
- 训练模型: 首先,你需要训练一个深度学习模型,通常使用 TensorFlow 或其他深度学习框架。这个模型可以是用于图像分类、目标检测、自然语言处理等各种任务的模型。
- 转换模型: 一旦训练好了模型,接下来需要将它转换为 TensorFlow Lite 格式。
- 元数据添加(可选): 如果需要,在转换后,你可以添加模型的元数据,这可以通过创建一个 FlatBuffer 格式的元数据文件来完成,其中包括有关模型的详细信息。
端模型推断
查看官方文档
www.tensorflow.org/lite/guide?...
可以看到主要分为两类:
1.不包含元数据的模型
使用Interpreter api
可以看到既可以使用多维数组,也可以使用ByteBuffer作为输入
并且数据类型只能使用基元类型,不能使用Integer、Float 等包装数据类型
float
int
long
byte
那么如何确定数组的维数呢?如图所示,一般在下载完毕tflite文件之后,都会带有一个描述
这里安利一个下载网址 tfhub.dev/ 里面有很多已经训练好了的模型可以供我们使用 Text/Image/Video/Audio 等等领域的模型,但是要注意是否支持tflite 格式,只有这个格式才能在端上运行
这里以这个图片分类模型位例子
可以看到这是一个图像分类的库,输入一张 宽 224 高 224 的 RGB图片,输出
输入类型为1 * 224 * 224 * 3 的一个四维数组 (表示输入为224* * 224 的一张RGB图片)
输出类型为 1 * 1001 的一个二维数组 (表示该类别一共能识别1001种类型)
step1 加载模型
首先先引入依赖
erlang
// tensorflow lite 依赖
implementation ("org.tensorflow:tensorflow-lite:2.8.0")
implementation ("org.tensorflow:tensorflow-lite-metadata:0.1.0")
implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu:2.9.0")
下载模型放在asset 目录下之后进行加载(也可以在线下载,减少apk的体积)
kotlin
@Throws(IOException::class)
private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
val fileDescriptor = assetManager.openFd(MODEL_FILE)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
step 2 初始化Interpreter
kotlin
private fun initializeInterpreter() {
// Load the TF Lite model
val assetManager = context.assets
val model = loadModelFile(assetManager)
interpreter= Interpreter(model)
val metadataExtractor = MetadataExtractor(model)
// Read input shape from model file
val inputShape = interpreter.getInputTensor(0).shape()
val outputShape = interpreter.getOutputTensor(0).shape()
val inputStream=metadataExtractor.getAssociatedFile("labels.txt")
label= FileUtil.loadLabels(inputStream, Charsets.UTF_8)
Log.d("label", "initializeInterpreter:${label.size} ")
Log.d("Shape", "Initialized TFLite interpreter. inputShape:${Arrays.toString(inputShape)}, outputShape:${Arrays.toString(outputShape)}")
}
可以看到如果一个模型不知到其输入和输出也可以通过interpreter.getInputTensor(0).shape()/interpreter.getOutputTensor(0).shape()来获取
或者在模型包含MataData的时候也可以通过 MetadataExtractor获取,因为元数据里面就包含输入输出
同时因为这个是个图像识别模型,我们这matadata中读取了labels.txt文件
step 3 构建 输入/输出对象
输出对象
scss
val result = Array(1) {
FloatArray(1001)
}
输入对象
我们获取到原始图片的bitmap是ARGB四通道,需要转化为RGB三通道
使用右移操作符
shr
和按位与操作符and
,我们提取了红色分量(pixel shr 16 and 0xFF
)、绿色分量(pixel shr 8 and 0xFF
)和蓝色分量(pixel and 0xFF
)。最后,将分离出的 RGB 分量值进行标准化处理,将像素值范围从 0-255 缩放到 0-1(通道度)。通过除以 255f,我们将颜色值转换为浮点数,并存储在
result
数组的相应位置上。
scss
fun bitmapToFloatArray(bitmap: Bitmap): Array<Array<Array<FloatArray>>> {
val height = bitmap.height
val width = bitmap.width
// 初始化一个float数组
val result = Array(1) {
Array(height) {
Array(width) {
FloatArray(3)
}
}
}
for (i in 0 until height) {
for (j in 0 until width) {
// 获取像素值
val pixel = bitmap.getPixel(j, i)
// 将RGB值分离并进行标准化(将颜色值标准化到0-1之间)
result[0][i][j][0] = ((pixel shr 16 and 0xFF) / 255.0f)
result[0][i][j][1] = ((pixel shr 8 and 0xFF) / 255.0f)
result[0][i][j][2] = ((pixel and 0xFF) / 255.0f)
}
}
return result
}
step4 运算推理
scss
val inputArray=bitmapToFloatArray(bitmap)//输入转化
interpreter.run(inputArray,result)//运算推理
val results= ArrayList<Classification>()
//获取结果后对分数低于0.5的对象进行过滤
result[0].forEachIndexed{index, score ->
if (score>=0.5){
results.add(Classification(label[index],score))
}
}
推理结果
2包含元数据的模型(简单易用)
如果模型包含了元数据,可以用 task库和support库开箱即用的api
例如
在task库中提供了ImageClassifier
类用于图像分类任务,ObjectDetector
类用于目标检测任务
并且对输出的检测结果进行了封装
step1 加载模型并初始化
scss
private fun setupClassifier() {
val baseOptions = BaseOptions.builder()
.setNumThreads(2)
.build()
val options = ImageClassifier.ImageClassifierOptions.builder()
.setBaseOptions(baseOptions)
.setMaxResults(maxResults)
.setScoreThreshold(threshold)
.build()
try {
classifier = ImageClassifier.createFromFileAndOptions(
context,
"mobilenet_v1_1.0_192_1_metadata_1.tflite",
options
)
} catch (e: IllegalStateException) {
e.printStackTrace()
}
}
这里都是调用了task库中的api方法,不仅包括了加载模型,还可以设置最大的输出结果(最多多少个),得分阈值(只有得分达到才可以获取)
step2 推理运算
kotlin
override fun classify(bitmap: Bitmap, rotation: Int): List<Classification> {
if(classifier == null) {
setupClassifier()
}
val imageProcessor = ImageProcessor.Builder().build()
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
val imageProcessingOptions = ImageProcessingOptions.builder()
.setOrientation(getOrientationFromRotation(rotation))
.build()
val results = classifier?.classify(tensorImage, imageProcessingOptions)
if (results != null) {
Log.d("label", "analyze: ${results.toList()}")
}
return results?.flatMap { classications ->
classications.categories.map { category ->
Classification(
name = category.displayName,
score = category.score
)
}
}?.distinctBy { it.name } ?: emptyList()
}
可以看到,我们不需要对bitmap单独处理了,只用转化为tensorImage就行
分类也可以直接得到结果
使用mlkit
developers.google.com/ml-kit/lang...
如果上面的都还算复杂,google 的mikit 那就是开箱即用
不仅简化了输入输出,就连模型都是官方训练好的
这是官网首页的截图
例如试一下ocr文本识别功能
Step1 依赖
mlkit 提供了两种依赖方式
- 直接捆绑依赖:模型直接依赖在apk里面
- 动态依赖:模型可以在应用安装之后动态的从GooglePlay 下法,好处是apk比较小,坏处是国内无法使用
这里选择第一种
arduino
implementation 'com.google.mlkit:text-recognition-chinese:16.0.0'
implementation ("com.google.android.gms:play-services-tasks:18.0.2")
第二个依赖用于在 Android 应用程序中处理异步任务。Tasks 库提供了一种方便的方式来处理和管理后台任务,包括异步操作、并发执行和任务链。
step 2 分析
kotlin
//1. 创建 TextRecognizer 实例
// When using Chinese script library
val recognizer = TextRecognition.getClient(ChineseTextRecognizerOptions.Builder().build())
@OptIn(ExperimentalGetImage::class) override fun analyze(image: ImageProxy) {
if (frameSkipCounter % 60 == 0) {
//2.输入
val inputImage=InputImage.fromMediaImage(image.image!!,image.imageInfo.rotationDegrees)
//3.运行模型
val result = recognizer.process(inputImage).addOnSuccessListener {
text=it.text
Log.d("text",it.text)
}
}
frameSkipCounter++
image.close()
}
使用起来非常简单
对于图像的输入,官方也提供了很多的类型,这里选择了camerax 进行获取图像,简化了对旋转角度的处理
效果如下