Android-机器学习组件-自定义模型

前言

通过 Android 机器学习组件-图像标签初探 - 掘金 (juejin.cn) 我们了解了如何使用基础模型实现图像标签功能。但是,基础模型所能识别的标签是有限的,是基于特定的数据集进行训练的。而实际业务场景中,需要我们结合自身的数据,需要量身定制的模型,以便模型推理的结果更符合我们实际的需求,下面就来了解一下如何使用自定义模型。

自定义模型

相比于标准模型,使用自定义模型有更多的灵活性,可以通过参数定制更多的内容。

使用自定义模型获取图像标签

  • 首先,使用自定义模型时需要依赖支持自定义模型的组件 com.google.mlkit:image-labeling-custom
  • 将模型文件放在项目代码 assets 目录下,模型文件(通常以 .tflite 或 .lite 结尾)
  • 完成 ImageLabeler 的初始化
kotlin 复制代码
    private fun initCustomLabeler() {
        val localModel = LocalModel.Builder().setAssetFilePath("model.tflite").build()
        val customImageLabelerOptions =
            CustomImageLabelerOptions.Builder(localModel).setConfidenceThreshold(0.5f)
                .setMaxResultCount(15).build()
        labeler = ImageLabeling.getClient(customImageLabelerOptions)
    }
  • 这里 setConfidenceThreshold 顾名思义,就是配置可信度最小值,小于这个可信度的结果是不会返回的
  • setMaxResultCount 返回标签个数集合的最大值

创建好 ImageLabeler 了之后,具体使用方式已经在Android 机器学习组件-图像标签初探 - 掘金 (juejin.cn)中说过了,我们用相同的图片再试一下。这里使用的自定义模型是有 Google 训练的 mobilenet_v3, 是一个专门用于图像分类相关任务的模型。

shell 复制代码
17:07:37.195 ImageLabelHelper         I  text=stage     ,confidence=8.178474  ,index=820 ,uri=content://media/external/images/media/101508
17:07:37.195 ImageLabelHelper         I  text=volcano   ,confidence=7.372629  ,index=981 ,uri=content://media/external/images/media/101508
17:07:37.195 ImageLabelHelper         I  text=spotlight ,confidence=5.266957  ,index=819 ,uri=content://media/external/images/media/101508
17:07:37.196 ImageLabelHelper         I  text=alp       ,confidence=4.521148  ,index=971 ,uri=content://media/external/images/media/101508
17:07:37.196 ImageLabelHelper         I  text=electric guitar,confidence=4.404194  ,index=547 ,uri=content://media/external/images/media/101508
17:07:37.197 ImageLabelHelper         I  text=geyser    ,confidence=4.381730  ,index=975 ,uri=content://media/external/images/media/101508
17:07:37.197 ImageLabelHelper         I  text=maypole   ,confidence=4.266276  ,index=646 ,uri=content://media/external/images/media/101508
17:07:37.197 ImageLabelHelper         I  text=planetarium,confidence=4.118463  ,index=728 ,uri=content://media/external/images/media/101508
17:07:37.197 ImageLabelHelper         I  text=ballplayer,confidence=4.013018  ,index=982 ,uri=content://media/external/images/media/101508
17:07:37.198 ImageLabelHelper         I  text=torch     ,confidence=3.877848  ,index=863 ,uri=content://media/external/images/media/101508
17:07:37.198 ImageLabelHelper         I  text=steam locomotive,confidence=3.836026  ,index=821 ,uri=content://media/external/images/media/101508
17:07:37.198 ImageLabelHelper         I  text=fountain  ,confidence=3.698731  ,index=563 ,uri=content://media/external/images/media/101508
17:07:37.199 ImageLabelHelper         I  text=unicycle  ,confidence=3.564799  ,index=881 ,uri=content://media/external/images/media/101508
17:07:37.199 ImageLabelHelper         I  text=jigsaw puzzle,confidence=3.548051  ,index=612 ,uri=content://media/external/images/media/101508
17:07:37.199 ImageLabelHelper         I  text=crash helmet,confidence=3.526286  ,index=519 ,uri=content://media/external/images/media/101508
shell 复制代码
17:09:00.118 ImageLabelHelper         I  text=basketball,confidence=10.063719 ,index=431 ,uri=content://media/external/images/media/101507
17:09:00.119 ImageLabelHelper         I  text=unicycle  ,confidence=5.737387  ,index=881 ,uri=content://media/external/images/media/101507
17:09:00.119 ImageLabelHelper         I  text=volleyball,confidence=5.126340  ,index=891 ,uri=content://media/external/images/media/101507
17:09:00.119 ImageLabelHelper         I  text=mountain bike,confidence=5.038583  ,index=672 ,uri=content://media/external/images/media/101507
17:09:00.120 ImageLabelHelper         I  text=bow       ,confidence=4.851575  ,index=457 ,uri=content://media/external/images/media/101507
17:09:00.120 ImageLabelHelper         I  text=comic book,confidence=4.779339  ,index=918 ,uri=content://media/external/images/media/101507
17:09:00.120 ImageLabelHelper         I  text=toilet seat,confidence=4.755874  ,index=862 ,uri=content://media/external/images/media/101507
17:09:00.120 ImageLabelHelper         I  text=racket    ,confidence=4.598331  ,index=753 ,uri=content://media/external/images/media/101507
17:09:00.121 ImageLabelHelper         I  text=chain saw ,confidence=4.398981  ,index=492 ,uri=content://media/external/images/media/101507
17:09:00.121 ImageLabelHelper         I  text=drum      ,confidence=4.151714  ,index=542 ,uri=content://media/external/images/media/101507
17:09:00.121 ImageLabelHelper         I  text=shield    ,confidence=3.973551  ,index=788 ,uri=content://media/external/images/media/101507
17:09:00.122 ImageLabelHelper         I  text=toyshop   ,confidence=3.807693  ,index=866 ,uri=content://media/external/images/media/101507
17:09:00.122 ImageLabelHelper         I  text=bobsled   ,confidence=3.746702  ,index=451 ,uri=content://media/external/images/media/101507
17:09:00.122 ImageLabelHelper         I  text=tricycle  ,confidence=3.694458  ,index=871 ,uri=content://media/external/images/media/101507
17:09:00.122 ImageLabelHelper         I  text=balance beam,confidence=3.685299  ,index=417 ,uri=content://media/external/images/media/101507
shell 复制代码
17:10:03.945 ImageLabelHelper         I  text=seashore  ,confidence=7.280358  ,index=979 ,uri=content://media/external/images/media/3719
17:10:03.946 ImageLabelHelper         I  text=aircraft carrier,confidence=5.773122  ,index=404 ,uri=content://media/external/images/media/3719
17:10:03.946 ImageLabelHelper         I  text=alp       ,confidence=5.731852  ,index=971 ,uri=content://media/external/images/media/3719
17:10:03.946 ImageLabelHelper         I  text=obelisk   ,confidence=5.652296  ,index=683 ,uri=content://media/external/images/media/3719
17:10:03.947 ImageLabelHelper         I  text=breakwater,confidence=5.417479  ,index=461 ,uri=content://media/external/images/media/3719
17:10:03.947 ImageLabelHelper         I  text=traffic light,confidence=5.388855  ,index=921 ,uri=content://media/external/images/media/3719
17:10:03.947 ImageLabelHelper         I  text=maze      ,confidence=5.183353  ,index=647 ,uri=content://media/external/images/media/3719
17:10:03.948 ImageLabelHelper         I  text=go-kart   ,confidence=5.009349  ,index=574 ,uri=content://media/external/images/media/3719
17:10:03.948 ImageLabelHelper         I  text=lakeside  ,confidence=4.928476  ,index=976 ,uri=content://media/external/images/media/3719
17:10:03.948 ImageLabelHelper         I  text=unicycle  ,confidence=4.758537  ,index=881 ,uri=content://media/external/images/media/3719
17:10:03.948 ImageLabelHelper         I  text=racer     ,confidence=4.605724  ,index=752 ,uri=content://media/external/images/media/3719
17:10:03.949 ImageLabelHelper         I  text=triumphal arch,confidence=4.504304  ,index=874 ,uri=content://media/external/images/media/3719
17:10:03.949 ImageLabelHelper         I  text=promontory,confidence=4.499630  ,index=977 ,uri=content://media/external/images/media/3719
17:10:03.949 ImageLabelHelper         I  text=crane     ,confidence=4.449341  ,index=518 ,uri=content://media/external/images/media/3719
17:10:03.949 ImageLabelHelper         I  text=trailer truck,confidence=4.262432  ,index=868 ,uri=content://media/external/images/media/3719

可以看到使用这个自定义模型,相比与基础模型,返回的图像标签类型更丰富了。同时标签类型索引也和之前的不同了,毕竟这个模型更大了,会有新的标签映射关系。

如何获取自定义模型

如果要使用自定义模型组件,就需要使用自己的模型。那么模型文件从哪里来呢?这个一般有两种方式

  1. 从网上获取别人训练好模型,比如直接从 TenserFlow Hub 下载别人训练好的模型。

这里需要注意的是,从网络上获取模型时要符合 TensorFlow Lite 的规范

选择模式时,从这里选择正确的模型,然后下载即可使用。

  1. 自己进行训练,这种一般是训练通用模型,也就是 PC 端可用的模型,然后转换为 TensorFlow Lite 类型的模型,以便移动端进行使用。许多深度学习框架 PyTorch,Paddle 训练之后的模型,都可以转换为符合 TensorFlow Lite 规范的模型。

    可以参考详细拆解YOLO的导出原理,以tflite格式为例实现Android端的调用

上述相关完整代码可以参考 Matisse

参考

相关推荐
喵~来学编程啦7 分钟前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记
深圳市青牛科技实业有限公司20 分钟前
【青牛科技】应用方案|D2587A高压大电流DC-DC
人工智能·科技·单片机·嵌入式硬件·机器人·安防监控
水豚AI课代表40 分钟前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
几两春秋梦_41 分钟前
符号回归概念
人工智能·数据挖掘·回归
用户691581141652 小时前
Ascend Extension for PyTorch的源码解析
人工智能
Chef_Chen2 小时前
从0开始学习机器学习--Day13--神经网络如何处理复杂非线性函数
神经网络·学习·机器学习
Troc_wangpeng2 小时前
R language 关于二维平面直角坐标系的制作
开发语言·机器学习
用户691581141652 小时前
Ascend C的编程模型
人工智能
-Nemophilist-2 小时前
机器学习与深度学习-1-线性回归从零开始实现
深度学习·机器学习·线性回归
神仙别闹2 小时前
基于tensorflow和flask的本地图片库web图片搜索引擎
前端·flask·tensorflow