pytorch复现_NMS

NMS(非极大值抑制)阈值是用于控制在一组重叠的边界框中保留哪些边界框的参数。当检测或识别算法生成多个边界框可能涵盖相同物体时,NMS用于筛选出最相关的边界框,通常是根据它们的置信度分数。
具体来说,NMS的工作原理如下:
1.首先,算法对图像中的目标进行检测,并为每个检测到的目标生成一个边界框。每个边界框都伴随一个与目标相关的置信度分数。
2.接下来,NMS算法将所有边界框按照它们的置信度分数进行排序,通常按照分数降序排列。
3.然后,NMS算法从分数最高的边界框开始,将该边界框添加到最终保留的边界框列表中。
4.对于剩余的边界框,NMS会计算它们与当前保留的边界框的IoU(交并比)。如果IoU大于NMS阈值,这些边界框将被抑制(丢弃),只保留一个。
5.重复步骤3和4,直到遍历所有边界框。

python 复制代码
import numpy as np

import numpy as np
def compute_iou(boxA,boxB):
    # 计算相交区域的坐标
    xA=max(boxA[0],boxB[0])
    yA=max(boxA[1],boxB[1])
    xB=min(boxA[2],boxB[2])
    yB=min(boxA[3],boxB[3])

    # 计算相交区域,如果是负数一定是不相交
    interArea = max(0,xB-xA)*max(0,yB-yA)

    # 计算A和B的面积
    boxAArea=(boxA[3]-boxA[1])*(boxA[2]-boxA[0])
    boxBArea=(boxB[3]-boxB[1])*(boxB[2]-boxB[0])

    # 计算iou
    iou=interArea/(boxAArea+boxBArea-interArea)

    return iou

def nms(boxes, scores, threshold):
    # boxes: 边界框列表,每个框是一个格式为 [x1, y1, x2, y2] 的列表
    # scores: 每个边界框的得分列表
    # threshold: NMS的IoU阈值

    # 按得分升序排列边界框
    sorted_indices = np.argsort(scores)
    boxes = [boxes[i] for i in sorted_indices]
    scores = [scores[i] for i in sorted_indices]

    keep = []  # 保留的边界框的索引列表

    while boxes:
        # 取得分最高的边界框
        current_box = boxes.pop()
        current_score = scores.pop()

        keep.append(sorted_indices[-1])
        sorted_indices = sorted_indices[:-1]
        discard_indices = []  # 需要丢弃的边界框的索引列表

        for i, box in enumerate(boxes):
            # 计算与当前边界框的IoU
            iou = compute_iou(current_box, box)

            # 如果IoU超过阈值,标记该边界框为需要丢弃
            if iou > threshold:
                discard_indices.append(i)

        # 移除标记为需要丢弃的边界框。从后往前删,不然for循环会出错
        for i in sorted(discard_indices, reverse=True):
            boxes.pop(i)
            scores.pop(i)
            sorted_indices = np.delete(sorted_indices, i) # np与list的方法不同

    return keep


# test

# 模拟一组边界框和得分
boxes = [[1, 1, 3, 3], [2, 2, 4, 4], [4, 4, 6, 6], [5, 5, 7, 7], [10, 10, 12, 12]]
scores = [0.9, 0.8, 0.7, 0.75, 0.6]

# 设置NMS阈值
nms_threshold = 0.9

# 调用nms函数进行非极大值抑制
keep_indices = nms(boxes, scores, nms_threshold)

# 打印保留下来的边界框的索引
print("保留的边界框索引:", keep_indices)

# 打印保留下来的边界框的坐标和得分
print("保留的边界框坐标和得分:")
for idx in keep_indices:
    print("边界框坐标:", boxes[idx])
    print("得分:", scores[idx])
相关推荐
大千AI助手8 分钟前
概率单位回归(Probit Regression)详解
人工智能·机器学习·数据挖掘·回归·大千ai助手·概率单位回归·probit回归
狂炫冰美式38 分钟前
3天,1人,从0到付费产品:AI时代个人开发者的生存指南
前端·人工智能·后端
LCG元1 小时前
垂直Agent才是未来:详解让大模型"专业对口"的三大核心技术
人工智能
我不是QI1 小时前
周志华《机器学习—西瓜书》二
人工智能·安全·机器学习
BBB努力学习程序设计2 小时前
Python面向对象编程:从代码搬运工到架构师
python·pycharm
操练起来2 小时前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
rising start2 小时前
五、python正则表达式
python·正则表达式
KG_LLM图谱增强大模型2 小时前
[500页电子书]构建自主AI Agent系统的蓝图:谷歌重磅发布智能体设计模式指南
人工智能·大模型·知识图谱·智能体·知识图谱增强大模型·agenticai
声网2 小时前
活动推荐丨「实时互动 × 对话式 AI」主题有奖征文
大数据·人工智能·实时互动
caiyueloveclamp2 小时前
【功能介绍03】ChatPPT好不好用?如何用?用户操作手册来啦!——【AI溯源篇】
人工智能·信息可视化·powerpoint·ai生成ppt·aippt