YOLOv8目标跟踪model.track的封装

YOLOv8目标跟踪model.track的封装

flyfish

在使用目标跟踪时, 调用model.track整个步骤就完成,track封装了内部运行的步骤。这里主要说回调部分。

使用model.track

py 复制代码
import cv2

from ultralytics import YOLO
from collections import defaultdict
import numpy as np

track_history = defaultdict(lambda: [])
# Open the video file
video_path = "1.mp4"


model = YOLO("yolov8s.pt")

# Open the video file

cap = cv2.VideoCapture(video_path)


# Retrieve video properties: width, height, and frames per second
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))

# Initialize video writer to save the output video with the specified properties
out = cv2.VideoWriter("detection-object-tracking-bytetrack.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))



# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()

    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True,tracker="bytetrack.yaml")
        # Get the boxes and track IDs
        boxes = results[0].boxes.xywh.cpu()
        track_ids = results[0].boxes.id.int().cpu().tolist()

        # Visualize the results on the frame
        annotated_frame = results[0].plot()
                # Plot the tracks
        for box, track_id in zip(boxes, track_ids):
            x, y, w, h = box
            track = track_history[track_id]
            track.append((float(x), float(y)))  # x, y center point
            if len(track) > 30:  # retain 90 tracks for 90 frames
                track.pop(0)

            # Draw the tracking lines
            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(
                annotated_frame,
                [points],
                isClosed=False,
                color=(230, 230, 230),
                thickness=10,
            ) 

        # Display the annotated frame
        out.write(annotated_frame)
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
out.release()
cap.release()
cv2.destroyAllWindows()

model.track 背后就是回调

我们分析下它的回调代码
model.predict()方法会触发on_predict_starton_predict_postprocess_end事件

分析回调

py 复制代码
from functools import partial
from pathlib import Path

import torch

from ultralytics.utils import IterableSimpleNamespace, yaml_load
from ultralytics.utils.checks import check_yaml

from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker

# A mapping of tracker types to corresponding tracker classes
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}


def on_predict_start(predictor: object, persist: bool = False) -> None:
    """
    Initialize trackers for object tracking during prediction.

    Args:
        predictor (object): The predictor object to initialize trackers for.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.

    Raises:
        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
    """
    if hasattr(predictor, "trackers") and persist:
        return

    tracker = check_yaml(predictor.args.tracker)
    cfg = IterableSimpleNamespace(**yaml_load(tracker))

    if cfg.tracker_type not in {"bytetrack", "botsort"}:
        raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")

    trackers = []
    for _ in range(predictor.dataset.bs):
        tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
        trackers.append(tracker)
        if predictor.dataset.mode != "stream":  # only need one tracker for other modes.
            break
    predictor.trackers = trackers
    predictor.vid_path = [None] * predictor.dataset.bs  # for determining when to reset tracker on new video


def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    """
    Postprocess detected boxes and update with object tracking.

    Args:
        predictor (object): The predictor object containing the predictions.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
    """
    path, im0s = predictor.batch[:2]

    is_obb = predictor.args.task == "obb"
    is_stream = predictor.dataset.mode == "stream"
    for i in range(len(im0s)):
        tracker = predictor.trackers[i if is_stream else 0]
        vid_path = predictor.save_dir / Path(path[i]).name
        if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
            tracker.reset()
            predictor.vid_path[i if is_stream else 0] = vid_path

        det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
        if len(det) == 0:
            continue
        tracks = tracker.update(det, im0s[i])
        if len(tracks) == 0:
            continue
        idx = tracks[:, -1].astype(int)
        predictor.results[i] = predictor.results[i][idx]

        update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
        predictor.results[i].update(**update_args)


def register_tracker(model: object, persist: bool) -> None:
    """
    Register tracking callbacks to the model for object tracking during prediction.

    Args:
        model (object): The model object to register tracking callbacks for.
        persist (bool): Whether to persist the trackers if they already exist.
    """
    model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
    model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))

简单仿写,可以独立运行

py 复制代码
 def on_predict_start(predictor: object, persist: bool = False) -> None:
    # 回调函数代码
    print("on_predict_start")
    pass

def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    # 回调函数代码
    print("on_predict_postprocess_end")
    pass
from functools import partial

def register_tracker(model: object, persist: bool) -> None:
    model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
    model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))


from functools import partial

class Model:
    def __init__(self):
        self.callbacks = {"on_predict_start": [], "on_predict_postprocess_end": []}

    def add_callback(self, event, callback):
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def predict(self):
        # 触发'on_predict_start'事件
        for callback in self.callbacks["on_predict_start"]:
            callback(self)

        # 模拟预测过程
        print("Predicting...")

        # 触发'on_predict_postprocess_end'事件
        for callback in self.callbacks["on_predict_postprocess_end"]:
            callback(self)

# 使用例子 model.predict()方法会触发on_predict_start和on_predict_postprocess_end事件,调用已注册的回调函数。
model = Model()
register_tracker(model, persist=True)
model.predict()

输出

on_predict_start
Predicting...
on_predict_postprocess_end

partial应用在回调函数中

在回调函数的场景中,partial 特别有用,因为它允许预设某些参数,而不是在每次调用时都传入这些参数。

假设有一个回调函数,它需要两个参数,但是在注册回调函数时,只能传入一个参数:

py 复制代码
def callback_function(event, persist):
    print(f"Event: {event}, Persist: {persist}")

希望将这个函数作为回调函数,但是只希望在事件发生时传入 event 参数,而 persist 参数是预设好的。这时可以使用 partial:

py 复制代码
from functools import partial

# 预设 persist 参数
partial_callback = partial(callback_function, persist=True)

# 当事件发生时,只需要传入 event 参数
partial_callback(event="on_predict_start")  # 输出: Event: on_predict_start, Persist: True
相关推荐
CSDN云计算5 分钟前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
艾派森16 分钟前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing112318 分钟前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机
小蜗子22 分钟前
Multi‐modal knowledge graph inference via media convergenceand logic rule
人工智能·知识图谱
SpikeKing35 分钟前
LLM - 使用 LLaMA-Factory 微调大模型 环境配置与训练推理 教程 (1)
人工智能·llm·大语言模型·llama·环境配置·llamafactory·训练框架
黄焖鸡能干四碗1 小时前
信息化运维方案,实施方案,开发方案,信息中心安全运维资料(软件资料word)
大数据·人工智能·软件需求·设计规范·规格说明书
1 小时前
开源竞争-数据驱动成长-11/05-大专生的思考
人工智能·笔记·学习·算法·机器学习
ctrey_1 小时前
2024-11-4 学习人工智能的Day21 openCV(3)
人工智能·opencv·学习
攻城狮_Dream1 小时前
“探索未来医疗:生成式人工智能在医疗领域的革命性应用“
人工智能·设计·医疗·毕业
学习前端的小z2 小时前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc