计算3D目标框的NMS

3D障碍物目标框(中心点坐标XYZ、长宽高lwh、朝向角theta)的非极大值抑制

cpp 复制代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>

// 定义3D目标框的结构体
struct BoundingBox3D
{
    double centerX, centerY, centerZ; // 中心点坐标
    double length, width, height;     // 长宽高
    double theta;                     // 朝向角
    double score;                     // 目标框得分

    BoundingBox3D(double x, double y, double z, double l, double w, double h, double t, double s)
        : centerX(x), centerY(y), centerZ(z), length(l), width(w), height(h), theta(t), score(s) {}
};

class NMS3D
{
public:
    // 构造函数,传入IoU阈值
    NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}

    // 执行NMS
    std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> resultBoxes;

        // 按得分降序排序
        std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);

        // 遍历排序后的框
        while (!sortedBoxes.empty())
        {
            // 保留得分最高的框
            BoundingBox3D topBox = sortedBoxes[0];
            resultBoxes.push_back(topBox);

            // 移除与当前框IoU大于阈值的框
            sortedBoxes.erase(sortedBoxes.begin());
            sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
        }

        return resultBoxes;
    }

private:
    // 按得分降序排序
    std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> sortedBoxes = boxes;
        std::sort(sortedBoxes.begin(), sortedBoxes.end(),
                  [](const BoundingBox3D &a, const BoundingBox3D &b)
                  {
                      return a.score > b.score;
                  });
        return sortedBoxes;
    }

    // 移除与指定框IoU大于阈值的框
    std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
                                                      const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> filteredBoxes;
        for (const auto &b : boxes)
        {
            if (calculateIoU(box, b) < iouThreshold_)
            {
                filteredBoxes.push_back(b);
            }
        }
        return filteredBoxes;
    }

    // 计算两个框的IoU(Intersection over Union)
    double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算两个框的相交部分的体积
        double intersectionVolume = calculateIntersectionVolume(box1, box2);

        // 计算两个框的并集部分的体积
        double unionVolume = box1.length * box1.width * box1.height +
                             box2.length * box2.width * box2.height -
                             intersectionVolume;

        // 计算IoU
        return intersectionVolume / unionVolume;
    }

    // 计算两个框的相交部分的体积
    double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算平面重叠面积
        double intersectArea = calIntersectionArea(box1, box2);
        
        double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);

        // 计算相交部分的体积
        return intersectArea * intersectHeight;
    }

    cv::Point rotatePoint(const cv::Point &point, double angle)
    {
        double rotatedX = point.x * cos(angle) - point.y * sin(angle);
        double rotatedY = point.x * sin(angle) + point.y * cos(angle);

        return cv::Point(rotatedX, rotatedY);
    }

    double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        cv::RotatedRect rect1(cv::Point2f(box1.centerX,box1.centerY),cv::Size2f(box1.width,box1.height),box1.theta);
        cv::RotatedRect rect2(cv::Point2f(box2.centerX,box2.centerY),cv::Size2f(box2.width,box2.height),box2.theta);

        std::vector<cv::Point2f> intersection;
        cv::rotatedRectangleIntersection(rect1,rect2, intersection);
        // std::cout <<rect1.center<< " "<<rect2.center<<std::endl;
        // std::cout <<rect1.size<< " "<<rect2.size<<std::endl;
        // std::cout << "intersection area:"<<intersection.size()<<std::endl;

        double union_area = cv::contourArea(intersection);
        // std::cout << "intersection area:"<<union_area<<std::endl;
        return union_area;
    }

    // 计算两个轴上的重叠部分长度
    double calculateOverlap(double center1, double size1, double center2, double size2)
    {
        double halfSize1 = size1 / 2;
        double halfSize2 = size2 / 2;
        double min1 = center1 - halfSize1;
        double max1 = center1 + halfSize1;
        double min2 = center2 - halfSize2;
        double max2 = center2 + halfSize2;

        // 计算重叠部分长度
        return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
    }

    double iouThreshold_; // IoU阈值
};

int main()
{
    std::vector<BoundingBox3D> inputBoxes;
    inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 200.0,200.0, 200.0, 45, 0.9));
    inputBoxes.push_back(BoundingBox3D(100,100, 10, 200.0, 200.0, 200.0, -45, 0.8));
    //inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));

    double iouThreshold = 0.5; // 可根据实际情况调整IoU阈值
    NMS3D nms(iouThreshold);
    std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);

    // 输出结果框
    for (const auto &box : resultBoxes)
    {
        std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
                  << "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
                  << "Theta: " << box.theta << ", "
                  << "Score: " << box.score << std::endl;
    }
    return 0;
}

关于cv::contourArea可能计算不准的问题,是由于传入的点没有按照一定的顺序排列(顺时针或逆时针)。参考解决博客

相关推荐
GIS数据转换器5 小时前
城市生命线安全保障:技术应用与策略创新
大数据·人工智能·安全·3d·智慧城市
m0_743106469 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106469 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
old_power16 小时前
【PCL】Segmentation 模块—— 基于图割算法的点云分割(Min-Cut Based Segmentation)
c++·算法·计算机视觉·3d
Thomas_YXQ17 小时前
Unity3D项目开发中的资源加密详解
游戏·3d·unity·unity3d·游戏开发
jimumeta20 小时前
不建模,无代码,如何构建一个3D虚拟展厅?
3d·虚拟展厅·3d展厅
菩提树下的凡夫1 天前
Halcon 3D基础知识及常用函数
3d
CASAIM1 天前
手持式三维激光扫描仪-3D扫描产品尺寸
3d·信息可视化
old_power2 天前
【PCL】Segmentation 模块—— 欧几里得聚类提取(Euclidean Cluster Extraction)
c++·计算机视觉·3d
3D小将3 天前
3D 模型格式转换之 STP 转 STL 深度解析
3d·建造者模式