计算3D目标框的NMS

3D障碍物目标框(中心点坐标XYZ、长宽高lwh、朝向角theta)的非极大值抑制

cpp 复制代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>

// 定义3D目标框的结构体
struct BoundingBox3D
{
    double centerX, centerY, centerZ; // 中心点坐标
    double length, width, height;     // 长宽高
    double theta;                     // 朝向角
    double score;                     // 目标框得分

    BoundingBox3D(double x, double y, double z, double l, double w, double h, double t, double s)
        : centerX(x), centerY(y), centerZ(z), length(l), width(w), height(h), theta(t), score(s) {}
};

class NMS3D
{
public:
    // 构造函数,传入IoU阈值
    NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}

    // 执行NMS
    std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> resultBoxes;

        // 按得分降序排序
        std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);

        // 遍历排序后的框
        while (!sortedBoxes.empty())
        {
            // 保留得分最高的框
            BoundingBox3D topBox = sortedBoxes[0];
            resultBoxes.push_back(topBox);

            // 移除与当前框IoU大于阈值的框
            sortedBoxes.erase(sortedBoxes.begin());
            sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
        }

        return resultBoxes;
    }

private:
    // 按得分降序排序
    std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> sortedBoxes = boxes;
        std::sort(sortedBoxes.begin(), sortedBoxes.end(),
                  [](const BoundingBox3D &a, const BoundingBox3D &b)
                  {
                      return a.score > b.score;
                  });
        return sortedBoxes;
    }

    // 移除与指定框IoU大于阈值的框
    std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
                                                      const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> filteredBoxes;
        for (const auto &b : boxes)
        {
            if (calculateIoU(box, b) < iouThreshold_)
            {
                filteredBoxes.push_back(b);
            }
        }
        return filteredBoxes;
    }

    // 计算两个框的IoU(Intersection over Union)
    double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算两个框的相交部分的体积
        double intersectionVolume = calculateIntersectionVolume(box1, box2);

        // 计算两个框的并集部分的体积
        double unionVolume = box1.length * box1.width * box1.height +
                             box2.length * box2.width * box2.height -
                             intersectionVolume;

        // 计算IoU
        return intersectionVolume / unionVolume;
    }

    // 计算两个框的相交部分的体积
    double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算平面重叠面积
        double intersectArea = calIntersectionArea(box1, box2);
        
        double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);

        // 计算相交部分的体积
        return intersectArea * intersectHeight;
    }

    cv::Point rotatePoint(const cv::Point &point, double angle)
    {
        double rotatedX = point.x * cos(angle) - point.y * sin(angle);
        double rotatedY = point.x * sin(angle) + point.y * cos(angle);

        return cv::Point(rotatedX, rotatedY);
    }

    double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        cv::RotatedRect rect1(cv::Point2f(box1.centerX,box1.centerY),cv::Size2f(box1.width,box1.height),box1.theta);
        cv::RotatedRect rect2(cv::Point2f(box2.centerX,box2.centerY),cv::Size2f(box2.width,box2.height),box2.theta);

        std::vector<cv::Point2f> intersection;
        cv::rotatedRectangleIntersection(rect1,rect2, intersection);
        // std::cout <<rect1.center<< " "<<rect2.center<<std::endl;
        // std::cout <<rect1.size<< " "<<rect2.size<<std::endl;
        // std::cout << "intersection area:"<<intersection.size()<<std::endl;

        double union_area = cv::contourArea(intersection);
        // std::cout << "intersection area:"<<union_area<<std::endl;
        return union_area;
    }

    // 计算两个轴上的重叠部分长度
    double calculateOverlap(double center1, double size1, double center2, double size2)
    {
        double halfSize1 = size1 / 2;
        double halfSize2 = size2 / 2;
        double min1 = center1 - halfSize1;
        double max1 = center1 + halfSize1;
        double min2 = center2 - halfSize2;
        double max2 = center2 + halfSize2;

        // 计算重叠部分长度
        return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
    }

    double iouThreshold_; // IoU阈值
};

int main()
{
    std::vector<BoundingBox3D> inputBoxes;
    inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 200.0,200.0, 200.0, 45, 0.9));
    inputBoxes.push_back(BoundingBox3D(100,100, 10, 200.0, 200.0, 200.0, -45, 0.8));
    //inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));

    double iouThreshold = 0.5; // 可根据实际情况调整IoU阈值
    NMS3D nms(iouThreshold);
    std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);

    // 输出结果框
    for (const auto &box : resultBoxes)
    {
        std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
                  << "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
                  << "Theta: " << box.theta << ", "
                  << "Score: " << box.score << std::endl;
    }
    return 0;
}

关于cv::contourArea可能计算不准的问题,是由于传入的点没有按照一定的顺序排列(顺时针或逆时针)。参考解决博客

相关推荐
天人合一peng8 分钟前
Unity 3D 电脑端和手机端都实现画线与清除功能
3d·unity·智能手机
kobesdu16 分钟前
当几何失效时:3D激光SLAM退化场景的本质与应对策略
人工智能·机器学习·3d
木斯佳33 分钟前
HarmonyOS 数据可视化实战:封装一个可复用的 3D 热点词球卡片组件
3d·信息可视化·harmonyos
代数狂人20 小时前
《深入浅出Godot 4与C# 3D游戏开发》第二章:编辑器导航
3d·编辑器·游戏引擎·godot
VBsemi-专注于MOSFET研发定制20 小时前
高端汽车零部件尺寸3D检测设备功率MOSFET选型方案:精密高效运动与成像电源驱动系统适配指南
3d·汽车
threelab1 天前
从工厂模式到简化封装:三维引擎架构演进之路 threejs设计
javascript·3d·架构·webgl
三毛的二哥1 天前
BEV:MapTR
人工智能·算法·计算机视觉·3d
地知通1 天前
地下管网AR可视化更新:支持AR拍照、智能巡检、3dtiles等4项功能
3d·ar·数字孪生·地下管网·移动巡检·参数化建模
林恒smileZAZ2 天前
Three.js实现更真实的3D地球[特殊字符]动态昼夜交替
开发语言·javascript·3d
三维频道2 天前
不止于精度:汽车精密锻铸件质检的“数据降维”与方案重构
3d·重构·汽车·工业数字化·蓝光三维扫描仪·汽车供应链·汽车智能制造