Action Recognition 8: Training TSN on a Self-Built Dataset

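Contents

1. Introduction

2. Finding and Cutting Videos

2.1 Converting Videos to Image Folders

2.2 Image Binary-Classification Labeling Tool

2.3 Switching the Arrow Keys to Auto-Play

2.4 Adding a Convert-Back-to-Video Feature

2.5 Adding Automatic Negative-Sample Fill-In

3. How Much Data Is Needed

3.1 Training a two-class (kick / non-kick) TSN action-recognition model with mmaction2: how many training and validation samples, and how many positive and negative samples, are needed, and how many videos of each give good results?

3.2 Does each video clip have a length requirement?

4. Converting the Data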


1. Introduction

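"Action Recognition 5: Training and Testing with mmaction2" walked through the demo from the official documentation that trains on the kinetics400_tiny dataset. Here we build our own dataset modeled on kinetics400_tiny and try training TSN on it.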

This post mainly covers how to build a labeling tool.

2. Finding and Cutting Videos

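I won't cover finding and downloading videos; any website that downloads Bilibili videos online will do. Instead, let's talk about cutting them. Section 3 of "Action Recognition 1: 2D Pose Estimation + Geometric Analysis + Finite State Machine" already has tool code for clipping videos and converting videos to image folders. However, I didn't find that clipping code very convenient, because you have to know each clip's frame-number range in advance. So in 2.2 I wrote an image binary-classification labeling tool instead, and then made some modifications to it.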

2.1 Converting Videos to Image Folders

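First, a script that converts videos into image folders. I put all the downloaded videos in D:\zero_track\mmaction2\input_videos. I then created a my_tools folder inside mmaction2 and saved the script below in it as mp4tojpg.py. For each video, the script creates an image folder next to the video, with the same name as the video. For example, test1.mp4 produces a test1 folder.
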
#!/usr/bin/env python3
import cv2
import os
import argparse

# ---------- Customizable ----------
VIDEO_EXTS = ('.mp4', '.avi', '.mov', '.mkv')   # supported video extensions
# ----------------------------------

def extract_frames(video_path):
    """Split a single video into a folder of images."""
    base_dir = os.path.dirname(video_path)
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    # The folder takes the video's own name, e.g. test1.mp4 -> test1/
    output_dir = os.path.join(base_dir, video_name)
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f'[WARN] Cannot open video: {video_path}')
        return 0

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Zero-padded names (000000.jpg, 000001.jpg, ...) keep string order equal to frame order
        frame_path = os.path.join(output_dir, f'{frame_idx:06d}.jpg')
        cv2.imwrite(frame_path, frame)
        frame_idx += 1

    cap.release()
    print(f'[INFO] Extracted {frame_idx} frames -> {output_dir}')
    return frame_idx

def walk_and_extract(root_dir):
    """Recursively walk a directory and extract frames from every video."""
    count = 0
    for dirpath, _, filenames in os.walk(root_dir):
        for file in filenames:
            if file.lower().endswith(VIDEO_EXTS):
                video_path = os.path.join(dirpath, file)
                extract_frames(video_path)
                count += 1
    print(f'[INFO] Batch extraction finished, {count} videos processed')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract video frames as images (single file or whole directory)')
    parser.add_argument('--input', help='path to a single video')
    parser.add_argument('--dir',  help='directory to batch-extract')
    args = parser.parse_args()

    # Default: if neither option is given, use a fixed directory
    if args.input is None and args.dir is None:
        # Change this to your own video folder
        args.dir = r'D:\zero_track\mmaction2\input_videos'

    # The two options are mutually exclusive
    if args.input is not None and args.dir is not None:
        parser.error('--input and --dir cannot be used together; pick one')

    # Run
    if args.input:
        extract_frames(args.input)
    else:
        walk_and_extract(args.dir)
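The script pads frame numbers to six digits with `f'{frame_idx:06d}.jpg'`, and that padding is what keeps frames in order later. The labeling tool in 2.2 orders a folder's images with a plain string sort (`self.images.sort()`). Here is a quick standard-library check of why the padding matters. `frame_names` is only an illustrative helper:

```python
def frame_names(indices, pad=True):
    """Build frame file names like mp4tojpg.py (pad=True) or without zero padding."""
    return [f'{i:06d}.jpg' if pad else f'{i}.jpg' for i in indices]

indices = [0, 2, 10, 100]
print(sorted(frame_names(indices)))             # ['000000.jpg', '000002.jpg', '000010.jpg', '000100.jpg']
print(sorted(frame_names(indices, pad=False)))  # ['0.jpg', '10.jpg', '100.jpg', '2.jpg']
```

Six digits cover up to 1,000,000 frames, which is about 9 hours of 30 fps video, so they are more than enough for short clips.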

2.2 Image Binary-Classification Labeling Tool

Next, I built a labeling tool for binary image classification:

import os
import pygame
import sys
import shutil
import time
import json
from pygame.locals import *


# Initialize pygame
pygame.init()

# Configuration
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h

WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)

# Create the window
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Image Classification Labeling Tool")

# Fonts
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

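class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
        self.folders = []               # absolute paths of all folders that contain images
        self.current_folder_index = 0   # index of the current folder
        self.images = []                # absolute paths of all images in the current folder
        self.current_image_index = 0    # index of the current image
        self.labels = {}                # image path -> 'positive' / 'negative'

        # Labeling state
        self.continuous_mode = False    # whether continuous (run) labeling is active
        self.continuous_label = None    # label applied to the whole run
        self.continuous_start_index = None  # index where the run started

        # Key-hold state
        self.key_pressed = {"left": False, "right": False}
        self.key_pressed_time = 0       # when the arrow key went down
        self.last_key_time = 0          # timer for repeated steps while a key is held
        self.key_repeat_delay = 0.8     # initial delay before repeating (raised to 0.8 s)
        self.key_repeat_interval = 0.15 # repeat interval (raised to 0.15 s)

        # Operation history (for undo)
        self.undo_stack = []
        self.max_undo_steps = 50

        # Confirmation dialog state
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""        # action to run when the dialog is confirmed

        # Find every folder that contains images
        self.find_image_folders()

        # Load the images of the current folder
        if self.folders:
            self.load_current_folder_images()

        # Load previously saved labels, if any
        self.load_labels()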

        
    def find_image_folders(self):
        """Find all folders that contain images."""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)

    def load_current_folder_images(self):
        """Load all images in the current folder."""
        folder_path = self.folders[self.current_folder_index]
        self.images = []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

        for file in os.listdir(folder_path):
            if file.lower().endswith(image_extensions):
                self.images.append(os.path.join(folder_path, file))

        # Sort by file name
        self.images.sort()
        self.current_image_index = 0

    def get_current_image(self):
        """Return the current image path."""
        if not self.images:
            return None
        return self.images[self.current_image_index]

    def next_image(self):
        """Go to the next image."""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()  # save state so this can be undone
            self.current_image_index += 1
            return True
        return False

    def prev_image(self):
        """Go to the previous image."""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            return True
        return False

    def label_current_image(self, label):
        """Label the current image."""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state so this can be undone
            self.labels[current_image] = label
            # Auto-save the labels
            self.save_labels()

    def start_continuous_labeling(self):
        """Start continuous labeling."""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state so this can be undone
            # If the current image already has a label, use it for the run
            if current_image in self.labels:
                self.continuous_label = self.labels[current_image]
            else:
                # Otherwise default to positive
                self.continuous_label = "positive"
                self.labels[current_image] = self.continuous_label

            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def end_continuous_labeling(self):
        """End continuous labeling."""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()  # save state so this can be undone
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)

            for i in range(start, end + 1):
                self.labels[self.images[i]] = self.continuous_label

            self.continuous_mode = False
            self.continuous_start_index = None
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def move_labeled_files(self, positive_dir, negative_dir):
        """Move labeled files into the positive/negative sample folders."""
        if not os.path.exists(positive_dir):
            os.makedirs(positive_dir)
        if not os.path.exists(negative_dir):
            os.makedirs(negative_dir)

        moved_count = 0
        files_to_remove = []

        for img_path, label in self.labels.items():
            if label in ["positive", "negative"] and os.path.exists(img_path):
                filename = os.path.basename(img_path)
                dest_dir = positive_dir if label == "positive" else negative_dir

                # Resolve file-name collisions
                counter = 1
                base_name, ext = os.path.splitext(filename)
                new_filename = filename
                while os.path.exists(os.path.join(dest_dir, new_filename)):
                    new_filename = f"{base_name}_{counter}{ext}"
                    counter += 1

                try:
                    shutil.move(img_path, os.path.join(dest_dir, new_filename))
                    moved_count += 1
                    files_to_remove.append(img_path)
                except Exception as e:
                    print(f"Failed to move file: {e}")

        # Drop moved files from the label dict
        for img_path in files_to_remove:
            if img_path in self.labels:
                del self.labels[img_path]

        # Refresh the current folder's image list
        self.load_current_folder_images()

        # Auto-save the labels
        self.save_labels()

        return moved_count

    def next_folder(self):
        """Switch to the next folder."""
        if self.current_folder_index < len(self.folders) - 1:
            # Does the current folder still have labeled files that were not moved?
            current_folder = os.path.normpath(self.folders[self.current_folder_index])
            # Compare parent folders exactly, so "test1" does not also match "test10"
            has_unmoved_labels = any(
                os.path.normpath(os.path.dirname(img_path)) == current_folder
                and os.path.exists(img_path)
                for img_path in self.labels.keys()
            )

            if has_unmoved_labels:
                # Show the confirmation dialog
                self.show_confirm_dialog = True
                self.confirm_action = "next_folder"
                self.confirm_message = "The current folder has labeled files that were not moved. Switch to the next folder anyway?"
                return False
            else:
                # Switch folders directly
                self.current_folder_index += 1
                self.load_current_folder_images()
                return True
        return False

    def prev_folder(self):
        """Switch to the previous folder."""
        if self.current_folder_index > 0:
            self.current_folder_index -= 1
            self.load_current_folder_images()
            return True
        return False

    def handle_key_repeats(self):
        """Handle a held-down arrow key."""
        current_time = time.time()

        # Check whether a repeat should fire
        if any(self.key_pressed.values()):
            # First repeat: wait for the longer initial delay
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            # Later repeats: use the shorter interval
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:
                    self.next_image()
                self.last_key_time = current_time

    def save_state(self):
        """Save the current state for undo."""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)  # drop the oldest state

        state = {
            "current_image_index": self.current_image_index,
            "labels": self.labels.copy(),
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        }

        self.undo_stack.append(state)

    def undo(self):
        """Undo the last operation."""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            return True
        return False

    def save_labels(self):
        """Save labels to a file."""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            # Only keep labels for files that still exist
            existing_labels = {k: v for k, v in self.labels.items() if os.path.exists(k)}
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f)
        except Exception as e:
            print(f"Failed to save labels: {e}")

    def load_labels(self):
        """Load labels from a file."""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    self.labels = json.load(f)
            except Exception as e:
                print(f"Failed to load labels: {e}")

def draw_button(screen, text, rect, hover=False, color=None):
    """Draw a button."""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR

    # Fill first
    pygame.draw.rect(screen, color, rect, border_radius=5)
    # Then the border
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)

    # Center the text
    text_surface = small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """Draw the confirmation dialog."""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2,
        width, height
    )

    # Dialog background
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)

    # Wrap the message character by character, breaking at the last space when
    # there is one; Chinese text has no spaces, so a plain split() would not wrap it
    lines = []
    current_line = ""
    for ch in message:
        if small_font.size(current_line + ch)[0] < width - 40:
            current_line += ch
        else:
            cut = current_line.rfind(" ")
            if cut > 0:
                lines.append(current_line[:cut])
                current_line = current_line[cut + 1:] + ch
            else:
                lines.append(current_line)
                current_line = ch

    if current_line:
        lines.append(current_line)

    # Draw every line (the blit belongs inside the loop)
    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
        screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))

    # Buttons
    yes_button = pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)

    draw_button(screen, "Yes", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "No", no_button, color=WARNING_COLOR)

    return dialog_rect, yes_button, no_button

def main():
    # Example root path; change it to your own
    root_path = r"D:\zero_track\mmaction2\input_videos\test1"

    # Create the labeling tool
    tool = ImageLabelingTool(root_path)

    # Output folders for positive and negative samples
    # positive_dir = os.path.join(root_path, "positive_samples")
    # negative_dir = os.path.join(root_path, "negative_samples")
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")

    # Main loop
    running = True
    clock = pygame.time.Clock()

    # Button area, in two rows
    button_height = 40
    button_width = 140
    button_margin = 15
    button_row1_y = WINDOW_HEIGHT - button_height - button_margin
    button_row2_y = WINDOW_HEIGHT - 2 * button_height - 2 * button_margin

    # Top row: navigation buttons
    nav_buttons = {
        "prev": pygame.Rect(button_margin, button_row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, button_row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, button_row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, button_row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, button_row2_y, button_width, button_height),
    }

    # Bottom row: labeling buttons
    label_buttons = {
        "positive": pygame.Rect(button_margin, button_row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, button_row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, button_row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, button_row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, button_row1_y, button_width, button_height),
    }

    # Image display area
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 220)

    # Time at which an arrow key was pressed
    tool.key_pressed_time = 0

    while running:
        mouse_pos = pygame.mouse.get_pos()

        # Handle held-key repeats
        tool.handle_key_repeats()

        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_RIGHT:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()  # record the press time
                    tool.next_image()  # respond once immediately
                elif event.key == K_LEFT:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()  # record the press time
                    tool.prev_image()  # respond once immediately
                elif event.key == K_w:  # label as positive
                    tool.label_current_image("positive")
                elif event.key == K_s:  # label as negative
                    tool.label_current_image("negative")
                elif event.key == K_UP:  # start continuous labeling
                    if not tool.start_continuous_labeling():
                        print("Cannot start continuous labeling")
                elif event.key == K_DOWN:  # end continuous labeling
                    if not tool.end_continuous_labeling():
                        print("No continuous labeling in progress")
                elif event.key == K_x:  # move files
                    moved = tool.move_labeled_files(positive_dir, negative_dir)
                    print(f"Moved {moved} files")
                elif event.key == K_c:  # next folder
                    tool.next_folder()
                # Check Ctrl+Z before plain Z; otherwise the plain-Z branch swallows it
                elif event.key == K_z and (pygame.key.get_mods() & KMOD_CTRL):  # Ctrl+Z: undo
                    if tool.undo():
                        print("Undid the last operation")
                    else:
                        print("Nothing to undo")
                elif event.key == K_z:  # previous folder
                    tool.prev_folder()
                elif event.key == K_ESCAPE:  # ESC closes the confirmation dialog
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False

            elif event.type == KEYUP:
                if event.key == K_RIGHT:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0  # reset the repeat timer
                elif event.key == K_LEFT:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0  # reset the repeat timer

            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1:  # left click
                    # Is the confirmation dialog open?
                    if tool.show_confirm_dialog:
                        dialog_rect, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # Navigation buttons
                        if nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("Undid the last operation")
                            else:
                                print("Nothing to undo")

                        # Labeling buttons
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("Cannot start continuous labeling")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("No continuous labeling in progress")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            moved = tool.move_labeled_files(positive_dir, negative_dir)
                            print(f"Moved {moved} files")

        # Clear the screen
        screen.fill(BG_COLOR)

        # Folder info
        if tool.folders:
            folder_text = f"Current folder: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))

        # Current image
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()

                # Scale the image to fit the display area
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)

                screen.blit(img, img_rect)

                # Image info (above the image)
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"
                if current_image_path in tool.labels:
                    info_text += f" - labeled: {tool.labels[current_image_path]}"
                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)

                # In continuous mode, show the range being labeled
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)

                    range_text = f"Range: {start_idx + 1} - {end_idx + 1}"
                    range_surface = small_font.render(range_text, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface, (20, 50))

                    # Bar under the image marking the range
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width

                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                     (start_x, image_area.y + image_area.height + 5,
                                      end_x - start_x, 5))

            except Exception as e:
                error_text = f"Cannot load image: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        else:
            no_image_text = "No images to display"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        # Continuous-labeling status
        if tool.continuous_mode:
            mode_text = f"Continuous labeling on - label: {tool.continuous_label}"
            text_surface = small_font.render(mode_text, True, HIGHLIGHT_COLOR)
            screen.blit(text_surface, (WINDOW_WIDTH - text_surface.get_width() - 20, 50))

        # Navigation buttons
        draw_button(screen, "Prev image (←)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next image (→)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "Prev folder (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Next folder (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Undo (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))

        # Labeling buttons
        draw_button(screen, "Positive (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "Negative (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "Start run (↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "End run (↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "Move files (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))

        # Confirmation dialog
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)

        # Update the display
        pygame.display.flip()
        clock.tick(30)

    # Save labels before exiting
    tool.save_labels()
    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    main()

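Here is how the tool works:

- Left/Right arrow keys: step backward/forward through the images.
- W: label the current image as positive.
- S: label the current image as negative.
- Up: start continuous labeling. If the current image is positive, the run is labeled positive; if it is negative, the run is labeled negative. If it has no label yet, the run defaults to positive.
- Down: end continuous labeling. Every image between the start and end points gets the run's label.
- X: move the labeled positive and negative samples into the newly created folders 1 and 0.

One caveat: when I ran this script, the Sogou input method automatically switched to Chinese input mode (very annoying). I had to press Shift to switch back to English and then click the window before the W, S and X keys would work. If you set your system's default input method to English, you shouldn't run into this.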
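The Up/Down run labeling boils down to one rule: every index between the start and end points gets the same label, whichever direction you moved. The sketch below isolates that rule as a standalone function, mirroring `start_continuous_labeling` / `end_continuous_labeling` above. `label_run` is a hypothetical name used only for illustration and is not part of the tool:

```python
def label_run(images, labels, start_index, end_index, default="positive"):
    """Apply one label to images[start..end] (inclusive), in either direction.

    The run label comes from the start image if it is already labeled,
    otherwise it defaults to "positive", as in the tool.
    """
    run_label = labels.get(images[start_index], default)
    lo, hi = min(start_index, end_index), max(start_index, end_index)
    for i in range(lo, hi + 1):
        labels[images[i]] = run_label
    return labels

frames = [f"{i:06d}.jpg" for i in range(6)]
labels = {frames[4]: "negative"}
# Press Up on frame 4 (already negative), step back to frame 1, then press Down
label_run(frames, labels, start_index=4, end_index=1)
print(sorted(labels.items()))
```

Because of the `min`/`max` step, you can start a run at either end of the action.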

Also, after one run the tool writes labels_backup.json into the image root folder (root_path). Unless you delete it manually, the next run loads the previous results.
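save_labels writes labels_backup.json as a flat JSON object that maps image paths to "positive" or "negative". That makes it easy to check what a leftover backup contains before you decide to keep or delete it. Below is a small standard-library sketch. `summarize_backup` is a hypothetical helper, and the path in the usage line is simply the root_path from main():

```python
import json
import os
from collections import Counter

def summarize_backup(root_path):
    """Count labels in root_path/labels_backup.json, as written by save_labels()."""
    labels_file = os.path.join(root_path, "labels_backup.json")
    if not os.path.exists(labels_file):
        return Counter()
    with open(labels_file, "r") as f:
        labels = json.load(f)
    # Same filter save_labels() applies: skip images that no longer exist (e.g. already moved)
    return Counter(v for k, v in labels.items() if os.path.exists(k))

# Usage (path taken from main(); change it to your own):
# print(summarize_backup(r"D:\zero_track\mmaction2\input_videos\test1"))
```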

2.3 Switching the Arrow Keys to Auto-Play

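This version makes three key changes:

- The single-step back/forward that used to be on the Left/Right arrows moves to A and D.
- Left now auto-plays backward (toward earlier frames) and Right auto-plays forward (toward later frames).
- Space pauses or resumes playback.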

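This way the arrow keys play through the images automatically, so you don't have to hold a key down the whole time, which gets tiring. Because the old arrow-key behavior now lives on A and D, you can still use A/D if auto-play feels too fast: tap them to step through slowly, or hold them down to play.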

  1. Move the single-step forward/backward hotkeys from the arrow keys to A / D

    Location: the KEYDOWN branch in main()

    Original code:

                if event.key == K_RIGHT:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"]  = False
                    tool.key_pressed_time     = time.time()
                    tool.next_image()               # respond once immediately
                elif event.key == K_LEFT:
                    tool.key_pressed["left"]  = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time     = time.time()
                    tool.prev_image()               # respond once immediately

Change it to:

                if event.key == K_d:                # single step forward
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"]  = False
                    tool.key_pressed_time     = time.time()
                    tool.next_image()
                elif event.key == K_a:              # single step backward
                    tool.key_pressed["left"]  = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time     = time.time()
                    tool.prev_image()

Change the matching KEYUP branch the same way:

                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time        = 0
                elif event.key == K_a:
                    tool.key_pressed["left"]  = False
                    tool.last_key_time        = 0

  2. Turn the ← → arrow keys into auto-play

    Still in the KEYDOWN branch, add the following right below the A/D branches you just changed:

                elif event.key == K_RIGHT:          # auto-play forward (toward later frames)
                    tool.play_direction = 1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:           # auto-play backward (toward earlier frames)
                    tool.play_direction = -1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:          # pause / resume
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()

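  3. Add four state variables to ImageLabelingTool.__init__

    Append them anywhere in __init__ (for example, right after self.max_undo_steps = 50):

        # Auto-play state
        self.playing        = False      # whether auto-play is running
        self.play_direction = 1          # 1 = next image, -1 = previous image
        self.last_play_tick = 0          # time of the last flip (ms)
        self.play_interval  = 400        # milliseconds; flip one image every 0.4 s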

  4. Do the actual timed flipping in the main loop

    Insert this at the very top of while running: (next to tool.handle_key_repeats() is fine):

        # Auto-play logic
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now

  5. Update the button labels to match (optional)

    Change these two lines:

        draw_button(screen, "Prev image (←)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next image (→)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))

to:

        draw_button(screen, "Prev image (A)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next image (D)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
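To get a feel for what play_interval does, here is a small standalone simulation of the timing check from step 4. It assumes the main loop runs exactly once every 33 ms, which is roughly the ceiling set by clock.tick(30) in main(). In practice, loading large images can make frames slower. `simulate_flips` is a hypothetical helper for illustration, not part of the tool:

```python
def simulate_flips(duration_ms, frame_ms, play_interval):
    """Count how many images auto-play would advance in duration_ms.

    Assumes the main loop runs exactly once every frame_ms milliseconds and
    uses the same strict '>' comparison as the auto-play snippet in step 4.
    """
    last_play_tick = 0  # the moment the arrow key was pressed
    flips = 0
    for now in range(frame_ms, duration_ms + 1, frame_ms):
        if now - last_play_tick > play_interval:
            flips += 1
            last_play_tick = now
    return flips

print(simulate_flips(1000, 33, 400))  # 2
print(simulate_flips(1000, 33, 30))   # 30
```

With play_interval = 400, auto-play flips one image every 429 ms (13 loop iterations), or about 2.3 images per second. Any interval shorter than one loop iteration flips once per iteration. At that point playback speed is capped by clock.tick(30) rather than by play_interval.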

Here is the complete code after these changes:

import os
import pygame
import sys
import shutil
import time
import json
from pygame.locals import *


# Initialize pygame
pygame.init()

# Configuration
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h

WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)

# Create the window
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Image Classification Labeling Tool")

# Fonts
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

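class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
        self.folders = []               # absolute paths of all folders that contain images
        self.current_folder_index = 0   # index of the current folder
        self.images = []                # absolute paths of all images in the current folder
        self.current_image_index = 0    # index of the current image
        self.labels = {}                # image path -> 'positive' / 'negative'

        # Auto-play state
        self.playing        = False      # whether auto-play is running
        self.play_direction = 1          # 1 = next image, -1 = previous image
        self.last_play_tick = 0          # time of the last flip (ms)
        self.play_interval  = 30         # milliseconds; flip one image every 0.03 s

        # Labeling state
        self.continuous_mode = False    # whether continuous (run) labeling is active
        self.continuous_label = None    # label applied to the whole run
        self.continuous_start_index = None  # index where the run started

        # Key-hold state
        self.key_pressed = {"left": False, "right": False}
        self.key_pressed_time = 0       # when the key went down
        self.last_key_time = 0          # timer for repeated steps while a key is held
        self.key_repeat_delay = 0.8     # initial delay before repeating (raised to 0.8 s)
        self.key_repeat_interval = 0.15 # repeat interval (raised to 0.15 s)

        # Operation history (for undo)
        self.undo_stack = []
        self.max_undo_steps = 50

        # Confirmation dialog state
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""        # action to run when the dialog is confirmed

        # Find every folder that contains images
        self.find_image_folders()

        # Load the images of the current folder
        if self.folders:
            self.load_current_folder_images()

        # Load previously saved labels, if any
        self.load_labels()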

        
    def find_image_folders(self):
        """Find every folder that contains images"""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)

    def load_current_folder_images(self):
        """Load all images in the current folder"""
        folder_path = self.folders[self.current_folder_index]
        self.images = []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

        for file in os.listdir(folder_path):
            if file.lower().endswith(image_extensions):
                self.images.append(os.path.join(folder_path, file))

        # Sort by file name
        self.images.sort()
        self.current_image_index = 0

    def get_current_image(self):
        """Return the path of the current image"""
        if not self.images:
            return None
        return self.images[self.current_image_index]

    def next_image(self):
        """Go to the next image"""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()  # save state for undo
            self.current_image_index += 1
            return True
        return False
        
    def prev_image(self):
        """Go to the previous image"""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            return True
        return False

    def label_current_image(self, label):
        """Label the current image"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()        # save state for undo
            self.labels[current_image] = label
            # Auto-save the labels
            self.save_labels()

    def start_continuous_labeling(self):
        """Start range labeling"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            # If the current image already has a label, use that label
            if current_image in self.labels:
                self.continuous_label = self.labels[current_image]
            else:
                # Otherwise default to positive
                self.continuous_label = "positive"
                self.labels[current_image] = self.continuous_label

            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def end_continuous_labeling(self):
        """End range labeling and label every image in the range"""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()  # save state for undo
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)

            for i in range(start, end + 1):
                self.labels[self.images[i]] = self.continuous_label

            self.continuous_mode = False
            self.continuous_start_index = None
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def move_labeled_files(self, positive_dir, negative_dir):
        """Move labeled files into the positive/negative sample folders"""
        if not os.path.exists(positive_dir):
            os.makedirs(positive_dir)
        if not os.path.exists(negative_dir):
            os.makedirs(negative_dir)

        moved_count = 0
        files_to_remove = []

        for img_path, label in self.labels.items():
            if label in ["positive", "negative"] and os.path.exists(img_path):
                filename = os.path.basename(img_path)
                dest_dir = positive_dir if label == "positive" else negative_dir

                # Handle file-name collisions
                counter = 1
                base_name, ext = os.path.splitext(filename)
                new_filename = filename 
                while os.path.exists(os.path.join(dest_dir, new_filename)):
                    new_filename = f"{base_name}_{counter}{ext}"
                    counter += 1

                try:
                    shutil.move(img_path, os.path.join(dest_dir, new_filename))
                    moved_count += 1
                    files_to_remove.append(img_path)
                except Exception as e:
                    print(f"Failed to move file: {e}")

        # Remove the moved files from the label dict
        for img_path in files_to_remove:
            if img_path in self.labels:
                del self.labels[img_path] 

        # Refresh the image list of the current folder
        self.load_current_folder_images()

        # Auto-save the labels
        self.save_labels()

        return moved_count

    def next_folder(self):
        """Go to the next folder"""
        if self.current_folder_index < len(self.folders) - 1:
            # Check whether the current folder still has labeled files that were not moved
            current_folder = self.folders[self.current_folder_index]
            has_unmoved_labels = any(
                img_path.startswith(current_folder) and os.path.exists(img_path)
                for img_path in self.labels.keys()
            )

            if has_unmoved_labels:
                # Show the confirmation dialog
                self.show_confirm_dialog = True
                self.confirm_action = "next_folder"
                self.confirm_message = "The current folder has labeled files that were not moved. Switch to the next folder anyway?"
                return False
            else:
                # Switch folders directly
                self.current_folder_index += 1
                self.load_current_folder_images()
                return True
        return False

    def prev_folder(self):
        """Go to the previous folder"""
        if self.current_folder_index > 0:
            self.current_folder_index -= 1
            self.load_current_folder_images()
            return True
        return False

    def handle_key_repeats(self):
        """Handle held navigation keys (A/D)"""
        current_time = time.time()

        # Check whether a repeat step is due
        if any(self.key_pressed.values()):
            # No repeat yet: wait for the longer initial delay
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            # Later repeats use the shorter interval
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:  
                    self.next_image()
                self.last_key_time = current_time

    def save_state(self):
        """Save the current state for undo"""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)  # drop the oldest state

        state = {
            "current_image_index": self.current_image_index,
            "labels": self.labels.copy(),
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        } 

        self.undo_stack.append(state)

    def undo(self):
        """Undo the last operation"""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            return True
        return False

    def save_labels(self):
        """Save the labels to a file"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            # Only keep labels for files that still exist
            existing_labels = {k: v for k, v in self.labels.items() if os.path.exists(k)}
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f)
        except Exception as e:
            print(f"Failed to save labels: {e}")

    def load_labels(self):
        """Load the labels from a file"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    self.labels = json.load(f)
            except Exception as e:
                print(f"Failed to load labels: {e}")

def draw_button(screen, text, rect, hover=False, color=None):
    """Draw a button"""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR

    # Fill the button body first
    pygame.draw.rect(screen, color, rect, border_radius=5)
    # Then draw the border
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)

    # Center the text
    text_surface = small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """Draw the confirmation dialog and return its rects"""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2,
        width, height
    )

    # Dialog background and border
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)

    # Word-wrap the message (splits on spaces, so text without spaces stays on one line)
    lines = []
    words = message.split()
    current_line = ""

    for word in words:
        test_line = current_line + word + " "
        if small_font.size(test_line)[0] < width - 40:
            current_line = test_line
        else:
            lines.append(current_line)
            current_line = word + " "

    if current_line:
        lines.append(current_line)

    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
        # blit inside the loop so every line is drawn, not just the last one
        screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))

    # Buttons
    yes_button = pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button  = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)

    draw_button(screen, "Yes", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "No", no_button, color=WARNING_COLOR)

    return dialog_rect, yes_button, no_button

def main():
    # Example root path; change it to your own folder
    root_path = r"D:\zero_track\mmaction2\input_videos\test1"

    # Create the labeling tool
    tool = ImageLabelingTool(root_path)

    # Output folders for positive/negative samples
    # positive_dir = os.path.join(root_path, "positive_samples")
    # negative_dir = os.path.join(root_path, "negative_samples")
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")

    # Main loop state
    running = True
    clock = pygame.time.Clock()

    # Button area, laid out in two rows
    button_height = 40
    button_width = 140
    button_margin = 15
    button_row1_y = WINDOW_HEIGHT - button_height - button_margin
    button_row2_y = WINDOW_HEIGHT - 2 * button_height - 2 * button_margin

    # First (top) row: navigation buttons
    nav_buttons = {
        "prev": pygame.Rect(button_margin, button_row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, button_row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, button_row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, button_row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, button_row2_y, button_width, button_height),
    }

    # Second (bottom) row: labeling buttons
    label_buttons = {
        "positive": pygame.Rect(button_margin, button_row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, button_row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, button_row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, button_row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, button_row1_y, button_width, button_height),
    }

    # Image display area
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 220)


    # Time when a navigation key went down
    tool.key_pressed_time = 0

    while running:
        mouse_pos = pygame.mouse.get_pos()

        # Handle held-key repeats
        tool.handle_key_repeats()

        # Auto-play
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now

        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_d:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()  # remember when the key went down
                    tool.next_image()  # respond once immediately
                elif event.key == K_a:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()  # remember when the key went down
                    tool.prev_image()  # respond once immediately

                elif event.key == K_RIGHT:          # auto-play forward
                    tool.play_direction = 1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:           # auto-play backward
                    tool.play_direction = -1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:          # pause / resume
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()

                elif event.key == K_w:  # label as positive
                    tool.label_current_image("positive")
                elif event.key == K_s:  # label as negative
                    tool.label_current_image("negative")
                elif event.key == K_UP: # start range labeling
                    if not tool.start_continuous_labeling():
                        print("Cannot start range labeling")
                elif event.key == K_DOWN: # end range labeling
                    if not tool.end_continuous_labeling():
                        print("No active range labeling")
                elif event.key == K_x: # move files
                    moved = tool.move_labeled_files(positive_dir, negative_dir)
                    print(f"Moved {moved} files")
                elif event.key == K_c: # next folder
                    tool.next_folder()
                # Ctrl+Z must be checked before plain Z, otherwise it can never be reached
                elif event.key == K_z and (pygame.key.get_mods() & KMOD_CTRL): # Ctrl+Z: undo
                    if tool.undo():
                        print("Undid the last operation")
                    else:
                        print("Nothing to undo")
                elif event.key == K_z:  # previous folder
                    tool.prev_folder()
                elif event.key == K_ESCAPE:  # ESC closes the confirmation dialog
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False

            elif event.type == KEYUP:
                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0  # reset the repeat timer
                elif event.key == K_a:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0  # reset the repeat timer

            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1: # left click
                    # While the confirmation dialog is open, clicks go to it
                    if tool.show_confirm_dialog:
                        dialog_rect, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # Navigation buttons
                        if nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("Undid the last operation")
                            else:
                                print("Nothing to undo")

                        # Labeling buttons
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("Cannot start range labeling")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("No active range labeling")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            moved = tool.move_labeled_files(positive_dir, negative_dir)
                            print(f"Moved {moved} files")

        # Clear the screen
        screen.fill(BG_COLOR)

        # Show folder info
        if tool.folders:
            folder_text = f"Current folder: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))

        # Show the current image
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()

                # Scale the image to fit the display area
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)

                screen.blit(img, img_rect)

                # Image info above the image
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"
                if current_image_path in tool.labels:
                    label = tool.labels[current_image_path]
                    info_text += f" - labeled: {label}"
                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)

                # In range-labeling mode, show the current range
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)

                    range_text = f"Range: {start_idx + 1} - {end_idx + 1}"
                    range_surface = small_font.render(range_text, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface, (20, 50))

                    # Draw a bar under the image marking the range
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width

                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                     (start_x, image_area.y + image_area.height + 5,
                                      end_x - start_x, 5))

            except Exception as e:
                error_text = f"Cannot load image: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        else:
            no_image_text = "No image to display"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        # Show range-labeling status
        if tool.continuous_mode:
            mode_text = f"Range labeling on - label: {tool.continuous_label}"
            text_surface = small_font.render(mode_text, True, HIGHLIGHT_COLOR)
            screen.blit(text_surface, (WINDOW_WIDTH - text_surface.get_width() - 20, 50))

        # Navigation buttons
        draw_button(screen, "Prev (a)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next (d)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "Prev folder (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Next folder (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Undo (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))

        # Labeling buttons
        draw_button(screen, "Positive (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "Negative (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "Range start(↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "Range end(↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "Move files (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))

        # Draw the confirmation dialog on top
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)


        # Update the display
        pygame.display.flip()
        clock.tick(30)

    # Save the labels before exiting
    tool.save_labels()
    pygame.quit()
    sys.exit()
        
if __name__ == "__main__":
    main()
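
The hold-to-repeat behavior in `handle_key_repeats` follows a simple rule. Pressing a key (KEYDOWN) gives one step. The first repeat comes once the key has been held longer than `key_repeat_delay` (0.8 s). After that, there is one step every `key_repeat_interval` (0.15 s) until KEYUP resets `last_key_time` to 0. Below is a minimal, pygame-free sketch of that rule, under a few assumptions:

- It uses integer milliseconds instead of `time.time()` seconds, so the arithmetic is exact.
- `should_repeat` and `count_steps` are hypothetical helper names.
- The 10 ms polling tick is an assumption. The real loop runs at roughly 30 FPS via `clock.tick(30)`.

```python
def should_repeat(now_ms, pressed_ms, last_repeat_ms, delay_ms=800, interval_ms=150):
    """Return (fire, new_last_repeat_ms) for a held navigation key.

    last_repeat_ms == 0 means no repeat has fired yet (the same sentinel the
    tool uses for last_key_time), so the first repeat waits delay_ms after the
    key went down; every later repeat waits interval_ms after the previous one.
    """
    if last_repeat_ms == 0:
        if now_ms - pressed_ms > delay_ms:
            return True, now_ms
        return False, 0
    if now_ms - last_repeat_ms > interval_ms:
        return True, now_ms
    return False, last_repeat_ms


def count_steps(hold_ms, tick_ms=10, pressed_ms=1000):
    """Simulate holding D for hold_ms while the main loop polls every tick_ms.

    Counts the immediate step on KEYDOWN plus every repeat step.
    """
    steps = 1  # KEYDOWN calls next_image() once right away
    last = 0   # KEYUP resets last_key_time to 0 in the tool
    for now in range(pressed_ms, pressed_ms + hold_ms + 1, tick_ms):
        fire, last = should_repeat(now, pressed_ms, last)
        if fire:
            steps += 1
    return steps


print(count_steps(500))   # 1: released before the 0.8 s delay
print(count_steps(2000))  # 9: 1 immediate step + 8 repeats
```

The check is a strict `>` evaluated once per loop iteration, so the effective repeat period is a little longer than 150 ms (160 ms with a 10 ms tick). pygame also offers a built-in alternative, `pygame.key.set_repeat(800, 150)`, which makes a held key generate repeated KEYDOWN events. With it, the manual timer could probably be dropped, though this has not been tried in this tool.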

2.4 Adding a convert-to-video feature

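With the logic above, pressing X moves the positive and negative samples into folders 1 and 0. I wanted to add a flag (`convert_to_video` in the code below) that controls whether the images are converted into video as they are moved, with the images themselves discarded. For example, suppose frames 10 to 20 are positive and frames 21 to 30 are negative. When the user presses X:

- The old logic moves frames 10 to 20 into folder 1 and frames 21 to 30 into folder 0.
- The new logic turns frames 10 to 20 into one video clip in folder 1, and frames 21 to 30 into another clip in folder 0.

It is fine if there are no negative samples or no positive samples.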
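Note that `move_labeled_files` in the code below puts every frame with the same label into one list. If frames 10 to 20 and 40 to 45 are both positive, or positive frames come from two different folders, they end up in a single clip. To get one clip per contiguous segment, as in the example above, the labels can first be split into runs of consecutive frame numbers per folder and label.

Here is a minimal sketch of that, under these assumptions:

- Frame files are named like `000010.jpg` (as written by the mp4tojpg.py script) or `frame_0010.jpg`.
- `frame_index`, `split_into_segments` and `video_name` are hypothetical helpers.
- The name parsing mirrors `extract_index` in the tool.

```python
import os
from collections import defaultdict


def frame_index(path):
    """Parse the frame number from names like 000010.jpg or frame_0010.jpg."""
    name = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(name.split('_')[-1])
    except ValueError:
        return None  # file name does not end with a number


def split_into_segments(labels):
    """Split {image_path: label} into runs of consecutive frames.

    Returns a list of (folder, label, paths); each paths list holds
    consecutive frame numbers from one folder that share one label.
    """
    buckets = defaultdict(list)
    for path, label in labels.items():
        idx = frame_index(path)
        if idx is not None and label in ("positive", "negative"):
            buckets[(os.path.dirname(path), label)].append((idx, path))

    segments = []
    for (folder, label), items in sorted(buckets.items()):
        items.sort()  # order by frame number
        run = [items[0][1]]
        for (prev_idx, _), (idx, path) in zip(items, items[1:]):
            if idx == prev_idx + 1:
                run.append(path)  # still contiguous
            else:
                segments.append((folder, label, run))  # gap: close this run
                run = [path]
        segments.append((folder, label, run))
    return segments


def video_name(folder, label, paths):
    """Clip name in the tool's pattern, e.g. test1_positive_10to20.mp4."""
    start, end = frame_index(paths[0]), frame_index(paths[-1])
    return f"{os.path.basename(folder)}_{label}_{start}to{end}.mp4"


# Example: positives 10-20 and 40-45, negatives 21-30, all in folder test1
labels = {}
for i in range(10, 21):
    labels[os.path.join("test1", f"{i:06d}.jpg")] = "positive"
for i in range(21, 31):
    labels[os.path.join("test1", f"{i:06d}.jpg")] = "negative"
for i in range(40, 46):
    labels[os.path.join("test1", f"{i:06d}.jpg")] = "positive"

for folder, label, paths in split_into_segments(labels):
    print(video_name(folder, label, paths), len(paths))
# prints:
# test1_negative_21to30.mp4 10
# test1_positive_10to20.mp4 11
# test1_positive_40to45.mp4 6
```

Each `(folder, label, paths)` tuple could then be passed to `images_to_video(paths, os.path.join(dest_dir, video_name(folder, label, paths)), fps=self.video_fps)`, with `dest_dir` chosen from the label as before, in place of the single call per label.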

python
import os
import pygame
import sys
import shutil
import time
import json
from pygame.locals import *
import cv2  # pip install opencv-python 

# Initialize pygame
pygame.init()

# Configuration
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h   

WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)

# Create the window
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Image Classification Labeling Tool")

# Fonts
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
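        self.folders = []               # absolute paths of all folders that contain images
        self.current_folder_index = 0   # index of the current folder
        self.images = []                # absolute paths of all images in the current folder
        self.current_image_index = 0    # index of the current image
        self.labels = {}                # path -> 'positive' / 'negative'

        self.convert_to_video = True  # True: turn labeled frames into clips; False: move the images
        self.video_fps = 10  # frame rate of the generated clips

        # Auto-play
        self.playing        = False      # whether auto-play is running
        self.play_direction = 1          # 1 = next image, -1 = previous image
        self.last_play_tick = 0          # tick (ms) of the last auto-play step
        self.play_interval  = 100        # ms between auto-play steps

        # Labeling state
        self.continuous_mode = False    # whether range labeling is active
        self.continuous_label = None    # label applied to the whole range
        self.continuous_start_index = None  # index where the range starts

        # Key-hold state
        self.key_pressed = {"left": False, "right": False}
        self.key_pressed_time = 0       # when the key went down
        self.last_key_time = 0          # time of the last repeat step, 0 = none yet
        self.key_repeat_delay = 0.8     # initial delay, raised to 0.8 s
        self.key_repeat_interval = 0.15 # repeat interval, raised to 0.15 s

        # Operation history (for undo)
        self.undo_stack = []
        self.max_undo_steps = 50

        # Confirmation dialog state
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""        # action that triggered the dialog

        # Find every folder that contains images
        self.find_image_folders()

        # Load the images of the current folder
        if self.folders:
            self.load_current_folder_images()

        # Restore saved labels, if any
        self.load_labels()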

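    def images_to_video(self, image_paths, output_path, fps=10):
        """Convert an image sequence into a video, then delete the source images"""
        if not image_paths:
            return

        # Read the first image to get the frame size
        frame = cv2.imread(image_paths[0])
        if frame is None:  # cv2.imread returns None instead of raising
            print(f"Failed to read image: {image_paths[0]}")
            return
        h, w = frame.shape[:2]

        # Create the video writer (mp4v codec for .mp4 output)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
        if not out.isOpened():
            print(f"Failed to open video writer: {output_path}")
            return

        for img_path in image_paths:
            frame = cv2.imread(img_path)
            if frame is None:
                print(f"Skipping unreadable image: {img_path}")
                continue
            out.write(frame)

        out.release()

        # Delete the source images
        for img_path in image_paths:
            if os.path.exists(img_path):
                os.remove(img_path)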
        
    def find_image_folders(self):
        """Find every folder that contains images"""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)

    def load_current_folder_images(self):
        """Load all images in the current folder"""
        folder_path = self.folders[self.current_folder_index]
        self.images = []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

        for file in os.listdir(folder_path):
            if file.lower().endswith(image_extensions):
                self.images.append(os.path.join(folder_path, file))

        # Sort by file name
        self.images.sort()
        self.current_image_index = 0

    def get_current_image(self):
        """Return the path of the current image"""
        if not self.images:
            return None
        return self.images[self.current_image_index]

    def next_image(self):
        """Go to the next image"""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()  # save state for undo
            self.current_image_index += 1
            return True
        return False
        
    def prev_image(self):
        """Go to the previous image"""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            return True
        return False

    def label_current_image(self, label):
        """Label the current image"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()        # save state for undo
            self.labels[current_image] = label
            # Auto-save the labels
            self.save_labels()

    def start_continuous_labeling(self):
        """Start range labeling"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            # If the current image already has a label, use that label
            if current_image in self.labels:
                self.continuous_label = self.labels[current_image]
            else:
                # Otherwise default to positive
                self.continuous_label = "positive"
                self.labels[current_image] = self.continuous_label

            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def end_continuous_labeling(self):
        """End range labeling and label every image in the range"""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()  # save state for undo
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)

            for i in range(start, end + 1):
                self.labels[self.images[i]] = self.continuous_label

            self.continuous_mode = False
            self.continuous_start_index = None
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def move_labeled_files(self, positive_dir, negative_dir):
        """Move labeled images into the positive/negative folders, or turn them into clips"""
        if not os.path.exists(positive_dir):
            os.makedirs(positive_dir)
        if not os.path.exists(negative_dir):
            os.makedirs(negative_dir)

        # Group images by label
        from collections import defaultdict
        groups = defaultdict(list)

        for img_path, label in self.labels.items():
            if label in ["positive", "negative"] and os.path.exists(img_path):
                groups[label].append(img_path)

        # Sort each group by file name
        for label in groups:
            groups[label].sort()

        moved_count = 0

        # Process each group.
        # Note: all frames with the same label go into ONE clip, even if they are
        # not contiguous or come from different folders.
        for label, image_paths in groups.items():
            if not image_paths:
                continue

            dest_dir = positive_dir if label == "positive" else negative_dir

            if self.convert_to_video:
                # Build a clip; the frame-number range goes into the file name
                def extract_index(path):
                    # Assumes names like frame_0010.jpg or 0010.jpg
                    name = os.path.splitext(os.path.basename(path))[0]
                    try:
                        return int(name.split('_')[-1])
                    except ValueError:
                        return None

                start_idx = extract_index(image_paths[0])
                end_idx = extract_index(image_paths[-1])

                if start_idx is not None and end_idx is not None:
                    range_str = f"{start_idx}to{end_idx}"
                else:
                    range_str = "unknown_range"

                folder_name = os.path.basename(os.path.dirname(image_paths[0]))
                video_name = f"{folder_name}_{label}_{range_str}.mp4"

                video_path = os.path.join(dest_dir, video_name)
                self.images_to_video(image_paths, video_path, fps=self.video_fps)
            else:
                # Original behavior: move the image files
                for img_path in image_paths:
                    filename = os.path.basename(img_path)
                    dest_path = os.path.join(dest_dir, filename)
                    shutil.move(img_path, dest_path)

            moved_count += len(image_paths)

            # Remove the processed images from the label dict
            for img_path in image_paths:
                self.labels.pop(img_path, None)

        # Reload the current folder
        self.load_current_folder_images()
        self.save_labels()

        return moved_count

    def next_folder(self):
        """Switch to the next folder"""
        if self.current_folder_index < len(self.folders) - 1:
            # Check whether the current folder still has labeled frames that were not exported
            # (compare the parent folder exactly; startswith would also match test10 for test1)
            current_folder = self.folders[self.current_folder_index]
            has_unmoved_labels = any(
                os.path.dirname(img_path) == current_folder and os.path.exists(img_path)
                for img_path in self.labels
            )

            if has_unmoved_labels:
                # Ask for confirmation first
                self.show_confirm_dialog = True
                self.confirm_action = "next_folder"
                self.confirm_message = "The current folder still has labeled frames that were not exported. Switch to the next folder anyway?"
                return False
            # Nothing pending: switch directly
            self.current_folder_index += 1
            self.load_current_folder_images()
            return True
        return False

    def prev_folder(self):
        """Switch to the previous folder"""
        if self.current_folder_index > 0:
            self.current_folder_index -= 1
            self.load_current_folder_images()
            return True
        return False

    def handle_key_repeats(self):
        """Handle long presses of the navigation keys (a/d)"""
        current_time = time.time()

        # Only act while a key is held down
        if any(self.key_pressed.values()):
            # First repeat: wait for the longer initial delay
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            # Later repeats: use the shorter interval
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:
                    self.next_image()
                self.last_key_time = current_time

    def save_state(self):
        """Push the current state onto the undo stack"""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)  # drop the oldest state

        state = {
            "current_image_index": self.current_image_index,
            "labels": self.labels.copy(),
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        }

        self.undo_stack.append(state)

    def undo(self):
        """Undo the last action"""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            return True
        return False

    def save_labels(self):
        """Save the label state to a JSON file"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            # Only keep labels of files that still exist
            existing_labels = {k: v for k, v in self.labels.items() if os.path.exists(k)}
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f)
        except Exception as e:
            print(f"Failed to save labels: {e}")

    def load_labels(self):
        """Load the label state from the JSON file"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    self.labels = json.load(f)
            except Exception as e:
                print(f"Failed to load labels: {e}")

def draw_button(screen, text, rect, hover=False, color=None):
    """Draw a button"""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR

    # Body first
    pygame.draw.rect(screen, color, rect, border_radius=5)
    # Then the border
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)

    # Center the text
    text_surface = small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """Draw the confirmation dialog; returns (dialog_rect, yes_button, no_button)"""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2,
        width, height
    )

    # Dialog background and border
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)

    # Wrap the message: by words if it contains spaces, otherwise by characters
    # (Chinese text has no spaces, so split() would keep it on a single line)
    sep = " " if " " in message else ""
    tokens = message.split(" ") if sep else list(message)
    lines = []
    current_line = ""
    for tok in tokens:
        test_line = current_line + sep + tok if current_line else tok
        if small_font.size(test_line)[0] < width - 40:
            current_line = test_line
        else:
            lines.append(current_line)
            current_line = tok
    if current_line:
        lines.append(current_line)

    # Draw every line (the blit must be inside the loop)
    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
        screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))

    # Yes / No buttons
    yes_button = pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)

    draw_button(screen, "Yes", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "No", no_button, color=WARNING_COLOR)

    return dialog_rect, yes_button, no_button

def main():
    # Example root path; change it to your own
    root_path = r"D:\zero_track\mmaction2\input_videos\test1"

    # Create the labeling tool
    tool = ImageLabelingTool(root_path)

    # Output folders for positive/negative samples
    # positive_dir = os.path.join(root_path, "positive_samples")
    # negative_dir = os.path.join(root_path, "negative_samples")
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")

    # Main loop
    running = True
    clock = pygame.time.Clock()

    # Button area: two rows
    button_height = 40
    button_width = 140
    button_margin = 15
    button_row1_y = WINDOW_HEIGHT - button_height - button_margin
    button_row2_y = WINDOW_HEIGHT - 2 * button_height - 2 * button_margin

    # Upper row: navigation buttons
    nav_buttons = {
        "prev": pygame.Rect(button_margin, button_row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, button_row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, button_row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, button_row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, button_row2_y, button_width, button_height),
    }

    # Lower row: labeling buttons
    label_buttons = {
        "positive": pygame.Rect(button_margin, button_row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, button_row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, button_row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, button_row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, button_row1_y, button_width, button_height),
    }

    # Image display area
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 220)

    # Time at which a/d went down (used for long-press repeat)
    tool.key_pressed_time = 0

    while running:
        mouse_pos = pygame.mouse.get_pos()

        # Long-press repeat for a/d
        tool.handle_key_repeats()

        # Auto-play
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now

        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_d:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()  # record when the key went down
                    tool.next_image()  # respond once immediately
                elif event.key == K_a:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()  # record when the key went down
                    tool.prev_image()  # respond once immediately

                elif event.key == K_RIGHT:  # auto-play forward
                    tool.play_direction = 1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:  # auto-play backward
                    tool.play_direction = -1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:  # pause / resume
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_v:  # toggle convert-to-video mode
                    tool.convert_to_video = not tool.convert_to_video
                    print("Convert-to-video mode: " + ("on" if tool.convert_to_video else "off"))
                elif event.key == K_w:  # label as positive
                    tool.label_current_image("positive")
                elif event.key == K_s:  # label as negative
                    tool.label_current_image("negative")
                elif event.key == K_UP:  # start continuous labeling
                    if not tool.start_continuous_labeling():
                        print("Cannot start continuous labeling")
                elif event.key == K_DOWN:  # end continuous labeling
                    if not tool.end_continuous_labeling():
                        print("No continuous labeling in progress")
                elif event.key == K_x:  # export labeled frames
                    moved = tool.move_labeled_files(positive_dir, negative_dir)
                    print(f"Moved {moved} files")
                elif event.key == K_c:  # next folder
                    tool.next_folder()
                elif event.key == K_z and (pygame.key.get_mods() & KMOD_CTRL):  # Ctrl+Z: undo (must be checked before plain z)
                    if tool.undo():
                        print("Undid the last action")
                    else:
                        print("Nothing to undo")
                elif event.key == K_z:  # previous folder
                    tool.prev_folder()
                elif event.key == K_ESCAPE:  # ESC closes the confirmation dialog
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False

            elif event.type == KEYUP:
                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0  # reset the repeat timer
                elif event.key == K_a:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0  # reset the repeat timer

            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1:  # left click
                    if tool.show_confirm_dialog:
                        # Get the Yes/No button rects to hit-test the click
                        _, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # Navigation buttons
                        if nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("Undid the last action")
                            else:
                                print("Nothing to undo")

                        # Labeling buttons
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("Cannot start continuous labeling")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("No continuous labeling in progress")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            moved = tool.move_labeled_files(positive_dir, negative_dir)
                            print(f"Moved {moved} files")

        # Clear the screen
        screen.fill(BG_COLOR)

        # Folder info
        if tool.folders:
            folder_text = f"Folder: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))

        # Current image
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()

                # Scale the image to fit the display area
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)

                screen.blit(img, img_rect)

                # Image info (above the image)
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"
                if current_image_path in tool.labels:
                    label = tool.labels[current_image_path]
                    info_text += f" - labeled: {'positive' if label == 'positive' else 'negative'}"
                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)

                # In continuous mode, show the labeled range
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)

                    range_text = f"Range: {start_idx + 1} - {end_idx + 1}"
                    range_surface = small_font.render(range_text, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface, (20, 50))

                    # Range indicator bar under the image
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width

                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                     (start_x, image_area.y + image_area.height + 5,
                                      end_x - start_x, 5))

            except Exception as e:
                error_text = f"Failed to load image: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        else:
            no_image_text = "No images to display"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        # Continuous labeling status
        if tool.continuous_mode:
            mode_text = f"Continuous labeling on, label: {'positive' if tool.continuous_label == 'positive' else 'negative'}"
            text_surface = small_font.render(mode_text, True, HIGHLIGHT_COLOR)
            screen.blit(text_surface, (WINDOW_WIDTH - text_surface.get_width() - 20, 50))

        # Navigation buttons
        draw_button(screen, "Prev (a)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next (d)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "Prev folder (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Next folder (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Undo (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))

        # Labeling buttons
        draw_button(screen, "Positive (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "Negative (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "Range start (↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "Range end (↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "Move files (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))

        # Confirmation dialog on top
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)

        # Update the display
        pygame.display.flip()
        clock.tick(30)

    # Save the labels before exiting
    tool.save_labels()
    pygame.quit()
    sys.exit()


if __name__ == "__main__":
    main()
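One thing to watch in move_labeled_files above: all frames that share a label are sorted and written into a single clip, so two separate positive segments (say frames 10-20 and 40-50) end up in one video whose name ends in _positive_10to50.mp4 and which jumps in time. For TSN training each clip should ideally contain one continuous action. Below is a minimal sketch, not part of the original tool, that splits frame paths into runs of consecutive frame numbers; the helper names `frame_index` and `split_into_runs` are hypothetical, and it assumes frame files are named like 000010.jpg or frame_0010.jpg, the same assumption extract_index makes.

```python
import os
from typing import List, Optional


def frame_index(path: str) -> Optional[int]:
    """Hypothetical helper: parse the frame number from 000010.jpg / frame_0010.jpg."""
    name = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(name.split("_")[-1])
    except ValueError:
        return None


def split_into_runs(image_paths: List[str]) -> List[List[str]]:
    """Hypothetical helper: split frame paths into runs of consecutive frame numbers.

    Paths whose frame number cannot be parsed each end up in their own run.
    """
    # Sort by frame number; names without a parsable number go last, in name order
    ordered = sorted(
        image_paths,
        key=lambda p: (frame_index(p) is None, frame_index(p) or 0, p),
    )
    runs: List[List[str]] = []
    prev = None
    for p in ordered:
        idx = frame_index(p)
        if runs and idx is not None and prev is not None and idx == prev + 1:
            runs[-1].append(p)  # consecutive frame: extend the current run
        else:
            runs.append([p])  # gap or unknown index: start a new run
        prev = idx
    return runs


if __name__ == "__main__":
    # Two separate positive segments: frames 10-20 and 40-50
    paths = [f"test1/{i:06d}.jpg" for i in list(range(10, 21)) + list(range(40, 51))]
    for run in split_into_runs(paths):
        print(frame_index(run[0]), "to", frame_index(run[-1]), len(run))
```

In the video branch of move_labeled_files, one could loop over `split_into_runs(image_paths)` and call images_to_video once per run, naming each clip from that run's first and last index, so every exported clip covers one uninterrupted segment.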

2.5 Adding automatic negative-sample filling

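I want a variable that controls whether only positive samples need to be labeled. When it is enabled, even if the user labels only positives, the frames before them are treated as negatives. For example, if frames 10 to 20 are labeled positive and frames 0 to 9 are left unlabeled, pressing x not only turns frames 10-20 into a positive clip in folder 1, but also turns frames 0-9 into a negative clip in folder 0.

To be precise: starting from the first frame, everything before the first positive sample is treated as negative. This does not mean that every unlabeled frame after it is negative too; these two cases have to be kept apart.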

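So instead of a boolean, I added an enumerated variable auto_neg_mode that switches between three behaviours:

  • auto_neg_mode = 0: auto-filling off (original behaviour)

  • auto_neg_mode = 1: only the frames before the first positive sample become negatives (the new requirement)

  • auto_neg_mode = 2: every unlabeled frame becomes a negative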
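The difference between the three modes is easiest to see on a plain list of per-frame labels. The sketch below is a hypothetical helper, `auto_fill_negatives`, not part of the tool; it is intended to mirror the auto-fill block used in move_labeled_files, with `None` standing for an unlabeled frame. Like the tool, it only fills anything when at least one positive exists, and it never overwrites an explicit label.

```python
from typing import List, Optional


def auto_fill_negatives(labels: List[Optional[str]], mode: int) -> List[Optional[str]]:
    """Hypothetical helper: return a copy of `labels` with unlabeled frames filled.

    mode 0: no filling
    mode 1: fill only the frames before the first 'positive'
    mode 2: fill every unlabeled frame
    In modes 1 and 2 nothing is filled if there is no positive frame at all,
    matching the `if pos_idx:` guard in the tool.
    """
    filled = list(labels)
    if mode == 0:
        return filled
    pos_idx = [i for i, lab in enumerate(filled) if lab == "positive"]
    if not pos_idx:
        return filled  # no positives: nothing to fill
    end = len(filled) if mode == 2 else pos_idx[0]
    for i in range(end):
        if filled[i] is None:  # only unlabeled frames are filled
            filled[i] = "negative"
    return filled


if __name__ == "__main__":
    # 31 frames (0-30): frames 10-20 labeled positive, everything else unlabeled
    frames = [None] * 10 + ["positive"] * 11 + [None] * 10
    for m in (0, 1, 2):
        out = auto_fill_negatives(frames, m)
        print(m, out.count("negative"), out.count(None))
```

For this input, mode 0 leaves all 20 unlabeled frames alone, mode 1 fills only frames 0-9 (10 negatives, frames 21-30 stay unlabeled), and mode 2 fills all 20.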

Below is the full implementation; the default is mode 1, which auto-fills only the frames before the first positive sample.

---------------- Code changes ----------------

  1. Add one variable in __init__:

    # Auto-fill negative samples mode
    # 0 = off, 1 = only before the first positive, 2 = all unlabeled
    self.auto_neg_mode = 1

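  2. In move_labeled_files, find the existing grouping code:

    # Group by label
    from collections import defaultdict
    groups = defaultdict(list)

and add the following block right after it (if an earlier version of the auto-fill block is already there, replace it entirely):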

python
# ---------- Auto-fill negative samples ----------
if self.auto_neg_mode:
    # Indices of frames in the current folder labeled positive
    pos_idx = [i for i, p in enumerate(self.images)
               if p in self.labels and self.labels[p] == 'positive']
    if pos_idx:                       # only fill when there is at least one positive
        first_pos = min(pos_idx)      # index of the first positive frame
        end_idx   = len(self.images) if self.auto_neg_mode == 2 else first_pos
        for i in range(0, end_idx):
            p = self.images[i]
            if p not in self.labels:   # only fill unlabeled frames
                self.labels[p] = 'negative'
# ------------------------------------------------

  3. To switch modes at any time from the UI, add another hotkey in main(), for example:

    elif event.key == K_1:
        tool.auto_neg_mode = (tool.auto_neg_mode + 1) % 3
        modes = ['off', 'before first positive', 'all unlabeled']
        print('Auto-fill mode:', modes[tool.auto_neg_mode])

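---------------- Result ----------------

  • With auto_neg_mode = 1, if the user labels only frames 10-20 as positive:

    - frames 0-9 are auto-labeled negative and exported as xxx_negative_0to9.mp4 into folder 0;

    - frames 10-20 are exported as the positive clip (xxx_positive_10to20.mp4) into folder 1;

    - frames 21 onward, if unlabeled, stay unlabeled and are not processed.

  • To go back to "all unlabeled frames are negatives" or to turn auto-filling off, just switch auto_neg_mode to 2 or 0; the modes do not interfere with each other.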
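To double-check the file names listed above, here is a small sketch that reproduces the naming rule used in move_labeled_files, `{folder}_{label}_{start}to{end}.mp4`, for one group of frames. The function name `clip_name` is hypothetical; the folder name test1 and the six-digit frame names follow the mp4tojpg.py output from 2.1.

```python
import os
from typing import List, Optional


def clip_name(image_paths: List[str], label: str) -> str:
    """Hypothetical helper: the clip file name the tool would build for one group of frames."""
    def index(path: str) -> Optional[int]:
        # Same parsing rule as extract_index: the number after the last '_'
        name = os.path.splitext(os.path.basename(path))[0]
        try:
            return int(name.split("_")[-1])
        except ValueError:
            return None

    paths = sorted(image_paths)
    start, end = index(paths[0]), index(paths[-1])
    if start is not None and end is not None:
        range_str = f"{start}to{end}"
    else:
        range_str = "unknown_range"
    folder = os.path.basename(os.path.dirname(paths[0]))
    return f"{folder}_{label}_{range_str}.mp4"


if __name__ == "__main__":
    root = os.path.join("input_videos", "test1")
    frames = [os.path.join(root, f"{i:06d}.jpg") for i in range(31)]
    print(clip_name(frames[0:10], "negative"))   # frames 0-9, auto-filled
    print(clip_name(frames[10:21], "positive"))  # frames 10-20, labeled by hand
```

Here the two calls give test1_negative_0to9.mp4 and test1_positive_10to20.mp4, i.e. the xxx placeholder above is the name of the frame folder.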

python
import os
import pygame
import sys
import shutil
import time
import json
from pygame.locals import *
import cv2  # pip install opencv-python 

# Initialise pygame
pygame.init()

# Configuration
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h

WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)

# Create the window
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Image classification labeling tool")

# Fonts
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

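class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
        self.folders = []               # absolute paths of all folders that contain images
        self.current_folder_index = 0   # index of the current folder
        self.images = []                # absolute paths of all images in the current folder
        self.current_image_index = 0    # index of the current image
        self.labels = {}                # path -> 'positive' / 'negative'

        # Auto-fill negative samples mode
        # 0 = off, 1 = only before the first positive, 2 = all unlabeled
        self.auto_neg_mode = 1

        self.convert_to_video = True    # export labeled frames as video clips
        self.video_fps = 10             # frame rate of the exported clips

        # Auto-play
        self.playing = False            # whether auto-play is running
        self.play_direction = 1         # 1 = forward, -1 = backward
        self.last_play_tick = 0         # time of the last frame change (ms)
        self.play_interval = 100        # milliseconds per frame

        # Continuous labeling state
        self.continuous_mode = False        # whether continuous labeling is active
        self.continuous_label = None        # label applied to the whole range
        self.continuous_start_index = None  # start index of the range

        # Long-press state of a/d
        self.key_pressed = {"left": False, "right": False}
        self.key_pressed_time = 0         # when the key went down (read by handle_key_repeats)
        self.last_key_time = 0            # time of the last repeat
        self.key_repeat_delay = 0.8       # initial delay before repeating (s)
        self.key_repeat_interval = 0.15   # interval between repeats (s)

        # Undo history
        self.undo_stack = []
        self.max_undo_steps = 50

        # Confirmation dialog state
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""        # action to run when the dialog is confirmed

        # Find every folder that contains images
        self.find_image_folders()

        # Load the images of the current folder
        if self.folders:
            self.load_current_folder_images()

        # Restore previously saved labels, if any
        self.load_labels()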

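    def images_to_video(self, image_paths, output_path, fps=10):
        """Encode an image sequence into a video, then delete the source frames"""
        if not image_paths:
            return

        # Read the first frame to get the frame size
        frame = cv2.imread(image_paths[0])
        if frame is None:
            print(f"Cannot read image: {image_paths[0]}")
            return
        h, w = frame.shape[:2]

        # Initialise the video writer (mp4v codec for .mp4)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        for img_path in image_paths:
            frame = cv2.imread(img_path)
            if frame is None:
                continue  # skip unreadable frames
            if frame.shape[:2] != (h, w):
                frame = cv2.resize(frame, (w, h))  # VideoWriter needs a fixed frame size
            out.write(frame)

        out.release()

        # Delete the source frames
        for img_path in image_paths:
            if os.path.exists(img_path):
                os.remove(img_path)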
        
    def find_image_folders(self):
        """Find every folder that contains images"""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)

    def load_current_folder_images(self):
        """Load all images in the current folder"""
        folder_path = self.folders[self.current_folder_index]
        self.images = []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

        for file in os.listdir(folder_path):
            if file.lower().endswith(image_extensions):
                self.images.append(os.path.join(folder_path, file))

        # Sort by file name (i.e. frame order)
        self.images.sort()
        self.current_image_index = 0

    def get_current_image(self):
        """Return the path of the current image"""
        if not self.images:
            return None
        return self.images[self.current_image_index]

    def next_image(self):
        """Go to the next image"""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()  # save state for undo
            self.current_image_index += 1
            return True
        return False

    def prev_image(self):
        """Go to the previous image"""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            return True
        return False

    def label_current_image(self, label):
        """Label the current image"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            self.labels[current_image] = label
            # Auto-save the label state
            self.save_labels()

    def start_continuous_labeling(self):
        """Start continuous (range) labeling"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            # Reuse the current image's label if it already has one
            if current_image in self.labels:
                self.continuous_label = self.labels[current_image]
            else:
                # Otherwise default to positive
                self.continuous_label = "positive"
                self.labels[current_image] = self.continuous_label

            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            # Auto-save the label state
            self.save_labels()
            return True
        return False

    def end_continuous_labeling(self):
        """End continuous labeling and label every frame in the range"""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()  # save state for undo
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)

            for i in range(start, end + 1):
                self.labels[self.images[i]] = self.continuous_label

            self.continuous_mode = False
            self.continuous_start_index = None
            # Auto-save the label state
            self.save_labels()
            return True
        return False

    def move_labeled_files(self, positive_dir, negative_dir):
        """Export labeled frames to the positive/negative folders; returns the number of frames processed."""
        os.makedirs(positive_dir, exist_ok=True)
        os.makedirs(negative_dir, exist_ok=True)

        # ---------- Auto-fill negative samples ----------
        if self.auto_neg_mode:
            # Indices of frames in the current folder labeled positive
            pos_idx = [i for i, p in enumerate(self.images)
                       if self.labels.get(p) == 'positive']
            if pos_idx:                       # only fill when there is at least one positive
                first_pos = min(pos_idx)      # index of the first positive frame
                end_idx = len(self.images) if self.auto_neg_mode == 2 else first_pos
                for i in range(end_idx):
                    p = self.images[i]
                    if p not in self.labels:  # only fill unlabeled frames
                        self.labels[p] = 'negative'
        # ------------------------------------------------

        # Group by (source folder, label) so one clip never mixes frames from different videos
        from collections import defaultdict
        groups = defaultdict(list)

        for img_path, label in self.labels.items():
            if label in ("positive", "negative") and os.path.exists(img_path):
                groups[(os.path.dirname(img_path), label)].append(img_path)

        def extract_index(path):
            # Assumes names like frame_0010.jpg or 000010.jpg (as written by mp4tojpg.py)
            name = os.path.splitext(os.path.basename(path))[0]
            try:
                return int(name.split('_')[-1])
            except ValueError:
                return None

        moved = 0
        for (folder, label), image_paths in groups.items():
            image_paths.sort()  # sort by file name, i.e. by frame order
            dest_dir = positive_dir if label == "positive" else negative_dir

            if self.convert_to_video:
                # Build the clip name from the first and last frame index
                start_idx = extract_index(image_paths[0])
                end_idx = extract_index(image_paths[-1])
                if start_idx is not None and end_idx is not None:
                    range_str = f"{start_idx}to{end_idx}"
                else:
                    range_str = "unknown_range"

                video_name = f"{os.path.basename(folder)}_{label}_{range_str}.mp4"
                video_path = os.path.join(dest_dir, video_name)
                self.images_to_video(image_paths, video_path, fps=self.video_fps)
            else:
                # Original behaviour: move the image files themselves
                for img_path in image_paths:
                    shutil.move(img_path, os.path.join(dest_dir, os.path.basename(img_path)))

            # These frames have been exported, so drop their labels
            for img_path in image_paths:
                self.labels.pop(img_path, None)
            moved += len(image_paths)

        # Reload the current folder (the exported frames are gone)
        self.load_current_folder_images()
        self.save_labels()
        return moved

    def next_folder(self):
        """Switch to the next folder"""
        if self.current_folder_index < len(self.folders) - 1:
            # Check whether the current folder still has labeled frames that were not exported
            # (compare the parent folder exactly; startswith would also match test10 for test1)
            current_folder = self.folders[self.current_folder_index]
            has_unmoved_labels = any(
                os.path.dirname(img_path) == current_folder and os.path.exists(img_path)
                for img_path in self.labels
            )

            if has_unmoved_labels:
                # Ask for confirmation first
                self.show_confirm_dialog = True
                self.confirm_action = "next_folder"
                self.confirm_message = "The current folder still has labeled frames that were not exported. Switch to the next folder anyway?"
                return False
            # Nothing pending: switch directly
            self.current_folder_index += 1
            self.load_current_folder_images()
            return True
        return False

    def prev_folder(self):
        """Switch to the previous folder"""
        if self.current_folder_index > 0:
            self.current_folder_index -= 1
            self.load_current_folder_images()
            return True
        return False

    def handle_key_repeats(self):
        """Handle long presses of the navigation keys (a/d)"""
        current_time = time.time()

        # Only act while a key is held down
        if any(self.key_pressed.values()):
            # First repeat: wait for the longer initial delay
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            # Later repeats: use the shorter interval
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:
                    self.next_image()
                self.last_key_time = current_time

    def save_state(self):
        """保存当前状态以便撤销"""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)  # 移除最旧的状态

        state = {
            "current_image_index": self.current_image_index,
            "labels": self.labels.copy(),
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        } 

        self.undo_stack.append(state)

    def undo(self):
        """撤销上一次操作"""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            return True
        return False

    def save_labels(self):
        """保存标记状态到文件"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            # 只保存仍然存在的文件的标记
            existing_labels = {k: v for k, v in self.labels.items() if os.path.exists(k)}
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f)
        except Exception as e:
            print(f"保存标记状态失败: {e}")

    def load_labels(self):
        """从文件加载标记状态"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    self.labels = json.load(f)
            except Exception as e:
                print(f"加载标记状态失败: {e}")

def draw_button(screen, text, rect, hover=False, color=None):
    """绘制按钮"""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR

    # 先画主体
    pygame.draw.rect(screen, color, rect, border_radius=5)
    # 再画边框
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)

    # 文字居中
    text_surface= small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """绘制确认对话框"""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2, 
        width, height
    )

    # 绘制对话框背景
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)

    # 绘制消息
    lines = []
    words = message.split()
    current_line = ""

    for word in words:
        test_line = current_line + word + " "
        if small_font.size(test_line)[0] < width - 40:
            current_line = test_line
        else:
            lines.append(current_line)
            current_line = word + " "

    if current_line:
        lines.append(current_line)

    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
    screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))

    # 绘制按钮
    yes_button= pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button  = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)

    draw_button(screen, "是", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "否", no_button, color=WARNING_COLOR)

    return dialog_rect, yes_button, no_button

def main():
    # Root folder to label (example path; change it to your own)
    root_path = r"D:\zero_track\mmaction2\input_videos\test1"

    # Create the labeling tool
    tool = ImageLabelingTool(root_path)

    # Output folders: "1" for positive samples, "0" for negative samples
    # positive_dir = os.path.join(root_path, "positive_samples")
    # negative_dir = os.path.join(root_path, "negative_samples")
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")
    os.makedirs(positive_dir, exist_ok=True)
    os.makedirs(negative_dir, exist_ok=True)

    # Main loop
    running = True
    clock = pygame.time.Clock()

    # Button layout: two rows
    button_height = 40
    button_width = 140
    button_margin = 15
    button_row1_y = WINDOW_HEIGHT - button_height - button_margin
    button_row2_y = WINDOW_HEIGHT - 2 * button_height - 2 * button_margin

    # Upper row: navigation buttons
    nav_buttons = {
        "prev": pygame.Rect(button_margin, button_row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, button_row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, button_row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, button_row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, button_row2_y, button_width, button_height),
    }

    # Lower row: labeling buttons
    label_buttons = {
        "positive": pygame.Rect(button_margin, button_row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, button_row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, button_row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, button_row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, button_row1_y, button_width, button_height),
    }

    # Image display area
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 220)

    # When a/d was pressed (used for key repeat)
    tool.key_pressed_time = 0

    while running:
        mouse_pos = pygame.mouse.get_pos()

        # Key repeat for held a/d
        tool.handle_key_repeats()

        # Auto-play: advance one frame every play_interval ms
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now

        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_d:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()  # remember when the key went down
                    tool.next_image()  # respond once immediately
                elif event.key == K_a:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()  # remember when the key went down
                    tool.prev_image()  # respond once immediately

                elif event.key == K_RIGHT:          # auto-play forward
                    tool.play_direction = 1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:           # auto-play backward
                    tool.play_direction = -1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:          # pause / resume
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_v:
                    tool.convert_to_video = not tool.convert_to_video
                    print("Convert-to-video mode: " + ("on" if tool.convert_to_video else "off"))
                elif event.key == K_w:  # label as positive
                    tool.label_current_image("positive")
                elif event.key == K_s:  # label as negative
                    tool.label_current_image("negative")
                elif event.key == K_UP:  # start continuous labeling
                    if not tool.start_continuous_labeling():
                        print("Cannot start continuous labeling")
                elif event.key == K_DOWN:  # end continuous labeling
                    if not tool.end_continuous_labeling():
                        print("No active continuous labeling")
                elif event.key == K_x:  # move the labeled files
                    tool.move_labeled_files(positive_dir, negative_dir)
                    print("Labeled files have been moved")
                elif event.key == K_c:  # next folder
                    tool.next_folder()
                elif event.key == K_z and (pygame.key.get_mods() & KMOD_CTRL):  # Ctrl+Z: undo
                    # Checked before plain z so that Ctrl+Z is reachable
                    if tool.undo():
                        print("Undid the last operation")
                    else:
                        print("Nothing to undo")
                elif event.key == K_z:  # previous folder
                    tool.prev_folder()
                elif event.key == K_ESCAPE:  # Esc closes the confirmation dialog
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False

            elif event.type == KEYUP:
                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0  # reset the repeat timer
                elif event.key == K_a:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0  # reset the repeat timer

            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1:  # left click
                    # While the confirmation dialog is open, only its buttons respond
                    if tool.show_confirm_dialog:
                        dialog_rect, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # Navigation buttons
                        if nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("Undid the last operation")
                            else:
                                print("Nothing to undo")

                        # Labeling buttons
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("Cannot start continuous labeling")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("No active continuous labeling")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            tool.move_labeled_files(positive_dir, negative_dir)
                            print("Labeled files have been moved")

        # Clear the screen
        screen.fill(BG_COLOR)

        # Folder info
        if tool.folders:
            folder_text = f"Current folder: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))

        # Current image
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()

                # Scale the image to fit the display area
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)

                screen.blit(img, img_rect)

                # Image info (above the image)
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"
                if current_image_path in tool.labels:
                    label = tool.labels[current_image_path]
                    info_text += f" - labeled: {label}"
                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)

                # In continuous mode, show the selected range
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)

                    range_text = f"Range: {start_idx + 1} - {end_idx + 1}"
                    range_surface = small_font.render(range_text, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface, (20, 50))

                    # Range indicator bar under the image
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width

                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                     (start_x, image_area.y + image_area.height + 5,
                                      end_x - start_x, 5))

            except Exception as e:
                error_text = f"Cannot load image: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        else:
            no_image_text = "No images to display"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        # Continuous-labeling status
        if tool.continuous_mode:
            mode_text = f"Continuous labeling on - label: {tool.continuous_label}"
            text_surface = small_font.render(mode_text, True, HIGHLIGHT_COLOR)
            screen.blit(text_surface, (WINDOW_WIDTH - text_surface.get_width() - 20, 50))

        # Navigation buttons
        draw_button(screen, "Prev (a)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next (d)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "Prev folder (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Next folder (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Undo (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))

        # Labeling buttons
        draw_button(screen, "Positive (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "Negative (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "Start range (↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "End range (↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "Move files (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))

        # Confirmation dialog on top
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)

        # Update the display
        pygame.display.flip()
        clock.tick(30)

    # Save labels before exiting
    tool.save_labels()
    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    main()
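In positive-driven mode, `move_labeled_files` applies an auto-fill rule controlled by `auto_neg_mode`: 0 turns it off, 1 fills only the frames before the first positive, and 2 fills every unlabeled frame. This rule is hard to verify through the GUI. Below is a minimal sketch that isolates it as a pure function so it can be checked on a plain list of frame names. `auto_fill_negatives` is a hypothetical helper and not part of the tool. It assumes the frame list is sorted by file name, as `load_current_folder_images` does.

```python
def auto_fill_negatives(images, labels, auto_neg_mode):
    """Return a copy of `labels` with unlabeled frames filled as 'negative'.

    Sketch of the positive-driven auto-fill rule (hypothetical helper):
      0 -> no auto-fill
      1 -> fill only unlabeled frames before the first positive frame
      2 -> fill every unlabeled frame in the folder
    Nothing is filled when the folder has no positive frame.
    `images` must already be sorted in playback order.
    """
    filled = dict(labels)  # never mutate the caller's dict
    if not auto_neg_mode:
        return filled
    pos_idx = [i for i, p in enumerate(images) if labels.get(p) == "positive"]
    if not pos_idx:
        return filled
    # pos_idx is ascending, so pos_idx[0] is the first positive frame
    end = len(images) if auto_neg_mode == 2 else pos_idx[0]
    for p in images[:end]:
        filled.setdefault(p, "negative")  # existing labels stay untouched
    return filled
```

Take six frames in which only frames 2 and 3 are positive. Mode 1 labels frames 0 and 1 as negative and leaves frames 4 and 5 unlabeled. Mode 2 also labels frames 4 and 5.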

2.6 Adding a negative-sample mode

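I wanted a variable that changes what happens when auto_neg_mode is 1. With the current logic, auto_neg_mode = 1 treats every frame before the first positive sample as a negative sample. I wanted to be able to switch to a negative-driven labeling workflow instead. In this workflow the user labels only negative samples, and pressing x turns each labeled range of negative frames into a video clip in folder 0. Frames before the first negative sample are not auto-labeled as positive. They are simply dropped from the list of frames being read. The switch is the new labeling_mode attribute: "positive" keeps the original behavior, and "negative" enables this mode.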
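Before wiring the negative-driven mode into the GUI, its intended behavior can be checked as a standalone function. This is a minimal sketch. It assumes the frame paths are sorted by name, and `plan_negative_clips` is a hypothetical helper. The sketch splits the negatives into contiguous runs so that each labeled interval becomes its own clip, which matches how the goal is phrased above. The listing below instead writes all negative frames of a folder into a single clip named after its first and last frame numbers, so separate intervals end up merged.

```python
def plan_negative_clips(images, labels):
    """Sketch of negative-driven mode (hypothetical helper).

    images: frame paths sorted in playback order
    labels: path -> 'positive' / 'negative'
    Returns (discarded, runs):
      discarded -- frames before the first negative (dropped, not labeled)
      runs      -- lists of consecutive negative frames, one list per clip
    """
    neg_idx = [i for i, p in enumerate(images) if labels.get(p) == "negative"]
    if not neg_idx:
        return [], []
    first = neg_idx[0]
    discarded = images[:first]
    runs, current = [], []
    for p in images[first:]:
        if labels.get(p) == "negative":
            current.append(p)
        elif current:  # any non-negative frame ends the current run
            runs.append(current)
            current = []
    if current:
        runs.append(current)
    return discarded, runs
```

Each list in `runs` could then be passed to `images_to_video` on its own, producing one clip per labeled interval.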

python
import os
import pygame
import sys
import shutil
import time
import json
from pygame.locals import *
import cv2  # pip install opencv-python

# Initialize pygame
pygame.init()

# Configuration
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h

WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)

# Create the window
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Image Classification Labeling Tool")

# Fonts
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
        self.folders = []               # absolute paths of all folders that contain images
        self.current_folder_index = 0   # index of the current folder
        self.images = []                # absolute paths of the images in the current folder
        self.current_image_index = 0    # index of the current image
        self.labels = {}                # path -> 'positive' / 'negative'
        self.discarded_images = set()   # frames dropped in negative-driven mode (hidden on reload)

        # Auto-fill mode
        # Positive-driven: 0 = off, 1 = only frames before the first positive, 2 = all unlabeled frames
        # Negative-driven: any non-zero value drops the frames before the first negative
        self.auto_neg_mode = 1

        # New: labeling mode
        # "positive" = positive-driven mode (the original behavior)
        # "negative" = negative-driven mode: label only negatives; frames before
        #              the first negative are dropped instead of auto-labeled
        self.labeling_mode = "negative"

        self.convert_to_video = True  # write each label group to an .mp4 instead of moving images
        self.video_fps = 10           # frame rate of the generated videos

        # Auto-play
        self.playing = False        # whether auto-play is on
        self.play_direction = 1     # 1 = next image, -1 = previous image
        self.last_play_tick = 0     # time of the last auto-play step (ms)
        self.play_interval = 100    # milliseconds between auto-play steps

        # Continuous labeling state
        self.continuous_mode = False        # whether continuous labeling is active
        self.continuous_label = None        # label applied to the whole range
        self.continuous_start_index = None  # start index of the range

        # Key-repeat state for held a/d
        self.key_pressed = {"left": False, "right": False}
        self.key_pressed_time = 0           # when the key went down
        self.last_key_time = 0              # time of the last repeat
        self.key_repeat_delay = 0.8         # delay before the first repeat (s)
        self.key_repeat_interval = 0.15     # interval between later repeats (s)

        # Undo history
        self.undo_stack = []
        self.max_undo_steps = 50

        # Confirmation dialog state
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""            # action to run when the user confirms

        # Find all folders that contain images
        self.find_image_folders()

        # Load the images of the current folder
        if self.folders:
            self.load_current_folder_images()

        # Restore saved labels, if any
        self.load_labels()

    def images_to_video(self, image_paths, output_path, fps=10):
        """Write an image sequence to an .mp4 file, then delete the source images."""
        if not image_paths:
            return

        # Read the first frame to get the video size
        frame = cv2.imread(image_paths[0])
        if frame is None:
            print(f"Cannot read image: {image_paths[0]}")
            return
        h, w = frame.shape[:2]

        # Initialize the video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        for img_path in image_paths:
            frame = cv2.imread(img_path)
            if frame is None:
                continue  # skip unreadable frames
            out.write(frame)

        out.release()

        # Delete the source images
        for img_path in image_paths:
            if os.path.exists(img_path):
                os.remove(img_path)

    def find_image_folders(self):
        """Collect every folder under root_path that contains images."""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)

    def load_current_folder_images(self):
        """Load all images in the current folder, skipping discarded frames."""
        folder_path = self.folders[self.current_folder_index]
        self.images = []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

        for file in os.listdir(folder_path):
            path = os.path.join(folder_path, file)
            # Frames dropped in negative-driven mode are not shown again
            if file.lower().endswith(image_extensions) and path not in self.discarded_images:
                self.images.append(path)

        # Sort by file name
        self.images.sort()
        self.current_image_index = 0

    def get_current_image(self):
        """Return the path of the current image, or None."""
        if not self.images:
            return None
        return self.images[self.current_image_index]

    def next_image(self):
        """Go to the next image."""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()  # save state for undo
            self.current_image_index += 1
            return True
        return False

    def prev_image(self):
        """Go to the previous image."""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            return True
        return False

    def label_current_image(self, label):
        """Label the current image."""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            self.labels[current_image] = label
            # Auto-save the labels
            self.save_labels()

    def start_continuous_labeling(self):
        """Start labeling a range of images."""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()  # save state for undo
            # Reuse the current image's label if it has one
            if current_image in self.labels:
                self.continuous_label = self.labels[current_image]
            else:
                # Otherwise default to the label the current mode is built around
                self.continuous_label = "negative" if self.labeling_mode == "negative" else "positive"
                self.labels[current_image] = self.continuous_label

            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def end_continuous_labeling(self):
        """Finish the range and give every image in it the same label."""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()  # save state for undo
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)

            for i in range(start, end + 1):
                self.labels[self.images[i]] = self.continuous_label

            self.continuous_mode = False
            self.continuous_start_index = None
            # Auto-save the labels
            self.save_labels()
            return True
        return False

    def move_labeled_files(self, positive_dir, negative_dir):
        """Move labeled files (or write them as videos) into the positive/negative folders."""
        if not os.path.exists(positive_dir):
            os.makedirs(positive_dir)
        if not os.path.exists(negative_dir):
            os.makedirs(negative_dir)

        # ---------- Auto-fill logic ----------
        if self.labeling_mode == "positive":
            if self.auto_neg_mode:
                # Indices of frames in this folder already labeled positive
                pos_idx = [i for i, p in enumerate(self.images)
                           if p in self.labels and self.labels[p] == 'positive']
                if pos_idx:                       # only fill when there is a positive
                    first_pos = min(pos_idx)      # index of the first positive
                    end_idx = len(self.images) if self.auto_neg_mode == 2 else first_pos
                    for i in range(0, end_idx):
                        p = self.images[i]
                        if p not in self.labels:   # only fill unlabeled frames
                            self.labels[p] = 'negative'
        else:
            # Negative-driven mode
            if self.auto_neg_mode:
                neg_idx = [i for i, p in enumerate(self.images)
                           if self.labels.get(p) == 'negative']
                if neg_idx:
                    first_neg = min(neg_idx)
                    # Drop the frames before the first negative: they are not labeled
                    # positive, and they stay hidden when the folder is reloaded
                    for p in self.images[:first_neg]:
                        self.discarded_images.add(p)
                        self.labels.pop(p, None)  # labels of other folders are kept
                    self.images = self.images[first_neg:]
                    self.current_image_index = 0
        # -----------------------------------

        # Group by label
        from collections import defaultdict
        groups = defaultdict(list)

        # Only take labels from the current folder, so leftover labels from
        # other folders are not merged into this folder's clips
        current_folder = self.folders[self.current_folder_index]
        for img_path, label in self.labels.items():
            if (label in ["positive", "negative"]
                    and os.path.dirname(img_path) == current_folder
                    and os.path.exists(img_path)):
                groups[label].append(img_path)

        # Sort each group by file name
        for label in groups:
            groups[label].sort()

        # Process each group
        moved = 0
        for label, image_paths in groups.items():
            if not image_paths:
                continue

            dest_dir = positive_dir if label == "positive" else negative_dir

            if self.convert_to_video:
                # Write the whole group as one video named by its first/last frame
                # number (non-contiguous frames end up in the same clip)
                def extract_index(path):
                    # Assumes names like frame_0010.jpg or 0010.jpg
                    name = os.path.splitext(os.path.basename(path))[0]
                    try:
                        return int(name.split('_')[-1])
                    except ValueError:
                        return None

                start_idx = extract_index(image_paths[0])
                end_idx = extract_index(image_paths[-1])

                if start_idx is not None and end_idx is not None:
                    range_str = f"{start_idx}to{end_idx}"
                else:
                    range_str = "unknown_range"

                folder_name = os.path.basename(os.path.dirname(image_paths[0]))
                video_name = f"{folder_name}_{label}_{range_str}.mp4"

                video_path = os.path.join(dest_dir, video_name)
                self.images_to_video(image_paths, video_path, fps=self.video_fps)
            else:
                # Original behavior: move the image files
                for img_path in image_paths:
                    filename = os.path.basename(img_path)
                    dest_path = os.path.join(dest_dir, filename)
                    shutil.move(img_path, dest_path)

            moved += len(image_paths)
            # Remove the processed frames from the label dict
            for img_path in image_paths:
                self.labels.pop(img_path, None)

        # Reload the current folder
        self.load_current_folder_images()
        self.save_labels()
        return moved

    def next_folder(self):
        """切换到下一个文件夹"""
        if self.current_folder_index < len(self.folders) - 1:
            # 检查当前文件夹是否有未移动的标记文件
            current_folder = self.folders[self.current_folder_index]
            has_unmoved_labels = any(
                img_path.startswith(current_folder) and os.path.exists(img_path)
                for img_path in self.labels.keys()
            )

            if has_unmoved_labels:  
                # 显示确认对话框
                self.show_confirm_dialog = True
                self.confirm_action = "next_folder"
                self.confirm_message = "当前文件夹有未移动的标记文件,确定要切换到下一个文件夹吗?"
                return False
            else:
                # 直接切换文件夹
                self.current_folder_index += 1
                self.load_current_folder_images()
                return True
        return False

    def prev_folder(self):
        """切换到上一个文件夹"""
        if self.current_folder_index > 0:
            self.current_folder_index -= 1
            self.load_current_folder_images()
            return True
        return False

    def handle_key_repeats(self):
        """处理方向键长按"""   
        current_time = time.time() 

        # 检查是否需要触发按键重复
        if any(self.key_pressed.values()):
            # 如果是第一次按下,等待较长时间
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            # 后续重复,使用较短的间隔  
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:  
                    self.next_image()
                self.last_key_time = current_time

    def save_state(self):
        """保存当前状态以便撤销"""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)  # 移除最旧的状态

        state = {
            "current_image_index": self.current_image_index,
            "labels": self.labels.copy(),
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        } 

        self.undo_stack.append(state)

    def undo(self):
        """撤销上一次操作"""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            return True
        return False

    def save_labels(self):
        """保存标记状态到文件"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            # 只保存仍然存在的文件的标记
            existing_labels = {k: v for k, v in self.labels.items() if os.path.exists(k)}
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f)
        except Exception as e:
            print(f"保存标记状态失败: {e}")

    def load_labels(self):
        """从文件加载标记状态"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    self.labels = json.load(f)
            except Exception as e:
                print(f"加载标记状态失败: {e}")

def draw_button(screen, text, rect, hover=False, color=None):
    """绘制按钮"""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR

    # 先画主体
    pygame.draw.rect(screen, color, rect, border_radius=5)
    # 再画边框
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)

    # 文字居中
    text_surface= small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """Draw a confirmation dialog; return its rect and the Yes/No button rects."""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2,
        width, height
    )

    # Dialog background and border
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)

    # Word-wrap the message to the dialog width.
    # Note: wrapping splits on spaces, so text without spaces stays on one line.
    lines = []
    words = message.split()
    current_line = ""

    for word in words:
        test_line = current_line + word + " "
        if small_font.size(test_line)[0] < width - 40:
            current_line = test_line
        else:
            if current_line:
                lines.append(current_line)
            current_line = word + " "

    if current_line:
        lines.append(current_line)

    # Draw every wrapped line (the blit must be inside the loop)
    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
        screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))

    # Yes / No buttons
    yes_button = pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)

    draw_button(screen, "Yes", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "No", no_button, color=WARNING_COLOR)

    return dialog_rect, yes_button, no_button

def main():
    # Example root path; change it to your own folder
    root_path = r"D:\zero_track\mmaction2\input_videos\test1"

    # Create the labeling tool instance
    tool = ImageLabelingTool(root_path)

    # Create the output folders for positive and negative samples
    # positive_dir = os.path.join(root_path, "positive_samples")
    # negative_dir = os.path.join(root_path, "negative_samples")
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")
    os.makedirs(positive_dir, exist_ok=True)
    os.makedirs(negative_dir, exist_ok=True)

    # Main loop state
    running = True
    clock = pygame.time.Clock()

    # Button area, laid out in two rows
    button_height = 40
    button_width = 140
    button_margin = 15
    button_row1_y = WINDOW_HEIGHT - button_height - button_margin          # bottom row
    button_row2_y = WINDOW_HEIGHT - 2 * button_height - 2 * button_margin  # row above it

    # Upper row: navigation buttons
    nav_buttons = {
        "prev": pygame.Rect(button_margin, button_row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, button_row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, button_row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, button_row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, button_row2_y, button_width, button_height),
    }

    # Mode toggle button, placed at the end of the upper (navigation) row
    mode_button = pygame.Rect(button_margin * 6 + button_width * 5, button_row2_y, button_width, button_height)

    # Bottom row: labeling buttons
    label_buttons = {
        "positive": pygame.Rect(button_margin, button_row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, button_row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, button_row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, button_row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, button_row1_y, button_width, button_height),
    }

    # Image display area
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 220)

    # Timestamp of the last key press (used for key repeat)
    tool.key_pressed_time = 0

    while running:
        mouse_pos = pygame.mouse.get_pos()

        # Handle held-down keys (key repeat)
        tool.handle_key_repeats()

        # Auto-play logic
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now

        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_d:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()  # record the press time
                    tool.next_image()  # respond once immediately
                elif event.key == K_a:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()  # record the press time
                    tool.prev_image()  # respond once immediately

                elif event.key == K_RIGHT:          # auto-play forward
                    tool.play_direction = 1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:           # auto-play backward
                    tool.play_direction = -1
                    tool.playing        = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:          # pause / resume
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_v:
                    tool.convert_to_video = not tool.convert_to_video
                    print("Convert-to-video mode: " + ("on" if tool.convert_to_video else "off"))
                elif event.key == K_w:  # label as positive
                    tool.label_current_image("positive")
                elif event.key == K_s:  # label as negative
                    tool.label_current_image("negative")
                elif event.key == K_UP:  # start continuous labeling
                    if not tool.start_continuous_labeling():
                        print("Cannot start continuous labeling")
                elif event.key == K_DOWN:  # end continuous labeling
                    if not tool.end_continuous_labeling():
                        print("No active continuous labeling")
                elif event.key == K_x:  # move files
                    moved = tool.move_labeled_files(positive_dir, negative_dir)
                    print(f"Moved {moved} files")
                elif event.key == K_c:  # next folder
                    tool.next_folder()
                # Check Ctrl+Z before plain Z; otherwise the plain-Z branch always matches and undo is unreachable
                elif event.key == K_z and (pygame.key.get_mods() & KMOD_CTRL):  # Ctrl+Z: undo
                    if tool.undo():
                        print("Undid the last operation")
                    else:
                        print("Nothing to undo")
                elif event.key == K_z:  # previous folder
                    tool.prev_folder()
                elif event.key == K_ESCAPE:  # ESC closes the confirmation dialog
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False

            elif event.type == KEYUP:
                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0  # reset the repeat timer
                elif event.key == K_a:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0  # reset the repeat timer

            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1:  # left click
                    # While the confirmation dialog is open, only its buttons respond
                    if tool.show_confirm_dialog:
                        dialog_rect, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # Navigation buttons
                        if nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("Undid the last operation")
                            else:
                                print("Nothing to undo")

                        # Mode toggle button
                        elif mode_button.collidepoint(mouse_pos):
                            tool.labeling_mode = "negative" if tool.labeling_mode == "positive" else "positive"
                            print(f"Switched to {'negative' if tool.labeling_mode == 'negative' else 'positive'}-first mode")

                        # Labeling buttons
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("Cannot start continuous labeling")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("No active continuous labeling")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            moved = tool.move_labeled_files(positive_dir, negative_dir)
                            print(f"Moved {moved} files")

        # Clear the screen
        screen.fill(BG_COLOR)

        # Show current folder info
        if tool.folders:
            folder_text = f"Current folder: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))

        # Show the current image
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()

                # Scale the image to fit the display area, keeping its aspect ratio
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)

                screen.blit(img, img_rect)

                # Image info, shown above the image
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"
                if current_image_path in tool.labels:
                    label = tool.labels[current_image_path]
                    info_text += f" - labeled: {'positive' if label == 'positive' else 'negative'}"
                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)

                # In continuous labeling mode, show the selected range
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)

                    range_text = f"Range: {start_idx + 1} - {end_idx + 1}"
                    range_surface = small_font.render(range_text, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface, (20, 50))

                    # Draw a bar under the image marking the selected range
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width

                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                     (start_x, image_area.y + image_area.height + 5,
                                      end_x - start_x, 5))

            except Exception as e:
                error_text = f"Failed to load image: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        else:
            no_image_text = "No image to display"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))

        # Show the continuous labeling status
        if tool.continuous_mode:
            mode_text = f"Continuous labeling on - label: {'positive' if tool.continuous_label == 'positive' else 'negative'}"
            text_surface = small_font.render(mode_text, True, HIGHLIGHT_COLOR)
            screen.blit(text_surface, (WINDOW_WIDTH - text_surface.get_width() - 20, 50))

        # Navigation buttons
        draw_button(screen, "Prev (a)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "Next (d)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "Prev folder (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Next folder (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "Undo (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))

        # Mode toggle button
        mode_text = f"{'Negative' if tool.labeling_mode == 'negative' else 'Positive'} mode"
        draw_button(screen, mode_text, mode_button, mode_button.collidepoint(mouse_pos))

        # Labeling buttons
        draw_button(screen, "Positive (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "Negative (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "Start range (↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "End range (↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "Move files (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))

        # Draw the confirmation dialog on top
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)

        # Update the display
        pygame.display.flip()
        clock.tick(30)

    # Save the labeling state before exiting
    tool.save_labels()
    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    main()
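
The tool saves its state through save_labels() to labels_backup.json in root_path. That file is a flat JSON object that maps each image path to "positive" or "negative". Counting it is a quick way to see your positive/negative balance before thinking about data volume. Below is a minimal sketch that assumes that file layout; count_labels and summarize_labels are hypothetical helpers and are not part of the tool.

```python
import json
from collections import Counter


def count_labels(labels):
    """Count "positive" / "negative" values in a {image_path: label} dict."""
    counts = Counter(labels.values())
    return counts["positive"], counts["negative"]


def summarize_labels(labels_file):
    """Load a labels_backup.json written by save_labels() and count it.

    Hypothetical helper; assumes the flat {image_path: label} layout.
    """
    with open(labels_file, "r", encoding="utf-8") as f:
        labels = json.load(f)
    return count_labels(labels)
```

For example, call summarize_labels(os.path.join(root_path, "labels_backup.json")). save_labels() drops entries whose files no longer exist. As a result, images that move_labeled_files has already moved into the 1 and 0 folders are not counted, so count the files in those folders for them instead.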

3. Data volume

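3.1 Training a two-class kick / non-kick TSN model with mmaction2: how many training and validation samples, and how many positive and negative samples (videos), are needed for good results?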

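As a rule of thumb, training TSN with MMAction2 for kick / non-kick binary classification needs at least roughly the following amounts of data to give reasonably good results:

  1. Positive samples (kick)

    100-200 clean, precisely trimmed "kick" clips is the bare minimum to keep training from failing. If the scenes are complex, or you need a stable recall above 90%, aim for about 300 clips.

  2. Negative samples (non-kick)

    Pair them at a positive-to-negative ratio of 1:1 to 1:3, i.e. 100-600 clips. The more complex the backgrounds and the more varied the other actions, the closer you should get to 1:3.

  3. Train / validation split

    A random 8:2 or 9:1 split is enough. The numbers above are train + validation totals; the validation set usually takes about 20%.

  4. Total number of videos

    Counting positives and negatives together, 300-800 clips is the "entry price". Below 200 clips overfitting is almost certain. Above 1000 clips the gains gradually diminish, although the model becomes more stable.

  5. Quick-validation trick

    If labeling is expensive, first collect a few dozen positive clips plus the same number of negatives (≈100 clips in total) for a small fine-tuning run. Using a self-supervised model such as VideoMAE V2 as the backbone can also reach a usable model within 10-20 epochs. Then add more data step by step and iterate.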
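
The sketch below turns these ratios into concrete clip counts. plan_split is a hypothetical helper and is not part of MMAction2. It splits each class separately so both classes end up in the validation set, and it uses int() truncation for the training share.

```python
def plan_split(num_positive, neg_ratio=1.0, train_ratio=0.8):
    """Turn the rule-of-thumb ratios into per-class (train, val) clip counts.

    num_positive: number of kick clips
    neg_ratio:    negatives per positive (1 to 3 in the text)
    train_ratio:  0.8 for an 8:2 split, 0.9 for 9:1
    Each class is split separately so both appear in the validation set.
    """
    num_negative = round(num_positive * neg_ratio)
    plan = {}
    for name, total in (("positive", num_positive), ("negative", num_negative)):
        n_train = int(total * train_ratio)  # truncate, like int(len(videos) * train_ratio)
        plan[name] = (n_train, total - n_train)
    return plan


print(plan_split(200, neg_ratio=2))  # {'positive': (160, 40), 'negative': (320, 80)}
```

With 200 kick clips at 1:2, this gives 600 clips in total, which is inside the 300-800 "entry price" range above.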

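3.2 Is there a length requirement for each clip?

In MMAction2, TSN does not enforce any clip length. However, its sampling scheme means that **clips that are too short or too long both hurt convergence and accuracy**. The rules of thumb below can serve as a practical spec when cutting clips:

  1. Minimum

    Include at least one complete action cycle. A kick, from lifting the leg to putting the foot down, usually takes 0.5-1 s, so do not go below 8-10 frames (≈0.3 s at 30 fps).

    If a clip is cut too short, it may not contain the whole kick. The few frames TSN samples sparsely can then show little of the actual motion, and the label no longer matches what the model sees.

  2. Recommended

    **1-3 s (30-90 frames)** is the sweet spot:

    • it covers the full extent of the action without pulling in too much irrelevant background;

    • it falls in the same few-second range as Something-Something clips. Kinetics clips are longer (about 10 s each), so the match with the pretraining data is approximate rather than exact.

  3. Maximum

    Keep each clip under 5-6 s.

    • TSN divides the whole video evenly into K segments and takes 1 frame from each. K is num_clips in the config, e.g. 3 in the 1x1x3 TSN configs and 8 in the 1x1x8 ones. The longer the clip, the sparser the sampled frames and the more the action is diluted, so accuracy drops;

    • if the raw video is long, first use a tool (trim_raw_video.py or MMAction2's trim.py) to cut each kick event precisely into a 1-3 s clip. Then feed those clips into the training pipeline.

  4. Uniform format

    After cutting, unify the frame rate to 30 fps (or 25 fps) and the resolution to 256×256 or 340×256. The later RawFrameDecode + Resize steps then do not have to resample or drop frames.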
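
The dilution effect in item 3 can be illustrated with a simplified, hypothetical re-implementation of segment-based sampling. It splits the clip into K equal segments and takes the centre frame of each, which is roughly what happens at test time; during training a random frame inside each segment is used instead. This is not MMAction2's SampleFrames code. The kick position and clip lengths below are made up for illustration.

```python
def tsn_center_indices(num_frames, num_segments):
    """Simplified TSN-style sampling: split the clip into num_segments equal
    segments and take the centre frame of each (test-time style; training
    picks a random frame inside each segment instead)."""
    seg_len = num_frames / num_segments
    return [int(seg_len * i + seg_len / 2) for i in range(num_segments)]


def frames_on_action(num_frames, num_segments, action_start, action_end):
    """Count sampled frames that fall inside the action [action_start, action_end)."""
    indices = tsn_center_indices(num_frames, num_segments)
    return sum(action_start <= i < action_end for i in indices)


# A 30-frame kick (1 s at 30 fps) in the middle of a 2 s clip vs. a 10 s clip, K = 8
print(frames_on_action(60, 8, 15, 45))     # 4 of 8 sampled frames show the kick
print(frames_on_action(300, 8, 135, 165))  # 0 of 8: every sampled frame misses it
```

In the 10 s clip, the sampled frames near the kick (131 and 168) land just outside it. This is why a long clip can end up labeled "kick" while the model only ever sees background.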

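In one sentence:
For kick / non-kick binary classification, clips of 1-3 s (30-90 frames) that fully contain the action should let TSN converge stably in MMAction2. Clips shorter than 0.3 s or longer than 6 s tend to cost noticeable accuracy.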
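
Below is a minimal sketch for checking cut clips against these limits. check_clip_length and clip_stats are hypothetical helpers. The thresholds are the rule-of-thumb values above, not MMAction2 settings. OpenCV reads the frame count from container metadata, so it can be slightly off for some files.

```python
def check_clip_length(num_frames, fps, min_s=0.3, max_s=6.0, sweet=(1.0, 3.0)):
    """Classify a clip against the rule-of-thumb limits in this section."""
    if fps <= 0:
        return "unknown"  # metadata missing or broken
    duration = num_frames / fps
    if duration < min_s:
        return "too_short"
    if duration > max_s:
        return "too_long"
    if sweet[0] <= duration <= sweet[1]:
        return "ok"
    return "acceptable"  # usable, but outside the 1-3 s sweet spot


def clip_stats(video_path):
    """Read (frame count, fps) with OpenCV."""
    import cv2  # imported lazily so check_clip_length works without OpenCV
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"cannot open {video_path}")
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return num_frames, fps
```

Usage would look like check_clip_length(*clip_stats(path)) for each file in the 0 and 1 folders. Clips flagged "too_short" or "too_long" are the ones to re-cut.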

4. Converting the data

(I'll fill in this part once I've finished labeling the data.)

......

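First, create a folder my_data with two sub-folders, 0 and 1. Since we are building a kick recognizer, put the kick clips in 1 and the non-kick clips in 0. Then write a script to convert this into the same format as kinetics400_tiny: train/ and val/ video folders, plus train_video.txt and val_video.txt lists with one "<filename> <label>" line per video.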

python 复制代码
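import os
import random
import shutil

# Source data: one sub-folder per class, named 0, 1, 2, ...
# (here: my_data/0 = non-kick, my_data/1 = kick)
src_root = "my_data"           # change to your own data path
dst_root = "kinetics_format"   # output path

VIDEO_EXTS = ('.mp4', '.avi', '.mov', '.mkv')  # video extensions to include

train_dir = os.path.join(dst_root, "train")
val_dir = os.path.join(dst_root, "val")
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Annotation lists in the kinetics400_tiny style: "<filename> <label>" per line
train_txt = os.path.join(dst_root, "train_video.txt")
val_txt = os.path.join(dst_root, "val_video.txt")

# train/val split ratio
train_ratio = 0.8
random.seed(0)  # fixed seed so the split is reproducible

with open(train_txt, "w", encoding="utf-8") as ftrain, open(val_txt, "w", encoding="utf-8") as fval:
    # Iterate over the class folders in a stable order
    for cls_name in sorted(os.listdir(src_root)):
        cls_path = os.path.join(src_root, cls_name)
        # The folder name is used as the integer label, so skip anything else
        if not os.path.isdir(cls_path) or not cls_name.isdigit():
            continue

        videos = sorted(v for v in os.listdir(cls_path)
                        if v.lower().endswith(VIDEO_EXTS))
        random.shuffle(videos)

        split_idx = int(len(videos) * train_ratio)
        splits = ((videos[:split_idx], train_dir, ftrain),
                  (videos[split_idx:], val_dir, fval))

        # Copy into train/ or val/ and write the annotation line.
        # Prefix the class name so that files with the same name in different
        # class folders do not overwrite each other, and replace spaces because
        # the list file is space-separated.
        for split_videos, out_dir, ftxt in splits:
            for v in split_videos:
                dst_name = f"{cls_name}_{v}".replace(" ", "_")
                shutil.copy(os.path.join(cls_path, v), os.path.join(out_dir, dst_name))
                ftxt.write(f"{dst_name} {cls_name}\n")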
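
After conversion, it is worth sanity-checking the list files before training. The sketch below is minimal and assumes the one-per-line "<filename> <label>" format written above, with binary labels 0/1. check_ann_file is a hypothetical helper. Lines with more than two fields, such as filenames containing spaces, are flagged because a space-separated list cannot represent them unambiguously.

```python
import os
from collections import Counter


def check_ann_file(ann_file, video_dir, valid_labels=("0", "1")):
    """Sanity-check a "<filename> <label>" list file.

    Returns (label_counts, problems). Hypothetical helper: it only checks
    the line format, the label values and that each video file exists.
    """
    counts = Counter()
    problems = []
    with open(ann_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # ignore blank lines
            if len(fields) != 2:
                problems.append(f"line {line_no}: expected 2 fields, got {len(fields)}")
                continue
            name, label = fields
            counts[label] += 1
            if label not in valid_labels:
                problems.append(f"line {line_no}: unexpected label {label!r}")
            if not os.path.isfile(os.path.join(video_dir, name)):
                problems.append(f"line {line_no}: missing file {name}")
    return counts, problems
```

For example, check_ann_file("kinetics_format/train_video.txt", "kinetics_format/train") returns the per-label counts, which you can compare against the 1:1 to 1:3 ratio from section 3.1, along with a list of problems to fix before training.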