android + tflite 分类APP开发-2

APP开发

build.gradle导入库

//implementation 'org.tensorflow:tensorflow-android:+'

implementation 'org.tensorflow:tensorflow-lite:2.4.0'

implementation 'org.tensorflow:tensorflow-lite-support:0.3.1

implementation 'org.tensorflow:tensorflow-lite-metadata:0.3.1'

加载模型

try {

tfLiteClassificationUtil = new TFLiteClassificationUtil(CONST.downPath + "/zjym.tflite");

Toast.makeText(MainActTflite.this, "模型加载成功!", Toast.LENGTH_SHORT).show();

} catch (Exception e) {

Toast.makeText(MainActTflite.this, "模型加载失败!", Toast.LENGTH_SHORT).show();

e.printStackTrace();

finish();

}

模型一般在assets目录下,在编译时会集成到APP中,不利于模型的迭代,这里模型保存在内部存储目录下。

分类预测

try { // 预测图像

FileInputStream fis = new FileInputStream(image_path);

imageView.setImageBitmap(BitmapFactory.decodeStream(fis));

long start = System.currentTimeMillis();

int\[\]\[\] res2Arr = tfLiteClassificationUtil.predictImage(image_path);

long end = System.currentTimeMillis();

String show_text = "预测结果标签:" + (int) res2Arrres2Arr.length-10 +

"\n名称:" + classNames.get((int) res2Arrres2Arr.length-10) +"概率:" + (float) res2Arrres2Arr.length - 11 / 256 +

"\n名称:" + classNames.get((int) res2Arrres2Arr.length-20) +"概率:" + (float) res2Arrres2Arr.length - 21 / 256 +

"\n名称:" + classNames.get((int) res2Arrres2Arr.length-30) +"概率:" + (float) res2Arrres2Arr.length - 31 / 256 +

"\n时间:" + (end - start) + "ms";

textView.setText(show_text);

} catch (Exception e) {

e.printStackTrace();

}

res2Arrres2Arr.length - 11 / 256,两个整数相除显示为0,添加(float)显示字符串

TFLiteClassificationUtil类功能模块

public TFLiteClassificationUtil(String modelPath) throws Exception {

File file = new File(modelPath);

if (!file.exists()) {

throw new Exception("model file is not exists!");

}

try {

Interpreter.Options options = new Interpreter.Options();

options.setNumThreads(NUM_THREADS);// 使用多线程预测

NnApiDelegate delegate = new NnApiDelegate();// 使用Android自带的API或者GPU加速

// GpuDelegate delegate = new GpuDelegate();

options.addDelegate(delegate);

tflite = new Interpreter(file, options);

// 获取输入,shape为{1, height, width, 3}

int\[\] imageShape = tflite.getInputTensor(tflite.getInputIndex("input_1")).shape();

DataType imageDataType = tflite.getInputTensor(tflite.getInputIndex("input_1")).dataType();

inputImageBuffer = new TensorImage(imageDataType);

// 获取输入,shape为{1, NUM_CLASSES}

int\[\] probabilityShape = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).shape();

DataType probabilityDataType = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).dataType();

//outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

outputProbabilityBuffer = TensorBuffer.createFixedSize(tflite.getOutputTensor(0).shape(), DataType.UINT8);

// 添加图像预处理方式

imageProcessor = new ImageProcessor.Builder()

.add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))

.add(new NormalizeOp(new float\[\] {0.0f}, new float\[\] {255.0f}))

.add(new QuantizeOp(0f, 0.003921569f))

.add(new CastOp(DataType.UINT8))

.build();

TensorProcessor probabilityPostProcessor = new TensorProcessor.Builder()

.add(new DequantizeOp((float) 0, (float) 0.00390625))

.add(new NormalizeOp(new float\[\]{0.0f}, new float\[\]{1.0f}))

.build();

} catch (Exception e) {

e.printStackTrace();

throw new Exception("load model fail!");

}

}

public int\[\]\[\] predictImage(String image_path) throws Exception {

if (!new File(image_path).exists()) {

throw new Exception("image file is not exists!");

}

FileInputStream fis = new FileInputStream(image_path);

Bitmap bitmap = BitmapFactory.decodeStream(fis);

int\[\]\[\] result = predictImage(bitmap);

if (bitmap.isRecycled()) {

bitmap.recycle();

}

return result;

}

// 重载方法,直接使用Bitmap预测

public int\[\]\[\] predictImage(Bitmap bitmap) throws Exception {

return predict(bitmap);

}

private int\[\]\[\] predict(Bitmap bmp) throws Exception {

inputImageBuffer = loadImage(bmp);

try {

tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());

} catch (Exception e) {

throw new Exception("predict image fail! log:" + e);

}

int\[\] results = outputProbabilityBuffer.getIntArray();

Log.d("results", Arrays.toString(results));

int\[\]\[\] arr = new intresults.length2;

for (int i=0;i<results.length;i++) {

arri0 = i;

arri1 = resultsi;

}

Arrays.sort(arr, Comparator.comparingInt(e -> e1));

//int l = getMaxResult(results);

return arr;//new float\[\]{l, resultsl};

}

tflite默认保存格式为UINT8,如果不加add(new CastOp(DataType.UINT8))可能显示

Cannot copy to a TensorFlowLite tensor (input_1) with 150528 bytes from a Java Buffer with 602112 bytes

默认的预训练模型是 EfficientNet-Lite0,如果为其他模型,其输入参数等也要修改。可通过下述方法查看。

Android Studio ->File ->open ->other ->tflite,打开tflite模型,build ->Make Project 会自动生成模型接口类,并移动模型到ml目录,查看类中模型参数。

相关推荐
方白羽5 小时前
Android Gradle 缓存与文件目录深度解析
android·gradle·android studio
曲幽9 小时前
Termux里的二进制和脚本,到底怎么运行才不踩坑?Termux-service 保活妙招!
android·termux·nohup·services·wake-lock
plainGeekDev10 小时前
单例模式 → object 声明
android·java·kotlin
程序员陆业聪10 小时前
读者点单·03|Compose 与传统 View 混用的 12 个真实坑
android
程序员陆业聪11 小时前
读者点单·02|Android 启动优化实战:Trace 抓取→Application 编排→冷启动全流程拆解
android
Coffeeee11 小时前
帮你快速理解AI Agent之我想招个Android实习生
android·人工智能·agent
恋猫de小郭12 小时前
苹果 AirPods 协议,Android 也可以使用完整版 AirPods 能力
android·前端·flutter
黄林晴12 小时前
告别无效重建:Gradle 9.6.0 解决 CI 构建缓存失效痛点告别无效重建:Gradle 9.6.0 解决 CI 建筑缓存失效痛点
android·gradle
张风捷特烈13 小时前
Flutter 类库大揭秘#01 | path_provider架构与设计
android·flutter
_阿南_1 天前
Android文件读写和分享总结
android