基于onnxruntime结合PyQt快速搭建视觉原型Demo

我在日常工作中经常使用PyQt和onnxruntime来快速生产demo软件,用于展示和测试,这里,我将以Yolov12为例,展示一下我的方案。

首先我们需要使用Yolov12训练一个模型,并导出ONNX文件。这部分内容网络上已有很多资料,可以使用ultralytics框架来完成,我在这里就不再赘述了,接下来的步骤直接从onnxruntime开始。

此处你需要针对你的模型去编写一个基于Onnxruntime的推理类,包括前处理,后处理,可视化等部分,做到输入是图片,输出是结果,要保证算法代码和软件代码的独立性。这里是我参考ultralytics框架写的一个yolov12的推理类,用于实例分割的推理。

复制代码
import cv2
import numpy as np
import onnxruntime as ort
import torch
import yoloSeg.utils.ops as ops
from yoloSeg.utils.results import Results


class YOLOv12Seg:
    """
    YOLOv12 instance-segmentation model served through ONNX Runtime.

    The class bundles the whole pipeline so that callers only deal with images:
    letterbox preprocessing, ONNX Runtime inference, NMS / mask decoding, and an
    optional visualisation helper. Calling an instance with an image returns a
    list of ``Results`` objects.
    """

    def __init__(self, onnx_model, classes, conf=0.25, iou=0.7, imgsz=640, providers=None):
        """
        Create the ONNX Runtime session and store the inference settings.

        Args:
            onnx_model: Path (or serialized bytes) passed to ``ort.InferenceSession``.
            classes: Sequence of class names; its length is used as ``nc`` for NMS.
            conf: Confidence threshold forwarded to ``ops.non_max_suppression``.
            iou: IoU threshold forwarded to ``ops.non_max_suppression``.
            imgsz: Network input size, either an int (square) or a (height, width) pair.
            providers: Optional list of ONNX Runtime execution providers. When omitted,
                inference runs on ``CPUExecutionProvider`` exactly as before.
        """
        self.session = ort.InferenceSession(
            onnx_model,
            providers=list(providers) if providers else ["CPUExecutionProvider"],
        )
        self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
        self.classes = classes
        self.conf = conf
        self.iou = iou

    def __call__(self, img):
        """
        Run the full pipeline on one image.

        Args:
            img: Input image as an H x W x 3 array (preprocessing reverses the channel
                axis, so BGR order as produced by OpenCV is assumed).

        Returns:
            The list of ``Results`` built by ``postprocess``.

        Raises:
            ValueError: If ``img`` is None (e.g. a failed image decode).
        """
        if img is None:
            raise ValueError("img must be an image array, got None")
        prep_img = self.preprocess(img, self.imgsz)
        outs = self.session.run(None, {self.session.get_inputs()[0].name: prep_img})
        return self.postprocess(img, prep_img, outs)

    def letterbox(self, img, new_shape=(640, 640)):
        """
        Resize the image to fit inside ``new_shape`` and pad the remainder.

        The aspect ratio is preserved; the leftover area is filled with the
        constant colour (114, 114, 114), split evenly between both sides.

        Args:
            img: Input image array, H x W x C.
            new_shape: Target (height, width).

        Returns:
            The resized and padded image.
        """
        shape = img.shape[:2]  # current shape [height, width]
        # Scale ratio (new / old): the tighter of the two axes wins.
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        # Size after scaling, as (width, height) because cv2.resize expects that order.
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
        if shape[::-1] != new_unpad:  # resize only when the size actually changes
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The -/+ 0.1 nudges make a half-pixel pad round down on one side and up on
        # the other, so the two pads always add up to the full padding.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
        return img

    def preprocess(self, img, new_shape):
        """
        Turn an image into the float32 NCHW tensor fed to the model.

        Args:
            img: Input image array, H x W x 3.
            new_shape: Target (height, width) for letterboxing.

        Returns:
            A contiguous float32 array of shape (1, 3, H, W) scaled to [0, 1].
        """
        img = self.letterbox(img, new_shape)
        img = img[..., ::-1].transpose([2, 0, 1])[None]  # BGR to RGB, HWC to 1xCHW
        img = np.ascontiguousarray(img)
        img = img.astype(np.float32) / 255  # Normalize to [0, 1]
        return img

    def postprocess(self, img, prep_img, outs):
        """
        Convert raw model outputs into ``Results`` objects.

        Args:
            img: The original (un-letterboxed) image.
            prep_img: The preprocessed tensor; its spatial size is used to map boxes back.
            outs: The two ONNX outputs, unpacked as (predictions, mask prototypes).

        Returns:
            A list with one ``Results`` per batch item.
        """
        preds, protos = [torch.from_numpy(p) for p in outs]
        preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
        results = []
        for i, pred in enumerate(preds):
            # Map boxes from the letterboxed frame back to the original image frame.
            pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)
            # Columns 6+ are passed on as mask coefficients, columns 0-5 as box data
            # (presumably xyxy, confidence, class -- confirm against ops.non_max_suppression).
            masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])
            results.append(Results(img, path="", names=self.classes, boxes=pred[:, :6], masks=masks))
        return results

    def process_mask(self, protos, masks_in, bboxes, shape):
        """
        Combine prototype masks with per-instance coefficients into binary masks.

        Args:
            protos: Prototype tensor of shape (C, mh, mw).
            masks_in: Per-instance mask coefficients, one row per detection.
            bboxes: Detection boxes used to crop each mask.
            shape: Target (height, width) the masks are scaled to.

        Returns:
            A tensor of per-instance masks thresholded at 0 (values 0.0 / 1.0).
        """
        c, mh, mw = protos.shape  # CHW
        masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # Matrix multiplication
        masks = ops.scale_masks(masks[None], shape)[0]  # Scale masks to original image size
        masks = ops.crop_mask(masks, bboxes)  # Crop masks to bounding boxes
        return masks.gt_(0.0)  # Convert to binary masks

    def visualize_segmentation(self, image, results, alpha=0.95, draw_boxes=False):
        """
        Overlay the predicted instance masks on a copy of the image.

        Args:
            image: The image to draw on (it is copied, never modified); the colour
                overlay assumes 3 channels.
            results: A single result object or a list of them; only the first list
                entry is drawn.
            alpha: Weight of the colour overlay passed to ``cv2.addWeighted``.
            draw_boxes: When True, also draw each bounding box with a
                "class: confidence" label. Off by default, matching the previous output.

        Returns:
            The annotated copy. If ``results`` is empty or carries no masks, the
            untouched copy is returned.
        """
        visualization = image.copy()
        # Only the first result is visualised.
        if isinstance(results, list):
            if not results:
                return visualization
            result = results[0]
        else:
            result = results
        # Nothing to draw without segmentation masks.
        if getattr(result, "masks", None) is None:
            return visualization

        boxes = result.boxes.cpu().numpy()
        masks = result.masks.data.cpu().numpy()
        num_instances = len(boxes)
        # One random colour per instance.
        palette = np.random.randint(0, 255, size=(num_instances, 3), dtype=np.uint8).tolist()
        height, width = image.shape[:2]

        for mask, color in zip(masks, palette):
            # Bring the mask to the image resolution if it differs.
            if mask.shape[:2] != (height, width):
                mask = cv2.resize(mask, (width, height))
            # Paint the instance colour where the mask is set, then blend it in.
            color_mask = np.zeros_like(image)
            color_mask[mask > 0] = color
            visualization = cv2.addWeighted(visualization, 1.0, color_mask, alpha, 0)

        if draw_boxes:
            # Boxes are drawn after all masks so later overlays cannot tint the labels.
            for i, color in enumerate(palette):
                class_name = self.classes[int(boxes.cls[i])]
                label = f"{class_name}: {boxes.conf[i]:.2f}"
                self._draw_box_label(visualization, boxes.xyxy[i], label, color)
        return visualization

    @staticmethod
    def _draw_box_label(canvas, box, label, color):
        """Draw one bounding box and its text label onto ``canvas`` in place."""
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
        # Filled background behind the text keeps the label readable.
        cv2.rectangle(canvas, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
        cv2.putText(canvas, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

完成算法模块的独立后,接下来就是基于PyQt的软件模块开发。此处之所以强调两者互相独立,是为了提高软件代码后续的复用率:以后即使换了一个算法,也只需要修改几行代码,就可以快速地实现第二个demo软件的开发。因为我们做的是视觉项目,所以整个软件需要突出的是原图、结果图,以及一些相关的功能按钮。接下来我将提供一个软件模板和对应的注释,方便大家使用。

复制代码
import sys
import os
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
                            QPushButton, QLabel, QFileDialog, QTabWidget, QScrollArea, 
                            QSplitter, QMessageBox)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, pyqtSignal, QThread
import cv2
import numpy as np
from yoloSeg.model import YOLOv12Seg as SegAI


class ImageProcessor(QThread):
    """Worker thread that runs the segmentation model so the UI stays responsive.

    The class is the only place where the GUI touches the algorithm: it owns the
    model instance, receives an image through ``set_images`` and publishes the
    rendered result through the ``detection_completed`` signal.
    """

    # Emitted with the visualisation image once a detection run has finished.
    detection_completed = pyqtSignal(np.ndarray)

    def __init__(self, model_path="weights/best.onnx", classes=None, conf=0.2, iou=0.8):
        """
        Load the model used by this worker.

        Args:
            model_path: Path of the ONNX model file.
            classes: Iterable of class names; defaults to ``["goldLine"]``.
            conf: Confidence threshold handed to the model wrapper.
            iou: IoU threshold handed to the model wrapper.

        All parameters default to the values that used to be hard-coded, so
        ``ImageProcessor()`` behaves as before.
        """
        super().__init__()
        self.image = None
        class_names = ["goldLine"] if classes is None else list(classes)
        self.ai = SegAI(model_path, class_names, conf=conf, iou=iou)

    def set_images(self, image):
        """Store the image that the next ``start()`` will process."""
        self.image = image

    def run(self):
        """Thread entry point: run the detection and emit the rendered result."""
        result_image = self.detection_algorithm()
        self.detection_completed.emit(result_image)

    def detection_algorithm(self):
        """Run inference on the stored image and return the visualisation.

        An image must have been supplied through ``set_images`` beforehand.
        """
        results = self.ai(self.image)
        visualization = self.ai.visualize_segmentation(self.image, results)
        return visualization


class ImageViewer(QLabel):
    """Label widget that shows an image scaled to fit while keeping its aspect ratio."""

    def __init__(self):
        super().__init__()
        self.setAlignment(Qt.AlignCenter)
        self.setMinimumSize(300, 300)
        self.setStyleSheet("border: 1px solid gray; background-color: #f0f0f0;")
        # Scaling is done manually in updatePixmap so the aspect ratio is preserved.
        self.setScaledContents(False)
        # Full-resolution pixmap; rescaled from this on every resize.
        self.current_pixmap = None

    def setImage(self, image):
        """
        Show a new image.

        Args:
            image: Either an OpenCV-style array (grayscale, BGR or BGRA; 8-bit data
                is assumed) or anything ``QPixmap`` accepts, such as a file path.

        An image that cannot be converted leaves the current display unchanged.
        """
        if isinstance(image, np.ndarray):
            pixmap = self._pixmap_from_array(image)
        else:
            pixmap = QPixmap(image)

        if pixmap.isNull():
            return
        self.current_pixmap = pixmap
        self.updatePixmap()

    @staticmethod
    def _pixmap_from_array(image):
        """Convert an OpenCV array to a QPixmap; returns a null pixmap if unsupported."""
        # QImage reads the raw buffer row by row, so the memory must be contiguous.
        image = np.ascontiguousarray(image)
        # Normalise every supported layout to 3-channel BGR first.
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif image.ndim != 3 or image.shape[2] != 3:
            return QPixmap()
        height, width = image.shape[:2]
        # strides[0] is the true byte length of one row of the buffer.
        # rgbSwapped() turns the BGR data into the RGB order the format declares.
        q_img = QImage(image.data, width, height, image.strides[0], QImage.Format_RGB888).rgbSwapped()
        return QPixmap.fromImage(q_img)

    def updatePixmap(self):
        """Rescale the stored pixmap to the current widget size and display it."""
        if self.current_pixmap is not None:
            scaled_pixmap = self.current_pixmap.scaled(
                self.width(), self.height(),
                Qt.KeepAspectRatio, Qt.SmoothTransformation
            )
            super().setPixmap(scaled_pixmap)

    def resizeEvent(self, event):
        """Keep the displayed image fitted to the widget when it is resized."""
        self.updatePixmap()
        super().resizeEvent(event)


class ImageProcessingApp(QMainWindow):
    """Main window: source image on the left, detection result on the right."""

    def __init__(self):
        super().__init__()
        # Currently loaded source image (BGR array) and the latest rendered result.
        self.image = None
        self.detection_result = None
        # Widgets
        self.initUI()
        # Background worker; its signal delivers the result back to the GUI thread.
        self.image_processor = ImageProcessor()
        self.image_processor.detection_completed.connect(self.on_detection_completed)

    def initUI(self):
        """Build the buttons, the two image panes and the status bar."""
        # Window title and size
        self.setWindowTitle('检测软件')
        self.setGeometry(100, 100, 1200, 700)
        # Central widget and overall layout
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)
        # Top button row
        button_layout = QHBoxLayout()
        self.btn_open = QPushButton('打开图片文件')
        self.btn_detection = QPushButton('开始检测')
        self.btn_save = QPushButton('保存图片')
        button_layout.addWidget(self.btn_open)
        button_layout.addWidget(self.btn_detection)
        button_layout.addWidget(self.btn_save)
        # Button signals
        self.btn_open.clicked.connect(self.open_file)
        self.btn_detection.clicked.connect(self.detect_image)
        self.btn_save.clicked.connect(self.save_image)
        # Saving only makes sense once a detection result exists.
        self.btn_save.setEnabled(False)
        # Splitter separating the left and right panes
        splitter = QSplitter(Qt.Horizontal)
        # Left: source image
        self.image_tabs = QTabWidget()
        self.image_viewer = ImageViewer()
        self.image_tabs.addTab(self.image_viewer, "图片队列")
        splitter.addWidget(self.image_tabs)
        # Right: results (tabbed)
        self.result_tabs = QTabWidget()
        self.detection_tab = ImageViewer()
        self.result_tabs.addTab(self.detection_tab, "图片检测")
        splitter.addWidget(self.result_tabs)
        # Initial splitter proportions
        splitter.setSizes([600, 600])
        # Assemble the main layout
        main_layout.addLayout(button_layout)
        main_layout.addWidget(splitter, 1)
        # Status bar
        self.statusBar().showMessage('就绪')

    def open_file(self):
        """Let the user pick an image file and show it in the left pane."""
        file_path, _ = QFileDialog.getOpenFileName(self,
                                    "选取文件",
                                    os.getcwd(),  # start directory
                                    "Image Files (*.jpg *.jpeg *.png *.bmp *.tif *.tiff)")
        if not file_path:
            return
        # np.fromfile + imdecode (instead of cv2.imread) copes with non-ASCII paths.
        try:
            raw = np.fromfile(file_path, dtype=np.uint8)
        except OSError:
            raw = np.empty(0, dtype=np.uint8)
        # IMREAD_COLOR always yields an 8-bit 3-channel BGR image, which is what the
        # model preprocessing and the viewer expect. imdecode rejects empty buffers,
        # so those are treated as a failed load without calling it.
        image = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
        if image is None:
            QMessageBox.warning(self, '警告', '无法加载任何图像文件!')
            return
        self.image = image
        self.image_viewer.setImage(image)

    def detect_image(self):
        """Start a detection run on the loaded image in the worker thread."""
        if self.image is None:
            QMessageBox.warning(self, '警告', '请先打开一张图片!')
            return
        if self.image_processor.isRunning():
            # Do not swap the image under a run that is still in progress.
            self.statusBar().showMessage('上一次检测尚未完成,请稍候...')
            return
        self.statusBar().showMessage('正在进行图片检测...')
        self.image_processor.set_images(self.image)
        self.image_processor.start()

    def on_detection_completed(self, result_image):
        """Slot for the worker's signal: show and remember the detection result."""
        self.detection_tab.setImage(result_image)
        # Select the tab by widget; the result pane has a single tab (index 0).
        self.result_tabs.setCurrentWidget(self.detection_tab)
        # Keep the result so it can be saved later.
        self.detection_result = result_image
        self.statusBar().showMessage('图片检测完成')
        self.btn_save.setEnabled(True)  # saving is now possible

    def save_image(self):
        """Write the latest detection result to ``result.jpg`` in the working directory."""
        if self.detection_result is None:
            self.statusBar().showMessage('没有可保存的检测结果')
            return
        # cv2.imwrite signals failure through its return value, not an exception.
        if cv2.imwrite("result.jpg", self.detection_result):
            self.statusBar().showMessage('图片保存完成')
        else:
            self.statusBar().showMessage('图片保存失败')


if __name__ == "__main__":
    # Create the Qt application from the command-line arguments.
    qt_app = QApplication(sys.argv)
    # Build and display the main window.
    main_window = ImageProcessingApp()
    main_window.show()
    # Run the event loop and use its return value as the process exit code.
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

这里注意ImageProcessor类其实就是算法类在软件中的代理类,类似于一个协议,后续只需要按需求修改这个类和算法类即可。

相关推荐
来自天蝎座的孙孙8 分钟前
洛谷P1595讲解(加强版)+错排讲解
python·算法
张子夜 iiii1 小时前
机器学习算法系列专栏:主成分分析(PCA)降维算法(初学者)
人工智能·python·算法·机器学习
weixin_456904271 小时前
一文讲清楚Pytorch 张量、链式求导、正向传播、反向求导、计算图等基础知识
人工智能·pytorch·学习
跟橙姐学代码2 小时前
学Python像学做人:从基础语法到人生哲理的成长之路
前端·python
Keying,,,,3 小时前
力扣hot100 | 矩阵 | 73. 矩阵置零、54. 螺旋矩阵、48. 旋转图像、240. 搜索二维矩阵 II
python·算法·leetcode·矩阵
桃源学社(接毕设)3 小时前
基于人工智能和物联网融合跌倒监控系统(LW+源码+讲解+部署)
人工智能·python·单片机·yolov8
yunhuibin3 小时前
pycharm2025导入anaconda创建的各个AI环境
人工智能·python
杨荧3 小时前
基于Python的电影评论数据分析系统 Python+Django+Vue.js
大数据·前端·vue.js·python
python-行者4 小时前
akamai鼠标轨迹
爬虫·python·计算机外设·akamai
jndingxin4 小时前
OpenCV图像注册模块
人工智能·opencv·计算机视觉