需求
在做网球球员多目标跟踪项目时,我希望项目能支持不同的跟踪方案。
比如当前方案是:
- YOLO 做人体检测
- ByteTrack 做多目标跟踪
- 输出标注视频
annotated.mp4 - 输出轨迹文件
tracks.csv
但后续可能会继续尝试其他方案,例如:
- YOLO + DeepSORT
- RT-DETR + ByteTrack
- Ground Truth Replay
- 自定义规则跟踪器
如果每新增一种算法,就去修改主流程代码,会导致代码越来越乱。
因此我希望做到:
- 主流程不关心具体算法细节。
- 所有跟踪方案都遵守同一个接口。
- 命令行可以通过名字选择不同方案。
- 新增方案时,只需要新增一个类,并注册到统一位置。
这个场景就很适合使用 策略模式 + 注册表模式。
解决架构
整体结构如下:
text
solutions/
__init__.py
tracker_solution.py
registry.py
yolo_bytetrack.py
benchmark_runner.py
benchmark_cli.py
其中:
TrackerSolution:定义所有跟踪方案必须实现的统一接口。YoloByteTrackSolution:一个具体跟踪方案,实现 YOLO + ByteTrack。SOLUTION_REGISTRY:注册表,把字符串名称映射到具体方案类。BenchmarkRunner:主运行器,只依赖TrackerSolution接口。benchmark_cli.py:命令行入口,根据--solution从注册表中选择方案。
这里用到了两个设计思想。
策略模式
策略模式的核心是:把一类可替换的算法封装成不同的类,并让它们实现同一个接口。
在这个项目里,不同的跟踪算法就是不同的策略。
例如:
text
TrackerSolution
├── YoloByteTrackSolution
├── YoloDeepSortSolution
└── OtherSolution
主流程只调用:
python
solution.run(video_path, output_dir)
至于 solution 内部到底使用 YOLO + ByteTrack,还是别的算法,主流程并不关心。
注册表模式
注册表模式的核心是:维护一张"名字 -> 实现类"的映射表。
比如:
python
SOLUTION_REGISTRY = {
"yolo_bytetrack": YoloByteTrackSolution,
}
这样命令行只需要传入:
powershell
python benchmark_cli.py --solution yolo_bytetrack
程序就可以通过注册表找到对应的类,并实例化运行。
这种方式特别适合 CLI、配置文件、插件式架构等场景。
代码示例
首先定义统一接口。
python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, TypeAlias
ProgressCallback: TypeAlias = Callable[[int, int], None]
@dataclass(frozen=True)
class TrackingResult:
input_video: Path
output_dir: Path
annotated_video: Path
tracks_file: Path
class TrackerSolution(ABC):
@property
def name(self) -> str:
return type(self).__name__
@abstractmethod
def run(
self,
video_path: Path,
output_dir: Path,
progress_callback: ProgressCallback | None = None,
) -> TrackingResult:
"""处理一个视频,并将结果写入 output_dir。"""
然后实现一个具体策略:YoloByteTrackSolution。
python
from pathlib import Path
from solutions.tracker_solution import ProgressCallback, TrackerSolution, TrackingResult
class YoloByteTrackSolution(TrackerSolution):
MODEL_PATH = Path(__file__).resolve().parents[1] / "weights" / "yolo11n.pt"
def __init__(
self,
confidence: float = 0.25,
device: str | int | None = None,
) -> None:
from ultralytics import YOLO
self.model = YOLO(self.MODEL_PATH)
self.confidence = confidence
self.device = device
def run(
self,
video_path: Path,
output_dir: Path,
progress_callback: ProgressCallback | None = None,
) -> TrackingResult:
output_dir.mkdir(parents=True, exist_ok=True)
annotated_video = output_dir / "annotated.mp4"
tracks_file = output_dir / "tracks.csv"
# 这里省略具体 YOLO + ByteTrack 推理逻辑
# 实际项目中会逐帧写入标注视频和 tracks.csv
return TrackingResult(
input_video=video_path,
output_dir=output_dir,
annotated_video=annotated_video,
tracks_file=tracks_file,
)
接着定义注册表。
python
from solutions.tracker_solution import TrackerSolution
from solutions.yolo_bytetrack import YoloByteTrackSolution
SOLUTION_REGISTRY: dict[str, type[TrackerSolution]] = {
"yolo_bytetrack": YoloByteTrackSolution,
}
主运行器只依赖抽象接口,不依赖具体算法。
python
from pathlib import Path
from typing import Iterable
from solutions import TrackerSolution, TrackingResult
class BenchmarkRunner:
def __init__(self, solution: TrackerSolution, runs_dir: Path = Path("runs")) -> None:
self.solution = solution
self.solution_root = Path(runs_dir) / solution.name
def run(self, videos: Iterable[Path]) -> list[TrackingResult]:
results: list[TrackingResult] = []
for video in videos:
result = self.solution.run(
video,
self.solution_root / video.stem,
)
results.append(result)
return results
最后在 CLI 中通过注册表选择具体方案。
python
from argparse import ArgumentParser
from pathlib import Path
from benchmark_runner import BenchmarkRunner
from solutions.registry import SOLUTION_REGISTRY
def main() -> None:
parser = ArgumentParser()
parser.add_argument(
"--solution",
choices=sorted(SOLUTION_REGISTRY),
default="yolo_bytetrack",
)
parser.add_argument("--input-dir", type=Path, default=Path("data/benchmark_double"))
parser.add_argument("--pattern", default="*.mp4")
parser.add_argument("--runs-dir", type=Path, default=Path("runs"))
args = parser.parse_args()
solution_class = SOLUTION_REGISTRY[args.solution]
solution = solution_class()
videos = sorted(args.input_dir.glob(args.pattern))
runner = BenchmarkRunner(solution, args.runs_dir)
results = runner.run(videos)
for result in results:
print(result.annotated_video)
print(result.tracks_file)
if __name__ == "__main__":
main()
这样新增一个方案时,只需要:
- 新建一个类,继承
TrackerSolution。 - 实现
run方法。 - 加入
SOLUTION_REGISTRY。
例如:
python
SOLUTION_REGISTRY = {
"yolo_bytetrack": YoloByteTrackSolution,
"yolo_deepsort": YoloDeepSortSolution,
}
主流程 BenchmarkRunner 不需要改。
总结一下:
- 策略模式解决的是:不同算法如何统一调用。
- 注册表模式解决的是:如何通过名字找到对应算法。
- 两者结合后,可以得到一个轻量级、可扩展的方案选择架构。
在机器学习实验、算法 Benchmark、数据处理 Pipeline 里,这种结构非常实用。