计算3D目标框的NMS

3D障碍物目标框(中心点坐标XYZ、长宽高lwh、朝向角theta)的非极大值抑制

cpp 复制代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>

// 定义3D目标框的结构体
struct BoundingBox3D
{
    double centerX, centerY, centerZ; // 中心点坐标
    double length, width, height;     // 长宽高
    double theta;                     // 朝向角
    double score;                     // 目标框得分

    BoundingBox3D(double x, double y, double z, double l, double w, double h, double t, double s)
        : centerX(x), centerY(y), centerZ(z), length(l), width(w), height(h), theta(t), score(s) {}
};

class NMS3D
{
public:
    // 构造函数,传入IoU阈值
    NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}

    // 执行NMS
    std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> resultBoxes;

        // 按得分降序排序
        std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);

        // 遍历排序后的框
        while (!sortedBoxes.empty())
        {
            // 保留得分最高的框
            BoundingBox3D topBox = sortedBoxes[0];
            resultBoxes.push_back(topBox);

            // 移除与当前框IoU大于阈值的框
            sortedBoxes.erase(sortedBoxes.begin());
            sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
        }

        return resultBoxes;
    }

private:
    // 按得分降序排序
    std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> sortedBoxes = boxes;
        std::sort(sortedBoxes.begin(), sortedBoxes.end(),
                  [](const BoundingBox3D &a, const BoundingBox3D &b)
                  {
                      return a.score > b.score;
                  });
        return sortedBoxes;
    }

    // 移除与指定框IoU大于阈值的框
    std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
                                                      const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> filteredBoxes;
        for (const auto &b : boxes)
        {
            if (calculateIoU(box, b) < iouThreshold_)
            {
                filteredBoxes.push_back(b);
            }
        }
        return filteredBoxes;
    }

    // 计算两个框的IoU(Intersection over Union)
    double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算两个框的相交部分的体积
        double intersectionVolume = calculateIntersectionVolume(box1, box2);

        // 计算两个框的并集部分的体积
        double unionVolume = box1.length * box1.width * box1.height +
                             box2.length * box2.width * box2.height -
                             intersectionVolume;

        // 计算IoU
        return intersectionVolume / unionVolume;
    }

    // 计算两个框的相交部分的体积
    double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算平面重叠面积
        double intersectArea = calIntersectionArea(box1, box2);
        
        double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);

        // 计算相交部分的体积
        return intersectArea * intersectHeight;
    }

    cv::Point rotatePoint(const cv::Point &point, double angle)
    {
        double rotatedX = point.x * cos(angle) - point.y * sin(angle);
        double rotatedY = point.x * sin(angle) + point.y * cos(angle);

        return cv::Point(rotatedX, rotatedY);
    }

    double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        cv::RotatedRect rect1(cv::Point2f(box1.centerX,box1.centerY),cv::Size2f(box1.width,box1.height),box1.theta);
        cv::RotatedRect rect2(cv::Point2f(box2.centerX,box2.centerY),cv::Size2f(box2.width,box2.height),box2.theta);

        std::vector<cv::Point2f> intersection;
        cv::rotatedRectangleIntersection(rect1,rect2, intersection);
        // std::cout <<rect1.center<< " "<<rect2.center<<std::endl;
        // std::cout <<rect1.size<< " "<<rect2.size<<std::endl;
        // std::cout << "intersection area:"<<intersection.size()<<std::endl;

        double union_area = cv::contourArea(intersection);
        // std::cout << "intersection area:"<<union_area<<std::endl;
        return union_area;
    }

    // 计算两个轴上的重叠部分长度
    double calculateOverlap(double center1, double size1, double center2, double size2)
    {
        double halfSize1 = size1 / 2;
        double halfSize2 = size2 / 2;
        double min1 = center1 - halfSize1;
        double max1 = center1 + halfSize1;
        double min2 = center2 - halfSize2;
        double max2 = center2 + halfSize2;

        // 计算重叠部分长度
        return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
    }

    double iouThreshold_; // IoU阈值
};

int main()
{
    std::vector<BoundingBox3D> inputBoxes;
    inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 200.0,200.0, 200.0, 45, 0.9));
    inputBoxes.push_back(BoundingBox3D(100,100, 10, 200.0, 200.0, 200.0, -45, 0.8));
    //inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));

    double iouThreshold = 0.5; // 可根据实际情况调整IoU阈值
    NMS3D nms(iouThreshold);
    std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);

    // 输出结果框
    for (const auto &box : resultBoxes)
    {
        std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
                  << "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
                  << "Theta: " << box.theta << ", "
                  << "Score: " << box.score << std::endl;
    }
    return 0;
}

关于cv::contourArea可能计算不准的问题,是由于传入的点没有按照一定的顺序排列(顺时针或逆时针)。参考解决博客

相关推荐
温轻舟13 小时前
3D词云图
前端·javascript·3d·交互·词云图·温轻舟
在下胡三汉14 小时前
粗略地看一下 glTF 2.0 的所有标准属性(顺便说一下,还有 .glb 的结构)
3d
zhongqu_3dnest17 小时前
3D可视化:开启多维洞察新时代
3d·3d建模·空间计算·3d可视化·三维空间·沉浸式体验
试着1 天前
【数据标注师】3D标注
3d·数据标注师·3d标注
工业3D_大熊10 天前
3D模式格式转换工具HOOPS Exchange如何将3D PDF转换为STEP格式?
3d·pdf·3d格式转换·3d模型格式转换·cad格式转换·cad数据格式转换·3d模型可视化
广州华锐视点11 天前
浅议 3D 展示技术为线上车展新体验带来的助力
3d
大霸王龙12 天前
AR眼镜与3D建模社区建设
3d·ar
杀生丸学AI12 天前
【物理重建】SPLART:基于3D高斯泼溅的铰链估计与部件级重建
3d·aigc·三维重建·视觉大模型·世界模型·空间智能·动态重建
Love__Tay13 天前
【Python小练习】3D散点图
开发语言·python·3d