替换SlowFast中Detectron2为Yolov8

一 需求

复制代码
FaceBookReserch中SlowFast源码中检测框是用Detectron2进行目标检测,本文想实现用yolov8替换detectron2

二 实施方案

首先,yolov8 支持有自定义库ultralytics(仅支持yolov8),安装对应库

bash 复制代码
pip install ultralytics

源码中slowfast/visualization.py 43行中

python 复制代码
if cfg.DETECTION.ENABLE:
       self.object_detector = Detectron2Predictor(cfg, gpu_id=self.gpu_id)

根据ultralytics文档进行定义

创建对应YOLOPredictor类(加入了检测框及其标签,具体见前一篇文章)

python 复制代码
class YOLOPredictor:

    def __init__(self, cfg, gpu_id=None):
        # 加载预训练的 YOLOv8n 模型
        self.model = YOLO('/root/autodl-tmp/data/runs/detect/train/weights/best.pt')
        self.detect_names, _, _ = get_class_names(cfg.DEMO.Detect_File_Path, None, None)

    def __call__(self, task):
        """
        Return bounding boxes predictions as a tensor.
        Args:
            task (TaskInfo object): task object that contain
                the necessary information for action prediction. (e.g. frames)
        Returns:
            task (TaskInfo object): the same task info object but filled with
                prediction values (a tensor) and the corresponding boxes for
                action detection task.
        """
        # """得到预测置信度"""
        # scores = outputs["instances"].scores[mask].tolist()
        # """获取类别标签"""
        # pred_labels = outputs["instances"].pred_classes[mask]
        # pred_labels = pred_labels.tolist()
        # """进行标签匹配"""
        # for i in range(len(pred_labels)):
        #     pred_labels[i] = self.detect_names[pred_labels[i]]
        # preds = [
        #     "[{:.4f}] {}".format(s, labels) for s, labels in zip(scores, pred_labels)
        # ]
        # """加入预测标签"""
        # task.add_detect_preds(preds)
        # task.add_bboxes(pred_boxes)
        middle_frame = task.frames[len(task.frames) // 2]
        outputs = self.model(middle_frame)
        boxes = outputs[0].boxes
        mask = boxes.conf >= 0.5
        pred_boxes = boxes.xyxy[mask]
        scores = boxes.conf[mask].tolist()
        pred_labels = boxes.cls[mask].to(torch.int)
        pred_labels = pred_labels.tolist()
        for i in range(len(pred_labels)):
            pred_labels[i] = self.detect_names[pred_labels[i]]
        preds = [
            "[{:.4f}] {}".format(s, labels) for s, labels in zip(scores, pred_labels)
        ]
        """加入预测标签"""
        task.add_detect_preds(preds)
        task.add_bboxes(pred_boxes)

        return task
相关推荐
boooo_hhh2 小时前
第28周——InceptionV1实现猴痘识别
python·深度学习·机器学习
白熊1882 小时前
【计算机视觉】OpenCV实战项目:基于OpenCV与face_recognition的实时人脸识别系统深度解析
人工智能·opencv·计算机视觉
闭月之泪舞2 小时前
OpenCv高阶(4.0)——案例:海报的透视变换
人工智能·opencv·计算机视觉
AI technophile3 小时前
OpenCV计算机视觉实战(5)——图像基础操作全解析
python·opencv·计算机视觉
九章云极AladdinEdu4 小时前
GPU SIMT架构的极限压榨:PTX汇编指令级并行优化实践
汇编·人工智能·pytorch·python·深度学习·架构·gpu算力
kyle~5 小时前
深度学习框架---TensorFlow概览
人工智能·深度学习·tensorflow
电鱼智能的电小鱼5 小时前
产线视觉检测设备技术方案:基于EFISH-SCB-RK3588/SAIL-RK3588的国产化替代赛扬N100/N150全场景技术解析
linux·人工智能·嵌入式硬件·计算机视觉·视觉检测·实时音视频
妄想成为master5 小时前
计算机视觉----基于锚点的车道线检测、从Line-CNN到CLRNet到CLRKDNet 本文所提算法Line-CNN 后续会更新以下全部算法
人工智能·计算机视觉·车道线检测
夜幕龙5 小时前
LeRobot 项目部署运行逻辑(七)—— ACT 在 Mobile ALOHA 训练与部署
人工智能·深度学习·机器学习
Echo``6 小时前
40:相机与镜头选型
开发语言·人工智能·深度学习·计算机视觉·视觉检测