目标检测之单类别NMS

long time no see!

在目标检测中,常见的是多类别NMS,也就是只对相同类别的boxes来计算IOU;但现实场景中经常遇到同一个物体被识别成2个类别,也就是模型认为它既是类别1也是类别2.这时候通过多类别nms就过滤不掉这种重叠的框。所以就需要进行单类别NMS:即把所有的boxes都认为是一个类别,然后再计算IOU来过滤。

这个函数的三个输入参数分别是:模型检测得到的框(x,y,w,h)、 每个框的得分、nms阈值

python 复制代码
def oneclass_nms(boxes, class_probs, nms_threshold):


    def get_iou(box1, box2):
        """
        计算两个边界框的IOU
        :param box1: 第一个边界框,格式为 [x1, y1, x2, y2]
        :param box2: 第二个边界框,格式为 [x1, y1, x2, y2]
        :return: IOU的值
        """
        x11, y11, x12, y12 = box1
        x21, y21, x22, y22 = box2

        # 计算边界框的交集
        inter_x1 = max(x11, x21)
        inter_y1 = max(y11, y21)
        inter_x2 = min(x12, x22)
        inter_y2 = min(y12, y22)

        # 计算交集面积
        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)

        # 计算边界框的总面积
        box1_area = (x12 - x11) * (y12 - y11)
        box2_area = (x22 - x21) * (y22 - y21)

        # 计算并集面积
        union_area = box1_area + box2_area - inter_area

        # 计算IOU
        iou = inter_area / union_area
        return iou

    # 初始化一个空列表来存储保留的边界
    boxes_list = copy.deepcopy(boxes.tolist())
    boxes_list_copy = copy.deepcopy(boxes.tolist())
    box_save = set()

    while boxes_list:
        box_a = boxes_list.pop(0)
        for box_b in boxes_list:
            if get_iou(box_a, box_b) > 0.1:
                box_save.add(boxes_list_copy.index(box_a))

    all_index = set(list(range(len(boxes_list_copy))))
    # 获取all_index中不在keep中的索引
    diff = all_index - box_save
    diff = list(diff)
    diff = sorted(diff, key=lambda x: x)

    return diff

在官方的代码中已经有boxes, class_probs, nms_threshold这三个参数的输出,我们只需把它传入上面的函数就可以了。在官方yolo的基础上修改代码如下(注释掉的是官方原始的代码)

在non_max_suppression这个函数里插入我们的单类别nms函数即可。把官方的nms注释掉换成自定义的nms就OK了

相关推荐
欣赏你流浪^2 小时前
物联网智能感知进阶:基于YOLO的琏雾系统视频分析
物联网·yolo·音视频
cver1233 小时前
人脸情绪检测数据集-9,400 张图片 智能客服系统 在线教育平台 心理健康监测 人机交互优化 市场研究与广告 安全监控系统
人工智能·安全·yolo·计算机视觉·目标跟踪·机器人·人机交互
ChironW6 小时前
Ubuntu 22.04 离线环境下完整安装 Anaconda、CUDA 12.1、NVIDIA 驱动及 cuDNN 8.9.3 教程
linux·运维·人工智能·深度学习·yolo·ubuntu
雪可问春风10 小时前
YOLOv8 训练报错:PyTorch 2.6+ 模型加载兼容性问题解决
人工智能·pytorch·yolo
极智视界1 天前
目标检测数据集 - 自动驾驶场景道路异常检测数据集下载「包含VOC、COCO、YOLO三种格式」
yolo·自动驾驶·voc·coco·目标检测数据集·道路异常检测数据集·算法训练数据集
是Dream呀1 天前
YOLOv6深度解析:实时目标检测的新突破
人工智能·yolo·目标检测
程序猿小D1 天前
【完整源码+数据集+部署教程】植物生长阶段检测系统源码和数据集:改进yolo11-rmt
python·yolo·计算机视觉·目标跟踪·数据集·yolo11·植物生长阶段检测系统
新手村领路人3 天前
c++ opencv调用yolo onnx文件
c++·opencv·yolo
zhangxiaomm3 天前
Ubuntu 搭建 yolov5
linux·yolo·ubuntu
音视频牛哥3 天前
从 AI 到实时视频通道:基于模块化架构的低延迟直播全链路实践
人工智能·opencv·yolo·计算机视觉·音视频·大牛直播sdk·ai人工智能