SAHI 目标检测的 Java 实现

SAHI（Slicing Aided Hyper Inference，切片辅助超推理）将大图切分成若干（通常相互重叠的）切片，分别送入检测模型推理，再把各切片的结果映射回原图坐标并合并；配合切片级微调，可显著提高小目标的检测精度。

1. 对图像进行切片分割

/**
 * Slices {@code tBitmap} into square tiles, runs the YOLOv5 detector on every tile and appends
 * each detection, translated back into full-image coordinates, to the shared {@code boxList}
 * ({xmin, ymin, xmax, ymax}), {@code pestIdxList} (class label) and {@code confidList}
 * (probability) collections.
 *
 * <p>Tile origins are placed every {@code sWh} pixels along each axis. The first tile that would
 * reach or pass the right/bottom edge is shifted back so it ends exactly on that edge, and no
 * further tiles are produced along that axis. Tiles overlap when {@code boxWh > sWh}.
 *
 * @param tBitmap source image
 * @param sWh     stride between consecutive tile origins, in pixels; must be positive
 * @param boxWh   tile edge length, in pixels; must be positive. If the image is smaller than a
 *                tile along an axis, the tile is shrunk to the image size on that axis.
 * @throws IllegalArgumentException if {@code sWh} or {@code boxWh} is not positive
 */
private static void sahiImg(Bitmap tBitmap, int sWh, int boxWh) {
    if (sWh <= 0 || boxWh <= 0) {
        // A zero stride would make the tile loop effectively endless.
        throw new IllegalArgumentException("sWh and boxWh must be positive: " + sWh + "," + boxWh);
    }
    int dImgW = tBitmap.getWidth();
    int dImgH = tBitmap.getHeight();
    // A tile larger than the image would give a negative origin and make createBitmap throw.
    int tileW = Math.min(boxWh, dImgW);
    int tileH = Math.min(boxWh, dImgH);

    int[] colOffsets = tileOffsets(dImgW, sWh, tileW);
    int[] rowOffsets = tileOffsets(dImgH, sWh, tileH);
    Log.d("testWH", dImgW + "," + dImgH + "," + rowOffsets.length + "," + colOffsets.length + "," + sWh);

    // Loop-invariant: the GPU preference does not change while tiling.
    if (CONST.decGpu == 1) {
        isGpu = true;
    }

    for (int i = 0; i < rowOffsets.length; i++) {
        int bY = rowOffsets[i];
        for (int j = 0; j < colOffsets.length; j++) {
            int bX = colOffsets[j];
            Bitmap part1bmap = Bitmap.createBitmap(tBitmap, bX, bY, tileW, tileH);
            Log.d("testBmap", bX + "," + bY + "," + tileW + "," + i + ",a" + j);

            YoloV5Ncnn.dObj[] yoloObj = CONST.yolov5ncnn.Detect(part1bmap, isGpu);
            // Free the tile's pixels right away. createBitmap may hand back the source bitmap
            // itself when the tile covers the whole image, and that one must not be recycled.
            if (part1bmap != tBitmap) {
                part1bmap.recycle();
            }
            if (yoloObj == null) {
                continue;
            }

            for (YoloV5Ncnn.dObj dObj : yoloObj) {
                // Shift tile-local coordinates by the tile origin to get full-image coordinates.
                Float[] tBoxArr = new Float[4];
                tBoxArr[0] = bX + dObj.x;
                tBoxArr[1] = bY + dObj.y;
                tBoxArr[2] = bX + dObj.x + dObj.w;
                tBoxArr[3] = bY + dObj.y + dObj.h;
                boxList.add(tBoxArr);
                Log.d("testXY", i + "|" + j + "|a" + dObj.prob + CONST.yPestArr[dObj.label]);
                pestIdxList.add(dObj.label);
                confidList.add(dObj.prob);
            }
        }
    }
}

/**
 * Computes the tile origins along one axis.
 *
 * <p>The candidate origins are {@code 0, stride, 2*stride, ...}, at most
 * {@code ceil(dim / stride)} of them. The first candidate that is {@code >= dim - tile} is
 * replaced by {@code dim - tile}, so that tile ends exactly on the far edge, and the sequence
 * stops there.
 *
 * @param dim    image extent along the axis
 * @param stride distance between consecutive origins; must be positive
 * @param tile   tile extent along the axis; expected to be {@code <= dim}
 * @return the origins in increasing order
 */
private static int[] tileOffsets(int dim, int stride, int tile) {
    int maxCount = (int) Math.ceil((double) dim / stride);
    int lastStart = dim - tile;
    int[] offsets = new int[maxCount];
    int n = 0;
    for (int k = 0; k < maxCount; k++) {
        int off = stride * k;
        if (off >= lastStart) {
            offsets[n++] = lastStart;
            break;
        }
        offsets[n++] = off;
    }
    return java.util.Arrays.copyOf(offsets, n);
}

2. NMS（非极大值抑制）

切片推理后，位于切片交界处的同一目标可能在相邻切片中被重复检测，因此需要用 NMS 去除重叠度过高的重复框，只保留置信度最高的一个。

/**
 * Greedy, class-agnostic non-maximum suppression over the detections of all slices.
 *
 * <p>First, boxes whose confidence is {@code > Float.parseFloat(CONST.conStr)} are kept as
 * candidates. The loop then repeats while candidates remain and fewer than
 * {@code CONST.numDetect} boxes have been picked:
 * <ol>
 *   <li>pick the highest-confidence candidate;</li>
 *   <li>drop every remaining candidate whose IoU with the picked box exceeds
 *       {@code CONST.iouThresh / 100}.</li>
 * </ol>
 * Boxes of different classes also suppress each other.
 *
 * @param box2Arr   boxes as {xmin, ymin, xmax, ymax}
 * @param pestIdxLs class label per box, parallel to {@code box2Arr}
 * @param conFLs    confidence per box, parallel to {@code box2Arr}
 * @return indices into {@code box2Arr} of the kept boxes, highest confidence first; or
 *     {@code null} when {@code box2Arr} is null/empty or no box passes the confidence threshold
 */
public static List<Integer> non_max_suppression(Float[][] box2Arr, List<Integer> pestIdxLs, List<Float> conFLs) {
    if (box2Arr == null || box2Arr.length == 0)
        return null;

    // Parse and scale the thresholds once instead of on every comparison.
    final float confThresh = Float.parseFloat(CONST.conStr);
    final float iouThresh = (float) CONST.iouThresh / 100;

    // Candidates above the confidence threshold: (original index, class label, confidence).
    List<Idxs> idxsList = new ArrayList<>();
    for (int i = 0; i < box2Arr.length; i++) {
        float confVal = conFLs.get(i);
        Log.d("box", i + "," + confVal);
        if (confVal > confThresh) {
            idxsList.add(new Idxs(i, pestIdxLs.get(i), confVal));
        }
    }
    if (idxsList.isEmpty())
        return null;

    // Ascending by score (per Idxs' natural ordering), so the best candidate is always last.
    Collections.sort(idxsList);
    for (int i = 0; i < idxsList.size(); i++) {
        Idxs it = idxsList.get(i);
        Log.d("idxNum", i + "," + it.getIndex() + "," + it.getPestIdx() + "," + it.getConfVal());
    }

    List<Integer> pickList = new ArrayList<>();
    while (!idxsList.isEmpty() && pickList.size() < CONST.numDetect) {
        int lastN = idxsList.size() - 1;
        Idxs best = idxsList.get(lastN);
        int lastIdx = best.getIndex();
        Log.d("idx", lastIdx + "," + best.getPestIdx() + "," + best.getConfVal() + "");
        pickList.add(lastIdx);

        // Keep only the remaining candidates that do not overlap the picked box too much.
        // Building a new list (still sorted) avoids relying on Idxs.equals, as removeAll did.
        List<Idxs> survivors = new ArrayList<>(lastN);
        for (int i = 0; i < lastN; i++) {
            Idxs cand = idxsList.get(i);
            // Written as !(a > b) so a NaN IoU (two zero-area boxes) keeps the candidate.
            if (!(iou(box2Arr[lastIdx], box2Arr[cand.getIndex()]) > iouThresh)) {
                survivors.add(cand);
            }
        }
        Log.d("testIdx", (idxsList.size() - survivors.size()) + "||" + idxsList.size());
        idxsList = survivors;
    }
    return pickList;
}

/**
 * Intersection-over-union of two boxes given as {xmin, ymin, xmax, ymax}.
 * Returns NaN when the union area is zero.
 */
private static float iou(Float[] a, Float[] b) {
    float ovW = Math.max(0, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
    float ovH = Math.max(0, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
    float overArea = ovW * ovH;
    float aArea = (a[2] - a[0]) * (a[3] - a[1]);
    float bArea = (b[2] - b[0]) * (b[3] - b[1]);
    return overArea / (aArea + bArea - overArea);
}

3. 绘制检测结果

private static void drawImg(Bitmap mBmap,Float[][] bboxes, List<Integer> pestIdxLs, List<Float> conFLs,List<Integer> pickIdxLs) {

copyBMap = mBmap.copy(Bitmap.Config.ARGB_8888, true);

final int[] colors = new int[] {

Color.rgb( 54, 67, 244),

Color.rgb( 99, 30, 233),

Color.rgb(176, 39, 156),

Color.rgb(183, 58, 103),

Color.rgb(181, 81, 63),

Color.rgb(243, 150, 33),

Color.rgb(244, 169, 3),

Color.rgb(212, 188, 0),

Color.rgb(136, 150, 0),

Color.rgb( 80, 175, 76),

Color.rgb( 74, 195, 139),

Color.rgb( 57, 220, 205),

Color.rgb( 59, 235, 255),

Color.rgb( 7, 193, 255),

Color.rgb( 0, 152, 255),

Color.rgb( 34, 87, 255),

Color.rgb( 72, 85, 121),

Color.rgb(158, 158, 158),

Color.rgb(139, 125, 96)

};

Canvas canvas = new Canvas(copyBMap);

Paint paint = new Paint();

paint.setStyle(Paint.Style.STROKE);

paint.setStrokeWidth(4);

Paint textbgpaint = new Paint();

textbgpaint.setColor(Color.WHITE);

textbgpaint.setStyle(Paint.Style.FILL);

Paint textpaint = new Paint();

textpaint.setColor(Color.BLACK);

textpaint.setTextSize(26);

textpaint.setTextAlign(Paint.Align.LEFT);

for (int i = 0; i < pickIdxLs.size(); i++) { //if(yObj[i].prob>= Float.parseFloat(CONST.conStr)) {//高于置信限值 CONST.conStr

paint.setColor(colors[i % 19]);

float rectX1 = bboxes[pickIdxLs.get(i)][0];

float rectY1 = bboxes[pickIdxLs.get(i)][1];

float rectX2 = bboxes[pickIdxLs.get(i)][2];

float rectY2 = bboxes[pickIdxLs.get(i)][3];

canvas.drawRect(rectX1, rectY1, rectX2, rectY2, paint);

{// draw filled text inside image

String text = CONST.yPestArr[pestIdxLs.get(pickIdxLs.get(i))] + " = " + String.format("%.1f", conFLs.get(pickIdxLs.get(i)) * 100) + "%";

float text_width = textpaint.measureText(text);

float text_height = -textpaint.ascent() + textpaint.descent();

float lX1 = bboxes[pickIdxLs.get(i)][0];

float lY1 = bboxes[pickIdxLs.get(i)][1] - text_height;

if (lY1 < 0)

lY1 = 0;

if (lX1 + text_width > copyBMap.getWidth())

lX1 = copyBMap.getWidth() - text_width;

canvas.drawRect(lX1, lY1, lX1 + text_width, lY1 + text_height, textbgpaint);

canvas.drawText(text, lX1, lY1 - textpaint.ascent(), textpaint);

CONST.rectStr += (int) Math.round(rectX1) + "," + (int) Math.round(rectY1) + "," + (int) Math.round(rectX2) + "," + (int) Math.round(rectY2) + ";";

CONST.detectStr += CONST.yPestArr[pestIdxLs.get(pickIdxLs.get(i))] + ","

  • String.format("%.1f", conFLs.get(pickIdxLs.get(i)) * 100) + ";";

} //}

}

canvas.save();

canvas.restore();

}

相关推荐
励志成为嵌入式工程师14 分钟前
c语言简单编程练习9
c语言·开发语言·算法·vim
捕鲸叉44 分钟前
创建线程时传递参数给线程
开发语言·c++·算法
A charmer1 小时前
【C++】vector 类深度解析:探索动态数组的奥秘
开发语言·c++·算法
Yaml41 小时前
Spring Boot 与 Vue 共筑二手书籍交易卓越平台
java·spring boot·后端·mysql·spring·vue·二手书籍
小小小妮子~1 小时前
Spring Boot详解:从入门到精通
java·spring boot·后端
hong1616881 小时前
Spring Boot中实现多数据源连接和切换的方案
java·spring boot·后端
wheeldown1 小时前
【数据结构】选择排序
数据结构·算法·排序算法
aloha_7892 小时前
从零记录搭建一个干净的mybatis环境
java·笔记·spring·spring cloud·maven·mybatis·springboot
记录成长java2 小时前
ServletContext,Cookie,HttpSession的使用
java·开发语言·servlet
睡觉谁叫~~~2 小时前
一文解秘Rust如何与Java互操作
java·开发语言·后端·rust