计算3D目标框的NMS

3D障碍物目标框(中心点坐标XYZ、长宽高lwh、朝向角theta)的非极大值抑制

cpp 复制代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>

// 定义3D目标框的结构体
struct BoundingBox3D
{
    double centerX, centerY, centerZ; // 中心点坐标
    double length, width, height;     // 长宽高
    double theta;                     // 朝向角
    double score;                     // 目标框得分

    BoundingBox3D(double x, double y, double z, double l, double w, double h, double t, double s)
        : centerX(x), centerY(y), centerZ(z), length(l), width(w), height(h), theta(t), score(s) {}
};

class NMS3D
{
public:
    // 构造函数,传入IoU阈值
    NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}

    // 执行NMS
    std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> resultBoxes;

        // 按得分降序排序
        std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);

        // 遍历排序后的框
        while (!sortedBoxes.empty())
        {
            // 保留得分最高的框
            BoundingBox3D topBox = sortedBoxes[0];
            resultBoxes.push_back(topBox);

            // 移除与当前框IoU大于阈值的框
            sortedBoxes.erase(sortedBoxes.begin());
            sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
        }

        return resultBoxes;
    }

private:
    // 按得分降序排序
    std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> sortedBoxes = boxes;
        std::sort(sortedBoxes.begin(), sortedBoxes.end(),
                  [](const BoundingBox3D &a, const BoundingBox3D &b)
                  {
                      return a.score > b.score;
                  });
        return sortedBoxes;
    }

    // 移除与指定框IoU大于阈值的框
    std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
                                                      const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> filteredBoxes;
        for (const auto &b : boxes)
        {
            if (calculateIoU(box, b) < iouThreshold_)
            {
                filteredBoxes.push_back(b);
            }
        }
        return filteredBoxes;
    }

    // 计算两个框的IoU(Intersection over Union)
    double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算两个框的相交部分的体积
        double intersectionVolume = calculateIntersectionVolume(box1, box2);

        // 计算两个框的并集部分的体积
        double unionVolume = box1.length * box1.width * box1.height +
                             box2.length * box2.width * box2.height -
                             intersectionVolume;

        // 计算IoU
        return intersectionVolume / unionVolume;
    }

    // 计算两个框的相交部分的体积
    double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算平面重叠面积
        double intersectArea = calIntersectionArea(box1, box2);
        
        double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);

        // 计算相交部分的体积
        return intersectArea * intersectHeight;
    }

    cv::Point rotatePoint(const cv::Point &point, double angle)
    {
        double rotatedX = point.x * cos(angle) - point.y * sin(angle);
        double rotatedY = point.x * sin(angle) + point.y * cos(angle);

        return cv::Point(rotatedX, rotatedY);
    }

    double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        cv::RotatedRect rect1(cv::Point2f(box1.centerX,box1.centerY),cv::Size2f(box1.width,box1.height),box1.theta);
        cv::RotatedRect rect2(cv::Point2f(box2.centerX,box2.centerY),cv::Size2f(box2.width,box2.height),box2.theta);

        std::vector<cv::Point2f> intersection;
        cv::rotatedRectangleIntersection(rect1,rect2, intersection);
        // std::cout <<rect1.center<< " "<<rect2.center<<std::endl;
        // std::cout <<rect1.size<< " "<<rect2.size<<std::endl;
        // std::cout << "intersection area:"<<intersection.size()<<std::endl;

        double union_area = cv::contourArea(intersection);
        // std::cout << "intersection area:"<<union_area<<std::endl;
        return union_area;
    }

    // 计算两个轴上的重叠部分长度
    double calculateOverlap(double center1, double size1, double center2, double size2)
    {
        double halfSize1 = size1 / 2;
        double halfSize2 = size2 / 2;
        double min1 = center1 - halfSize1;
        double max1 = center1 + halfSize1;
        double min2 = center2 - halfSize2;
        double max2 = center2 + halfSize2;

        // 计算重叠部分长度
        return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
    }

    double iouThreshold_; // IoU阈值
};

int main()
{
    std::vector<BoundingBox3D> inputBoxes;
    inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 200.0,200.0, 200.0, 45, 0.9));
    inputBoxes.push_back(BoundingBox3D(100,100, 10, 200.0, 200.0, 200.0, -45, 0.8));
    //inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));

    double iouThreshold = 0.5; // 可根据实际情况调整IoU阈值
    NMS3D nms(iouThreshold);
    std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);

    // 输出结果框
    for (const auto &box : resultBoxes)
    {
        std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
                  << "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
                  << "Theta: " << box.theta << ", "
                  << "Score: " << box.score << std::endl;
    }
    return 0;
}

关于cv::contourArea可能计算不准的问题,是由于传入的点没有按照一定的顺序排列(顺时针或逆时针)。参考解决博客

相关推荐
图扑可视化12 小时前
水墨国风智慧大坝 3D 可视化系统技术实现
3d·数字孪生·智慧水利·水利发电
普密斯科技13 小时前
齿轮平面度与正反面智能检测方案:3D视觉技术破解精密制造品控难题
人工智能·计算机视觉·平面·3d·自动化·视觉检测
丷丩16 小时前
第 2 篇:入门实操|3dtubetilecreater 环境搭建全教程(零踩坑版)
3d·gis·postgis·管线·自动建模·管网
丷丩19 小时前
第3篇:技术拆解|3dtubetilecreater 前后端架构全解析(Vue+Express+PostGIS)
vue.js·3d·架构
一碗白开水一21 小时前
【论文解读】PETRv2: AUnified Framework for 3D Perception from Multi-Camera Images
3d
大模型实验室Lab4AI1 天前
ICLR 2026|上海交通提出 π³,突破参考视图束缚,提升 3D 几何重建鲁棒性
3d
threelab1 天前
Vue3 + Trilab:打造高扩展性三维可视化插件化框架实战指南
javascript·3d·webgl
AGV算法笔记1 天前
最新感知算法论文分析:RaCFormer 如何提升雷达相机 3D 目标检测性能?
数码相机·算法·3d·自动驾驶·机器人视觉·3d目标检测·感知算法
MIXLLRED1 天前
随笔——dddmr_navigation开源3D导航栈介绍与分析
3d·开源·navigation·dddmr
Coovally AI模型快速验证1 天前
无人机拍叶片→AI找缺陷:CEA-DETR改进RT-DETR做风电叶片表面缺陷检测,mAP50达89.4%
人工智能·3d·视觉检测·无人机·异常检测·工业质检