SAHI（Slicing Aided Hyper Inference，切片辅助超推理）采用切片辅助推理和微调技术，可提高小目标的检测精度。
1 对图像进行切片分割
/**
 * Runs the YOLOv5 detector over overlapping tiles of {@code tBitmap} (SAHI-style slicing) and
 * appends every detection, shifted back into full-image coordinates, to {@code boxList},
 * {@code pestIdxList} and {@code confidList}.
 *
 * <p>Tile origins are multiples of {@code sWh} along each axis. The last tile of a row/column is
 * shifted back so that it ends exactly on the image edge. The three result lists are only
 * appended to here, never cleared.
 *
 * @param tBitmap image to slice
 * @param sWh     stride between consecutive tile origins, in pixels; must be positive
 * @param boxWh   tile side length, in pixels; must be positive. On an axis where the image is
 *                smaller than {@code boxWh}, the tile is shrunk to the image size (a negative
 *                crop origin would otherwise make Bitmap.createBitmap throw)
 * @throws IllegalArgumentException if {@code sWh} or {@code boxWh} is not positive
 */
private static void sahiImg(Bitmap tBitmap, int sWh, int boxWh) {
    if (sWh <= 0 || boxWh <= 0) {
        throw new IllegalArgumentException("sWh and boxWh must be positive: " + sWh + "," + boxWh);
    }
    int dImgW = tBitmap.getWidth();
    int dImgH = tBitmap.getHeight();
    int cNum = (int) Math.ceil((float) dImgW / sWh);
    int rNum = (int) Math.ceil((float) dImgH / sWh);
    Log.d("testWH", dImgW + "," + dImgH + "," + rNum + "," + cNum + "," + sWh);
    // Bitmap.createBitmap rejects crops outside the source, so never request a tile larger than the image.
    int tileW = Math.min(boxWh, dImgW);
    int tileH = Math.min(boxWh, dImgH);
    List<Integer> xOrigins = tileOrigins(dImgW, sWh, tileW, cNum);
    List<Integer> yOrigins = tileOrigins(dImgH, sWh, tileH, rNum);
    // Mirror the setting on every call: previously isGpu was only ever set to true,
    // so switching back to CPU (decGpu != 1) had no effect.
    isGpu = CONST.decGpu == 1;
    for (int i = 0; i < yOrigins.size(); i++) {
        int bY = yOrigins.get(i);
        for (int j = 0; j < xOrigins.size(); j++) {
            int bX = xOrigins.get(j);
            Bitmap part1bmap = Bitmap.createBitmap(tBitmap, bX, bY, tileW, tileH);
            Log.d("testBmap", bX + "," + bY + "," + boxWh + "," + i + ",a" + j);
            YoloV5Ncnn.dObj[] yoloObj = CONST.yolov5ncnn.Detect(part1bmap, isGpu);
            if (yoloObj == null) {
                continue;
            }
            for (YoloV5Ncnn.dObj dObj : yoloObj) {
                // Shift the tile-local box by the tile origin: {xmin, ymin, xmax, ymax}.
                Float[] tBoxArr = new Float[4];
                tBoxArr[0] = bX + dObj.x;
                tBoxArr[1] = bY + dObj.y;
                tBoxArr[2] = bX + dObj.x + dObj.w;
                tBoxArr[3] = bY + dObj.y + dObj.h;
                boxList.add(tBoxArr);
                Log.d("testXY", i + "|" + j + "|a" + dObj.prob + CONST.yPestArr[dObj.label]);
                pestIdxList.add(dObj.label);
                confidList.add(dObj.prob);
            }
        }
    }
}

/**
 * Returns the tile start offsets along one axis: 0, stride, 2*stride, ... for at most
 * {@code maxCount} tiles. The first origin whose tile would reach or pass the edge is replaced
 * by {@code extent - tile}, and that origin ends the sequence. Requires {@code tile <= extent}.
 */
private static List<Integer> tileOrigins(int extent, int stride, int tile, int maxCount) {
    List<Integer> origins = new ArrayList<>();
    int lastOrigin = extent - tile;
    for (int k = 0; k < maxCount; k++) {
        int origin = stride * k;
        if (origin >= lastOrigin) {
            origins.add(lastOrigin);
            break;
        }
        origins.add(origin);
    }
    return origins;
}
2 NMS（非极大值抑制）
/**
 * Class-agnostic (single-class) non-maximum suppression over the detections collected from all
 * tiles.
 *
 * <p>Only boxes whose confidence is strictly greater than {@code CONST.conStr} are considered.
 * They are sorted with {@code Idxs}' natural ordering (ascending by score, per the original
 * comment). The algorithm then repeatedly keeps the last (best) remaining box and discards every
 * other remaining box whose IoU with it exceeds {@code CONST.iouThresh / 100}, regardless of
 * class label. At most {@code CONST.numDetect} boxes are kept.
 *
 * <p>The former {@code sleep(20)} per kept box was removed; it only delayed the result.
 *
 * @param box2Arr   candidate boxes as {xmin, ymin, xmax, ymax}
 * @param pestIdxLs class label of each box, parallel to {@code box2Arr}
 * @param conFLs    confidence of each box, parallel to {@code box2Arr}
 * @return indices into {@code box2Arr} of the kept boxes in the order they were picked, or
 *         {@code null} when there are no boxes or none passes the confidence threshold
 */
public static List<Integer> non_max_suppression(Float[][] box2Arr, List<Integer> pestIdxLs, List<Float> conFLs) {
    if (box2Arr == null || box2Arr.length == 0)
        return null;
    // Thresholds are loop-invariant: parse/scale them once.
    final float confThresh = Float.parseFloat(CONST.conStr);
    final float iouThresh = (float) CONST.iouThresh / 100;

    // Keep only boxes above the confidence threshold, remembering index, label and score.
    List<Idxs> idxsList = new ArrayList<>();
    for (int i = 0; i < box2Arr.length; i++) {
        float confVal = conFLs.get(i);
        Log.d("box", i + "," + confVal);
        if (confVal > confThresh) {
            idxsList.add(new Idxs(i, pestIdxLs.get(i), confVal));
        }
    }
    if (idxsList.isEmpty())
        return null;
    Collections.sort(idxsList); // ascending by score, so the best box is last
    for (int i = 0; i < idxsList.size(); i++) {
        Idxs e = idxsList.get(i);
        Log.d("idxNum", i + "," + e.getIndex() + "," + e.getPestIdx() + "," + e.getConfVal());
    }

    List<Integer> pickList = new ArrayList<>();
    while (!idxsList.isEmpty() && pickList.size() < CONST.numDetect) {
        int before = idxsList.size();
        Idxs best = idxsList.remove(before - 1);
        int bestIdx = best.getIndex();
        Log.d("idx", bestIdx + "," + best.getPestIdx() + "," + best.getConfVal() + "");
        pickList.add(bestIdx);
        // Keep only candidates that do not overlap the picked box too much.
        // A NaN IoU (both boxes zero-area) is not "> iouThresh", so such boxes survive.
        List<Idxs> survivors = new ArrayList<>(idxsList.size());
        for (Idxs cand : idxsList) {
            if (boxIou(box2Arr[bestIdx], box2Arr[cand.getIndex()]) > iouThresh)
                continue;
            survivors.add(cand);
        }
        Log.d("testIdx", (before - survivors.size()) + "||" + before);
        idxsList = survivors;
    }
    return pickList;
}

/** Intersection-over-union of two {xmin, ymin, xmax, ymax} boxes (NaN when both areas are zero). */
private static float boxIou(Float[] a, Float[] b) {
    float ovW = Math.max(0, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
    float ovH = Math.max(0, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
    float overArea = ovW * ovH;
    float aArea = (a[2] - a[0]) * (a[3] - a[1]);
    float bArea = (b[2] - b[0]) * (b[3] - b[1]);
    return overArea / (aArea + bArea - overArea);
}
3 绘制检测结果图
/**
 * Draws the boxes selected by NMS onto a mutable ARGB_8888 copy of {@code mBmap} and stores the
 * copy in {@code copyBMap}. Also appends one "x1,y1,x2,y2;" record per box to
 * {@code CONST.rectStr} and one "name,confidencePercent;" record per box to
 * {@code CONST.detectStr}.
 *
 * @param mBmap     source image; not modified
 * @param bboxes    candidate boxes as {xmin, ymin, xmax, ymax}
 * @param pestIdxLs class label of each box, parallel to {@code bboxes}
 * @param conFLs    confidence of each box, parallel to {@code bboxes}
 * @param pickIdxLs indices into {@code bboxes} to draw; {@code null} (what non_max_suppression
 *                  returns when nothing is detected) or empty draws nothing
 */
private static void drawImg(Bitmap mBmap, Float[][] bboxes, List<Integer> pestIdxLs, List<Float> conFLs, List<Integer> pickIdxLs) {
    copyBMap = mBmap.copy(Bitmap.Config.ARGB_8888, true);
    if (pickIdxLs == null || pickIdxLs.isEmpty())
        return;
    final int[] colors = new int[] {
            Color.rgb( 54, 67, 244),
            Color.rgb( 99, 30, 233),
            Color.rgb(176, 39, 156),
            Color.rgb(183, 58, 103),
            Color.rgb(181, 81, 63),
            Color.rgb(243, 150, 33),
            Color.rgb(244, 169, 3),
            Color.rgb(212, 188, 0),
            Color.rgb(136, 150, 0),
            Color.rgb( 80, 175, 76),
            Color.rgb( 74, 195, 139),
            Color.rgb( 57, 220, 205),
            Color.rgb( 59, 235, 255),
            Color.rgb( 7, 193, 255),
            Color.rgb( 0, 152, 255),
            Color.rgb( 34, 87, 255),
            Color.rgb( 72, 85, 121),
            Color.rgb(158, 158, 158),
            Color.rgb(139, 125, 96)
    };
    Canvas canvas = new Canvas(copyBMap);
    Paint paint = new Paint();
    paint.setStyle(Paint.Style.STROKE);
    paint.setStrokeWidth(4);
    Paint textbgpaint = new Paint();
    textbgpaint.setColor(Color.WHITE);
    textbgpaint.setStyle(Paint.Style.FILL);
    Paint textpaint = new Paint();
    textpaint.setColor(Color.BLACK);
    textpaint.setTextSize(26);
    textpaint.setTextAlign(Paint.Align.LEFT);
    // The label height depends only on the paint, not on the text.
    float text_height = -textpaint.ascent() + textpaint.descent();
    int imgW = copyBMap.getWidth();
    StringBuilder rectSb = new StringBuilder();
    StringBuilder detectSb = new StringBuilder();
    for (int i = 0; i < pickIdxLs.size(); i++) {
        int idx = pickIdxLs.get(i);
        float rectX1 = bboxes[idx][0];
        float rectY1 = bboxes[idx][1];
        float rectX2 = bboxes[idx][2];
        float rectY2 = bboxes[idx][3];
        String pestName = CONST.yPestArr[pestIdxLs.get(idx)];
        float confPct = conFLs.get(idx) * 100;

        paint.setColor(colors[i % colors.length]);
        canvas.drawRect(rectX1, rectY1, rectX2, rectY2, paint);

        // Label: black text on a white background just above the box's top-left corner,
        // clamped so it stays inside the image.
        String text = pestName + " = " + String.format("%.1f", confPct) + "%";
        float text_width = textpaint.measureText(text);
        float lX1 = Math.max(0, Math.min(rectX1, imgW - text_width));
        float lY1 = Math.max(0, rectY1 - text_height);
        canvas.drawRect(lX1, lY1, lX1 + text_width, lY1 + text_height, textbgpaint);
        canvas.drawText(text, lX1, lY1 - textpaint.ascent(), textpaint);

        rectSb.append(Math.round(rectX1)).append(',')
                .append(Math.round(rectY1)).append(',')
                .append(Math.round(rectX2)).append(',')
                .append(Math.round(rectY2)).append(';');
        // Fixed: the original used '-' instead of '+' here (String minus String does not compile).
        // Locale.ROOT: in decimal-comma locales "%.1f" would give "87,5", corrupting the comma-separated record.
        detectSb.append(pestName).append(',')
                .append(String.format(java.util.Locale.ROOT, "%.1f", confPct)).append(';');
    }
    CONST.rectStr += rectSb;
    CONST.detectStr += detectSb;
}