A Multi-Algorithm Video Keyframe Extraction Tool Built with PyQt5

Build your own video keyframe extraction toolkit

In video processing, content analysis, and edit previews, keyframe extraction is a fundamental yet powerful technique. This article walks you through building, from scratch, a video keyframe extraction tool that supports three mainstream algorithms and ships with a complete graphical interface, and then digs into its core principles and implementation details.


🎯 Why Extract Keyframes?

A video is made up of thousands upon thousands of frames, but a large share of them are highly similar or outright redundant. A keyframe is a "representative" frame that captures a significant change in the video's content. Extracting keyframes lets you:

  • Quickly preview a video's main content
  • Reduce the computational cost of video analysis
  • Build video summaries, thumbnails, and scene segmentations
  • Support video retrieval and content understanding

However, most tools on the market are either limited to a single function or come with the steep learning curve of a command line. So I decided to build a desktop application that works out of the box, lets you choose the algorithm, and shows the results visually.


🛠️ Tech Stack and Dependencies

The project is built in Python, with the following core dependencies:

  • PyQt5: cross-platform desktop GUI
  • OpenCV (cv2): video reading and image processing
  • NumPy / SciPy: numerical computation and signal processing
  • JSON: metadata storage (Python standard library)

Installation command:

bash
pip install opencv-python numpy scipy pyqt5

Note: matplotlib is mentioned in the code but never actually used, so it does not need to be installed.
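
If you want to confirm the environment before launching the GUI, a quick sanity check such as the following (a minimal sketch, not part of the tool itself) prints the installed versions of the four core dependencies:

python
import cv2
import numpy as np
import scipy
from PyQt5.QtCore import PYQT_VERSION_STR

# Print the versions of the four core dependencies
print("OpenCV:", cv2.__version__)
print("NumPy:", np.__version__)
print("SciPy:", scipy.__version__)
print("PyQt5:", PYQT_VERSION_STR)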


🧠 The Three Keyframe Extraction Algorithms in Detail

The tool has three classic algorithms built in, each with its own strengths and weaknesses and suited to different scenarios.

1️⃣ Local Maxima Algorithm

Principle

Compute the pixel-wise difference between adjacent frames in the LUV color space to obtain a frame-difference sequence. After smoothing the sequence, look for its local maxima; these points usually correspond to scene cuts or vigorous motion.

Advantages

  • Sensitive to scene cuts
  • Simple and efficient

Drawbacks

  • Insensitive to gradual changes
  • May miss important content in static scenes

Code highlights

python
diff = cv2.absdiff(curr_frame, prev_frame)
diff_sum_mean = np.sum(diff) / (diff.shape[0] * diff.shape[1])
frame_diffs.append(diff_sum_mean)

# Smooth the difference curve, then find its local maxima
sm_diff_array = self.smooth(diff_array, len_window)
frame_indexes = argrelextrema(sm_diff_array, np.greater)[0]
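
To see how the second half of this pipeline behaves in isolation, here is a small self-contained sketch (synthetic data, nothing from the tool itself) that smooths a noisy 1-D "frame difference" signal with a Hanning window and then picks its local maxima with scipy.signal.argrelextrema:

python
import numpy as np
from scipy.signal import argrelextrema

# Synthetic "frame difference" signal: three pronounced changes plus noise
rng = np.random.default_rng(0)
spikes = np.zeros(300)
spikes[[60, 150, 240]] = 10.0
signal = np.convolve(spikes, np.hanning(15), mode="same") + rng.normal(0, 0.2, 300)

# Smooth with a normalized Hanning window (the same idea as the tool's smooth())
window_len = 25
w = np.hanning(window_len)
smoothed = np.convolve(signal, w / w.sum(), mode="same")

# Indices of local maxima on the smoothed curve are the keyframe candidates
candidates = argrelextrema(smoothed, np.greater)[0]
print(candidates)

In the tool, the candidate indices are then mapped back to frame numbers and those frames are written to disk.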

2️⃣ Frame Difference Statistics Algorithm

Principle

This is also based on frame differences, but it introduces a statistical threshold. Compute the mean and standard deviation (std) of all frame differences, set the threshold to mean + 2.05 * std, and treat every frame whose difference exceeds it as a keyframe.

Advantages

  • Adaptive threshold, robust across different videos
  • Saves metadata (-meta.json) for later analysis

Drawbacks

  • Sensitive to lighting changes
  • Requires two passes over the video (compute statistics, then extract frames)

Key logic

python
diff_threshold = data["stats"]["sd"] * 2.05 + data["stats"]["mean"]
if fi["diff_count"] > diff_threshold:
    # save this frame as a keyframe
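
The thresholding step on its own looks roughly like this (a standalone sketch; diff_counts is made-up data standing in for the per-frame difference counts the tool collects in its first pass):

python
import numpy as np

# Made-up per-frame difference magnitudes: mostly quiet frames plus two abrupt changes
rng = np.random.default_rng(1)
diff_counts = rng.integers(80, 160, size=30).astype(float)
diff_counts[[7, 21]] = 5000.0

stats = {"mean": float(np.mean(diff_counts)), "sd": float(np.std(diff_counts))}
threshold = stats["mean"] + 2.05 * stats["sd"]

# Only the frames whose difference exceeds mean + 2.05 * std survive
keyframe_ids = np.flatnonzero(diff_counts > threshold)
print(round(threshold, 1), keyframe_ids)   # expect indices 7 and 21

In the real algorithm these counts come from cv2.countNonZero on blurred, downscaled grayscale frame differences, and the selected frame numbers are used to seek back into the video and save the corresponding frames.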

3️⃣ Histogram Clustering Algorithm

Principle

Treat each frame's per-channel color histogram as a feature vector. An incremental clustering pass assigns similar frames to the same cluster, and within each cluster the frame most similar to the cluster centroid is selected as the keyframe.

Advantages

  • Captures content similarity even when objects move within the frame
  • Well suited to extracting "representative" shots (e.g., slide transitions, product showcases)

Drawbacks

  • High computational complexity (up to O(n²))
  • Large memory footprint (stores histograms for all frames)

Core idea

python
# Compute color-histogram similarity (weighted intersection)
d = 0.30 * d_r + 0.59 * d_g + 0.11 * d_b
if d > threshold:
    # join the existing cluster
else:
    # create a new cluster
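
The similarity measure is easy to reproduce on its own. The sketch below (my own helper, not the tool's code) computes the weighted histogram-intersection score between two frames; note that OpenCV decodes frames as BGR, so channel 0 is blue and channel 2 is red:

python
import cv2
import numpy as np

def hist_similarity(frame_a, frame_b):
    """Weighted histogram-intersection similarity between two same-sized BGR frames."""
    score = 0.0
    # Apply the 0.30 / 0.59 / 0.11 weights to the R, G, B channels (BGR order: 2, 1, 0)
    for channel, weight in zip((2, 1, 0), (0.30, 0.59, 0.11)):
        h_a = cv2.calcHist([frame_a], [channel], None, [256], [0, 256]).flatten()
        h_b = cv2.calcHist([frame_b], [channel], None, [256], [0, 256]).flatten()
        # Normalized intersection: shared histogram mass over the frame's pixel count
        score += weight * np.sum(np.minimum(h_a, h_b)) / np.sum(h_b)
    return score

# Identical frames score 1.0; a slightly brightened copy still scores high
a = np.random.randint(0, 256, (120, 160, 3), dtype=np.uint8)
b = np.clip(a.astype(np.int16) + 10, 0, 255).astype(np.uint8)
print(hist_similarity(a, a), hist_similarity(a, b))

In the tool this score is compared against a threshold of 0.91: above it, the frame joins the best-matching cluster and the centroid is updated as a running mean; below it, a new cluster is started.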

💻 GUI Design Highlights

✅ Parallel multi-task processing

  • Each task runs in its own worker thread, so tasks do not block one another (see the sketch after this list)
  • Supports batch adding, batch starting, and per-task retries
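
The threading follows the standard Qt worker-object pattern: a worker QObject is moved onto a QThread and communicates with the GUI via signals. A minimal sketch of the pattern, with simplified names rather than the tool's actual classes:

python
import sys
import time
from PyQt5.QtCore import QObject, QThread, pyqtSignal
from PyQt5.QtWidgets import QApplication, QPushButton

class Worker(QObject):
    finished = pyqtSignal(str)

    def run(self):
        time.sleep(2)                # stands in for a long extraction job
        self.finished.emit("done")

app = QApplication(sys.argv)
button = QPushButton("working...")
button.show()

thread = QThread()
worker = Worker()
worker.moveToThread(thread)                # run() executes in the worker thread
thread.started.connect(worker.run)
worker.finished.connect(button.setText)    # slots run back on the GUI thread
worker.finished.connect(thread.quit)
thread.start()

sys.exit(app.exec_())

In the real tool each task gets its own worker/thread pair, and the progress_updated and finished signals drive that task's progress bar and status label.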

✅ Flexible algorithm switching

  • Pick the algorithm from a drop-down menu
  • An "Apply to all tasks" button updates the configuration of every task in one click

✅ Friendly user experience

  • Real-time progress bar (0-50%: analysis phase, 50-100%: saving phase)
  • Color-coded status labels (waiting / running / done / failed)
  • An "📂 Open" button that jumps straight to the output folder (works on Windows, macOS, and Linux)

✅ Safe shutdown

  • Closing the window automatically stops all background worker threads
  • Prevents resource leaks

📂 Project Structure and Output

  • Default output directory: ~/KeyframesOutput
  • Each video gets its own subfolder, named:
    • <video name>_local_maxima/
    • <video name>_frame_diff/
    • <video name>_histogram/
  • Keyframe filenames: keyframe_000123.jpg
  • The frame difference algorithm additionally writes <video name>-meta.json with the frame-difference statistics
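
For reference, the naming scheme above can be reproduced with a few lines of path handling (the input path here is just an example of my own):

python
import os

video_path = "/videos/demo.mp4"        # example input path, not from the tool
output_dir = os.path.expanduser("~/KeyframesOutput")

name = os.path.splitext(os.path.basename(video_path))[0]          # "demo"
folder = os.path.join(output_dir, f"{name}_local_maxima")          # .../demo_local_maxima
keyframe = os.path.join(folder, f"keyframe_{123:06d}.jpg")         # .../keyframe_000123.jpg
meta = os.path.join(output_dir, f"{name}_frame_diff", f"{name}-meta.json")

print(folder, keyframe, meta, sep="\n")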

🚀 Usage Guide

  1. Launch the app: run the script to open the main window
  2. Add videos: click "📁 添加视频" (Add Videos); multiple files can be selected at once
  3. Choose an algorithm: switch via the drop-down menu (the default is "局部最大值算法", the local maxima algorithm)
  4. Start extraction:
    • Click "▶️ 开始提取" (Start Extraction) to run all waiting tasks
    • Or click an individual task's "▶️ 执行" (Run) button to run just that task
  5. View the results:
    • When a task finishes, click "📂 打开" (Open) to browse its keyframes
    • Failed tasks can be re-run with "🔄 重试" (Retry)

[Screenshot: single task]

[Screenshot: multiple tasks running in parallel]

[Screenshot: extraction output]


📌 Closing Thoughts

This tool is not just a utility; it is also a platform for experimenting with algorithms. By comparing how the three algorithms behave on different videos, you can build an intuitive feel for two core ideas in computer vision: change detection and content representation.

Full source code:

python
import sys
import os
import subprocess
import platform
import cv2
import numpy as np
import json
from scipy.signal import argrelextrema
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
    QPushButton, QListWidget, QListWidgetItem, QLabel, QFileDialog,
    QMessageBox, QProgressBar, QComboBox, QGroupBox
)
from PyQt5.QtCore import Qt, pyqtSignal, QObject, QThread, QTimer
from PyQt5.QtGui import QFont, QIcon


def open_file_location(file_path):
    try:
        system = platform.system()
        if system == "Windows":
            subprocess.run(['explorer', '/select,', os.path.normpath(file_path)], shell=True)
        elif system == "Darwin":
            subprocess.run(['open', '-R', file_path])
        else:
            folder = os.path.dirname(file_path)
            subprocess.run(['xdg-open', folder])
    except Exception as e:
        QMessageBox.warning(None, "提示", f"无法打开文件位置:\n{str(e)}")


class KeyframeAlgorithm(QObject):
    """关键帧提取算法基类"""
    progress_updated = pyqtSignal(int, int)  # task_id, progress
    finished = pyqtSignal(int, bool, str, str)  # task_id, success, message, output_path
    
    def __init__(self, task_id, video_path, output_dir):
        super().__init__()
        self.task_id = task_id
        self.video_path = video_path
        self.output_dir = output_dir
        self._is_running = True
    
    def stop(self):
        self._is_running = False
    
    @staticmethod
    def smooth(x, window_len=13, window='hanning'):
        s = np.r_[2 * x[0] - x[window_len:1:-1],
                  x, 2 * x[-1] - x[-1:-window_len:-1]]

        if window == 'flat':  # moving average
            w = np.ones(window_len, 'd')
        else:
            w = getattr(np, window)(window_len)
        y = np.convolve(w / w.sum(), s, mode='same')
        return y[window_len - 1:-window_len + 1]


class LocalMaximaAlgorithm(KeyframeAlgorithm):
    """基于局部最大值的关键帧提取算法"""
    
    name = "局部最大值算法"
    
    def extract(self):
        try:
            name = os.path.splitext(os.path.basename(self.video_path))[0]
            dir_path = os.path.join(self.output_dir, f"{name}_local_maxima")
            os.makedirs(dir_path, exist_ok=True)
            
            USE_LOCAL_MAXIMA = True
            len_window = 50

            cap = cv2.VideoCapture(self.video_path)
            if not cap.isOpened():
                self.finished.emit(self.task_id, False, "无法打开视频文件", "")
                return
            
            curr_frame = None
            prev_frame = None
            frame_diffs = []
            frames = []
            success, frame = cap.read()
            i = 0
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            
            while success and self._is_running:
                # 更新进度
                progress = int((i / total_frames) * 50) if total_frames > 0 else 0
                self.progress_updated.emit(self.task_id, progress)
                    
                luv = cv2.cvtColor(frame, cv2.COLOR_BGR2LUV)
                curr_frame = luv
                if curr_frame is not None and prev_frame is not None:
                    diff = cv2.absdiff(curr_frame, prev_frame)
                    diff_sum = np.sum(diff)
                    diff_sum_mean = diff_sum / (diff.shape[0] * diff.shape[1])
                    frame_diffs.append(diff_sum_mean)
                    frames.append(i)
                prev_frame = curr_frame
                i = i + 1
                success, frame = cap.read()
            cap.release()

            if not self._is_running:
                return

            keyframe_id_set = set()
            if USE_LOCAL_MAXIMA and frame_diffs:
                diff_array = np.array(frame_diffs)
                sm_diff_array = self.smooth(diff_array, len_window)
                frame_indexes = np.asarray(argrelextrema(sm_diff_array, np.greater))[0]
                for idx in frame_indexes:
                    if idx < len(frames):
                        keyframe_id_set.add(frames[idx])

            # 保存关键帧
            cap = cv2.VideoCapture(self.video_path)
            success, frame = cap.read()
            idx = 0
            saved_count = 0
            total_keyframes = len(keyframe_id_set)
            
            while success and keyframe_id_set and self._is_running:
                # 更新进度
                progress = 50 + int((saved_count / total_keyframes) * 50) if total_keyframes > 0 else 100
                self.progress_updated.emit(self.task_id, progress)
                    
                if idx in keyframe_id_set:
                    name = f"keyframe_{idx:06d}.jpg"
                    cv2.imwrite(os.path.join(dir_path, name), frame)
                    saved_count += 1
                    keyframe_id_set.remove(idx)
                idx = idx + 1
                success, frame = cap.read()
            cap.release()
            
            if self._is_running:
                self.finished.emit(self.task_id, True, f"成功提取 {saved_count} 个关键帧", dir_path)
            
        except Exception as e:
            self.finished.emit(self.task_id, False, f"算法执行失败: {str(e)}", "")


class FrameDiffAlgorithm(KeyframeAlgorithm):
    """基于帧差的关键帧提取算法"""
    
    name = "帧差统计算法"
    
    @staticmethod
    def get_video_info(source_path):
        cap = cv2.VideoCapture(source_path)
        info = {
            "framecount": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "codec": int(cap.get(cv2.CAP_PROP_FOURCC))
        }
        cap.release()
        return info

    @staticmethod
    def scale(img, xScale, yScale):
        return cv2.resize(img, None, fx=xScale, fy=yScale, interpolation=cv2.INTER_AREA)

    def calculate_frame_stats(self):
        cap = cv2.VideoCapture(self.video_path)
        data = {"frame_info": []}
        last_frame = None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number = 0
        
        while cap.isOpened() and self._is_running:
            ret, frame = cap.read()
            if frame is None:
                break

            frame_number = int(cap.get(cv2.CAP_PROP_POS_FRAMES) - 1)
            
            # 更新进度
            progress = int((frame_number / total_frames) * 50) if total_frames > 0 else 0
            self.progress_updated.emit(self.task_id, progress)
                
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = self.scale(gray, 0.25, 0.25)
            gray = cv2.GaussianBlur(gray, (9, 9), 0.0)

            if frame_number >= 0 and last_frame is not None:
                diff = cv2.absdiff(gray, last_frame)
                diff_mag = int(cv2.countNonZero(diff))
                data["frame_info"].append({"frame_number": frame_number, "diff_count": diff_mag})

            last_frame = gray

        cap.release()

        diff_counts = [fi["diff_count"] for fi in data["frame_info"]]
        if diff_counts:
            data["stats"] = {
                "num": int(len(diff_counts)),
                "min": int(np.min(diff_counts)),
                "max": int(np.max(diff_counts)),
                "mean": float(np.mean(diff_counts)),
                "median": float(np.median(diff_counts)),
                "sd": float(np.std(diff_counts))
            }
        return data

    def extract(self):
        try:
            name = os.path.splitext(os.path.basename(self.video_path))[0]
            dir_path = os.path.join(self.output_dir, f"{name}_frame_diff")
            os.makedirs(dir_path, exist_ok=True)

            # 计算帧差统计数据
            data = self.calculate_frame_stats()
            
            if not self._is_running:
                return
            
            if "stats" not in data:
                self.finished.emit(self.task_id, False, "无法计算帧差统计数据", "")
                return

            diff_threshold = data["stats"]["sd"] * 2.05 + data["stats"]["mean"]

            cap = cv2.VideoCapture(self.video_path)
            saved_count = 0
            total_frames = len(data["frame_info"])

            for index, fi in enumerate(data["frame_info"]):
                if not self._is_running:
                    cap.release()
                    return
                    
                # 更新进度
                progress = 50 + int((index / total_frames) * 50) if total_frames > 0 else 100
                self.progress_updated.emit(self.task_id, progress)
                    
                if fi["diff_count"] < diff_threshold:
                    continue

                # 将视频定位到关键帧并读取该帧
                cap.set(cv2.CAP_PROP_POS_FRAMES, fi["frame_number"])
                ret, frame = cap.read()
                if not ret:
                    continue

                # 保存关键帧图像到目标文件夹
                frame_filename = os.path.join(dir_path, f"keyframe_{fi['frame_number']:06d}.jpg")
                cv2.imwrite(frame_filename, frame)
                saved_count += 1

            cap.release()
            
            # 保存元数据
            data_fp = os.path.join(dir_path, f"{name}-meta.json")
            with open(data_fp, 'w') as f:
                json.dump(data, f, indent=4)

            if self._is_running:
                self.finished.emit(self.task_id, True, f"成功提取 {saved_count} 个关键帧", dir_path)
            
        except Exception as e:
            self.finished.emit(self.task_id, False, f"算法执行失败: {str(e)}", "")


class HistogramAlgorithm(KeyframeAlgorithm):
    """基于直方图聚类关键帧提取算法"""
    
    name = "直方图聚类算法"
    
    def extract(self):
        try:
            name = os.path.splitext(os.path.basename(self.video_path))[0]
            output_folder = os.path.join(self.output_dir, f"{name}_histogram")
            os.makedirs(output_folder, exist_ok=True)

            cap = cv2.VideoCapture(self.video_path)
            if not cap.isOpened():
                self.finished.emit(self.task_id, False, "无法打开视频文件!", "")
                return

            # 获取视频帧数
            num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cluster = np.zeros(num_frames)
            cluster_count = np.zeros(num_frames)
            cluster_num = 0

            threshold = 0.91
            centroid_r = np.zeros((num_frames, 256))
            centroid_g = np.zeros((num_frames, 256))
            centroid_b = np.zeros((num_frames, 256))

            # 读取首帧,形成第一个聚类
            ret, frame = cap.read()
            if not ret:
                self.finished.emit(self.task_id, False, "无法读取第一帧!", "")
                return

            cluster_num += 1
            prev_count_r = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()
            prev_count_g = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
            prev_count_b = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()

            cluster[0] = 1
            cluster_count[0] += 1
            centroid_r[0] = prev_count_r
            centroid_g[0] = prev_count_g
            centroid_b[0] = prev_count_b

            visit = 1

            # 遍历视频的其他帧
            for k in range(1, num_frames):
                if not self._is_running:
                    cap.release()
                    return
                    
                ret, frame = cap.read()
                if not ret:
                    break
                    
                # 更新进度
                progress = int((k / num_frames) * 50)
                self.progress_updated.emit(self.task_id, progress)

                tmp_count_r = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()
                tmp_count_g = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
                tmp_count_b = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()

                cluster_group_id = 1
                max_similarity = 0

                # 计算相似度
                for cluster_idx in range(visit, cluster_num + 1):
                    s_r = np.sum(np.minimum(centroid_r[cluster_idx - 1], tmp_count_r))
                    s_g = np.sum(np.minimum(centroid_g[cluster_idx - 1], tmp_count_g))
                    s_b = np.sum(np.minimum(centroid_b[cluster_idx - 1], tmp_count_b))

                    d_r = s_r / np.sum(tmp_count_r)
                    d_g = s_g / np.sum(tmp_count_g)
                    d_b = s_b / np.sum(tmp_count_b)
                    d = 0.30 * d_r + 0.59 * d_g + 0.11 * d_b

                    if d > max_similarity:
                        cluster_group_id = cluster_idx
                        max_similarity = d

                # 判断是否加入现有聚类或形成新聚类
                if max_similarity > threshold:
                    centroid_r[cluster_group_id - 1] = (centroid_r[cluster_group_id - 1] * cluster_count[cluster_group_id - 1] + tmp_count_r) / (cluster_count[cluster_group_id - 1] + 1)
                    centroid_g[cluster_group_id - 1] = (centroid_g[cluster_group_id - 1] * cluster_count[cluster_group_id - 1] + tmp_count_g) / (cluster_count[cluster_group_id - 1] + 1)
                    centroid_b[cluster_group_id - 1] = (centroid_b[cluster_group_id - 1] * cluster_count[cluster_group_id - 1] + tmp_count_b) / (cluster_count[cluster_group_id - 1] + 1)
                    cluster_count[cluster_group_id - 1] += 1
                    cluster[k] = cluster_group_id
                else:
                    cluster_num += 1
                    visit += 1
                    cluster_count[cluster_num - 1] += 1
                    centroid_r[cluster_num - 1] = tmp_count_r
                    centroid_g[cluster_num - 1] = tmp_count_g
                    centroid_b[cluster_num - 1] = tmp_count_b
                    cluster[k] = cluster_num

            cap.release()

            if not self._is_running:
                return

            # 提取每个聚类的关键帧
            max_similarity = np.zeros(cluster_num)
            frame_indices = np.zeros(cluster_num, dtype=int)

            cap = cv2.VideoCapture(self.video_path)
            frame_number = 0

            while frame_number < num_frames and self._is_running:
                ret, frame = cap.read()
                if not ret:
                    break

                tmp_count_r = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()
                tmp_count_g = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
                tmp_count_b = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()

                cluster_id = int(cluster[frame_number])
                if cluster_id > 0 and cluster_id <= cluster_num:  # 确保聚类ID有效
                    s_r = np.sum(np.minimum(centroid_r[cluster_id - 1], tmp_count_r))
                    s_g = np.sum(np.minimum(centroid_g[cluster_id - 1], tmp_count_g))
                    s_b = np.sum(np.minimum(centroid_b[cluster_id - 1], tmp_count_b))

                    d_r = s_r / np.sum(tmp_count_r)
                    d_g = s_g / np.sum(tmp_count_g)
                    d_b = s_b / np.sum(tmp_count_b)
                    d = 0.30 * d_r + 0.59 * d_g + 0.11 * d_b

                    if d > max_similarity[cluster_id - 1]:
                        max_similarity[cluster_id - 1] = d
                        frame_indices[cluster_id - 1] = frame_number

                frame_number += 1

                # 更新进度
                progress = 50 + int((frame_number / num_frames) * 25) if num_frames > 0 else 75
                self.progress_updated.emit(self.task_id, progress)

            cap.release()

            if not self._is_running:
                return

            # 保存关键帧到文件夹
            cap = cv2.VideoCapture(self.video_path)
            saved_count = 0
            total_clusters = len(frame_indices)
            
            for i, idx in enumerate(frame_indices):
                if not self._is_running:
                    cap.release()
                    return
                    
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, img = cap.read()
                if ret:
                    frame_filename = os.path.join(output_folder, f'keyframe_{int(idx):06d}.jpg')
                    cv2.imwrite(frame_filename, img)
                    saved_count += 1
                
                # 更新进度
                progress = 75 + int(((i + 1) / total_clusters) * 25) if total_clusters > 0 else 100
                self.progress_updated.emit(self.task_id, progress)
                    
            cap.release()

            if self._is_running:
                self.finished.emit(self.task_id, True, f"成功提取 {saved_count} 个关键帧", output_folder)
            
        except Exception as e:
            self.finished.emit(self.task_id, False, f"算法执行失败: {str(e)}", "")


# 算法映射
ALGORITHMS = {
    "局部最大值算法": LocalMaximaAlgorithm,
    "帧差统计算法": FrameDiffAlgorithm,
    "直方图聚类算法": HistogramAlgorithm
}


class KeyframeExtractor(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("🎬 视频关键帧提取器")
        self.resize(1200, 700)
        
        self.tasks = []
        self.current_algorithm = "局部最大值算法"
        self.output_dir = os.path.expanduser("~/KeyframesOutput")
        self.threads = {}  # 存储线程引用
        
        # 创建输出目录
        os.makedirs(self.output_dir, exist_ok=True)
        
        self.init_ui()
        self.check_dependencies()

    def init_ui(self):
        central = QWidget()
        self.setCentralWidget(central)
        layout = QVBoxLayout(central)

        # 顶部控制区域
        top_group = QGroupBox("控制面板")
        top_layout = QHBoxLayout(top_group)
        
        self.add_btn = QPushButton("📁 添加视频")
        self.add_btn.clicked.connect(self.add_videos)
        self.start_btn = QPushButton("▶️ 开始提取")
        self.start_btn.clicked.connect(self.start_extraction)
        self.clear_btn = QPushButton("🗑️ 清空列表")
        self.clear_btn.clicked.connect(self.clear_all_tasks)
        self.output_btn = QPushButton("📂 输出目录")
        self.output_btn.clicked.connect(self.select_output_dir)
        
        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(ALGORITHMS.keys())
        self.algorithm_combo.setCurrentText(self.current_algorithm)
        self.algorithm_combo.currentTextChanged.connect(self.on_algorithm_changed)
        
        # 添加"应用到所有任务"按钮
        self.apply_to_all_btn = QPushButton("🔄 应用到所有任务")
        self.apply_to_all_btn.clicked.connect(self.apply_algorithm_to_all_tasks)
        
        top_layout.addWidget(self.add_btn)
        top_layout.addWidget(self.start_btn)
        top_layout.addWidget(self.clear_btn)
        top_layout.addWidget(self.output_btn)
        top_layout.addWidget(QLabel("算法选择:"))
        top_layout.addWidget(self.algorithm_combo)
        top_layout.addWidget(self.apply_to_all_btn)
        top_layout.addStretch()
        
        layout.addWidget(top_group)

        # 任务列表
        list_group = QGroupBox("任务列表")
        list_layout = QVBoxLayout(list_group)
        self.list_widget = QListWidget()
        list_layout.addWidget(self.list_widget)
        layout.addWidget(list_group)

        # 状态栏
        self.status_label = QLabel("就绪")
        self.status_label.setAlignment(Qt.AlignCenter)
        self.status_label.setStyleSheet("color: #666; padding: 5px; font-size: 13px;")
        layout.addWidget(self.status_label)

        self.apply_stylesheet()

    def apply_stylesheet(self):
        self.setStyleSheet("""
            QMainWindow {
                background-color: #f5f7fa;
            }
            QGroupBox {
                font-weight: bold;
                font-size: 12px;
                border: 2px solid #dcdfe6;
                border-radius: 8px;
                margin-top: 1ex;
                padding-top: 10px;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 5px 0 5px;
            }
            QPushButton {
                background-color: #409eff;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 6px;
                font-weight: bold;
                min-width: 80px;
            }
            QPushButton:hover {
                background-color: #66b1ff;
            }
            QPushButton:disabled {
                background-color: #c0c4cc;
                color: #909399;
            }
            QListWidget {
                background-color: white;
                border: 1px solid #e4e7ed;
                border-radius: 6px;
                padding: 5px;
            }
            QListWidget::item {
                padding: 0px;
                border-bottom: 1px solid #ebeef5;
            }
            QComboBox {
                padding: 6px;
                border: 1px solid #dcdfe6;
                border-radius: 4px;
                min-width: 120px;
            }
            QLabel {
                font-size: 13px;
            }
            QProgressBar {
                border: 1px solid #dcdfe6;
                border-radius: 4px;
                text-align: center;
                background-color: #f5f7fa;
            }
            QProgressBar::chunk {
                background-color: #409eff;
                border-radius: 3px;
            }
        """)

    def on_algorithm_changed(self, text):
        self.current_algorithm = text

    def apply_algorithm_to_all_tasks(self):
        """将当前选择的算法应用到所有任务,并重置任务状态"""
        if not self.tasks:
            QMessageBox.information(self, "提示", "没有任务可以应用算法")
            return
        
        # 检查是否有正在运行的任务
        running_tasks = [i for i, t in enumerate(self.tasks) if t['status'] == '进行中']
        if running_tasks:
            reply = QMessageBox.question(
                self, "确认操作",
                "有任务正在运行,是否停止这些任务并应用新算法?",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No
            )
            if reply == QMessageBox.No:
                return
            
            # 停止正在运行的任务
            for task_id in running_tasks:
                task = self.tasks[task_id]
                if task.get('worker'):
                    task['worker'].stop()
                if task.get('thread') and task['thread'].isRunning():
                    task['thread'].quit()
                    task['thread'].wait(1000)
                if task_id in self.threads:
                    del self.threads[task_id]
        
        # 应用到所有任务
        updated_count = 0
        for task in self.tasks:
            # 更新算法
            task['algorithm'] = self.current_algorithm
            
            # 重置任务状态(除了正在运行的任务)
            if task['status'] != '进行中':
                task['status'] = '等待'
                task['progress'] = 0
                task['output_path'] = ''
                
                # 更新UI
                if task['status_label']:
                    task['status_label'].setText("等待")
                    task['status_label'].setStyleSheet("")  # 清除样式
                
                if task['progress_bar']:
                    task['progress_bar'].setValue(0)
                
                if task['open_btn']:
                    task['open_btn'].setEnabled(False)
                
                if task['exec_btn']:
                    task['exec_btn'].setText("▶️ 执行")
                    task['exec_btn'].setEnabled(True)
                
                if task['algo_info']:
                    task['algo_info'].setText(f"算法: {self.current_algorithm}")
            
            updated_count += 1
        
        if updated_count > 0:
            self.status_label.setText(f"已将算法应用到 {updated_count} 个任务")
            
            # 如果之前有运行中的任务,重新启用按钮
            if running_tasks:
                self.start_btn.setEnabled(True)
                self.add_btn.setEnabled(True)
                self.clear_btn.setEnabled(True)
        else:
            QMessageBox.information(self, "提示", "没有任务可以更新算法")

    def select_output_dir(self):
        dir_path = QFileDialog.getExistingDirectory(
            self, "选择输出目录", self.output_dir
        )
        if dir_path:
            self.output_dir = dir_path
            self.status_label.setText(f"输出目录已设置为: {dir_path}")

    def add_videos(self):
        files, _ = QFileDialog.getOpenFileNames(
            self, "选择视频文件", "",
            "视频文件 (*.mp4 *.mkv *.avi *.mov *.flv *.wmv *.webm *.m4v *.ts *.mts)"
        )
        for file in files:
            task = {
                'video_path': file,
                'algorithm': self.current_algorithm,
                'status': '等待',
                'progress': 0,
                'output_path': '',
                'widget': None,
                'progress_bar': None,
                'status_label': None,
                'open_btn': None,
                'remove_btn': None,
                'exec_btn': None,
                'thread': None,
                'worker': None,
                'algo_info': None  # 添加算法信息标签引用
            }
            self.tasks.append(task)
            self.add_task_to_list(len(self.tasks) - 1)

    def add_task_to_list(self, index):
        item = QListWidgetItem()
        widget = QWidget()
        layout = QHBoxLayout(widget)

        # 视频信息
        video_info = QLabel(f"{os.path.basename(self.tasks[index]['video_path'])}")
        video_info.setWordWrap(True)
        video_info.setFont(QFont("Arial", 9))
        
        # 算法信息
        algo_info = QLabel(f"算法: {self.tasks[index]['algorithm']}")
        algo_info.setStyleSheet("color: #909399; font-size: 11px;")
        algo_info.setFixedWidth(120)

        # 状态标签
        status_label = QLabel("等待")
        status_label.setMinimumWidth(80)
        status_label.setAlignment(Qt.AlignCenter)

        # 进度条
        progress_bar = QProgressBar()
        progress_bar.setRange(0, 100)
        progress_bar.setValue(0)
        progress_bar.setFixedHeight(20)
        progress_bar.setFixedWidth(150)

        # 执行按钮
        exec_btn = QPushButton("▶️ 执行")
        exec_btn.setFixedWidth(70)
        exec_btn.setFixedHeight(30)
        exec_btn.clicked.connect(lambda _, idx=index: self.execute_single_task(idx))

        # 打开按钮
        open_btn = QPushButton("📂 打开")
        open_btn.setFixedWidth(70)
        open_btn.setFixedHeight(30)
        open_btn.setEnabled(False)
        open_btn.clicked.connect(lambda _, idx=index: self.open_output_folder(idx))

        # 移除按钮
        remove_btn = QPushButton("🗑️")
        remove_btn.setFixedWidth(40)
        remove_btn.setFixedHeight(30)
        remove_btn.setStyleSheet("background-color: #f56c6c; color: white;")
        remove_btn.clicked.connect(lambda _, idx=index: self.remove_task(idx))

        layout.addWidget(video_info)
        layout.addWidget(algo_info)
        layout.addWidget(status_label)
        layout.addWidget(progress_bar)
        layout.addWidget(exec_btn)
        layout.addWidget(open_btn)
        layout.addWidget(remove_btn)
        layout.setStretch(0, 1)
        layout.setAlignment(Qt.AlignVCenter)

        item.setSizeHint(widget.sizeHint())
        self.list_widget.addItem(item)
        self.list_widget.setItemWidget(item, widget)

        self.tasks[index]['widget'] = widget
        self.tasks[index]['status_label'] = status_label
        self.tasks[index]['progress_bar'] = progress_bar
        self.tasks[index]['open_btn'] = open_btn
        self.tasks[index]['remove_btn'] = remove_btn
        self.tasks[index]['exec_btn'] = exec_btn
        self.tasks[index]['algo_info'] = algo_info  # 保存算法标签引用

    def open_output_folder(self, index):
        output_path = self.tasks[index]['output_path']
        if output_path and os.path.exists(output_path):
            open_file_location(output_path)
        else:
            QMessageBox.warning(self, "文件夹不存在", "输出文件夹尚未生成或已被删除。")

    def remove_task(self, index):
        if index >= len(self.tasks):
            return
            
        task = self.tasks[index]
        if task['status'] == '进行中':
            reply = QMessageBox.question(
                self, "确认停止",
                "任务正在运行,确定要停止并移除吗?",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No
            )
            if reply == QMessageBox.Yes:
                # 停止工作线程
                if task.get('worker'):
                    task['worker'].stop()
                if task.get('thread') and task['thread'].isRunning():
                    task['thread'].quit()
                    task['thread'].wait(1000)  # 等待1秒
                self.finish_task_removal(index)
            return

        self.finish_task_removal(index)

    def finish_task_removal(self, index):
        if index < len(self.tasks):
            self.list_widget.takeItem(index)
            if index in self.threads:
                del self.threads[index]
            del self.tasks[index]
            self.status_label.setText("已移除一个任务")

    def clear_all_tasks(self):
        if any(t['status'] == '进行中' for t in self.tasks):
            QMessageBox.warning(self, "无法清空", "有任务正在运行,请等待完成后再清空!")
            return
            
        reply = QMessageBox.question(
            self, "确认清空",
            "确定要清空所有任务吗?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No
        )
        if reply == QMessageBox.Yes:
            # 停止所有工作线程
            for task in self.tasks:
                if task.get('worker'):
                    task['worker'].stop()
                if task.get('thread') and task['thread'].isRunning():
                    task['thread'].quit()
                    task['thread'].wait(500)
            self.threads.clear()
            self.tasks.clear()
            self.list_widget.clear()
            self.status_label.setText("任务列表已清空")

    def start_extraction(self):
        if not self.tasks:
            QMessageBox.warning(self, "提示", "请先添加视频文件!")
            return

        waiting_tasks = [i for i, t in enumerate(self.tasks) if t['status'] == '等待']
        if not waiting_tasks:
            QMessageBox.information(self, "提示", "没有待处理的任务!")
            return

        self.start_btn.setEnabled(False)
        self.add_btn.setEnabled(False)
        self.clear_btn.setEnabled(False)

        # 使用定时器延迟启动任务,避免UI阻塞
        QTimer.singleShot(100, lambda: self.start_tasks_delayed(waiting_tasks))

    def start_tasks_delayed(self, task_indices):
        for i in task_indices:
            self.run_task(i)

    def run_task(self, index):
        if index >= len(self.tasks):
            return
            
        task = self.tasks[index]
        task['status'] = '进行中'
        task['status_label'].setText("进行中")
        task['status_label'].setStyleSheet("color: #e6a23c; font-weight: bold;")
        task['remove_btn'].setEnabled(False)
        task['exec_btn'].setEnabled(False)
        task['progress_bar'].setValue(0)

        # 创建算法实例
        algorithm_class = ALGORITHMS.get(task['algorithm'])
        if not algorithm_class:
            self.status_label.setText(f"未知算法: {task['algorithm']}")
            return
            
        worker = algorithm_class(index, task['video_path'], self.output_dir)
        
        # 创建线程
        thread = QThread()
        worker.moveToThread(thread)
        
        # 连接信号
        worker.progress_updated.connect(self.on_progress_updated)
        worker.finished.connect(self.on_task_finished)
        thread.started.connect(worker.extract)
        
        # 存储引用
        task['worker'] = worker
        task['thread'] = thread
        self.threads[index] = thread
        
        # 启动线程
        thread.start()

    def on_progress_updated(self, task_id, progress):
        if task_id < len(self.tasks):
            task = self.tasks[task_id]
            task['progress'] = progress
            task['progress_bar'].setValue(progress)

    def on_task_finished(self, task_id, success, message, output_path):
        # 使用定时器延迟UI更新,避免递归重绘
        QTimer.singleShot(0, lambda: self.process_task_result(task_id, success, message, output_path))

    def process_task_result(self, task_id, success, message, output_path):
        if task_id >= len(self.tasks):
            return
            
        task = self.tasks[task_id]
        task['status'] = '完成' if success else '失败'
        color = "#67c23a" if success else "#f56c6c"
        task['status_label'].setText("完成" if success else "失败")
        task['status_label'].setStyleSheet(f"color: {color}; font-weight: bold;")
        task['progress_bar'].setValue(100)
        task['remove_btn'].setEnabled(True)
        task['exec_btn'].setEnabled(True)  # 启用执行按钮,允许重新执行
        task['exec_btn'].setText("🔄 重试")  # 修改按钮文本为"重试"

        # 清理线程
        if task.get('thread'):
            task['thread'].quit()
            task['thread'].wait(500)
            if task_id in self.threads:
                del self.threads[task_id]

        if success:
            task['output_path'] = output_path
            self.status_label.setText(f"✅ {os.path.basename(task['video_path'])} 提取成功!")
            task['open_btn'].setEnabled(True)
        else:
            self.status_label.setText(f"❌ {message}")
            task['open_btn'].setEnabled(False)

        # 检查是否所有批量任务完成
        if not any(t['status'] == '进行中' for t in self.tasks):
            self.start_btn.setEnabled(True)
            self.add_btn.setEnabled(True)
            self.clear_btn.setEnabled(True)

    def execute_single_task(self, index):
        task = self.tasks[index]
        
        # 如果任务已完成或失败,重置为等待状态
        if task['status'] in ['完成', '失败']:
            task['status'] = '等待'
            task['progress'] = 0
            task['output_path'] = ''
            task['status_label'].setText("等待")
            task['status_label'].setStyleSheet("")  # 清除样式
            task['progress_bar'].setValue(0)
            task['open_btn'].setEnabled(False)
            task['exec_btn'].setText("▶️ 执行")  # 恢复执行按钮文本
            
        # 如果任务正在运行,不允许重复执行
        if task['status'] == '进行中':
            QMessageBox.warning(self, "提示", "任务正在运行,请等待完成!")
            return

        self.run_task(index)

    def check_dependencies(self):
        try:
            import cv2
            import numpy as np
            from scipy.signal import argrelextrema
        except ImportError as e:
            missing = str(e).split(" ")[-1]
            msg = (
                f"缺少必要的依赖库: {missing}\n\n"
                "请安装以下依赖:\n"
                "pip install opencv-python numpy scipy matplotlib"
            )
            QMessageBox.critical(self, "依赖缺失", msg)
            sys.exit(1)

    def closeEvent(self, event):
        """窗口关闭时停止所有工作线程"""
        for task in self.tasks:
            if task.get('worker'):
                task['worker'].stop()
            if task.get('thread') and task['thread'].isRunning():
                task['thread'].quit()
                task['thread'].wait(1000)
        event.accept()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    
    # 设置应用程序属性以减少重绘问题
    app.setAttribute(Qt.AA_UseHighDpiPixmaps, True)
    app.setAttribute(Qt.AA_EnableHighDpiScaling, True)
    
    window = KeyframeExtractor()
    window.show()
    
    sys.exit(app.exec_())

Let every frame speak, and make the video's content clear at a glance.


💡 Tip: for long videos, run the frame difference statistics algorithm first as a quick screening pass, then apply histogram clustering to the segments that matter for finer extraction; you get both efficiency and quality.
