使用C++版本的opencv dnn 部署onnx模型

使用OpenCV的DNN模块在C++中部署ONNX模型涉及几个步骤,包括加载模型、预处理输入数据、进行推理以及处理输出。

构建了yolo类,方便调用

yolo.h 文件

cpp 复制代码
#ifndef YOLO_H
#define YOLO_H
#include <fstream>
#include <sstream>
#include <iostream>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

struct yoloDetectionResult_detection_thread
{
    cv::Point2f DetectionResultLocation; // 目标中心点像素位置
    cv::Point2d DetectionResultClassAndConf; //类型、置信度
    cv::Rect DetectionResultRect; //目标矩形框
    cv::Mat DetectionResultIMG;   //目标像素
    unsigned char object_no = -1;  //目标序号
    unsigned char object_mission = -1; //目标任务状态
    int frame_no;  //图像帧号
};
class detect_result
{
public:
    int classId;
    float confidence;
    cv::Rect_<float> box;

};

class YOLO
{
public:
    YOLO();
    ~YOLO();
    void init(std::string onnxpath);
    void detect(cv::Mat& frame, std::vector<detect_result>& result);
    void draw_frame(cv::Mat& frame, std::vector<detect_result>& results);

private:
    cv::dnn::Net net;

    const float confidence_threshold_ = 0.4f;
    const float nms_threshold_ = 0.4f;
    const int model_input_width_ = 640;
    const int model_input_height_ = 640;
    double HighWidthHeightRatio = 25;
    double LowWidthHeightRatio = 0.05;
};

#endif // !YOLO_H

yolo.cpp

cpp 复制代码
#include "yolo.h"

YOLO::YOLO()
{

}

YOLO::~YOLO()
{

}

void YOLO::init(std::string onnxpath)
{

    this->net = cv::dnn::readNetFromONNX(onnxpath);
}

void YOLO::detect(cv::Mat& frame, std::vector<detect_result>& results)
{
    int w = frame.cols;
    int h = frame.rows;
    int _max = std::max(h, w);
    cv::Mat image = cv::Mat::zeros(cv::Size(_max, _max), CV_8UC3);
    if(frame.channels()==1){
        cv::cvtColor(frame, frame, cv::COLOR_GRAY2BGR);
    }  
    cv::Rect roi(0, 0, w, h);
    frame.copyTo(image(cv::Rect(0, 0, w, h)));

    float x_factor = static_cast<float>(image.cols) / model_input_width_;
    float y_factor = static_cast<float>(image.rows) / model_input_height_;
    cv::Mat blob = cv::dnn::blobFromImage(image, 1 / 255.0, cv::Size(model_input_width_, model_input_height_), cv::Scalar(0, 0, 0), true, false);
    this->net.setInput(blob);
    cv::Mat preds = this->net.forward("output0");
    //outputname,使用Netron看一下输出的名字,一般为output0或者output
    cv::Mat det_output(preds.size[1], preds.size[2], CV_32F, preds.ptr<float>());

    std::vector<cv::Rect> boxes;
    std::vector<int> classIds;
    std::vector<float> confidences;
    for (int i = 0; i < det_output.rows; i++)
    {
        float box_conf = det_output.at<float>(i, 4);
        if (box_conf < nms_threshold_)
        {
            continue;
        }

        cv::Mat classes_confidences = det_output.row(i).colRange(5, 6);
        cv::Point classIdPoint;
        double cls_conf;
        cv::minMaxLoc(classes_confidences, 0, &cls_conf, 0, &classIdPoint);


        if (cls_conf > confidence_threshold_)
        {
            float cx = det_output.at<float>(i, 0);
            float cy = det_output.at<float>(i, 1);
            float ow = det_output.at<float>(i, 2);
            float oh = det_output.at<float>(i, 3);
            int x = static_cast<int>((cx - 0.5 * ow) * x_factor);
            int y = static_cast<int>((cy - 0.5 * oh) * y_factor);
            int width = static_cast<int>(ow * x_factor);
            int height = static_cast<int>(oh * y_factor);
            cv::Rect box;
            box.x = x;
            box.y = y;
            box.width = width;
            box.height = height;

            boxes.push_back(box);
            classIds.push_back(classIdPoint.x);
            confidences.push_back(cls_conf * box_conf);
        }
    }

    std::vector<int> indexes;
    cv::dnn::NMSBoxes(boxes, confidences, confidence_threshold_, nms_threshold_, indexes);
    for (size_t i = 0; i < indexes.size(); i++)
    {
        detect_result dr;
        int index = indexes[i];
        int idx = classIds[index];
        dr.box = boxes[index];
        dr.classId = idx;
        dr.confidence = confidences[index];
        results.push_back(dr);
    }
    std::vector<cv::Rect>().swap(boxes);
    std::vector<int>().swap( classIds);
    std::vector<float>().swap( confidences);
    std::vector<int>().swap( indexes);
}

void YOLO::draw_frame(cv::Mat& frame, std::vector<detect_result>& results)
{
    for (auto dr : results)
    {

        cv::rectangle(frame, dr.box, cv::Scalar(0, 0, 255), 2, 8);
        cv::rectangle(frame, cv::Point(dr.box.tl().x, dr.box.tl().y - 20), cv::Point(dr.box.br().x, dr.box.tl().y), cv::Scalar(255, 0, 0), -1);

        std::string label = cv::format("%.2f", dr.confidence);
        label = dr.classId + ":" + label;

        cv::putText(frame, label, cv::Point(dr.box.x, dr.box.y + 6), 1, 2, cv::Scalar(0, 255, 0), 2);
    }


}

下面是调用函数编写部分

cpp 复制代码
#include<string>
#include"yolo.h"
#include<opencv2\opencv.hpp>
#include<iostream>
int main(){
    YOLO* yolo = new YOLO;
	std::string modelPath =  "C:\\Resource\\model\\XXX.onnx";//模型的地址
    std::string imgPath=  "C:\\Resource\\model\\XXX.jpg";//模型的地址
	//clock_t start_times{},end_times{};
	yolo->init(modelPath);
	std::vector<detect_result> output;
    cv::Mat yoloImages = cv::imread(imgPath);
    if(!yoloImages.empty()){  
        //start_times= clock();
        yolo->detect(yoloImages, output);
    	yolo->draw_frame(yoloImages, output);
    	//end_times = clock();
    	//double FPS = 1 / ((double)(end_times - start_times) / CLOCKS_PER_SEC);
    	cv::imshow("images", yoloImages);
    	cv::waitKey(1);

        std::vector<detect_result>().swap(output);
        std::string().swap(model);
        if(yolo!=NULL){
            delete yolo;
            yolo =NULL;
        }
    }
}	
相关推荐
游客5201 小时前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
小俊俊的博客1 小时前
海康RGBD相机使用C++和Opencv采集图像记录
c++·opencv·海康·rgbd相机
7yewh1 小时前
嵌入式Linux QT+OpenCV基于人脸识别的考勤系统 项目
linux·开发语言·arm开发·驱动开发·qt·opencv·嵌入式linux
Kai HVZ1 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
_WndProc1 小时前
C++ 日志输出
开发语言·c++·算法
薄荷故人_1 小时前
从零开始的C++之旅——红黑树及其实现
数据结构·c++
m0_748240021 小时前
Chromium 中chrome.webRequest扩展接口定义c++
网络·c++·chrome
biter00881 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
qq_433554541 小时前
C++ 面向对象编程:+号运算符重载,左移运算符重载
开发语言·c++
吃个糖糖1 小时前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉