【从视频到数据集:焦糖玛奇朵的魔法工具Video To YOLO Dataset】

Video To YOLO Dataset

这个工具用于把视频文件夹转换成 YOLO 检测数据集,可作为 dataset_clean_tool.py 的前置工具。

输出结构

选择一个目录后,工具会在该目录下创建或使用:

text 复制代码
selected_dir/
  videos/
  images/
  labels/
  empty_images/
  dataset.yaml
  • videos:可选地把所选目录根部的视频复制整理到这里。
  • images:按指定帧间隔抽帧得到的图片。
  • labels:ONNX 自动标注生成的 YOLO txt。
  • empty_images:模型没有输出任何目标的图片,会从 images 移到这里,不生成 txt。
  • dataset.yaml:YOLO 训练配置文件。

运行

powershell 复制代码
pip install -r requirements.txt
python .\video_to_yolo_tool.py

使用流程

  1. 点击 选择视频/数据集目录
  2. 设置 图片名前缀 和 每多少帧抽一帧,点击 开始抽帧生成 images。
  3. 选择 YOLO ONNX 模型。
  4. 如 ONNX 没有类别 metadata,手动填写类别名称,每行一个或逗号分隔。
  5. 点击 开始自动标注 images,工具会生成对应 txt 和 dataset.yaml;没有检出目标的图片会被移到 empty_images。

图片命名

抽帧图片统一使用 ASCII 文件名:

text 复制代码
图片名前缀_视频序号_帧序号.jpg

例如:

text 复制代码
efall_0520_3_24.jpg

这样可以避免中文路径或特殊字符在训练、跨平台拷贝、dataloader 中引发问题。

说明

自动标注逻辑兼容常见 YOLOv8 ONNX 输出,也尽量兼容带 objectness 的 YOLO 输出。若模型导出时带 NMS 且输出为 [x1,y1,x2,y2,score,class],也会按该格式解析。

启动

将下面的内容保存为 run_video_to_yolo_yolov8.bat(与 video_to_yolo_tool.py 放在同一目录),并按本机环境修改其中 Anaconda、conda 环境和 CUDA 的路径:

bash 复制代码
@echo off
rem Launcher for the Video To YOLO GUI: activates the conda environment, puts
rem the environment's DLL directories and the CUDA bin directory on PATH,
rem checks that onnxruntime can be imported, then starts video_to_yolo_tool.py.
echo [VideoToYolo] Activating yolov8 environment...
rem %~dp0 expands to the directory that contains this .bat file.
cd /d "%~dp0"
rem NOTE: the Anaconda, conda-env and CUDA paths below are machine-specific.
call "C:\ProgramData\Anaconda3\Scripts\activate.bat" "C:\Users\zhang\.conda\envs\yolov8"
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1"
rem Prepend the env's DLL directories and the CUDA bin directory to PATH.
set "PATH=C:\Users\zhang\.conda\envs\yolov8;C:\Users\zhang\.conda\envs\yolov8\DLLs;C:\Users\zhang\.conda\envs\yolov8\Library\bin;C:\Users\zhang\.conda\envs\yolov8\Lib\site-packages\onnxruntime\capi;C:\Users\zhang\.conda\envs\yolov8\Lib\site-packages\torch\lib;C:\Users\zhang\.conda\envs\yolov8\Lib\site-packages\torchvision;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin;%PATH%"
echo [VideoToYolo] Current Python:
where python
echo [VideoToYolo] Checking onnxruntime...
rem Import check; stdout and stderr are both captured in the startup log.
python -c "import sys, os; import onnxruntime as ort; print('python=', sys.executable); print('providers=', ort.get_available_providers()); print('CONDA_PREFIX=', os.environ.get('CONDA_PREFIX'))" > "%~dp0video_to_yolo_startup.log" 2>&1
if errorlevel 1 (
  echo onnxruntime import failed. See "%~dp0video_to_yolo_startup.log"
  type "%~dp0video_to_yolo_startup.log"
  pause
  exit /b 1
)
type "%~dp0video_to_yolo_startup.log"
echo [VideoToYolo] Starting GUI...
python "%~dp0video_to_yolo_tool.py"

源码

python 复制代码
from __future__ import annotations

import ast
import argparse
import math
import os
import queue
import random
import shutil
import subprocess
import sys
import threading
import traceback
from dataclasses import dataclass
from pathlib import Path
from tkinter import (
    BOTH,
    BOTTOM,
    DISABLED,
    END,
    LEFT,
    NORMAL,
    RIGHT,
    TOP,
    X,
    Y,
    BooleanVar,
    Button,
    Checkbutton,
    DoubleVar,
    Entry,
    Frame,
    IntVar,
    Label,
    LabelFrame,
    StringVar,
    Text,
    Tk,
    filedialog,
    messagebox,
)
from tkinter import ttk


def configure_conda_dll_paths() -> None:
    """Register the environment's DLL directories for this process (Windows only).

    Does nothing unless running on Windows with ``os.add_dll_directory``
    available. Otherwise every existing candidate directory under
    ``CONDA_PREFIX`` (when set) and ``sys.prefix`` is passed to
    ``os.add_dll_directory`` and prepended to ``PATH``; the ``bin`` directory
    under ``CUDA_PATH`` / ``CUDA_HOME`` is handled the same way when present.
    """
    if not (os.name == "nt" and hasattr(os, "add_dll_directory")):
        return

    def register(directory: Path) -> None:
        # Best effort: a directory the loader rejects is still put on PATH.
        try:
            os.add_dll_directory(str(directory))
        except OSError:
            pass
        os.environ["PATH"] = str(directory) + os.pathsep + os.environ.get("PATH", "")

    roots = [Path(sys.prefix)]
    active_env = os.environ.get("CONDA_PREFIX")
    if active_env:
        # The activated conda environment takes precedence over sys.prefix.
        roots.insert(0, Path(active_env))
    cuda_root = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")

    relative_dirs = (
        (),
        ("DLLs",),
        ("Library", "bin"),
        ("Lib", "site-packages", "onnxruntime", "capi"),
        ("Lib", "site-packages", "torch", "lib"),
        ("Lib", "site-packages", "torchvision"),
    )
    visited: set[Path] = set()
    for root in roots:
        for parts in relative_dirs:
            candidate = root.joinpath(*parts)
            if candidate in visited or not candidate.exists():
                continue
            visited.add(candidate)
            register(candidate)

    if cuda_root:
        cuda_bin = Path(cuda_root) / "bin"
        if cuda_bin.exists():
            register(cuda_bin)


def build_runtime_env() -> dict[str, str]:
    """Return a copy of the environment prepared for the labeling subprocess.

    The copy has the existing DLL directories of the Python environment and
    the CUDA ``bin`` directory prepended to ``PATH``, ``CUDA_PATH`` set (to
    the current value or the built-in default), and ``CONDA_PREFIX`` filled
    in when it was missing. ``os.environ`` itself is not modified.
    """
    child_env = dict(os.environ)
    env_root = Path(child_env.get("CONDA_PREFIX") or sys.prefix)
    cuda_root = child_env.get("CUDA_PATH") or r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1"

    candidates = (
        env_root,
        env_root / "DLLs",
        env_root / "Library" / "bin",
        env_root / "Lib" / "site-packages" / "onnxruntime" / "capi",
        env_root / "Lib" / "site-packages" / "torch" / "lib",
        env_root / "Lib" / "site-packages" / "torchvision",
        Path(cuda_root) / "bin",
    )
    # Only directories that exist are added, ahead of the inherited PATH.
    present = [str(directory) for directory in candidates if directory.exists()]
    inherited = child_env.get("PATH", "")
    child_env["PATH"] = os.pathsep.join(present) + os.pathsep + inherited

    child_env["CUDA_PATH"] = cuda_root
    child_env.setdefault("CONDA_PREFIX", str(env_root))
    return child_env


configure_conda_dll_paths()

import cv2
import numpy as np
import yaml


# Lower-case file suffixes treated as video files.
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".mpeg", ".mpg"}
# Lower-case file suffixes treated as image files.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


@dataclass
class Detection:
    """One detected object: class index, confidence score and box corners."""

    # Index of the predicted class.
    class_id: int
    # Confidence score of the detection.
    score: float
    # Box corners as (x1, y1, x2, y2).
    xyxy: tuple[float, float, float, float]


class VideoToYoloApp:
    def __init__(self, root: Tk) -> None:
        """Configure the main window, create the UI state and build the widgets.

        Args:
            root: The Tk root window that hosts the application.
        """
        self.root = root
        self.root.title("焦糖玛奇朵的魔法工具:Video To YOLO Dataset")
        self.root.geometry("980x700")
        self.root.minsize(820, 560)

        # Current selections and background-task state.
        self.dataset_dir: Path | None = None
        self.model_path: Path | None = None
        self.worker: threading.Thread | None = None
        # Log messages are queued here and drained by _poll_log_queue.
        self.log_queue: queue.Queue[str] = queue.Queue()

        # Tk variables bound to the widgets created in _build_ui.
        self.dataset_var = StringVar(value="未选择目录")
        self.model_var = StringVar(value="未选择 ONNX 模型")
        self.prefix_var = StringVar(value="")
        self.interval_var = IntVar(value=25)
        self.image_ext_var = StringVar(value="jpg")
        self.copy_videos_var = BooleanVar(value=True)
        self.conf_var = DoubleVar(value=0.35)
        self.iou_var = DoubleVar(value=0.50)
        self.imgsz_var = IntVar(value=640)
        self.balance_class_var = IntVar(value=0)
        self.balance_percent_var = DoubleVar(value=90.0)
        self.balance_seed_var = IntVar(value=20260521)
        # NOTE(review): names_var is not read by any method of this class in the
        # visible code; class names are taken from the names_text widget instead.
        self.names_var = StringVar(value="")
        self.status_var = StringVar(value="选择视频目录后,先抽帧,再选择 ONNX 自动标注")

        self._build_ui()
        # Start the periodic drain of log_queue into the log widget.
        self._poll_log_queue()

    def _build_ui(self) -> None:
        """Create and lay out all widgets of the main window.

        Layout: a top bar for choosing the directory, a left column with the
        three task panels (frame extraction, ONNX auto labeling, class
        statistics / pruning), a right column with the progress bars and the
        log, and a status bar at the bottom.
        """
        # Top bar: directory chooser and the currently selected path.
        top_bar = Frame(self.root, padx=10, pady=10)
        top_bar.pack(side=TOP, fill=X)

        Button(top_bar, text="选择视频/数据集目录", command=self.choose_dataset_dir).pack(side=LEFT, padx=(0, 8))
        Label(top_bar, textvariable=self.dataset_var, anchor="w").pack(side=LEFT, fill=X, expand=True)

        body = Frame(self.root, padx=10, pady=0)
        body.pack(side=TOP, fill=BOTH, expand=True, pady=(0, 10))

        # Left column with a fixed width; pack_propagate(False) keeps the
        # children from resizing it.
        left_panel = Frame(body, width=390)
        left_panel.pack(side=LEFT, fill=Y, padx=(0, 10))
        left_panel.pack_propagate(False)

        # Panel 1: frame extraction settings.
        extract_frame = LabelFrame(left_panel, text="1. 视频抽帧")
        extract_frame.pack(side=TOP, fill=X)

        row = Frame(extract_frame)
        row.pack(fill=X, padx=10, pady=(10, 4))
        Label(row, text="图片名前缀").pack(side=LEFT)
        Entry(row, textvariable=self.prefix_var, width=18).pack(side=RIGHT)

        row = Frame(extract_frame)
        row.pack(fill=X, padx=10, pady=4)
        Label(row, text="每多少帧抽一帧").pack(side=LEFT)
        ttk.Spinbox(row, from_=1, to=100000, textvariable=self.interval_var, width=10).pack(side=RIGHT)

        row = Frame(extract_frame)
        row.pack(fill=X, padx=10, pady=4)
        Label(row, text="图片格式").pack(side=LEFT)
        ttk.Combobox(row, textvariable=self.image_ext_var, values=("jpg", "png"), width=8, state="readonly").pack(
            side=RIGHT
        )

        Checkbutton(extract_frame, text="将根目录视频复制整理到 videos 文件夹", variable=self.copy_videos_var).pack(
            anchor="w", padx=8, pady=4
        )
        Button(extract_frame, text="开始抽帧生成 images", command=self.start_extract_frames).pack(
            fill=X, padx=10, pady=(8, 10)
        )

        # Panel 2: ONNX model selection and auto-labeling parameters.
        label_frame = LabelFrame(left_panel, text="2. ONNX 自动标注")
        label_frame.pack(side=TOP, fill=X, pady=(10, 0))

        Button(label_frame, text="选择 ONNX 模型", command=self.choose_model).pack(fill=X, padx=10, pady=(10, 4))
        Label(label_frame, textvariable=self.model_var, anchor="w", wraplength=360).pack(fill=X, padx=10, pady=(0, 8))

        grid = Frame(label_frame)
        grid.pack(fill=X, padx=10, pady=4)
        Label(grid, text="置信度").grid(row=0, column=0, sticky="w", pady=3)
        Entry(grid, textvariable=self.conf_var, width=8).grid(row=0, column=1, sticky="e", pady=3)
        Label(grid, text="NMS IoU").grid(row=1, column=0, sticky="w", pady=3)
        Entry(grid, textvariable=self.iou_var, width=8).grid(row=1, column=1, sticky="e", pady=3)
        Label(grid, text="输入尺寸").grid(row=2, column=0, sticky="w", pady=3)
        Entry(grid, textvariable=self.imgsz_var, width=8).grid(row=2, column=1, sticky="e", pady=3)
        grid.columnconfigure(0, weight=1)

        Label(label_frame, text="类别名称,逗号或换行分隔;留空时优先读取 ONNX metadata").pack(
            anchor="w", padx=10, pady=(8, 2)
        )
        # Free-text class names; read back by get_names_from_ui.
        self.names_text = Text(label_frame, height=5, wrap="word")
        self.names_text.pack(fill=X, padx=10)

        Button(label_frame, text="开始自动标注 images", command=self.start_auto_label).pack(fill=X, padx=10, pady=(10, 4))
        Button(label_frame, text="仅生成/刷新 dataset.yaml", command=self.write_yaml_from_ui).pack(
            fill=X, padx=10, pady=(0, 10)
        )

        # Panel 3: label statistics and random pruning of one class.
        balance_frame = LabelFrame(left_panel, text="3. 类别统计/随机移出")
        balance_frame.pack(side=TOP, fill=X, pady=(10, 0))
        Button(balance_frame, text="统计 labels 类别数量", command=self.show_label_stats).pack(
            fill=X, padx=10, pady=(10, 6)
        )
        balance_grid = Frame(balance_frame)
        balance_grid.pack(fill=X, padx=10, pady=4)
        Label(balance_grid, text="类别ID").grid(row=0, column=0, sticky="w", pady=3)
        Entry(balance_grid, textvariable=self.balance_class_var, width=8).grid(row=0, column=1, sticky="e", pady=3)
        Label(balance_grid, text="移出百分比").grid(row=1, column=0, sticky="w", pady=3)
        Entry(balance_grid, textvariable=self.balance_percent_var, width=8).grid(row=1, column=1, sticky="e", pady=3)
        Label(balance_grid, text="随机种子").grid(row=2, column=0, sticky="w", pady=3)
        Entry(balance_grid, textvariable=self.balance_seed_var, width=8).grid(row=2, column=1, sticky="e", pady=3)
        balance_grid.columnconfigure(0, weight=1)
        Button(balance_frame, text="随机移出仅包含该类别的图片", command=self.start_prune_single_class).pack(
            fill=X, padx=10, pady=(6, 10)
        )

        # Right column: progress bars on top, scrolling log below.
        right_panel = Frame(body)
        right_panel.pack(side=RIGHT, fill=BOTH, expand=True)

        progress_frame = Frame(right_panel)
        progress_frame.pack(side=TOP, fill=X)
        # Determinate bar driven by set_progress.
        self.progress = ttk.Progressbar(progress_frame, orient="horizontal", mode="determinate")
        self.progress.pack(side=LEFT, fill=X, expand=True)
        # Indeterminate bar started and stopped by set_busy.
        self.busy = ttk.Progressbar(progress_frame, orient="horizontal", mode="indeterminate", length=80)
        self.busy.pack(side=RIGHT, padx=(8, 0))

        log_frame = LabelFrame(right_panel, text="日志")
        log_frame.pack(side=TOP, fill=BOTH, expand=True, pady=(8, 0))
        # Kept DISABLED so the user cannot type into it; _poll_log_queue
        # enables it briefly while appending.
        self.log_text = Text(log_frame, wrap="word", state=DISABLED)
        self.log_text.pack(side=LEFT, fill=BOTH, expand=True)
        scrollbar = ttk.Scrollbar(log_frame, command=self.log_text.yview)
        scrollbar.pack(side=RIGHT, fill=Y)
        self.log_text.configure(yscrollcommand=scrollbar.set)

        # Status bar bound to status_var.
        ttk.Label(self.root, textvariable=self.status_var, anchor="w", relief="sunken").pack(side=BOTTOM, fill=X)

    def choose_dataset_dir(self) -> None:
        """Ask for the dataset directory and record it.

        On a non-empty choice the directory is stored, shown in the top bar,
        the image-name prefix is reset to the default for that directory, and
        two log lines are written. Cancelling the dialog changes nothing.
        """
        chosen = filedialog.askdirectory(title="选择包含视频的目录,或已存在 videos 的数据集目录")
        if chosen:
            dataset_dir = Path(chosen)
            self.dataset_dir = dataset_dir
            self.dataset_var.set(str(dataset_dir))
            self.prefix_var.set(default_dataset_prefix(dataset_dir))
            self.log(f"选择目录: {dataset_dir}")
            self.log("工具会在该目录下创建/使用 images、labels、videos 三个文件夹。")

    def choose_model(self) -> None:
        """Ask for an ONNX model file and record it.

        After storing the path, class names are read from the model via
        ``try_read_onnx_names``; when any are returned they replace the
        content of the class-name text box.
        """
        selected = filedialog.askopenfilename(
            filetypes=(("ONNX model", "*.onnx"), ("All files", "*.*")),
            title="选择 ONNX 模型",
        )
        if not selected:
            return

        model_path = Path(selected)
        self.model_path = model_path
        self.model_var.set(str(model_path))
        self.log(f"选择模型: {model_path}")

        metadata_names = try_read_onnx_names(model_path)
        if not metadata_names:
            return
        self.names_text.delete("1.0", END)
        self.names_text.insert("1.0", "\n".join(metadata_names))
        self.log(f"已从 ONNX metadata 读取 {len(metadata_names)} 个类别名。")

    def start_extract_frames(self) -> None:
        """Validate the extraction settings and run frame extraction in a worker thread.

        Shows a warning and returns when no directory is selected or when the
        frame interval field does not contain a number. All settings,
        including the dataset directory, are captured before the worker
        starts, so changing the selection while the job runs cannot redirect
        it to another directory.
        """
        # IntVar.get() raises TclError when the entry text is not numeric.
        from tkinter import TclError

        dataset_dir = self.dataset_dir
        if not dataset_dir:
            messagebox.showwarning("缺少目录", "请先选择视频/数据集目录")
            return
        try:
            interval = max(1, int(self.interval_var.get()))
        except (TclError, ValueError):
            messagebox.showwarning("参数无效", "抽帧间隔需要是正整数")
            return
        ext = self.image_ext_var.get().lstrip(".").lower()
        copy_videos = bool(self.copy_videos_var.get())
        prefix = sanitize_prefix(self.prefix_var.get()) or default_dataset_prefix(dataset_dir)

        def job() -> None:
            self.set_busy(True)
            try:
                total = extract_frames_from_dataset(
                    dataset_dir,
                    interval=interval,
                    image_ext=ext,
                    prefix=prefix,
                    copy_root_videos=copy_videos,
                    progress=self.set_progress,
                    log=self.log,
                )
                (dataset_dir / "labels").mkdir(parents=True, exist_ok=True)
                self.log(f"抽帧完成,共生成 {total} 张图片。")
                self.set_status(f"抽帧完成: {total} 张图片")
            except Exception as exc:
                error = str(exc)
                self.log(f"抽帧失败: {error}")
                self.log(traceback.format_exc())
                # The message box must be opened from the Tk main loop.
                self.root.after(0, lambda error=error: messagebox.showerror("抽帧失败", error))
            finally:
                self.set_busy(False)

        self.run_worker(job)

    def start_auto_label(self) -> None:
        """Validate the labeling settings and run auto labeling in a worker thread.

        Shows a warning and returns when the directory or the model is
        missing, or when the confidence, IoU or input-size field does not
        contain a number. The dataset directory and model path are captured
        before the worker starts so a later re-selection cannot affect a
        running job.
        """
        # DoubleVar.get() / IntVar.get() raise TclError on non-numeric text.
        from tkinter import TclError

        dataset_dir = self.dataset_dir
        model_path = self.model_path
        if not dataset_dir:
            messagebox.showwarning("缺少目录", "请先选择视频/数据集目录")
            return
        if not model_path:
            messagebox.showwarning("缺少模型", "请先选择 ONNX 模型")
            return
        names = self.get_names_from_ui()
        try:
            conf = float(self.conf_var.get())
            iou = float(self.iou_var.get())
            imgsz = int(self.imgsz_var.get())
        except (TclError, ValueError):
            messagebox.showwarning("参数无效", "置信度、NMS IoU 和输入尺寸需要填写数字")
            return

        def job() -> None:
            self.set_busy(True)
            try:
                self.run_auto_label_subprocess(
                    dataset_dir=dataset_dir,
                    model_path=model_path,
                    names=names,
                    conf=conf,
                    iou=iou,
                    imgsz=imgsz,
                )
                self.set_status("自动标注完成")
            except Exception as exc:
                error = str(exc)
                self.log(f"自动标注失败: {error}")
                self.log(traceback.format_exc())
                # The message box must be opened from the Tk main loop.
                self.root.after(0, lambda error=error: messagebox.showerror("自动标注失败", error))
            finally:
                self.set_busy(False)

        self.run_worker(job)

    def run_auto_label_subprocess(
        self,
        dataset_dir: Path,
        model_path: Path,
        names: list[str],
        conf: float,
        iou: float,
        imgsz: int,
    ) -> None:
        script_path = Path(__file__).resolve()
        command = [
            sys.executable,
            str(script_path),
            "--auto-label-cli",
            "--dataset",
            str(dataset_dir),
            "--model",
            str(model_path),
            "--conf",
            str(conf),
            "--iou",
            str(iou),
            "--imgsz",
            str(imgsz),
        ]
        if names:
            command.extend(["--names", "|".join(names)])
        self.log(f"启动自动标注子进程: {sys.executable}")
        process = subprocess.Popen(
            command,
            cwd=str(script_path.parent),
            env=build_runtime_env(),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        assert process.stdout is not None
        for line in process.stdout:
            line = line.rstrip()
            if line.startswith("PROGRESS "):
                _tag, current, total = line.split()
                self.set_progress(int(current), int(total))
            else:
                self.log(line)
        return_code = process.wait()
        if return_code != 0:
            raise RuntimeError(f"自动标注子进程失败,退出码 {return_code}")

    def write_yaml_from_ui(self) -> None:
        """Write dataset.yaml from the class names entered in the UI.

        Shows a warning instead when no directory is selected or no class
        names have been entered.
        """
        dataset_dir = self.dataset_dir
        if not dataset_dir:
            messagebox.showwarning("缺少目录", "请先选择视频/数据集目录")
            return

        class_names = self.get_names_from_ui()
        if class_names:
            write_dataset_yaml(dataset_dir, class_names)
            self.log(f"已生成: {dataset_dir / 'dataset.yaml'}")
            return
        messagebox.showwarning("缺少类别", "请填写类别名称,或先选择带 metadata 的 ONNX 模型")

    def show_label_stats(self) -> None:
        """Compute the label statistics of the selected dataset and write them to the log."""
        if self.dataset_dir:
            class_names = self.get_names_for_stats()
            log_label_stats(compute_label_stats(self.dataset_dir, class_names), self.log)
        else:
            messagebox.showwarning("缺少目录", "请先选择视频/数据集目录")

    def start_prune_single_class(self) -> None:
        """Validate the pruning settings and run the random prune in a worker thread.

        Shows a warning and returns when no directory is selected, when one of
        the numeric fields does not contain a number, or when the percentage
        is not in (0, 100]. The dataset directory is captured before the
        worker starts so a later re-selection cannot affect a running job.
        """
        # IntVar.get() / DoubleVar.get() raise TclError on non-numeric text.
        from tkinter import TclError

        dataset_dir = self.dataset_dir
        if not dataset_dir:
            messagebox.showwarning("缺少目录", "请先选择视频/数据集目录")
            return
        try:
            class_id = int(self.balance_class_var.get())
            percent = float(self.balance_percent_var.get())
            seed = int(self.balance_seed_var.get())
        except (TclError, ValueError):
            messagebox.showwarning("参数无效", "类别ID、移出百分比和随机种子需要填写数字")
            return
        if percent <= 0 or percent > 100:
            messagebox.showwarning("比例无效", "移出百分比需要在 0-100 之间")
            return
        names = self.get_names_for_stats()

        def job() -> None:
            self.set_busy(True)
            try:
                result = prune_single_class_images(
                    dataset_dir,
                    class_id=class_id,
                    percent=percent,
                    seed=seed,
                    names=names,
                    log=self.log,
                )
                self.log(
                    f"随机移出完成: 候选 {result['candidates']} 张,移出 {result['moved']} 张,保留 {result['kept']} 张"
                )
                # Log the statistics again so the effect of the prune is visible.
                log_label_stats(compute_label_stats(dataset_dir, names), self.log)
                self.set_status("随机移出完成")
            except Exception as exc:
                error = str(exc)
                self.log(f"随机移出失败: {error}")
                self.log(traceback.format_exc())
                # The message box must be opened from the Tk main loop.
                self.root.after(0, lambda error=error: messagebox.showerror("随机移出失败", error))
            finally:
                self.set_busy(False)

        self.run_worker(job)

    def get_names_from_ui(self) -> list[str]:
        text = self.names_text.get("1.0", END).strip()
        if not text:
            return []
        names: list[str] = []
        for chunk in text.replace(",", "\n").splitlines():
            name = chunk.strip()
            if name:
                names.append(name)
        return names

    def get_names_for_stats(self) -> list[str]:
        if self.dataset_dir is None:
            return self.get_names_from_ui()
        return self.get_names_from_ui() or read_dataset_yaml_names(self.dataset_dir / "dataset.yaml")

    def run_worker(self, target) -> None:
        if self.worker and self.worker.is_alive():
            messagebox.showinfo("正在运行", "当前任务还没有结束")
            return
        self.progress.configure(value=0, maximum=100)
        self.worker = threading.Thread(target=target, daemon=True)
        self.worker.start()

    def set_progress(self, current: int, total: int) -> None:
        def update() -> None:
            self.progress.configure(maximum=max(total, 1), value=current)

        self.root.after(0, update)

    def set_busy(self, busy: bool) -> None:
        def update() -> None:
            if busy:
                self.busy.start(10)
            else:
                self.busy.stop()

        self.root.after(0, update)

    def set_status(self, message: str) -> None:
        self.root.after(0, lambda message=message: self.status_var.set(message))

    def log(self, message: str) -> None:
        """Queue ``message`` for display; _poll_log_queue appends it to the log widget."""
        self.log_queue.put(message)

    def _poll_log_queue(self) -> None:
        while True:
            try:
                message = self.log_queue.get_nowait()
            except queue.Empty:
                break
            self.log_text.configure(state=NORMAL)
            self.log_text.insert(END, message + "\n")
            self.log_text.see(END)
            self.log_text.configure(state=DISABLED)
        self.root.after(100, self._poll_log_queue)


def discover_videos(dataset_dir: Path) -> list[Path]:
    """List the video files directly inside ``dataset_dir`` and ``dataset_dir/videos``.

    Each folder is scanned in sorted order (root first), without recursion.
    Entries that resolve to the same file are reported once, keeping the
    first path found.
    """
    found: list[Path] = []
    for folder in (dataset_dir, dataset_dir / "videos"):
        if folder.exists():
            found.extend(
                entry
                for entry in sorted(folder.iterdir())
                if entry.is_file() and entry.suffix.lower() in VIDEO_EXTS
            )

    # dict preserves insertion order, so the first path per resolved file wins.
    first_by_target: dict[Path, Path] = {}
    for entry in found:
        first_by_target.setdefault(entry.resolve(), entry)
    return list(first_by_target.values())


def extract_frames_from_dataset(
    dataset_dir: Path,
    interval: int,
    image_ext: str,
    prefix: str,
    copy_root_videos: bool,
    progress,
    log,
) -> int:
    """Extract frames from every video of the dataset into ``images/``.

    Creates ``videos/``, ``images/`` and ``labels/`` under ``dataset_dir``.
    With ``copy_root_videos`` set, videos found directly in ``dataset_dir``
    are copied into ``videos/`` (unless a file of that name is already there)
    and the copy is used as the source. Each source file is processed once,
    even when both a root video and its existing copy in ``videos/`` are
    discovered.

    Args:
        dataset_dir: Dataset root directory.
        interval: Save one frame every ``interval`` frames.
        image_ext: Image file extension without the dot.
        prefix: Image file name prefix.
        copy_root_videos: Whether to copy root-level videos into ``videos/``.
        progress: Callable ``progress(done, total)`` invoked after each video.
        log: Callable taking one log message string.

    Returns:
        Total number of images written.

    Raises:
        FileNotFoundError: If no video file is found.
    """
    videos_dir = dataset_dir / "videos"
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    videos_dir.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    videos = discover_videos(dataset_dir)
    if not videos:
        raise FileNotFoundError("所选目录或 videos 子目录下没有找到视频文件")

    root_dir = dataset_dir.resolve()
    prepared_videos: list[Path] = []
    seen_sources: set[Path] = set()
    for video in videos:
        source = video
        if copy_root_videos and video.parent.resolve() == root_dir:
            target = videos_dir / video.name
            dest = target if target.exists() else unique_path(target)
            if not dest.exists():
                shutil.copy2(video, dest)
                log(f"复制视频到 videos: {dest.name}")
            source = dest
        # A root video and its earlier copy in videos/ map to the same file;
        # without this check its frames would be extracted twice under two
        # different video indices.
        key = source.resolve()
        if key in seen_sources:
            continue
        seen_sources.add(key)
        prepared_videos.append(source)

    total_saved = 0
    video_count = len(prepared_videos)
    for video_idx, video in enumerate(sorted(prepared_videos), start=1):
        log(f"开始抽帧 [{video_idx}/{video_count}]: {video.name}")
        saved = extract_video_frames(video, images_dir, interval, image_ext, prefix, video_idx)
        total_saved += saved
        progress(video_idx, video_count)
        log(f"{video.name}: 生成 {saved} 张图片")
    return total_saved


def extract_video_frames(
    video_path: Path,
    images_dir: Path,
    interval: int,
    image_ext: str,
    prefix: str,
    video_index: int,
) -> int:
    """Save every ``interval``-th frame of one video and return how many were saved.

    Frames are counted from 0 and written to ``images_dir`` as
    ``{prefix}_{video_index}_{frame_index}.{image_ext}``.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        raise RuntimeError(f"无法打开视频: {video_path}")

    written = 0
    position = 0
    try:
        grabbed, frame = capture.read()
        while grabbed:
            if not position % interval:
                target = images_dir / f"{prefix}_{video_index}_{position}.{image_ext}"
                safe_imwrite(target, frame)
                written += 1
            position += 1
            grabbed, frame = capture.read()
    finally:
        # Release the capture even when writing a frame fails.
        capture.release()
    return written


class YoloOnnxDetector:
    """YOLO detector backed by an ONNX Runtime session.

    The session is created with the CUDA provider when available, with the
    CPU provider as fallback. If inference fails on a CUDA session, the
    session is rebuilt on CPU only and the inference is retried once.
    """

    def __init__(self, model_path: Path, imgsz: int, names: list[str], log) -> None:
        """Load the model and resolve its input name and size.

        Args:
            model_path: Path of the ONNX model file.
            imgsz: Requested input size, passed to ``resolve_input_size``
                together with the model's declared input shape.
            names: Class names; when empty they are read from the model via
                ``try_read_onnx_names``.
            log: Callable taking one log message string.

        Raises:
            RuntimeError: If onnxruntime cannot be imported.
        """
        try:
            import onnxruntime as ort
        except Exception as exc:
            raise RuntimeError(
                "onnxruntime 导入失败。建议用 yolov8 conda 环境运行本工具,"
                "或重新安装 onnxruntime/onnxruntime-gpu。"
            ) from exc

        self.model_path = model_path
        self.ort = ort
        self.log = log
        self.active_providers: list[str] = []
        self.session = self._create_session(prefer_cuda=True)
        self.input = self.session.get_inputs()[0]
        self.input_name = self.input.name
        self.input_h, self.input_w = resolve_input_size(self.input.shape, imgsz)
        self.names = names or try_read_onnx_names(model_path)
        log(f"模型输入: {self.input_name}, size={self.input_w}x{self.input_h}")

    def _create_session(self, prefer_cuda: bool):
        """Create an inference session, trying CUDA first when requested and available.

        The providers reported by the created session are stored in
        ``self.active_providers``.

        Raises:
            Exception: The last session-creation error when every attempt fails.
            RuntimeError: If no usable provider is available at all.
        """
        available = self.ort.get_available_providers()
        provider_attempts: list[list[str]] = []
        if prefer_cuda and "CUDAExecutionProvider" in available:
            provider_attempts.append(["CUDAExecutionProvider", "CPUExecutionProvider"])
        provider_attempts.append(["CPUExecutionProvider"])

        last_exc: Exception | None = None
        for providers in provider_attempts:
            providers = [provider for provider in providers if provider in available]
            if not providers:
                continue
            try:
                session = self.ort.InferenceSession(str(self.model_path), providers=providers)
                self.active_providers = session.get_providers()
                self.log(f"ONNX Runtime providers: {self.active_providers}")
                return session
            except Exception as exc:
                last_exc = exc
                self.log(f"创建 ONNX session 失败,providers={providers}: {exc}")
        if last_exc:
            raise last_exc
        raise RuntimeError(f"没有可用的 ONNX Runtime provider: {available}")

    def _can_fall_back_to_cpu(self) -> bool:
        """Return True when the current session uses CUDA, so a CPU-only retry is meaningful.

        A CUDA session is created with the CPU provider listed as well, so
        the presence of the CPU provider cannot be used to tell the two
        cases apart; the presence of the CUDA provider can.
        """
        return "CUDAExecutionProvider" in getattr(self, "active_providers", [])

    def predict(self, image_bgr: np.ndarray, conf_threshold: float, iou_threshold: float) -> list[Detection]:
        """Run the model on one BGR image and return the detections.

        Args:
            image_bgr: Image array as read by OpenCV (BGR channel order).
            conf_threshold: Minimum score for a detection to be kept.
            iou_threshold: IoU threshold used for non-maximum suppression.

        Returns:
            The detections produced by ``postprocess_yolo_output`` from the
            model's first output.
        """
        input_tensor, ratio, pad_x, pad_y = preprocess_yolo(image_bgr, self.input_w, self.input_h)
        try:
            outputs = self.session.run(None, {self.input_name: input_tensor})
        except Exception as exc:
            if not self._can_fall_back_to_cpu():
                # Already CPU-only: nothing left to fall back to.
                raise
            self.log(f"CUDA 推理失败,自动切换 CPU 后重试: {exc}")
            self.session = self._create_session(prefer_cuda=False)
            self.input = self.session.get_inputs()[0]
            self.input_name = self.input.name
            outputs = self.session.run(None, {self.input_name: input_tensor})
        detections = postprocess_yolo_output(
            outputs[0],
            image_shape=image_bgr.shape[:2],
            ratio=ratio,
            pad_x=pad_x,
            pad_y=pad_y,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            num_classes=len(self.names) if self.names else None,
        )
        return detections


def auto_label_images(
    dataset_dir: Path,
    detector: YoloOnnxDetector,
    conf_threshold: float,
    iou_threshold: float,
    progress,
    log,
) -> int:
    """Label every image under ``images/`` with the detector.

    For each readable image a YOLO txt file is written to the mirrored path
    under ``labels/``. Images with no detections are moved to
    ``empty_images/`` and any existing label file for them is removed.
    Unreadable images are skipped with a log message. The periodic progress
    message is logged every 20 images and for the last image, whatever the
    outcome for that image was.

    Args:
        dataset_dir: Dataset root containing ``images/``.
        detector: Detector used for prediction.
        conf_threshold: Minimum detection score.
        iou_threshold: NMS IoU threshold.
        progress: Callable ``progress(done, total)`` invoked after each image.
        log: Callable taking one log message string.

    Returns:
        The largest class id written to any label file, or -1 if none.

    Raises:
        FileNotFoundError: If ``images/`` is missing or contains no images.
    """
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    empty_images_dir = dataset_dir / "empty_images"
    labels_dir.mkdir(parents=True, exist_ok=True)
    empty_images_dir.mkdir(parents=True, exist_ok=True)
    if not images_dir.is_dir():
        raise FileNotFoundError(f"未找到 images 目录: {images_dir}")

    images = [path for path in sorted(images_dir.rglob("*")) if path.is_file() and path.suffix.lower() in IMAGE_EXTS]
    if not images:
        raise FileNotFoundError("images 目录中没有图片")

    total = len(images)
    max_class = -1
    empty_count = 0
    for idx, image_path in enumerate(images, start=1):
        image = safe_imread(image_path)
        if image is None:
            log(f"跳过无法读取图片: {image_path.name}")
        else:
            relative = image_path.relative_to(images_dir)
            detections = detector.predict(image, conf_threshold, iou_threshold)
            label_path = labels_dir / relative.with_suffix(".txt")
            label_path.parent.mkdir(parents=True, exist_ok=True)
            if detections:
                write_yolo_txt(label_path, detections, image.shape[1], image.shape[0])
                max_class = max(max_class, max(det.class_id for det in detections))
            else:
                # No objects: drop a stale label and move the image aside.
                if label_path.exists():
                    label_path.unlink()
                empty_dest = unique_path(empty_images_dir / relative)
                empty_dest.parent.mkdir(parents=True, exist_ok=True)
                shutil.move(str(image_path), str(empty_dest))
                empty_count += 1
        # Logged for every milestone image, including skipped and empty ones,
        # so the final "N/N" message is never lost.
        if idx % 20 == 0 or idx == total:
            log(f"已标注 {idx}/{total}")
        progress(idx, total)
    if empty_count:
        log(f"空检出图片已移动到 empty_images: {empty_count} 张")
    return max_class


def preprocess_yolo(image_bgr: np.ndarray, input_w: int, input_h: int) -> tuple[np.ndarray, float, float, float]:
    """Letterbox a BGR image to the model input size and convert it to a tensor.

    The image is converted to RGB, resized with its aspect ratio preserved,
    pasted onto a canvas filled with 114, scaled to [0, 1] and rearranged to
    shape ``(1, 3, input_h, input_w)`` as float32.

    Args:
        image_bgr: Source image in BGR channel order.
        input_w: Model input width.
        input_h: Model input height.

    Returns:
        ``(tensor, ratio, left, top)`` where ``ratio`` is the resize factor
        and ``left`` / ``top`` are the integer pixel offsets at which the
        resized image was placed on the canvas.
    """
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    h, w = image_rgb.shape[:2]
    ratio = min(input_w / w, input_h / h)
    # For extreme aspect ratios the scaled short side rounds to 0, which
    # cv2.resize rejects; keep at least one pixel.
    new_w = max(1, int(round(w * ratio)))
    new_h = max(1, int(round(h * ratio)))
    resized = cv2.resize(image_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((input_h, input_w, 3), 114, dtype=np.uint8)
    pad_x = (input_w - new_w) / 2
    pad_y = (input_h - new_h) / 2
    # The -0.1 makes an exact .5 padding round down.
    left = int(round(pad_x - 0.1))
    top = int(round(pad_y - 0.1))
    canvas[top : top + new_h, left : left + new_w] = resized
    tensor = canvas.astype(np.float32) / 255.0
    tensor = np.transpose(tensor, (2, 0, 1))[None, ...]
    return tensor, ratio, left, top

def _parse_channel_first_six_row(
    row: np.ndarray, num_classes: int | None
) -> tuple[float, float, float, float, float, int, bool]:
    """Parse a 6-value row that came from a channel-first (raw head) output.

    Such a row is ``[cx, cy, w, h, ...]`` followed by either objectness plus
    one class score (when ``num_classes`` is 1) or two class scores. It is
    never the ``[x1, y1, x2, y2, score, class]`` layout.
    """
    row = row.astype(float)
    if num_classes == 1:
        return row[0], row[1], row[2], row[3], float(row[4] * row[5]), 0, False
    class_scores = row[4:]
    class_id = int(np.argmax(class_scores))
    return row[0], row[1], row[2], row[3], float(class_scores[class_id]), class_id, False


def postprocess_yolo_output(
    output: np.ndarray,
    image_shape: tuple[int, int],
    ratio: float,
    pad_x: float,
    pad_y: float,
    conf_threshold: float,
    iou_threshold: float,
    num_classes: int | None,
) -> list[Detection]:
    """Convert a raw model output into detections in original-image coordinates.

    Accepts a 2-D output or a 3-D output whose first (batch) entry is used.
    A channel-first layout (few rows, many columns) is transposed so each
    row is one candidate. Output that already has 6 columns is left
    unchanged, so a short list of ``[x1, y1, x2, y2, score, class]`` rows is
    not transposed by mistake. Boxes are mapped back through the letterbox
    offsets and ratio, clipped to the image, filtered by score and reduced
    with per-class NMS.

    Args:
        output: First output tensor of the model.
        image_shape: ``(height, width)`` of the original image.
        ratio: Resize factor used by the letterbox preprocessing.
        pad_x: Horizontal letterbox offset in input pixels.
        pad_y: Vertical letterbox offset in input pixels.
        conf_threshold: Minimum score for a candidate to be kept.
        iou_threshold: IoU threshold for NMS.
        num_classes: Number of classes when known, else None.

    Returns:
        Detections kept by NMS, in the order returned by ``nms``.

    Raises:
        ValueError: If the output is not 2-D after removing the batch axis.
    """
    pred = np.asarray(output)
    if pred.ndim == 3:
        pred = pred[0]
    if pred.ndim != 2:
        raise ValueError(f"暂不支持的 ONNX 输出维度: {pred.shape}")

    channels_first = pred.shape[0] < pred.shape[1] and pred.shape[0] <= 256 and pred.shape[1] != 6
    if channels_first:
        pred = pred.T
    # Six values per candidate from a channel-first head are box + scores,
    # not the post-NMS layout the generic row parser assumes for length 6.
    raw_six_values = channels_first and pred.shape[1] == 6

    raw: list[Detection] = []
    image_h, image_w = image_shape
    for row in pred:
        if raw_six_values:
            parsed = _parse_channel_first_six_row(row, num_classes)
        else:
            parsed = parse_prediction_row(row, num_classes)
        if parsed is None:
            continue
        cx, cy, bw, bh, score, class_id, coords_are_xyxy = parsed
        if score < conf_threshold:
            continue
        if coords_are_xyxy:
            x1, y1, x2, y2 = cx, cy, bw, bh
        else:
            x1 = cx - bw / 2
            y1 = cy - bh / 2
            x2 = cx + bw / 2
            y2 = cy + bh / 2
        # Undo the letterbox: remove the padding, then the resize.
        x1 = (x1 - pad_x) / ratio
        y1 = (y1 - pad_y) / ratio
        x2 = (x2 - pad_x) / ratio
        y2 = (y2 - pad_y) / ratio
        x1 = float(np.clip(x1, 0, image_w))
        y1 = float(np.clip(y1, 0, image_h))
        x2 = float(np.clip(x2, 0, image_w))
        y2 = float(np.clip(y2, 0, image_h))
        if x2 <= x1 or y2 <= y1:
            # Box collapsed after clipping.
            continue
        raw.append(Detection(class_id=class_id, score=float(score), xyxy=(x1, y1, x2, y2)))

    keep = nms(raw, iou_threshold)
    return [raw[idx] for idx in keep]


def parse_prediction_row(row: np.ndarray, num_classes: int | None) -> tuple[float, float, float, float, float, int, bool] | None:
    """Decode one prediction row into ``(a, b, c, d, score, class_id, coords_are_xyxy)``.

    Supported layouts, checked in this order:

    * 6 values: ``[x1, y1, x2, y2, score, class]``; the last tuple item is True.
    * fewer than 5 values: unsupported, returns None.
    * ``num_classes + 5`` values: ``[cx, cy, w, h, objectness, class scores...]``;
      the score is objectness times the best class score.
    * anything else with at least 5 values: ``[cx, cy, w, h, class scores...]``.
      This includes the 5-value row of a single-class head.

    NOTE: a 6-value row is always read as the first layout, even when it
    could also be a 2-class or objectness + 1-class row; a row on its own
    does not carry enough information to tell them apart.

    Args:
        row: 1-D array holding one prediction.
        num_classes: Number of classes when known, else None.

    Returns:
        The decoded tuple, or None when the row is too short.
    """
    row = row.astype(float)
    length = row.shape[0]
    if length == 6:
        return row[0], row[1], row[2], row[3], row[4], int(row[5]), True
    if length < 5:
        return None

    if num_classes and length == num_classes + 5:
        class_scores = row[5 : 5 + num_classes]
        class_id = int(np.argmax(class_scores))
        score = float(row[4] * class_scores[class_id])
        return row[0], row[1], row[2], row[3], score, class_id, False

    # Everything after the four box values is a class score; this also
    # covers the num_classes + 4 layout.
    class_scores = row[4:]
    class_id = int(np.argmax(class_scores))
    score = float(class_scores[class_id])
    return row[0], row[1], row[2], row[3], score, class_id, False


def nms(detections: list[Detection], iou_threshold: float) -> list[int]:
    """Run greedy non-maximum suppression separately for each class.

    Args:
        detections: Candidate detections; only ``xyxy``, ``score`` and
            ``class_id`` are read.
        iou_threshold: A box is dropped when its IoU with an already kept box
            of the same class is greater than this value.

    Returns:
        Indices into ``detections`` of the surviving boxes, ordered by
        descending score.
    """
    if not detections:
        return []

    boxes = np.array([item.xyxy for item in detections], dtype=np.float32)
    scores = np.array([item.score for item in detections], dtype=np.float32)
    classes = np.array([item.class_id for item in detections], dtype=np.int32)

    survivors: list[int] = []
    for label in sorted(set(classes.tolist())):
        members = np.where(classes == label)[0]
        # Highest score first within this class.
        remaining = members[np.argsort(scores[members])[::-1]]
        while remaining.size > 0:
            best = int(remaining[0])
            survivors.append(best)
            rest = remaining[1:]
            if rest.size == 0:
                break
            overlaps = box_iou(boxes[best], boxes[rest])
            remaining = rest[overlaps <= iou_threshold]

    return sorted(survivors, key=lambda index: detections[index].score, reverse=True)


def box_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return the IoU of one ``[x1, y1, x2, y2]`` box against each row of ``boxes``.

    Negative widths, heights and areas are treated as zero, and the union is
    floored at ``1e-9`` so degenerate boxes give 0 instead of dividing by zero.
    """
    left = np.maximum(box[0], boxes[:, 0])
    top = np.maximum(box[1], boxes[:, 1])
    right = np.minimum(box[2], boxes[:, 2])
    bottom = np.minimum(box[3], boxes[:, 3])

    overlap_w = np.maximum(0, right - left)
    overlap_h = np.maximum(0, bottom - top)
    inter = overlap_w * overlap_h

    ref_area = max(0.0, (box[2] - box[0]) * (box[3] - box[1]))
    other_w = np.maximum(0, boxes[:, 2] - boxes[:, 0])
    other_h = np.maximum(0, boxes[:, 3] - boxes[:, 1])
    other_area = other_w * other_h

    union = ref_area + other_area - inter
    return inter / np.maximum(union, 1e-9)


def write_yolo_txt(label_path: Path, detections: list[Detection], image_w: int, image_h: int) -> None:
    """Write detections to ``label_path`` as YOLO label rows.

    Each row is ``class_id cx cy w h``, where the box centre and size are
    divided by the image width/height and printed with six decimals. An empty
    detection list produces an empty file.
    """
    rows: list[str] = []
    for det in detections:
        left, top, right, bottom = det.xyxy
        center_x = ((left + right) / 2) / image_w
        center_y = ((top + bottom) / 2) / image_h
        width = (right - left) / image_w
        height = (bottom - top) / image_h
        rows.append(f"{det.class_id} {center_x:.6f} {center_y:.6f} {width:.6f} {height:.6f}")
    # Every row is newline-terminated; no rows means an empty file.
    label_path.write_text("".join(f"{row}\n" for row in rows), encoding="utf-8")


def write_dataset_yaml(dataset_dir: Path, names: list[str]) -> Path:
    """Write ``dataset.yaml`` into ``dataset_dir`` and return its path.

    The file records the resolved dataset path, uses ``images`` for both the
    ``train`` and ``val`` entries, and maps each class index to its name.
    """
    target = dataset_dir / "dataset.yaml"
    config = {
        "path": str(dataset_dir.resolve()),
        "train": "images",
        "val": "images",
        "names": dict(enumerate(names)),
    }
    # sort_keys=False keeps the key order written above.
    text = yaml.safe_dump(config, allow_unicode=True, sort_keys=False)
    target.write_text(text, encoding="utf-8")
    return target


def read_dataset_yaml_names(yaml_path: Path) -> list[str]:
    """Read the class names from a dataset YAML file.

    ``names`` may be a list, or a mapping whose keys convert to ``int``; a
    mapping is returned in ascending key order. A missing file, an unreadable
    or malformed file, or any other ``names`` value yields an empty list.
    """
    if not yaml_path.exists():
        return []
    try:
        loaded = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
        names = loaded.get("names")
        if isinstance(names, dict):
            ordered_keys = sorted(names, key=int)
            return [str(names[key]) for key in ordered_keys]
        if isinstance(names, list):
            return [str(entry) for entry in names]
    except Exception:
        # Best effort: any read/parse problem is reported as "no names".
        return []
    return []


def read_label_class_ids(label_path: Path) -> list[int]:
    """Return the class id of every non-blank row in a YOLO label file.

    The first field of each row is parsed as a float and truncated to ``int``.
    If the file cannot be read or any row fails to parse, the whole result is
    an empty list.
    """
    try:
        with label_path.open("r", encoding="utf-8") as handle:
            return [int(float(fields[0])) for fields in map(str.split, handle) if fields]
    except Exception:
        return []


def compute_label_stats(dataset_dir: Path, names: list[str]) -> dict:
    """Collect per-class counts from every ``*.txt`` under ``dataset_dir/labels``.

    Args:
        dataset_dir: Dataset root that contains the ``labels`` directory.
        names: Class names; every index of this list appears in ``class_ids``
            even when no label uses it.

    Returns:
        A dict with ``total_labels``, ``empty_labels`` (files yielding no class
        ids), ``class_ids`` (sorted), ``names``, ``object_counts`` (objects per
        class), ``image_counts`` (label files containing the class) and
        ``only_class_image_counts`` (label files containing only that class).

    Raises:
        FileNotFoundError: If the ``labels`` directory does not exist.
    """
    labels_dir = dataset_dir / "labels"
    if not labels_dir.is_dir():
        raise FileNotFoundError(f"未找到 labels 目录: {labels_dir}")

    ids_per_file = [read_label_class_ids(path) for path in sorted(labels_dir.rglob("*.txt"))]
    labelled = [ids for ids in ids_per_file if ids]
    total_labels = len(ids_per_file)
    empty_labels = total_labels - len(labelled)

    object_counts: dict[int, int] = {}
    image_counts: dict[int, int] = {}
    only_class_image_counts: dict[int, int] = {}
    for ids in labelled:
        for class_id in ids:
            object_counts[class_id] = object_counts.get(class_id, 0) + 1
        distinct = set(ids)
        for class_id in distinct:
            image_counts[class_id] = image_counts.get(class_id, 0) + 1
        if len(distinct) == 1:
            (sole_id,) = distinct
            only_class_image_counts[sole_id] = only_class_image_counts.get(sole_id, 0) + 1

    # Union of every id seen in the labels and every index covered by names.
    known_ids = set(range(len(names)))
    known_ids.update(object_counts, image_counts, only_class_image_counts)

    return {
        "total_labels": total_labels,
        "empty_labels": empty_labels,
        "class_ids": sorted(known_ids),
        "names": names,
        "object_counts": object_counts,
        "image_counts": image_counts,
        "only_class_image_counts": only_class_image_counts,
    }


def log_label_stats(stats: dict, log) -> None:
    """Send a per-class summary of ``stats`` to the ``log`` callable.

    Emits a heading, one line with the label-file totals, then one line per id
    in ``stats["class_ids"]``. Ids outside the range of ``stats["names"]`` are
    shown as ``class_<id>``; missing counts are shown as 0.
    """
    log("类别统计:")
    log(f"  label文件数: {stats['total_labels']},空txt: {stats['empty_labels']}")
    for class_id in stats["class_ids"]:
        names = stats["names"]
        if 0 <= class_id < len(names):
            label = names[class_id]
        else:
            label = f"class_{class_id}"
        objects = stats["object_counts"].get(class_id, 0)
        images = stats["image_counts"].get(class_id, 0)
        exclusive = stats["only_class_image_counts"].get(class_id, 0)
        log(f"  {class_id}: {label} | 目标数 {objects} | 出现图片数 {images} | 仅含该类图片数 {exclusive}")


def prune_single_class_images(
    dataset_dir: Path,
    class_id: int,
    percent: float,
    seed: int,
    names: list[str],
    log,
) -> dict[str, int]:
    """Move a random share of the images labelled only with ``class_id`` out of the dataset.

    Label files whose objects all belong to ``class_id`` are the candidates.
    ``percent`` of them (rounded down) are picked with a ``random.Random(seed)``
    sample and moved, together with their images when found, into
    ``class_balance_removed/class_<id>_<name>_<percent>pct`` under
    ``dataset_dir``. Each move is recorded in ``manifest.csv`` in that folder.

    Args:
        dataset_dir: Dataset root containing ``images`` and ``labels``.
        class_id: Class whose single-class images are thinned out.
        percent: Share of the candidates to move, as a percentage.
        seed: Seed for the random selection, making it repeatable.
        names: Class names, used only to name the destination folder.
        log: Callable receiving one message with the destination folder.

    Returns:
        ``{"candidates": ..., "moved": ..., "kept": ...}``.

    Raises:
        FileNotFoundError: If ``images`` or ``labels`` is missing.
        ValueError: From ``random.sample`` when ``percent`` asks for more
            files than there are candidates, or for a negative number.
    """
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    if not images_dir.is_dir():
        raise FileNotFoundError(f"未找到 images 目录: {images_dir}")
    if not labels_dir.is_dir():
        raise FileNotFoundError(f"未找到 labels 目录: {labels_dir}")

    # A label file qualifies only when every object in it has class_id.
    candidates: list[Path] = [
        label_path
        for label_path in sorted(labels_dir.rglob("*.txt"))
        if set(read_label_class_ids(label_path)) == {class_id}
    ]

    move_count = int(len(candidates) * percent / 100.0)
    selected = sorted(random.Random(seed).sample(candidates, move_count)) if move_count else []
    class_name = names[class_id] if 0 <= class_id < len(names) else f"class_{class_id}"
    bucket = dataset_dir / "class_balance_removed" / f"class_{class_id}_{sanitize_prefix(class_name)}_{percent:g}pct"
    bucket.mkdir(parents=True, exist_ok=True)

    rows: list[str] = []
    for label_path in selected:
        rel_label = label_path.relative_to(labels_dir)
        image_path = find_image_for_label(images_dir, rel_label)
        dest_label = unique_path(bucket / "labels" / rel_label)
        dest_label.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(label_path), str(dest_label))

        # Both image columns stay empty when no image exists for the label,
        # so the label paths stay in the old_label/new_label columns.
        old_image = new_image = ""
        if image_path is not None:
            rel_image = image_path.relative_to(images_dir)
            dest_image = unique_path(bucket / "images" / rel_image)
            dest_image.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(image_path), str(dest_image))
            old_image = rel_image.as_posix()
            new_image = dest_image.relative_to(bucket).as_posix()
        rows.append(f"{old_image},{new_image},{rel_label.as_posix()},{dest_label.relative_to(bucket).as_posix()}")

    manifest_path = bucket / "manifest.csv"
    if manifest_path.exists():
        # A repeated run into the same bucket appends, so the record of files
        # moved earlier is not lost.
        with manifest_path.open("a", encoding="utf-8") as manifest:
            manifest.writelines(f"{row}\n" for row in rows)
    else:
        header = "old_image,new_image,old_label,new_label"
        manifest_path.write_text("\n".join([header, *rows]) + "\n", encoding="utf-8")

    log(f"已移出到: {bucket}")
    moved = len(rows)
    return {"candidates": len(candidates), "moved": moved, "kept": len(candidates) - moved}


def find_image_for_label(images_dir: Path, rel_label: Path) -> Path | None:
    """Locate the image file that corresponds to a label file.

    Args:
        images_dir: Root directory of the images.
        rel_label: Label path relative to the labels root.

    Returns:
        The first existing ``images_dir / rel_label`` path with its suffix
        replaced by an entry of ``IMAGE_EXTS`` (tried in iteration order), or
        ``None`` when no such file exists.
    """
    for ext in IMAGE_EXTS:
        # Replace only the final suffix: stripping it first and calling
        # with_suffix again would turn "clip.001.txt" into "clip.jpg".
        candidate = images_dir / rel_label.with_suffix(ext)
        if candidate.exists():
            return candidate
    return None


def try_read_onnx_names(model_path: Path) -> list[str]:
    """Read class names from an ONNX model's custom metadata, if present.

    The model is opened with ``CPUExecutionProvider`` when onnxruntime offers
    it, otherwise with every available provider. The metadata keys ``names``,
    ``classes`` and ``labels`` are tried in that order and the first one that
    parses to a non-empty list wins. Any failure to import onnxruntime or to
    load the model yields an empty list.
    """
    try:
        import onnxruntime as ort

        available = ort.get_available_providers()
        preferred = [name for name in ("CPUExecutionProvider",) if name in available]
        session = ort.InferenceSession(str(model_path), providers=preferred or available)
        metadata = session.get_modelmeta().custom_metadata_map
    except Exception:
        return []

    raw_values = (metadata.get(key) for key in ("names", "classes", "labels"))
    for raw in raw_values:
        if raw:
            parsed = parse_names_value(raw)
            if parsed:
                return parsed
    return []


def parse_names_value(value: str) -> list[str]:
    """Turn a class-name string into a list of names.

    Accepted forms:

    * a dict literal such as ``{0: 'person', 1: 'car'}``, returned in
      ascending integer key order (or in the literal's own order when the
      keys are not numeric);
    * a list literal, or a tuple literal written with parentheses;
    * plain text with names separated by commas and/or newlines.

    Args:
        value: Raw text from model metadata or user input.

    Returns:
        The names as strings; empty when nothing usable is found.
    """
    try:
        parsed = ast.literal_eval(value)
    except Exception:
        parsed = None

    if isinstance(parsed, dict):
        try:
            keys = sorted(parsed, key=int)
        except (TypeError, ValueError):
            # Non-numeric keys cannot be ordered by index; keep the literal's
            # order instead of letting the error escape to the caller.
            keys = list(parsed)
        return [str(parsed[key]) for key in keys]
    if isinstance(parsed, list):
        return [str(item) for item in parsed]
    # Only an explicitly parenthesised tuple counts as a literal; bare
    # comma-separated text keeps going through the plain split below.
    if isinstance(parsed, tuple) and value.lstrip().startswith("("):
        return [str(item) for item in parsed]

    if isinstance(value, str):
        return [part.strip() for part in value.replace(",", "\n").splitlines() if part.strip()]
    return []


def resolve_input_size(shape: list, fallback: int) -> tuple[int, int]:
    """Return ``(height, width)`` from a model input shape.

    Entries 2 and 3 of ``shape`` are used when both are positive ``int``
    values; otherwise (too few entries, or a non-int or non-positive
    dimension) ``(fallback, fallback)`` is returned.
    """
    if len(shape) < 4:
        return fallback, fallback
    height, width = shape[2], shape[3]
    dims_are_static = all(isinstance(dim, int) and dim > 0 for dim in (height, width))
    if dims_are_static:
        return height, width
    return fallback, fallback


def safe_imread(path: Path) -> np.ndarray | None:
    """Read an image by loading the file's bytes and decoding them with OpenCV.

    Returns ``None`` for an empty file; otherwise returns whatever
    ``cv2.imdecode`` produces for the bytes in colour mode.
    """
    raw = np.fromfile(str(path), dtype=np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None


def safe_imwrite(path: Path, image: np.ndarray) -> None:
    """Encode ``image`` in the format implied by ``path.suffix`` and write it to ``path``.

    Parent directories are created as needed.

    Raises:
        RuntimeError: If OpenCV reports that encoding failed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    success, buffer = cv2.imencode(path.suffix, image)
    if success:
        buffer.tofile(str(path))
        return
    raise RuntimeError(f"图片编码失败: {path}")


def unique_path(path: Path) -> Path:
    """Return ``path`` if it is free, else the first free ``<stem>_<n><suffix>`` sibling.

    ``n`` runs from 1 to 99999.

    Raises:
        RuntimeError: If every numbered candidate already exists.
    """
    if not path.exists():
        return path
    folder, base, ext = path.parent, path.stem, path.suffix
    numbered = (folder / f"{base}_{number}{ext}" for number in range(1, 100000))
    free = next((option for option in numbered if not option.exists()), None)
    if free is None:
        raise RuntimeError(f"无法生成不重复路径: {path}")
    return free


def safe_stem(path: Path) -> str:
    """Return a filename-safe version of ``path``'s stem.

    Every character that is neither alphanumeric nor ``-``/``_`` becomes
    ``_``. The result is cut to 80 characters, and ``"video"`` is used when
    the stem is blank.
    """
    base = path.stem.strip() or "video"
    allowed_punctuation = ("-", "_")
    pieces = [ch if (ch in allowed_punctuation or ch.isalnum()) else "_" for ch in base]
    return "".join(pieces)[:80] or "video"


def sanitize_prefix(value: str) -> str:
    """Reduce ``value`` to ASCII letters, digits, ``-`` and ``_``.

    Every other character becomes ``_``; runs of underscores collapse to one
    and leading/trailing underscores are removed. The result may be empty.
    """
    mapped = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in ("-", "_")) else "_"
        for ch in value.strip()
    )
    # Dropping the empty pieces between underscores both collapses runs and
    # trims the ends.
    return "_".join(piece for piece in mapped.split("_") if piece)


def default_dataset_prefix(dataset_dir: Path) -> str:
    """Derive a default image-name prefix from the dataset folder name.

    A folder named ``<digits>efall`` (case-insensitive) gives
    ``efall_<digits>``. Any other name is passed through ``sanitize_prefix``,
    with ``"dataset"`` as the fallback when that leaves nothing.
    """
    folder = dataset_dir.name.strip()
    lower = folder.lower()
    marker = "efall"
    leading = lower[: -len(marker)]
    if lower.endswith(marker) and leading.isdigit():
        return f"efall_{leading}"
    return sanitize_prefix(folder) or "dataset"


def main() -> int:
    """Run the tool and return a process exit code.

    When the first command-line argument is ``--auto-label-cli`` the remaining
    arguments go to ``auto_label_cli``; otherwise the Tk window is opened and
    0 is returned once its main loop ends.
    """
    cli_args = sys.argv[1:]
    if cli_args and cli_args[0] == "--auto-label-cli":
        return auto_label_cli(cli_args[1:])

    root = Tk()
    VideoToYoloApp(root)
    root.mainloop()
    return 0


def auto_label_cli(argv: list[str]) -> int:
    """Auto-label a dataset from the command line.

    Messages go to stdout, progress as ``PROGRESS <current> <total>`` lines.
    After labelling, ``dataset.yaml`` is written and the label statistics are
    logged.

    Args:
        argv: Arguments after ``--auto-label-cli``: ``--dataset`` and
            ``--model`` (required), ``--conf``, ``--iou``, ``--imgsz`` and
            ``--names`` (class names separated by ``|``).

    Returns:
        0 on success; 1 if an exception occurred, after printing its traceback.
    """
    parser = argparse.ArgumentParser()
    for required_flag in ("--dataset", "--model"):
        parser.add_argument(required_flag, required=True)
    parser.add_argument("--conf", type=float, default=0.35)
    parser.add_argument("--iou", type=float, default=0.50)
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--names", default="")
    options = parser.parse_args(argv)

    dataset_dir = Path(options.dataset)
    model_path = Path(options.model)
    # Empty pieces are dropped, so an empty --names gives an empty list.
    names = [part for part in options.names.split("|") if part]

    def emit(message: str) -> None:
        print(message, flush=True)

    def report_progress(current: int, total: int) -> None:
        print(f"PROGRESS {current} {total}", flush=True)

    try:
        emit(f"子进程 Python: {sys.executable}")
        detector = YoloOnnxDetector(model_path, imgsz=options.imgsz, names=names, log=emit)
        final_names = names or detector.names
        max_class = auto_label_images(
            dataset_dir,
            detector=detector,
            conf_threshold=options.conf,
            iou_threshold=options.iou,
            progress=report_progress,
            log=emit,
        )
        if not final_names:
            # No names from the command line or the detector: make placeholders.
            if max_class >= 0:
                final_names = [f"class_{idx}" for idx in range(max_class + 1)]
            else:
                final_names = ["class_0"]
        write_dataset_yaml(dataset_dir, final_names)
        emit(f"自动标注完成,dataset.yaml 已生成,类别数 {len(final_names)}。")
        stats = compute_label_stats(dataset_dir, final_names)
        log_label_stats(stats, emit)
    except Exception:
        traceback.print_exc()
        return 1
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
相关推荐
计算机魔术师2 小时前
【AI面试八股文 Vol.3.4:训练微调部署选型】从预训练到量化部署:LLM 工程落地如何做模型选择
人工智能·后端·面试·架构·moe·vol.3.3·vol.3.4
叶子Talk2 小时前
OpenAI破解80年数学猜想,AI首次做出原创证明
人工智能·数学·算法·机器学习·ai·openai·ai推理
刘一说2 小时前
AI科技热点日报 | 2026年5月22日
数据库·人工智能·科技
Binary_ey2 小时前
红外系统杂散光难管控?OAS精准助力高质量成像
人工智能
2501_940041742 小时前
探索非主流游戏机制的AI生成指南
人工智能·游戏
Agent手记2 小时前
环保排放数据自动上报全流程自动化—— 2026企业级智能体(Agent)落地全指南
运维·人工智能·ai·自动化
m0_641889292 小时前
GEO分析系统建设:解构大模型搜索的“黑盒”——AI推荐的底层检索与交叉验证机制
人工智能·geo·智能营销·geo优化·geo平台
Dust-Chasing2 小时前
Claude Code源码剖析 - Phase3
开发语言·人工智能·学习
yzx9910133 小时前
人工智能写作开发:从自动化内容到真正的创造力
人工智能·自动化·ai写作