Python调用onnx模型

概述:

核心功能

高性能推理:使用 onnxruntime 引擎,支持 GPU (CUDA)、CPU 加速,相比原生模型推理更轻量。

异步回调系统:识别结果通过 ThreadPoolExecutor 异步推送,检测主流程不会因处理逻辑(如写入数据库、发送邮件)而产生卡顿。

动态任务管理:支持在程序运行期间动态地 增加 或 删除 检测源,无需重启程序。

自动重连机制:针对网络摄像头等不稳定源,内置了自动断线重连逻辑。

频率控制:可自定义 interval(检测间隔),在保证监控质量的同时极大降低 CPU/GPU 负载。

一、环境

java 复制代码
创建:conda create -n onnx python=3.9
激活:conda activate onnx

二、依赖

java 复制代码
创建requirements.txt文件
写入内容如下:
    # ONNX RUN requirements
    # Usage: pip install -r requirements.txt

    # Base ------------------------------------------------------------------------
    numpy>=1.26.4
    # CPU版本
    onnxruntime>=1.19.2

    # CUDA版本
    #onnxruntime-gpu>=1.19.2

    opencv-python>=4.9.0.80

执行安装依赖:pip install -r requirements.txt

三、创建yolo_detector.py(核心工具类)

python 复制代码
import cv2
import numpy as np
import onnxruntime as ort
import json
import time
import threading
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

class YOLOv5Detector:
    def __init__(self, model_path, names=None, conf_thres=0.25, iou_thres=0.45):
        """
        :param model_path: ONNX模型路径
        :param names: 类别字典, 如 {0: 'person', 1: 'car'}
        :param conf_thres: 置信度阈值(过滤虚警)
        :param iou_thres: 交并比阈值(消除重复的框) 
        """
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.names = names if names else {}
        
        # 硬件加速适配
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if \
                    'CUDAExecutionProvider' in ort.get_available_providers() else ['CPUExecutionProvider']
        
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.img_size = self.session.get_inputs()[0].shape[2]
        self.callbacks = []
        self.executor = ThreadPoolExecutor(max_workers=4)

    def add_callback(self, func):
        """注册回调函数。只要识别到目标,就会调用此函数"""
        self.callbacks.append(func)

    def _trigger_callbacks(self, data):
        """使用守护线程异步执行回调,确保不卡住识别主流程"""
        payload = json.dumps(data, ensure_ascii=False)
        for func in self.callbacks:
            self.executor.submit(func, payload)

    def _preprocess(self, img):
        h, w = img.shape[:2]
        r = self.img_size / max(h, w)
        new_unpad = (int(w * r), int(h * r))
        resized_img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        canvas[:new_unpad[1], :new_unpad[0], :] = resized_img
        img_in = canvas.transpose((2, 0, 1))[::-1]
        img_in = np.ascontiguousarray(img_in).astype(np.float32) / 255.0
        return img_in[None, :], r

    def _postprocess(self, outputs, ratio):
        predictions = outputs[0][0]
        mask = predictions[:, 4] > self.conf_thres
        predictions = predictions[mask]
        if len(predictions) == 0: return []

        boxes, scores, class_ids = [], [], []
        for pred in predictions:
            score = pred[4]
            class_score = pred[5:]
            class_id = np.argmax(class_score)
            conf = float(score * class_score[class_id])
            if conf > self.conf_thres:
                x, y, w, h = pred[:4]
                boxes.append([(x - w/2)/ratio, (y - h/2)/ratio, w/ratio, h/ratio])
                scores.append(conf)
                class_ids.append(int(class_id))
        
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_thres, self.iou_thres)
        results = []
        if len(indices) > 0:
            for i in indices.flatten():
                cid = class_ids[i]
                results.append({
                    "class_id": cid,
                    "name": self.names.get(cid, str(cid)),
                    "score": round(scores[i], 3),
                    "box": [int(v) for v in boxes[i]]
                })
        return results

    def detect_frame(self, frame):
        """单帧检测接口"""
        if frame is None: return []
        img_in, ratio = self._preprocess(frame)
        outputs = self.session.run(None, {self.input_name: img_in})
        return self._postprocess(outputs, ratio)

    def start_polling(self, source, stop_event, interval=1.0):
        """
        视频流/长任务轮询
        :param stop_event: threading.Event 对象,用于外部停止线程
        """
        last_emit_time = 0
        while not stop_event.is_set():
            cap = cv2.VideoCapture(source)
            if not cap.isOpened():
                # 若连接失败,循环检查停止信号,避免死等
                for _ in range(10): 
                    if stop_event.is_set(): break
                    time.sleep(0.5)
                continue
            
            print(f"[Info] 开始轮询源: {source}")
            while cap.isOpened() and not stop_event.is_set():
                ret, frame = cap.read()
                if not ret: break
                
                now = time.time()
                if now - last_emit_time > interval:
                    detections = self.detect_frame(frame)
                    if detections:
                        msg = {
                            "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                            "source": str(source),
                            "detections": detections
                        }
                        self._trigger_callbacks(msg)
                    last_emit_time = now
            
            cap.release()
            if stop_event.is_set(): break
            time.sleep(1)
        print(f"[Info] 任务已移除,线程退出: {source}")

class DetectionManager:
    """动态任务管理器"""
    def __init__(self, detector):
        self.detector = detector
        self.active_tasks = {}   # 存储 {source: thread}
        self.stop_events = {}    # 存储 {source: stop_event}

    def add_task(self, source, interval=1.0):
        """动态新增任务"""
        
        if source in self.active_tasks and self.active_tasks[source].is_alive():
            print(f"[Manager] 任务 {source} 已经在运行中")
            return

        # 创建停止信号
        stop_event = threading.Event()
        self.stop_events[source] = stop_event
        
        t = threading.Thread(
            target=self.detector.start_polling, 
            args=(source, stop_event, interval),
            daemon=True
        )
        t.start()
        self.active_tasks[source] = t
        print(f"[Manager] 成功添加任务: {source}")

    def remove_task(self, source):
        """动态移除任务"""
        if source in self.stop_events:
            print(f"[Manager] 正在停止任务: {source}...")
            self.stop_events[source].set()
            
            # 等待线程结束(不阻塞主线程建议在后台清理)
            # self.active_tasks[source].join() 
            
            # 3. 清理字典引用
            del self.stop_events[source]
            del self.active_tasks[source]
            return True
        else:
            print(f"[Manager] 未找到任务: {source}")
            return False

    def list_tasks(self):
        """查看当前运行中的任务"""
        return list(self.active_tasks.keys())

四、测试类(main.py

python 复制代码
from yolo_detector import YOLOv5Detector, DetectionManager
import time

# 业务回调
def my_callback(msg_json):
    print(f"--- 识别到目标 ---\n{msg_json}\n")

# 初始化
detector = YOLOv5Detector(model_path="./best.onnx", names={0: "shou", 1: "jiao"})

# 注册回调函数
detector.add_callback(my_callback)

# 创建管理器
manager = DetectionManager(detector)

# --- 模拟业务场景 ---

# 第一次:传一张图片过来
manager.add_task("images/12d1090c-c151-4d59-a09c-3463537c5a9e.png", 10)

# 第二次:传一个 RTSP 流过来
time.sleep(5) 
manager.add_task("images/5d8e456c-ae87-4c40-8adc-c5ea81a99611.png", 10)

# 移除第一个检测源
# time.sleep(10)
# manager.remove_task("/Users/wangqingpan/Desktop/hmbaby/train/images/12d1090c-c151-4d59-a09c-3463537c5a9e.png")

# 移除第二个检测源
# time.sleep(10)
# manager.remove_task("/Users/wangqingpan/Desktop/hmbaby/train/images/f49a57d7-d31b-47e0-9440-caadfe335538.png")

# 第三次:传一个本地摄像头
# time.sleep(5)
# manager.add_task(0) 

# 主程序不能死
while True:
    time.sleep(1)

五、运行测试

python 复制代码
需要准备onnx文件和要识别的图片或流地址,准备好后修改main.py中的参数

执行:python main.py

六、输出JSON格式说明

| time | 识别发生的具体时间 |

| source | 视频源标识 |

| detections | 检测到的目标列表 |

| detections.name | 类别名称 |

| detections.score | 置信度 (0-1) |

| detections.box | 坐标 [x, y, w, h] |

七、Docker部署

(1) 创建Dockerfile
python 复制代码
# 使用轻量级的 Python 镜像
FROM python:3.9-slim

# 设置工作目录
WORKDIR /app

# 安装 OpenCV 所需的系统依赖 (libGL.so.1 等)
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# 复制依赖文件(如果没有 requirements.txt,可以直接 RUN pip)
RUN pip install --no-cache-dir \
    opencv-python-headless \
    numpy \
    onnxruntime \
    datetime

# 复制项目代码和模型
COPY . .

# 启动程序
CMD ["python", "main.py"]
(2) 构建镜像
python 复制代码
docker build -t 名称:版本 .
(3) 运行镜像

指定路径或摄像头需要在代码中做相应的指定

python 复制代码
   基础运行:docker run -d --name 名称 名称:版本
   映射本地图片路径:docker run -d \
                    --name 名称 \
                     -v 宿主机路径:/data/hmbaby \
                    名称:版本
   使用宿主机摄像头:docker run -d \
                    --name 名称 \
                    --device=/dev/video0:/dev/video0 \
                    名称:版本
相关推荐
Ava的硅谷新视界2 小时前
用了一天 Claude Opus 4.7,聊几点真实感受
开发语言·后端·编程
AC赳赳老秦2 小时前
OpenClaw生成博客封面图+标题,适配CSDN视觉搜索,提升点击量
运维·人工智能·python·自动化·php·deepseek·openclaw
浪客川3 小时前
【百例RUST - 010】字符串
开发语言·后端·rust
m0_493934533 小时前
如何监控AWR数据收集Job_DBA_SCHEDULER_JOBS中的BSLN_MAINTAIN_STATS
jvm·数据库·python
xiaotao1313 小时前
01-编程基础与数学基石:概率与统计
人工智能·python·numpy·pandas
赵侃侃爱分享3 小时前
学完Python第一次写程序写了这个简单的计算器
开发语言·python
a9511416424 小时前
Go语言如何操作OSS_Go语言阿里云OSS上传教程【完整】
jvm·数据库·python
2401_897190554 小时前
MySQL中如何利用LIMIT配合函数分页_MySQL分页查询优化
jvm·数据库·python
断眉的派大星4 小时前
# Python 魔术方法(魔法方法)超详细讲解
开发语言·python