YOLOv8目标跟踪model.track的封装
flyfish
在使用目标跟踪时,只需调用 `model.track`,整个流程就完成了——`track` 封装了内部运行的全部步骤。这里主要讲其中的回调部分。
使用model.track
py
import cv2
from ultralytics import YOLO
from collections import defaultdict
import numpy as np
# Per-track history of box centre points, keyed by track ID.
track_history = defaultdict(list)
# Maximum number of centre points kept (and drawn) for each track.
max_track_len = 30

video_path = "1.mp4"
model = YOLO("yolov8s.pt")

# Open the video file
cap = cv2.VideoCapture(video_path)
# Retrieve video properties: width, height, and frames per second
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Initialize video writer to save the output video with the specified properties
out = cv2.VideoWriter("detection-object-tracking-bytetrack.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))

try:
    # Loop through the video frames
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            # End of the video (or a read failure): stop processing.
            break

        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True, tracker="bytetrack.yaml")
        result = results[0]
        # Visualize the results on the frame
        annotated_frame = result.plot()

        # boxes.id is None when the frame carries no track IDs (e.g. nothing was
        # detected or the tracker returned no tracks); calling .int() on it would crash.
        if result.boxes.id is not None:
            boxes = result.boxes.xywh.cpu()
            track_ids = result.boxes.id.int().cpu().tolist()
            for box, track_id in zip(boxes, track_ids):
                # xywh: the first two values are the box centre point.
                center_x, center_y = box[:2].tolist()
                track = track_history[track_id]
                track.append((float(center_x), float(center_y)))
                if len(track) > max_track_len:  # keep only the most recent points
                    track.pop(0)
                # Draw the tracking line through the stored centre points
                points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
                cv2.polylines(
                    annotated_frame,
                    [points],
                    isClosed=False,
                    color=(230, 230, 230),
                    thickness=10,
                )

        # Save and display the annotated frame
        out.write(annotated_frame)
        cv2.imshow("YOLOv8 Tracking", annotated_frame)
        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Always release the writer/capture and close windows, even if tracking raised.
    out.release()
    cap.release()
    cv2.destroyAllWindows()
`model.track` 的背后就是回调机制:它先通过 `register_tracker` 注册跟踪回调,再调用 `model.predict()`。下面分析它的回调代码。
`model.predict()` 方法会触发 `on_predict_start` 和 `on_predict_postprocess_end` 两个事件。
分析回调
py
from functools import partial
from pathlib import Path
import torch
from ultralytics.utils import IterableSimpleNamespace, yaml_load
from ultralytics.utils.checks import check_yaml
from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
# Maps the `tracker_type` value read from the tracker YAML config to the tracker
# class instantiated for it in on_predict_start.
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
def on_predict_start(predictor: object, persist: bool = False) -> None:
    """
    Create the tracker instance(s) used while the predictor runs.

    Args:
        predictor (object): Predictor that receives the ``trackers`` and ``vid_path`` attributes.
        persist (bool, optional): When True and the predictor already has trackers, leave them
            untouched. Defaults to False.

    Raises:
        AssertionError: If the configured tracker_type is neither 'bytetrack' nor 'botsort'.
    """
    has_trackers = hasattr(predictor, "trackers")
    if has_trackers and persist:
        return

    config_path = check_yaml(predictor.args.tracker)
    cfg = IterableSimpleNamespace(**yaml_load(config_path))
    if cfg.tracker_type not in {"bytetrack", "botsort"}:
        raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")

    batch_size = predictor.dataset.bs
    # Stream sources get one tracker each; every other mode shares a single tracker.
    tracker_count = batch_size if predictor.dataset.mode == "stream" else min(batch_size, 1)
    tracker_cls = TRACKER_MAP[cfg.tracker_type]
    predictor.trackers = [tracker_cls(args=cfg, frame_rate=30) for _ in range(tracker_count)]
    # Last seen video path per tracker slot; on_predict_postprocess_end compares against it
    # to decide when a tracker must be reset for a new video.
    predictor.vid_path = [None] * batch_size
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    """
    Feed each image's detections to its tracker and replace the results with tracked boxes.

    Args:
        predictor (object): Predictor holding ``batch``, ``results``, ``trackers`` and ``vid_path``.
        persist (bool, optional): When False, a tracker is reset whenever its source video path
            changes. Defaults to False.
    """
    paths, images = predictor.batch[:2]
    use_obb = predictor.args.task == "obb"
    per_source = predictor.dataset.mode == "stream"

    for i, image in enumerate(images):
        # Stream mode owns one tracker per source; other modes all share slot 0.
        slot = i if per_source else 0
        tracker = predictor.trackers[slot]

        current_path = predictor.save_dir / Path(paths[i]).name
        if not persist and predictor.vid_path[slot] != current_path:
            tracker.reset()
            predictor.vid_path[slot] = current_path

        result = predictor.results[i]
        det = (result.obb if use_obb else result.boxes).cpu().numpy()
        if len(det) == 0:
            continue  # nothing detected: results are left as-is and the tracker is not updated

        tracks = tracker.update(det, image)
        if not len(tracks):
            continue

        # The last column of each track row indexes the detection it came from.
        keep = tracks[:, -1].astype(int)
        predictor.results[i] = result[keep]
        field = "obb" if use_obb else "boxes"
        predictor.results[i].update(**{field: torch.as_tensor(tracks[:, :-1])})
def register_tracker(model: object, persist: bool) -> None:
    """
    Attach the two tracking callbacks to ``model`` so they run during prediction.

    Args:
        model (object): Model whose ``add_callback`` receives the tracking hooks.
        persist (bool): Preset on both hooks via ``functools.partial`` so the caller only
            has to pass the predictor when the event fires.
    """
    hooks = (
        ("on_predict_start", on_predict_start),
        ("on_predict_postprocess_end", on_predict_postprocess_end),
    )
    for event_name, hook in hooks:
        model.add_callback(event_name, partial(hook, persist=persist))
下面是一个简单的仿写版本,可以独立运行:
py
def on_predict_start(predictor: object, persist: bool = False) -> None:
    """Demo stand-in for the tracker initialiser: only announces that the event fired."""
    print("on_predict_start")
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    """Demo stand-in for the tracking post-processor: only announces that the event fired."""
    print("on_predict_postprocess_end")
from functools import partial
def register_tracker(model: object, persist: bool) -> None:
    """Register the demo callbacks on ``model``, presetting ``persist`` on each with partial."""
    hooks = {
        "on_predict_start": on_predict_start,
        "on_predict_postprocess_end": on_predict_postprocess_end,
    }
    for event_name, hook in hooks.items():
        model.add_callback(event_name, partial(hook, persist=persist))
from functools import partial
class Model:
    """Minimal stand-in for a model that fires registered callbacks around ``predict``."""

    def __init__(self):
        # Event name -> callbacks, called in registration order.
        self.callbacks = {"on_predict_start": [], "on_predict_postprocess_end": []}

    def add_callback(self, event, callback):
        """Register ``callback`` for ``event``.

        Only the events listed in ``self.callbacks`` are accepted; any other event name is
        ignored (the callback is not stored).
        """
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def run_callbacks(self, event):
        """Call every callback registered for ``event``, passing this model as the argument."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def predict(self):
        """Simulate a prediction, firing the start and post-process events around it."""
        # Fire the 'on_predict_start' event
        self.run_callbacks("on_predict_start")
        # Simulated prediction step
        print("Predicting...")
        # Fire the 'on_predict_postprocess_end' event
        self.run_callbacks("on_predict_postprocess_end")
# Usage: model.predict() fires on_predict_start and on_predict_postprocess_end,
# invoking every callback that register_tracker attached to the model.
model = Model()
register_tracker(model=model, persist=True)
model.predict()
输出
on_predict_start
Predicting...
on_predict_postprocess_end
partial 在回调函数中的应用
在回调函数的场景中,partial 特别有用:它允许预先固定某些参数,而不必在每次调用时都传入这些参数。
假设有一个回调函数需要两个参数,但事件触发时调用方只会传入一个参数:
py
def callback_function(event, persist):
    """Print the event name together with the ``persist`` flag."""
    line = f"Event: {event}, Persist: {persist}"
    print(line)
希望将这个函数作为回调函数,但是只希望在事件发生时传入 event 参数,而 persist 参数是预设好的。这时可以使用 partial:
py
from functools import partial
# Bind persist=True up front; the returned callable then only needs `event`.
partial_callback = partial(callback_function, persist=True)
# When the event fires, the caller supplies just the event name.
partial_callback(event="on_predict_start")  # prints: Event: on_predict_start, Persist: True