设计模式实战笔记-策略模式+注册表模式

需求

在做网球球员多目标跟踪项目时,我希望项目能支持不同的跟踪方案。

比如当前方案是:

  • YOLO 做人体检测
  • ByteTrack 做多目标跟踪
  • 输出标注视频 annotated.mp4
  • 输出轨迹文件 tracks.csv

但后续可能会继续尝试其他方案,例如:

  • YOLO + DeepSORT
  • RT-DETR + ByteTrack
  • Ground Truth Replay
  • 自定义规则跟踪器

如果每新增一种算法,就去修改主流程代码,会导致代码越来越乱。

因此我希望做到:

  1. 主流程不关心具体算法细节。
  2. 所有跟踪方案都遵守同一个接口。
  3. 命令行可以通过名字选择不同方案。
  4. 新增方案时,只需要新增一个类,并注册到统一位置。

这个场景就很适合使用 策略模式 + 注册表模式

解决架构

整体结构如下:

text 复制代码
solutions/
  __init__.py
  tracker_solution.py
  registry.py
  yolo_bytetrack.py

benchmark_runner.py
benchmark_cli.py

其中:

  • TrackerSolution:定义所有跟踪方案必须实现的统一接口。
  • YoloByteTrackSolution:一个具体跟踪方案,实现 YOLO + ByteTrack。
  • SOLUTION_REGISTRY:注册表,把字符串名称映射到具体方案类。
  • BenchmarkRunner:主运行器,只依赖 TrackerSolution 接口。
  • benchmark_cli.py:命令行入口,根据 --solution 从注册表中选择方案。

这里用到了两个设计思想。

策略模式

策略模式的核心是:把一类可替换的算法封装成不同的类,并让它们实现同一个接口。

在这个项目里,不同的跟踪算法就是不同的策略。

例如:

text 复制代码
TrackerSolution
  ├── YoloByteTrackSolution
  ├── YoloDeepSortSolution
  └── OtherSolution

主流程只调用:

python 复制代码
solution.run(video_path, output_dir)

至于 solution 内部到底使用 YOLO + ByteTrack,还是别的算法,主流程并不关心。

注册表模式

注册表模式的核心是:维护一张"名字 -> 实现类"的映射表。

比如:

python 复制代码
SOLUTION_REGISTRY = {
    "yolo_bytetrack": YoloByteTrackSolution,
}

这样命令行只需要传入:

powershell 复制代码
python benchmark_cli.py --solution yolo_bytetrack

程序就可以通过注册表找到对应的类,并实例化运行。

这种方式特别适合 CLI、配置文件、插件式架构等场景。

代码示例

首先定义统一接口。

python 复制代码
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, TypeAlias


ProgressCallback: TypeAlias = Callable[[int, int], None]


@dataclass(frozen=True)
class TrackingResult:
    input_video: Path
    output_dir: Path
    annotated_video: Path
    tracks_file: Path


class TrackerSolution(ABC):
    @property
    def name(self) -> str:
        return type(self).__name__

    @abstractmethod
    def run(
        self,
        video_path: Path,
        output_dir: Path,
        progress_callback: ProgressCallback | None = None,
    ) -> TrackingResult:
        """处理一个视频,并将结果写入 output_dir。"""

然后实现一个具体策略:YoloByteTrackSolution

python 复制代码
from pathlib import Path

from solutions.tracker_solution import ProgressCallback, TrackerSolution, TrackingResult


class YoloByteTrackSolution(TrackerSolution):
    MODEL_PATH = Path(__file__).resolve().parents[1] / "weights" / "yolo11n.pt"

    def __init__(
        self,
        confidence: float = 0.25,
        device: str | int | None = None,
    ) -> None:
        from ultralytics import YOLO

        self.model = YOLO(self.MODEL_PATH)
        self.confidence = confidence
        self.device = device

    def run(
        self,
        video_path: Path,
        output_dir: Path,
        progress_callback: ProgressCallback | None = None,
    ) -> TrackingResult:
        output_dir.mkdir(parents=True, exist_ok=True)

        annotated_video = output_dir / "annotated.mp4"
        tracks_file = output_dir / "tracks.csv"

        # 这里省略具体 YOLO + ByteTrack 推理逻辑
        # 实际项目中会逐帧写入标注视频和 tracks.csv

        return TrackingResult(
            input_video=video_path,
            output_dir=output_dir,
            annotated_video=annotated_video,
            tracks_file=tracks_file,
        )

接着定义注册表。

python 复制代码
from solutions.tracker_solution import TrackerSolution
from solutions.yolo_bytetrack import YoloByteTrackSolution


SOLUTION_REGISTRY: dict[str, type[TrackerSolution]] = {
    "yolo_bytetrack": YoloByteTrackSolution,
}

主运行器只依赖抽象接口,不依赖具体算法。

python 复制代码
from pathlib import Path
from typing import Iterable

from solutions import TrackerSolution, TrackingResult


class BenchmarkRunner:
    def __init__(self, solution: TrackerSolution, runs_dir: Path = Path("runs")) -> None:
        self.solution = solution
        self.solution_root = Path(runs_dir) / solution.name

    def run(self, videos: Iterable[Path]) -> list[TrackingResult]:
        results: list[TrackingResult] = []

        for video in videos:
            result = self.solution.run(
                video,
                self.solution_root / video.stem,
            )
            results.append(result)

        return results

最后在 CLI 中通过注册表选择具体方案。

python 复制代码
from argparse import ArgumentParser
from pathlib import Path

from benchmark_runner import BenchmarkRunner
from solutions.registry import SOLUTION_REGISTRY


def main() -> None:
    parser = ArgumentParser()
    parser.add_argument(
        "--solution",
        choices=sorted(SOLUTION_REGISTRY),
        default="yolo_bytetrack",
    )
    parser.add_argument("--input-dir", type=Path, default=Path("data/benchmark_double"))
    parser.add_argument("--pattern", default="*.mp4")
    parser.add_argument("--runs-dir", type=Path, default=Path("runs"))

    args = parser.parse_args()

    solution_class = SOLUTION_REGISTRY[args.solution]
    solution = solution_class()

    videos = sorted(args.input_dir.glob(args.pattern))

    runner = BenchmarkRunner(solution, args.runs_dir)
    results = runner.run(videos)

    for result in results:
        print(result.annotated_video)
        print(result.tracks_file)


if __name__ == "__main__":
    main()

这样新增一个方案时,只需要:

  1. 新建一个类,继承 TrackerSolution
  2. 实现 run 方法。
  3. 加入 SOLUTION_REGISTRY

例如:

python 复制代码
SOLUTION_REGISTRY = {
    "yolo_bytetrack": YoloByteTrackSolution,
    "yolo_deepsort": YoloDeepSortSolution,
}

主流程 BenchmarkRunner 不需要改。

总结一下:

  • 策略模式解决的是:不同算法如何统一调用。
  • 注册表模式解决的是:如何通过名字找到对应算法。
  • 两者结合后,可以得到一个轻量级、可扩展的方案选择架构。

在机器学习实验、算法 Benchmark、数据处理 Pipeline 里,这种结构非常实用。