基于onnxruntime结合PyQt快速搭建视觉原型Demo

我在日常工作中经常使用PyQt和onnxruntime来快速搭建demo软件,用于展示和测试。这里我将以YOLOv12为例,展示一下我的方案。

首先我们需要使用YOLOv12训练一个模型,并export出ONNX文件。这个部分网络上有很多内容,可以使用ultralytics框架来完成,我在这里就不再赘述了,接下来的步骤直接从onnxruntime开始。

此处你需要针对你的模型去编写一个基于Onnxruntime的推理类,包括前处理,后处理,可视化等部分,做到输入是图片,输出是结果,要保证算法代码和软件代码的独立性。这里是我参考ultralytics框架写的一个yolov12的推理类,用于实例分割的推理。

复制代码
import cv2
import numpy as np
import onnxruntime as ort
import torch
import yoloSeg.utils.ops as ops
from yoloSeg.utils.results import Results


class YOLOv12Seg:
    """
    YOLOv12 instance segmentation model executed through ONNX Runtime.

    The class bundles letterbox preprocessing, ONNX inference, NMS-based postprocessing and a simple mask
    overlay, so that callers hand in an image and get results back without touching any UI code.

    Attributes:
        session: ONNX Runtime inference session created from ``onnx_model``.
        input_name: Name of the first model input, resolved once at construction.
        imgsz: Network input size as (height, width).
        classes: Class names; ``len(classes)`` is passed to NMS as the class count.
        conf: Confidence threshold passed to NMS.
        iou: IoU threshold passed to NMS.
    """

    def __init__(self, onnx_model, classes, conf=0.25, iou=0.7, imgsz=640):
        """
        Create the inference session and store the inference settings.

        Args:
            onnx_model: Model passed straight to ``ort.InferenceSession``.
            classes: Class names of the model.
            conf: Confidence threshold for NMS.
            iou: IoU threshold for NMS.
            imgsz: Network input size; an int is expanded to a square (imgsz, imgsz).
        """
        self.session = ort.InferenceSession(
            onnx_model,
            # Inference is pinned to the CPU provider; no GPU provider is requested.
            providers=["CPUExecutionProvider"],
        )
        # The input name never changes for a loaded model, so look it up once instead of per call.
        self.input_name = self.session.get_inputs()[0].name
        self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
        self.classes = classes
        self.conf = conf
        self.iou = iou

    def __call__(self, img):
        """
        Run preprocessing, inference and postprocessing on one image.

        Args:
            img: Input image array; assumed to be an HWC BGR image as produced by OpenCV
                (TODO confirm with callers).

        Returns:
            The list built by ``postprocess`` (one ``Results`` object per batch item).
        """
        prep_img = self.preprocess(img, self.imgsz)
        outs = self.session.run(None, {self.input_name: prep_img})
        return self.postprocess(img, prep_img, outs)

    def letterbox(self, img, new_shape=(640, 640)):
        """
        Resize the image to fit ``new_shape`` while keeping its aspect ratio, then pad the borders.

        Args:
            img: Image array whose first two dimensions are (height, width).
            new_shape: Target (height, width).

        Returns:
            The resized image padded with the constant colour (114, 114, 114).
        """
        shape = img.shape[:2]  # current shape [height, width]
        # Scale ratio (new / old); the smaller ratio guarantees both sides fit.
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        # Size after scaling, in (width, height) order as expected by cv2.resize.
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        # Half of the total padding on each axis.
        dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2
        if shape[::-1] != new_unpad:  # resize only when the size actually changes
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The -0.1 / +0.1 offsets split an odd total padding into two different integers
        # (e.g. 2.5 -> 2 and 3) instead of rounding both halves the same way.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
        return img

    def preprocess(self, img, new_shape):
        """
        Convert an image into the float32 BCHW tensor that is fed to the model.

        Args:
            img: Input image; assumed HWC with BGR channel order (TODO confirm with callers).
            new_shape: Target (height, width) for letterboxing.

        Returns:
            A contiguous float32 array of shape (1, C, H, W) scaled to [0, 1].
        """
        img = self.letterbox(img, new_shape)
        img = img[..., ::-1].transpose([2, 0, 1])[None]  # reverse channels, HWC -> CHW, add batch axis
        img = np.ascontiguousarray(img)
        img = img.astype(np.float32) / 255  # Normalize to [0, 1]
        return img

    def postprocess(self, img, prep_img, outs):
        """
        Turn the raw model outputs into ``Results`` objects.

        Args:
            img: Original image; its shape is the target for box and mask rescaling.
            prep_img: Preprocessed tensor; its spatial size is the source for box rescaling.
            outs: Model outputs; exactly two arrays are expected (predictions and mask prototypes).

        Returns:
            A list with one ``Results`` object per batch item.
        """
        preds, protos = [torch.from_numpy(p) for p in outs]
        preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
        results = []
        for i, pred in enumerate(preds):
            # Columns 0-3 are box coordinates; map them from network input size back to the image.
            pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)
            # Columns from 6 onwards are used as mask coefficients; the first 6 columns are passed on
            # as boxes (presumably box, confidence, class — verify against ops.non_max_suppression).
            masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])
            results.append(Results(img, path="", names=self.classes, boxes=pred[:, :6], masks=masks))
        return results

    def process_mask(self, protos, masks_in, bboxes, shape):
        """
        Combine prototype masks with per-instance coefficients into binary instance masks.

        Args:
            protos: Prototype tensor of shape (C, mask_h, mask_w).
            masks_in: Per-instance mask coefficients; multiplied with the flattened prototypes.
            bboxes: Boxes used to crop each mask.
            shape: Target (height, width) the masks are scaled to.

        Returns:
            Tensor of masks thresholded in place at 0 (values become 0 or 1).
        """
        c, mh, mw = protos.shape  # CHW
        masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # Matrix multiplication
        masks = ops.scale_masks(masks[None], shape)[0]  # Scale masks to original image size
        masks = ops.crop_mask(masks, bboxes)  # Crop masks to bounding boxes
        return masks.gt_(0.0)  # Convert to binary masks

    def visualize_segmentation(self, image, results, alpha=0.95, draw_boxes=False):
        """
        Overlay the instance masks of the first result on a copy of the image.

        Args:
            image: Image to draw on; it is copied, never modified. A 3-channel image is required
                as soon as there is at least one mask to draw.
            results: A single result object or a list of them; only the first list entry is used.
            alpha: Weight of the coloured mask that is added on top of the image
                (``image * 1.0 + colour * alpha``, saturated by ``cv2.addWeighted``).
            draw_boxes: When True, also draw each bounding box with a "class: confidence" label.
                NOTE(review): this path reads ``boxes.xyxy`` — confirm the boxes object in
                yoloSeg.utils.results exposes it.

        Returns:
            The annotated copy; an unchanged copy when there is no result or no masks.
        """
        visualization = image.copy()
        if isinstance(results, list):
            if not results:
                # Nothing was predicted; indexing [0] here would raise IndexError.
                return visualization
            result = results[0]  # only the first result is visualised
        else:
            result = results
        if getattr(result, "masks", None) is None:
            return visualization

        boxes = result.boxes.cpu().numpy()
        masks = result.masks.data.cpu().numpy()
        # One random colour per instance; the upper bound is exclusive, so 256 allows the value 255.
        colors = np.random.randint(0, 256, size=(len(boxes), 3), dtype=np.uint8)
        for i, (mask, color_row) in enumerate(zip(masks, colors)):
            # Bring the mask to the image size if the two differ.
            if mask.shape[:2] != image.shape[:2]:
                mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
            color = color_row.tolist()
            # Paint the instance colour on a black canvas, then add it onto the running result.
            color_mask = np.zeros_like(image)
            color_mask[mask > 0] = color
            visualization = cv2.addWeighted(visualization, 1.0, color_mask, alpha, 0)
            if draw_boxes:
                x1, y1, x2, y2 = map(int, boxes.xyxy[i])
                cv2.rectangle(visualization, (x1, y1), (x2, y2), color, 2)
                label = f"{self.classes[int(boxes.cls[i])]}: {boxes.conf[i]:.2f}"
                text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
                # Filled background behind the label so the white text stays readable.
                cv2.rectangle(visualization, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
                cv2.putText(visualization, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        return visualization

完成算法模块的独立后,接下来就是基于PyQt的软件模块开发。此处之所以强调两者互相独立,是为了提高软件代码后续的复用率:以后即使换了一个算法,也只需要修改几行代码,就可以快速地完成第二个demo软件的开发。因为我们做的是视觉项目,所以整个软件需要突出的是原图、结果图,以及一些相关的功能按钮。接下来我将提供一个软件模板和对应的注释,方便大家使用。

复制代码
import sys
import os
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
                            QPushButton, QLabel, QFileDialog, QTabWidget, QScrollArea, 
                            QSplitter, QMessageBox)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, pyqtSignal, QThread
import cv2
import numpy as np
from yoloSeg.model import YOLOv12Seg as SegAI


class ImageProcessor(QThread):
    """
    Worker thread that runs the segmentation model so the UI thread is not blocked.

    Usage: call ``set_images`` with the image to process, then ``start()``. When a result is
    available, ``detection_completed`` is emitted with the visualised image.
    """
    detection_completed = pyqtSignal(np.ndarray)

    def __init__(self, model_path="weights/best.onnx", classes=None, conf=0.2, iou=0.8):
        """
        Build the worker and load the model.

        Args:
            model_path: Path of the ONNX model handed to the algorithm class.
            classes: Class names of the model; defaults to ``["goldLine"]``.
            conf: Confidence threshold forwarded to the algorithm class.
            iou: IoU threshold forwarded to the algorithm class.
        """
        super().__init__()
        self.image = None
        # A None default avoids sharing one mutable list between instances.
        self.ai = SegAI(model_path, ["goldLine"] if classes is None else classes, conf=conf, iou=iou)

    def set_images(self, image):
        """Store the image that the next run will process."""
        self.image = image

    def run(self):
        """Thread entry point: run detection and emit the result if there is one."""
        result_image = self.detection_algorithm()
        if result_image is not None:
            self.detection_completed.emit(result_image)

    def detection_algorithm(self):
        """
        Run inference and visualisation on the stored image.

        Returns:
            The visualised image, or None when no image has been set.
        """
        # Work on a snapshot so inference and visualisation use the same image even if
        # set_images() is called while this thread is running.
        image = self.image
        if image is None:
            return None
        results = self.ai(image)
        return self.ai.visualize_segmentation(image, results)


class ImageViewer(QLabel):
    """Label that shows an image scaled to the widget size while keeping its aspect ratio."""

    def __init__(self):
        super().__init__()
        self.setAlignment(Qt.AlignCenter)
        self.setMinimumSize(300, 300)
        self.setStyleSheet("border: 1px solid gray; background-color: #f0f0f0;")
        # Scaling is done manually in updatePixmap() so the aspect ratio is preserved.
        self.setScaledContents(False)
        self.current_pixmap = None

    def setImage(self, image):
        """
        Display an image.

        Args:
            image: Either a uint8 numpy array in OpenCV layout (grayscale, BGR or BGRA) or
                anything ``QPixmap`` accepts as a constructor argument (e.g. a file path).

        Unsupported arrays and images that fail to load are ignored; the previous image stays.
        """
        if isinstance(image, np.ndarray):
            q_img = self._to_qimage(image)
            if q_img is None:
                return
            pixmap = QPixmap.fromImage(q_img)
        else:
            pixmap = QPixmap(image)

        if pixmap.isNull():
            return
        self.current_pixmap = pixmap
        self.updatePixmap()

    @staticmethod
    def _to_qimage(image):
        """Convert a uint8 OpenCV-style array to a QImage that owns its data, or return None."""
        if image.dtype != np.uint8 or image.size == 0:
            return None
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        # QImage reads the raw buffer row by row, so the array must be contiguous; the real row
        # stride is passed instead of assuming 3 * width.
        image = np.ascontiguousarray(image)
        height, width = image.shape[:2]
        bytes_per_line = image.strides[0]
        if image.ndim == 2:
            # .copy() detaches the QImage from the numpy buffer, which goes away on return.
            return QImage(image.data, width, height, bytes_per_line, QImage.Format_Grayscale8).copy()
        if image.ndim == 3 and image.shape[2] == 3:
            # OpenCV stores BGR; rgbSwapped() returns a new image with the channels swapped.
            return QImage(image.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped()
        if image.ndim == 3 and image.shape[2] == 4:
            # NOTE(review): Format_ARGB32 matches OpenCV's BGRA byte order on little-endian
            # hosts only — confirm if big-endian targets matter.
            return QImage(image.data, width, height, bytes_per_line, QImage.Format_ARGB32).copy()
        return None

    def updatePixmap(self):
        """Rescale the stored pixmap to the current widget size and show it."""
        if self.current_pixmap is None:
            return
        # Scale to the label size while keeping the aspect ratio.
        scaled_pixmap = self.current_pixmap.scaled(
            self.width(), self.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation
        )
        super().setPixmap(scaled_pixmap)

    def resizeEvent(self, event):
        """Re-scale the displayed image whenever the widget is resized."""
        self.updatePixmap()
        super().resizeEvent(event)


class ImageProcessingApp(QMainWindow):
    """
    Main window of the demo: source image on the left, detection result on the right,
    and buttons to open an image, run detection and save the result.
    """

    def __init__(self):
        super().__init__()
        self.image = None             # image currently loaded (BGR, 3 channels)
        self.detection_result = None  # last visualised result, None until a detection finishes
        # Widgets
        self.initUI()
        self.image_processor = ImageProcessor()
        self.image_processor.detection_completed.connect(self.on_detection_completed)
        # Re-enable the detect button whenever the worker stops, with or without a result.
        self.image_processor.finished.connect(self._on_worker_finished)

    def initUI(self):
        """Create the widgets, lay them out and connect the button signals."""
        # Window title and size
        self.setWindowTitle('检测软件')
        self.setGeometry(100, 100, 1200, 700)
        # Central widget and top-level layout
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)
        # Button row at the top
        button_layout = QHBoxLayout()
        self.btn_open = QPushButton('打开图片文件')
        self.btn_detection = QPushButton('开始检测')
        self.btn_save = QPushButton('保存图片')
        button_layout.addWidget(self.btn_open)
        button_layout.addWidget(self.btn_detection)
        button_layout.addWidget(self.btn_save)
        # Button signals
        self.btn_open.clicked.connect(self.open_file)
        self.btn_detection.clicked.connect(self.detect_image)
        self.btn_save.clicked.connect(self.save_image)
        # Saving is only possible once a detection result exists.
        self.btn_save.setEnabled(False)
        # Splitter separating the left and right areas
        splitter = QSplitter(Qt.Horizontal)
        # Left: source image
        self.image_tabs = QTabWidget()
        self.image_viewer = ImageViewer()
        self.image_tabs.addTab(self.image_viewer, "图片队列")
        splitter.addWidget(self.image_tabs)
        # Right: result area (tabbed)
        self.result_tabs = QTabWidget()
        self.detection_tab = ImageViewer()
        self.result_tabs.addTab(self.detection_tab, "图片检测")
        splitter.addWidget(self.result_tabs)
        # Initial splitter sizes
        splitter.setSizes([600, 600])
        # Assemble the main layout; the splitter takes all remaining space.
        main_layout.addLayout(button_layout)
        main_layout.addWidget(splitter, 1)
        # Status bar
        self.statusBar().showMessage('就绪')

    def open_file(self):
        """Let the user pick an image file, load it and show it in the left viewer."""
        file_path, _ = QFileDialog.getOpenFileName(self,
                                    "选取文件",
                                    os.getcwd(),  # start directory
                                    "Image Files (*.jpg *.jpeg *.png *.bmp *.tif *.tiff)")
        if not file_path:
            return
        # The file is read through numpy and decoded from memory rather than with cv2.imread
        # (presumably so paths with non-ASCII characters also work — confirm).
        try:
            raw = np.fromfile(file_path, dtype=np.uint8)
        except OSError:
            raw = np.empty(0, dtype=np.uint8)
        # IMREAD_COLOR always yields a 3-channel 8-bit BGR image, which is what the viewer and
        # the model preprocessing expect. imdecode returns None when decoding fails.
        image = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
        if image is None:
            QMessageBox.warning(self, '警告', '无法加载任何图像文件!')
            return
        self.image = image
        self.image_viewer.setImage(image)

    def detect_image(self):
        """Start detection on the loaded image in the worker thread."""
        if self.image is None:
            QMessageBox.warning(self, '警告', '请先打开一张图片!')
            return
        if self.image_processor.isRunning():
            # A detection is already in progress; do not swap its input mid-run.
            return
        self.btn_detection.setEnabled(False)
        self.statusBar().showMessage('正在进行图片检测...')
        self.image_processor.set_images(self.image)
        self.image_processor.start()

    def _on_worker_finished(self):
        """Allow a new detection once the worker thread has stopped."""
        self.btn_detection.setEnabled(True)

    def on_detection_completed(self, result_image):
        """Show and remember the detection result delivered by the worker thread."""
        self.detection_tab.setImage(result_image)
        # Select the detection tab by widget instead of a hard-coded index.
        self.result_tabs.setCurrentIndex(self.result_tabs.indexOf(self.detection_tab))
        self.detection_result = result_image
        self.statusBar().showMessage('图片检测完成')
        self.btn_save.setEnabled(True)

    def save_image(self):
        """Write the latest detection result to ``result.jpg`` in the working directory."""
        if self.detection_result is None:
            return
        # cv2.imwrite reports failure through its return value.
        if cv2.imwrite("result.jpg", self.detection_result):
            self.statusBar().showMessage('图片保存完成')
        else:
            self.statusBar().showMessage('图片保存失败')

    def closeEvent(self, event):
        """Wait for a running detection before the window, and with it the thread, goes away."""
        if self.image_processor.isRunning():
            self.image_processor.wait()
        super().closeEvent(event)


if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window and
    # hand the event loop's exit code back to the interpreter.
    qt_app = QApplication(sys.argv)
    main_window = ImageProcessingApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

这里注意ImageProcessor类其实就是算法类在软件中的代理类,类似于一个协议,后续只需要按需求修改这个类和算法类即可。

相关推荐
老朋友此林44 分钟前
MiniMind:3块钱成本 + 2小时!训练自己的0.02B的大模型。minimind源码解读、MOE架构
人工智能·python·nlp
多巴胺与内啡肽.2 小时前
Opencv进阶操作:图像拼接
人工智能·opencv·计算机视觉
宸汐Fish_Heart2 小时前
Python打卡训练营Day22
开发语言·python
小草cys2 小时前
查看YOLO版本的三种方法
人工智能·深度学习·yolo
伊织code2 小时前
PyTorch API 9 - masked, nested, 稀疏, 存储
pytorch·python·ai·api·-·9·masked
白熊1883 小时前
【计算机视觉】OpenCV实战项目:ETcTI_smart_parking智能停车系统深度解析
人工智能·opencv·计算机视觉
wxl7812273 小时前
基于flask+pandas+csv的报表实现
python·flask·pandas
鸡鸭扣4 小时前
DRF/Django+Vue项目线上部署:腾讯云+Centos7.6(github的SSH认证)
前端·vue.js·python·django·腾讯云·drf
钢铁男儿4 小时前
Python中的标识、相等性与别名:深入理解对象引用机制
java·网络·python
且慢.5894 小时前
Python_day22
python·机器学习