动作识别12——yolo26s-pose+PoseC3D第1篇之标注工具升级2.0

目录

一、前言

二、动作识别标注工具2.0

[2.1 标注工具](#2.1 标注工具)

[2.2 生成的json格式](#2.2 生成的json格式)

三、基于annotations.json生成pkl

[3.1 回顾一下pkl 文件格式](#3.1 回顾一下pkl 文件格式)

[3.2 基于annotations.json生成pkl脚本](#3.2 基于annotations.json生成pkl脚本)


一、前言

我在Github上也上传了相应的代码:

https://github.com/guanlinyi/Action-Recognition/tree/main

针对多人骨架错乱问题,将标注工具1.0升级为2.0,功能更新如下:

1.预先使用yolo26进行人物目标检测、跟踪和姿态估计,若视频帧不存在人则自动删除视频帧

2.支持通过Q和E切换不同的人物进行动作标注。

3.显示连续标记范围的时候,增加显示当前帧与方向键上键标记的开始连续帧之间已经过了多少帧,如此就能在按方向键下键结束连续标记之前知道生成的视频帧有多少张。

4.当按x将连续标记的视频帧区间转为视频后,该区间的视频帧不会被删除,而是在视频帧图片信息后面新增已标注的信息,比如: (1/1428) - 已标记人物动作{1: 0, 2:1}, 其中{1: 0, 2:1}是一个字典,它表示人物1被标记为动作0,人物2倍标记为动作1。这种设计意味着用户可以回到已转换为视频的视频帧区间内标注该区间的另一个人物动作(通过Q和E切换不同的人物),而且标注以后,上面的字典信息也会更新,比如新增人物3动作0,变成 (1/1428) - 已标记人物动作{1: 0, 2:1, 3:0}

  1. 在build_video_name函数中写了生成mp4名称的逻辑,名称中除了有连续标记的区间range_str,还新增了P{person_id}用于标记人物id,suffix则是原来的动作类别id

suffix = "A001" if label == "positive" else "A002"

return f"{folder_name}_{range_str}P{person_id}{suffix}"

6.新增用于裁剪视频后生成annotations.json,基于该json可以更容易转换成能够输入训练的pkl格式。

7.新增独立脚本generate_ntu60_pkl.py用于根据标注工具生成的annotations.json生成输入训练的pkl。

标注工具生成的文件结果如下

复制代码
long_video1
├──ok/
│  ├── positive/         # 存放所有裁剪出来的正样本动作视频
│      ├── xxx_P1_A001.mp4
│      └── ..
│  ├── negative/         # 存放所有裁剪出来的负样本动作视频
│      ├── xxx_P2_A002.mp4
│      └── ...
├──000000.jpg         # 原始的视频帧
├──000001.jpg
├──......
├──yolo_pose_results.json         # 记录yolo推理结果的文件
├──yolo_pose_results_stats.json   # 记录yolo推理结果状态的文件
├──labels_backup.json             # 记录标注工具标注进度的文件
└── annotations.pkl               # 记录标注信息的文件,可用于转成输入训练的pkl文件

long_video2同理
├──...

二、动作识别标注工具2.0

2.1 标注工具

D:\zero_track\mmaction2\my_tools\auto_label_tool5_ok.py

python 复制代码
import os
import pygame
import sys
import shutil
import time
import json
import subprocess
import numpy as np
from pygame.locals import *
import cv2

ROOT_PATH = r"D:\zero_track\mmaction2\my_code\input\fisheye_back_2026-01-29-14-13-15_30fps"

# 初始化pygame
pygame.init()

# 配置参数
SCREEN_WIDTH, SCREEN_HEIGHT = pygame.display.Info().current_w, pygame.display.Info().current_h   
WINDOW_WIDTH, WINDOW_HEIGHT = SCREEN_WIDTH - 100, SCREEN_HEIGHT - 100
BG_COLOR = (40, 44, 52)
TEXT_COLOR = (220, 220, 220)
HIGHLIGHT_COLOR = (97, 175, 239)
BUTTON_COLOR = (56, 58, 66)
BUTTON_HOVER_COLOR = (72, 74, 82)
WARNING_COLOR = (255, 152, 0)
CONFIRM_COLOR = (76, 175, 80)
BOX_COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), 
              (255, 0, 255), (0, 255, 255), (255, 128, 0), (128, 255, 0)]

# 创建窗口
screen = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("图像分类标注工具 - YOLO-Pose 多人标注")

# 字体
font = pygame.font.SysFont("SimHei", 24)
small_font = pygame.font.SysFont("SimHei", 18)

def run_yolo_pose_inference(image_folder, output_json_path, device='cuda', conf_threshold=0.5):
    """
    运行YOLO-pose推理,逐帧处理并使用跟踪模式保持ID一致性
    返回:处理后的图片列表和推理结果
    """
    print(f"开始对文件夹 {image_folder} 进行YOLO-pose逐帧跟踪推理...")
    print(f"设备: {device}, 置信度阈值: {conf_threshold}")
    
    # 定义检查连续性的辅助函数
    def check_continuity(frame_numbers):
        """检查帧号是否连续"""
        if not frame_numbers or len(frame_numbers) <= 1:
            return True
        
        sorted_frames = sorted(frame_numbers)
        for i in range(len(sorted_frames) - 1):
            if sorted_frames[i] + 1 != sorted_frames[i + 1]:
                return False
        return True
    
    try:
        from ultralytics import YOLO
        import torch
        
        # 检查CUDA是否可用
        if device == 'cuda' and not torch.cuda.is_available():
            print("CUDA不可用,将使用CPU进行推理")
            device = 'cpu'
        
        # 加载YOLO-pose模型
        print("正在加载YOLO-pose模型...")
        model_path = "checkpoints/yolo26s-pose.pt"  # 请确保模型路径正确
        
        # 检查模型文件是否存在
        if not os.path.exists(model_path):
            print(f"警告: 模型文件不存在: {model_path}")
            # 尝试使用默认的yolo-pose模型
            model_path = 'yolov8s-pose.pt'
            print(f"尝试使用默认模型: {model_path}")
        
        model = YOLO(model_path)
        model.to(device)
        
        # 获取文件夹中所有图片并排序
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        image_files = []
        for f in os.listdir(image_folder):
            if f.lower().endswith(image_extensions):
                image_files.append(f)
        image_files.sort()  # 确保按文件名顺序处理
        
        if not image_files:
            print(f"文件夹 {image_folder} 中没有找到图片")
            return [], {}
        
        inference_results = {}
        valid_images = []
        deleted_count = 0
        
        print(f"开始逐帧处理 {len(image_files)} 张图片...")
        
        # 逐帧处理,使用persist=True保持跟踪状态
        for i, img_file in enumerate(image_files):
            if (i + 1) % 10 == 0 or (i + 1) == len(image_files):
                print(f"处理进度: {i+1}/{len(image_files)}...")
            
            img_path = os.path.join(image_folder, img_file)
            
            try:
                # 使用track模式,persist=True保持跟踪状态
                # 这是关键:逐帧处理,但persist=True会保持跟踪器状态
                if i == 0:
                    # 第一帧,开始新的跟踪
                    results = model.track(
                        source=img_path,
                        conf=conf_threshold,
                        persist=True,  # 开始跟踪并保持状态
                        verbose=False,
                        tracker="bytetrack.yaml" if i == 0 else None  # 第一帧初始化跟踪器
                    )
                else:
                    # 后续帧,继续之前的跟踪
                    results = model.track(
                        source=img_path,
                        conf=conf_threshold,
                        persist=True,  # 继续跟踪,使用之前的状态
                        verbose=False
                    )
                
                if not results or len(results) == 0:
                    # 没有检测结果,删除图片
                    os.remove(img_path)
                    deleted_count += 1
                    continue
                
                result = results[0]  # 单张图片,只有一个结果
                
                # 获取关键点和边界框
                keypoints = result.keypoints
                boxes = result.boxes
                
                # 检查是否检测到人物
                if (keypoints is not None and len(keypoints) > 0 and 
                    boxes is not None and len(boxes) > 0):
                    
                    persons = []
                    
                    # 获取跟踪ID(如果有)
                    track_ids = None
                    if hasattr(boxes, 'id') and boxes.id is not None:
                        track_ids = boxes.id.cpu().numpy().astype(int)
                    
                    # 处理每个检测到的人物
                    for person_idx in range(len(keypoints)):
                        kpt_data = keypoints.data[person_idx].cpu().numpy()
                        box_data = boxes.data[person_idx].cpu().numpy()
                        
                        # 过滤低置信度的检测
                        if box_data[5] < conf_threshold:
                            continue
                        
                        # 获取边界框
                        bbox = box_data[:4].tolist()
                        
                        # 获取跟踪ID
                        if track_ids is not None and person_idx < len(track_ids):
                            person_id = int(track_ids[person_idx])
                        else:
                            # 如果没有跟踪ID,使用索引
                            person_id = person_idx
                        
                        # 处理关键点数据
                        keypoints_list = []
                        for point in kpt_data:
                            x, y, v = float(point[0]), float(point[1]), float(point[2])
                            # 只保留置信度足够高的关键点
                            if v > 0.3:  # 关键点可见性阈值
                                keypoints_list.extend([x, y, v])
                            else:
                                keypoints_list.extend([0.0, 0.0, 0.0])
                        
                        person_info = {
                            "class_id": int(boxes.cls[person_idx].item()),
                            "id": person_id,
                            "bbox": bbox,
                            "keypoints": keypoints_list,
                            "confidence": float(box_data[5]),
                        }
                        persons.append(person_info)
                    
                    if persons:
                        # 按ID排序,便于查看
                        persons.sort(key=lambda x: x["id"])
                        inference_results[img_file] = persons
                        valid_images.append(img_path)
                        
                        # 可选:显示当前帧的跟踪信息
                        if (i + 1) % 20 == 0:
                            current_ids = [p["id"] for p in persons]
                            print(f"  帧 {i+1}: 检测到 {len(persons)} 人, IDs: {sorted(current_ids)}")
                    else:
                        # 没有检测到有效人物,删除图片
                        os.remove(img_path)
                        deleted_count += 1
                else:
                    # 没有检测到人物,删除图片
                    os.remove(img_path)
                    deleted_count += 1
                    
            except torch.cuda.OutOfMemoryError:
                print(f"警告: 处理 {img_file} 时显存不足,跳过此图片")
                # 清理CUDA缓存
                torch.cuda.empty_cache()
                # 保留图片,不删除
                valid_images.append(img_path)
                continue
            except Exception as e:
                print(f"处理图片 {img_file} 时出错: {e}")
                # 出错时保留图片,不删除
                valid_images.append(img_path)
        
        # 推理完成后清理显存
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        print(f"YOLO-pose逐帧跟踪推理完成。")
        print(f"原始图片数: {len(image_files)}")
        print(f"保留有人的图片: {len(valid_images)}")
        print(f"删除无人的图片: {deleted_count}")
        
        # 分析ID变化情况
        if inference_results:
            all_ids = set()
            id_transitions = {}
            
            # 收集所有ID
            for img_file, persons in inference_results.items():
                for person in persons:
                    all_ids.add(person["id"])
            
            # 分析每个ID的出现情况
            for person_id in sorted(all_ids):
                frames_with_id = []
                for img_file, persons in inference_results.items():
                    if any(p["id"] == person_id for p in persons):
                        frames_with_id.append(img_file)
                
                if frames_with_id:
                    # 提取帧号(假设文件名中有数字)
                    frame_numbers = []
                    for f in frames_with_id:
                        import re
                        numbers = re.findall(r'\d+', f)
                        if numbers:
                            frame_numbers.append(int(numbers[-1]))
                        else:
                            # 如果没有数字,使用索引
                            frame_numbers.append(frames_with_id.index(f) + 1)
                    
                    if len(frame_numbers) > 1:
                        id_transitions[person_id] = {
                            "start_frame": min(frame_numbers),
                            "end_frame": max(frame_numbers),
                            "frame_count": len(frame_numbers),
                            "frames": sorted(frame_numbers)
                        }
            
            print(f"检测到 {len(all_ids)} 个不同的人物ID")
            print("ID追踪情况:")
            for person_id, info in sorted(id_transitions.items()):
                continuity = "连续" if check_continuity(info["frames"]) else "不连续"
                print(f"  ID {person_id}: 出现在 {info['frame_count']} 帧中 ({continuity}), "
                      f"从帧 {info['start_frame']} 到 {info['end_frame']}")
        
        # 保存推理结果到JSON
        if inference_results:
            # 转换为可序列化的格式
            json_data = {}
            for img_file, persons in inference_results.items():
                json_persons = []
                for person in persons:
                    # 转换numpy类型为Python原生类型
                    json_person = {
                        "id": int(person["id"]),
                        "bbox": [float(coord) for coord in person["bbox"]],
                        "keypoints": [float(val) for val in person["keypoints"]],
                        "confidence": float(person["confidence"]),
                        "class_id": int(person["class_id"])
                    }
                    json_persons.append(json_person)
                json_data[img_file] = json_persons
            
            # 确保输出目录存在
            os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
            
            with open(output_json_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=2, ensure_ascii=False)
            
            print(f"推理结果已保存到: {output_json_path}")
            
            # 保存统计信息
            stats = {
                "total_images": len(image_files),
                "valid_images": len(valid_images),
                "deleted_images": deleted_count,
                "detected_persons_total": sum(len(persons) for persons in inference_results.values()),
                "unique_person_ids": len(all_ids) if inference_results else 0,
                "tracking_mode": True,
                "id_transitions": id_transitions
            }
            
            stats_path = output_json_path.replace(".json", "_stats.json")
            with open(stats_path, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)
            print(f"统计信息已保存到: {stats_path}")
        else:
            print("警告: 没有检测到任何人物")
        
        return valid_images, inference_results
        
    except ImportError as e:
        print(f"无法导入ultralytics库: {e}")
        print("请安装: pip install ultralytics")
        print("正在使用备用方案...")
        
        # 备用方案:返回所有图片,不进行推理
        all_images = [os.path.join(image_folder, f) for f in os.listdir(image_folder) 
                     if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        all_images.sort()
        return all_images, {}
        
    except Exception as e:
        print(f"YOLO-pose推理失败: {e}")
        import traceback
        traceback.print_exc()
        
        # 出错时返回所有图片作为备用
        all_images = [os.path.join(image_folder, f) for f in os.listdir(image_folder) 
                     if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        all_images.sort()
        return all_images, {}

class ImageLabelingTool:
    def __init__(self, root_path):
        self.root_path = root_path
        self.folders = []               # 所有含图片的文件夹绝对路径
        self.current_folder_index = 0   # 当前文件夹索引
        self.images = []                # 当前文件夹内所有图片绝对路径
        self.current_image_index = 0    # 当前图片索引
        self.labels = {}                # 路径 -> {person_id: 'positive'/'negative'}
        self.inference_data = {}        # 图片名称 -> [人物列表]
        self.selected_person_id = 0     # 当前选定的人物ID
        self.person_id_list = []        # 当前图片中所有人物ID列表
        
        # 显示设置
        self.show_bboxes = True         # 是否显示边界框
        self.show_keypoints = True      # 是否显示关键点
        
        self.convert_to_video = True    # 是否启用转视频模式
        self.video_fps = 30             # 视频帧率
        
        # 自动播放相关
        self.playing = False
        self.play_direction = 1
        self.last_play_tick = 0
        self.play_interval = 100
        
        # 标记状态
        self.continuous_mode = False
        self.continuous_label = None
        self.continuous_start_index = None
        
        # 键盘长按状态
        self.key_pressed = {"left": False, "right": False}
        self.last_key_time = 0
        self.key_repeat_delay = 0.8
        self.key_repeat_interval = 0.15
        
        # 操作历史
        self.undo_stack = []
        self.max_undo_steps = 50
        
        # 不再删除原图,改为移动到ok文件夹
        self.discard_before_last_action = False  # 禁用删除原图
        
        # 确认对话框状态
        self.show_confirm_dialog = False
        self.confirm_message = ""
        self.confirm_action = ""
        
        # 获取所有包含图片的文件夹
        self.find_image_folders()
        
        # 加载当前文件夹的图片和推理结果
        if self.folders:
            self.load_current_folder_images()
        
        # 加载保存的标记状态
        self.load_labels()

        # 新增:annotations 文件路径及数据存储
        self.annotations_path = os.path.join(self.root_path, "annotations.json")
        self.annotations_data = {"samples": []}
        self.load_annotations()

    def cancel_current_person_label(self):
        """取消当前图片中选定人物的标记"""
        current_image = self.get_current_image()
        if current_image and current_image in self.labels:
            if self.selected_person_id in self.labels[current_image]:
                self.save_state()  # 保存状态以便撤销
                del self.labels[current_image][self.selected_person_id]
                self.save_labels()
                return True
        return False

    def images_to_video(self, image_paths, output_path, fps=10):
        """将图片序列转为视频"""
        if not image_paths:
            return False
        
        print(f"正在合成视频: {output_path}")
        print(f"图片数量: {len(image_paths)},帧率: {fps}")
        
        try:
            # 读取第一张图获取尺寸
            first_image_path = image_paths[0]
            if not os.path.exists(first_image_path):
                print(f"第一张图片不存在: {first_image_path}")
                return False
            
            frame = cv2.imread(first_image_path)
            if frame is None:
                print(f"无法读取第一张图片: {first_image_path}")
                return False
            
            h, w, _ = frame.shape
            
            # 初始化视频写入器
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            
            if not out.isOpened():
                print(f"无法创建视频写入器: {output_path}")
                return False
            
            # 逐帧写入图片
            success_count = 0
            for i, img_path in enumerate(image_paths):
                if not os.path.exists(img_path):
                    print(f"警告: 图片不存在,跳过: {img_path}")
                    continue
                
                frame = cv2.imread(img_path)
                if frame is None:
                    print(f"警告: 无法读取图片,跳过: {img_path}")
                    continue
                
                # 如果图片尺寸与第一张不同,调整尺寸
                if frame.shape[0] != h or frame.shape[1] != w:
                    print(f"警告: 图片尺寸不匹配,调整尺寸: {img_path}")
                    frame = cv2.resize(frame, (w, h))
                
                out.write(frame)
                success_count += 1
                
                if (i + 1) % 50 == 0:
                    print(f"  已处理 {i + 1}/{len(image_paths)} 张图片")
            
            out.release()
            cv2.destroyAllWindows()
            
            if success_count > 0:
                print(f"视频合成成功: {output_path} ({success_count} 张图片)")
                return True
            else:
                print(f"视频合成失败: 没有成功处理的图片")
                if os.path.exists(output_path):
                    os.remove(output_path)
                return False
                
        except Exception as e:
            print(f"视频合成失败: {e}")
            import traceback
            traceback.print_exc()
            
            # 清理未完成的视频文件
            if os.path.exists(output_path):
                try:
                    os.remove(output_path)
                except:
                    pass
            return False

    def analyze_id_tracking(self):
        """分析ID跟踪情况"""
        if not self.inference_data:
            print("没有推理数据,无法分析ID跟踪")
            return
        
        id_frames = {}  # ID -> [出现的图片列表]
        
        # 收集每个ID出现的图片
        for img_name, persons in self.inference_data.items():
            for person in persons:
                person_id = person.get("id", 0)
                if person_id not in id_frames:
                    id_frames[person_id] = []
                id_frames[person_id].append(img_name)
        
        print("ID跟踪分析:")
        print("-" * 50)
        
        # 按ID出现的频率排序
        sorted_ids = sorted(id_frames.items(), key=lambda x: len(x[1]), reverse=True)
        
        for person_id, frames in sorted_ids:
            frame_numbers = []
            for f in frames:
                # 提取帧号
                import re
                numbers = re.findall(r'\d+', f)
                if numbers:
                    frame_numbers.append(int(numbers[-1]))
            
            if frame_numbers:
                continuity = "连续" if self._check_continuity(frame_numbers) else "不连续"
                
                print(f"ID {person_id}: 出现在 {len(frames)} 帧中, {continuity}")
                print(f"      帧范围: {min(frame_numbers)} - {max(frame_numbers)}")
        
        print("-" * 50)
        print(f"总共检测到 {len(sorted_ids)} 个不同ID")

    def preprocess_folder_with_yolo(self, folder_path):
        """对文件夹进行YOLO-pose预处理"""
        json_path = os.path.join(folder_path, "yolo_pose_results.json")
        
        # 检查是否已有预处理结果
        if os.path.exists(json_path):
            print(f"加载已有的YOLO-pose推理结果: {json_path}")
            with open(json_path, 'r') as f:
                inference_data = json.load(f)
            
            # 获取有效的图片列表
            valid_images = []
            for img_name in os.listdir(folder_path):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    if img_name in inference_data:
                        valid_images.append(os.path.join(folder_path, img_name))
            
            valid_images.sort()
            return valid_images, inference_data
        else:
            # 运行YOLO-pose推理
            return run_yolo_pose_inference(folder_path, json_path)
    
    def find_image_folders(self):
        """查找所有包含图片的文件夹"""
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        for root, dirs, files in os.walk(self.root_path):
            has_images = any(file.lower().endswith(image_extensions) for file in files)
            if has_images:
                self.folders.append(root)
    
    def load_current_folder_images(self):
        """加载当前文件夹中的所有图片和YOLO-pose推理结果"""
        folder_path = self.folders[self.current_folder_index]
        
        # 预处理文件夹(YOLO-pose推理)
        self.images, self.inference_data = self.preprocess_folder_with_yolo(folder_path)
        
        # 重置指针
        self.current_image_index = 0 if self.images else -1
        self.selected_person_id = 0
        self.update_person_id_list()
    
    def update_person_id_list(self):
        """更新当前图片的人物ID列表"""
        self.person_id_list = []
        current_image = self.get_current_image()
        if current_image:
            img_name = os.path.basename(current_image)
            if img_name in self.inference_data:
                persons = self.inference_data[img_name]
                self.person_id_list = [person.get("id", i) for i, person in enumerate(persons)]
                if self.person_id_list and self.selected_person_id not in self.person_id_list:
                    self.selected_person_id = self.person_id_list[0]
    
    def get_current_image(self):
        """获取当前图片"""
        if not self.images or self.current_image_index < 0:
            return None
        return self.images[self.current_image_index]
    
    def next_image(self):
        """切换到下一张图片"""
        if self.current_image_index < len(self.images) - 1:
            self.save_state()
            self.current_image_index += 1
            self.update_person_id_list()
            return True
        return False
    
    def prev_image(self):
        """切换到上一张图片"""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            self.update_person_id_list()
            return True
        return False
    
    def label_current_image(self, label):
        """标记当前图片中的选定人物"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()
        
            # 确保labels字典中有当前图片的条目
            if current_image not in self.labels:
                self.labels[current_image] = {}
        
            # 标记选定的人物
            self.labels[current_image][self.selected_person_id] = label
        
            # 自动保存标记状态
            self.save_labels()
    
    def start_continuous_labeling(self):
        """开始连续标记"""
        current_image = self.get_current_image()
        if current_image:
            self.save_state()
            
            # 如果当前图片的选定人物已经有标签,使用该标签
            if (current_image in self.labels and 
                self.selected_person_id in self.labels[current_image]):
                self.continuous_label = self.labels[current_image][self.selected_person_id]
            else:
                # 如果没有标签,默认为正样本
                self.continuous_label = "positive"
                if current_image not in self.labels:
                    self.labels[current_image] = {}
                self.labels[current_image][self.selected_person_id] = self.continuous_label
            
            self.continuous_mode = True
            self.continuous_start_index = self.current_image_index
            self.save_labels()
            return True
        return False
    
    def end_continuous_labeling(self):
        """结束连续标记"""
        if self.continuous_mode and self.continuous_start_index is not None:
            self.save_state()
            start = min(self.continuous_start_index, self.current_image_index)
            end = max(self.continuous_start_index, self.current_image_index)
            
            # 为范围内的所有图片的选定人物标记相同的标签
            for i in range(start, end + 1):
                img_path = self.images[i]
                if img_path not in self.labels:
                    self.labels[img_path] = {}
                self.labels[img_path][self.selected_person_id] = self.continuous_label
            
            self.continuous_mode = False
            self.continuous_start_index = None
            self.save_labels()
            return True
        return False
    
    def get_current_label(self):
        """获取当前图片选定人物的标签"""
        current_image = self.get_current_image()
        if current_image and current_image in self.labels:
            person_labels = self.labels[current_image]
            if isinstance(person_labels, dict):
                return person_labels.get(self.selected_person_id)
            else:
                # 旧格式,假设人物ID为0
                return person_labels if self.selected_person_id == 0 else None
        return None

    def get_next_unprocessed_index(self, start_idx):
        """从 start_idx 开始向后查找第一个未标记为 __processed__ 的图片索引,找不到返回 None"""
        for i in range(start_idx, len(self.images)):
            img_path = self.images[i]
            if img_path in self.labels and "__processed__" in self.labels[img_path]:
                continue
            return i
        return None

    def get_prev_unprocessed_index(self, start_idx):
        """从 start_idx 开始向前查找第一个未标记为 __processed__ 的图片索引,找不到返回 None"""
        for i in range(start_idx, -1, -1):
            img_path = self.images[i]
            if img_path in self.labels and "__processed__" in self.labels[img_path]:
                continue
            return i
        return None

    def move_labeled_files(self, positive_dir, negative_dir):

        """
        移动已标记的文件到 ok 文件夹,并按连续片段生成视频 + annotation
        修改后:不复制原图,仅生成视频,并在 labels 中标记已处理帧
        """
        if not self.folders:
            return "无文件夹可处理"

        current_folder = self.folders[self.current_folder_index]
        ok_folder = os.path.join(current_folder, "ok")
        positive_ok = os.path.join(ok_folder, "positive")
        negative_ok = os.path.join(ok_folder, "negative")

        # 创建输出目录(视频目录)
        if self.convert_to_video:
            for folder in [positive_ok, negative_ok]:
                os.makedirs(folder, exist_ok=True)

        from collections import defaultdict
        groups = defaultdict(list)   # (label, person_id) -> list of img_paths

        # 收集所有正负标记
        for img_path, person_labels in list(self.labels.items()):
            if not os.path.exists(img_path):
                continue
            for person_id, label in person_labels.items():
                if person_id == "__processed__":  # 跳过已处理标记
                    continue
                if label in ["positive", "negative"]:
                    groups[(label, person_id)].append(img_path)

        if not groups:
            return "没有找到已标记的文件"

        video_count = 0
        processed_frames = 0
        new_samples = []
        last_processed_frame_index = -1   # 记录最后一个被处理的帧的索引

        # 处理每个 (标签, 人物ID) 组
        for (label, person_id), img_paths in groups.items():
            # 按帧号排序
            img_paths.sort(key=lambda p: self.extract_frame_number(p))

            # 切分成连续帧段
            segments = []
            current_segment = [img_paths[0]]
            for i in range(1, len(img_paths)):
                prev_frame = self.extract_frame_number(img_paths[i-1])
                curr_frame = self.extract_frame_number(img_paths[i])
                if curr_frame == prev_frame + 1:
                    current_segment.append(img_paths[i])
                else:
                    segments.append(current_segment)
                    current_segment = [img_paths[i]]
            segments.append(current_segment)

            # 处理每个连续片段
            for seg_paths in segments:
                if len(seg_paths) < 2:
                    continue   # 单帧忽略,也可根据需要保留

                # --- 生成视频(仅当开启转视频模式)---
                video_name = self.build_video_name(seg_paths, person_id, label)
                video_path = None
                if self.convert_to_video:
                    dest_dir = positive_ok if label == "positive" else negative_ok
                    video_path = os.path.join(dest_dir, video_name + ".mp4")
                    if self.images_to_video(seg_paths, video_path, fps=self.video_fps):
                        video_count += 1
                        processed_frames += len(seg_paths)

                        # --- 生成 annotation ---
                        try:
                            annotation = self.build_annotation(
                                seg_paths, person_id, label, video_name
                            )
                            self.annotations_data["samples"].append(annotation)
                            new_samples.append(annotation)
                            print(f"  生成 annotation: {video_name} 共 {len(seg_paths)} 帧")
                        except Exception as e:
                            print(f"  生成 annotation 失败 {video_name}: {e}")

                # --- 标记已处理帧(无论是否生成视频,都标记为已处理)---
                for img_path in seg_paths:
                    # 添加 __processed__ 标记,记录视频信息
                    if img_path not in self.labels:
                        self.labels[img_path] = {}
                    if "__processed__" not in self.labels[img_path]:
                        self.labels[img_path]["__processed__"] = []
                    self.labels[img_path]["__processed__"].append({
                        "person_id": person_id,
                        "label": label,
                        "video": video_name + ".mp4" if video_path else None,
                        "timestamp": time.time()
                    })

                    # 删除该人物 ID 的原标记(已完成)
                    # if person_id in self.labels[img_path]:
                    #     del self.labels[img_path][person_id]
                        # 如果该图片没有其他标记且没有其他 ID 待处理,则保留空字典,不删除整个条目
                        # 也可以选择清理空字典,但保留方便下次判断
                    # 记录最后处理的帧索引,用于跳转
                    try:
                        idx = self.images.index(img_path)
                        if idx > last_processed_frame_index:
                            last_processed_frame_index = idx
                    except ValueError:
                        pass

        # 保存更新后的 labels 和 annotations
        self.save_labels()
        if new_samples:
            self.save_annotations()

        # 构造状态消息
        result_msg = f"已生成 {video_count} 个视频,处理 {processed_frames} 帧"
        if new_samples:
            result_msg += f",新增 {len(new_samples)} 个 annotation"
        self.status_message = result_msg
        self.status_message_timer = pygame.time.get_ticks()  # 记录显示起始时间

        print(result_msg)

        # 自动跳转到最后一个被处理帧的下一帧(未处理)
        if last_processed_frame_index != -1:
            next_idx = self.get_next_unprocessed_index(last_processed_frame_index + 1)
            if next_idx is not None:
                self.current_image_index = next_idx
                self.update_person_id_list()
                print(f"已跳转到下一未处理帧: 索引 {next_idx+1}")
            else:
                # 如果没有下一帧,留在当前帧或第一帧
                print("所有帧均已处理")

        return result_msg

    def draw_detections(self, screen, img_rect, current_image_path):
        """在图片上绘制检测结果"""
        if not current_image_path or not self.show_bboxes:
            return
        
        img_name = os.path.basename(current_image_path)
        if img_name not in self.inference_data:
            return
        
        persons = self.inference_data[img_name]
        if not persons:
            return
        
        # 获取图片显示的实际尺寸和位置
        img_display_rect = img_rect
        
        # 读取图片原始尺寸
        original_img = pygame.image.load(current_image_path)
        orig_width, orig_height = original_img.get_size()
        
        # 计算缩放比例
        scale_x = img_display_rect.width / orig_width
        scale_y = img_display_rect.height / orig_height
        
        for i, person in enumerate(persons):
            person_id = person.get("id", i)
            bbox = person.get("bbox", [])
            
            if len(bbox) >= 4:
                # 缩放边界框坐标
                x1 = img_display_rect.x + bbox[0] * scale_x
                y1 = img_display_rect.y + bbox[1] * scale_y
                x2 = img_display_rect.x + bbox[2] * scale_x
                y2 = img_display_rect.y + bbox[3] * scale_y
                
                width = x2 - x1
                height = y2 - y1
                
                # 选择颜色
                color_idx = person_id % len(BOX_COLORS)
                color = BOX_COLORS[color_idx]
                
                # 绘制边界框
                if person_id == self.selected_person_id:
                    # 选中的边界框加粗
                    pygame.draw.rect(screen, color, (x1, y1, width, height), 3)
                    
                    # 在框上方显示ID
                    id_text = f"ID: {person_id}"
                    text_surface = small_font.render(id_text, True, color)
                    screen.blit(text_surface, (x1, y1 - 25))
                    
                    # 显示置信度
                    conf = person.get("confidence", 0)
                    conf_text = f"{conf:.2f}"
                    conf_surface = small_font.render(conf_text, True, color)
                    screen.blit(conf_surface, (x1, y1 - 45))
                else:
                    pygame.draw.rect(screen, color, (x1, y1, width, height), 2)
                
                # 绘制关键点
                if self.show_keypoints:
                    keypoints = person.get("keypoints", [])
                    for j in range(0, len(keypoints), 3):
                        if j + 2 < len(keypoints):
                            kx = img_display_rect.x + keypoints[j] * scale_x
                            ky = img_display_rect.y + keypoints[j+1] * scale_y
                            confidence = keypoints[j+2]
                            
                            if confidence > 0.5:  # 只显示置信度高的关键点
                                pygame.draw.circle(screen, color, (int(kx), int(ky)), 3)
    
    def select_person_by_click(self, mouse_pos, img_rect):
        """通过鼠标点击选择人物"""
        current_image = self.get_current_image()
        if not current_image:
            return False
        
        img_name = os.path.basename(current_image)
        if img_name not in self.inference_data:
            return False
        
        persons = self.inference_data[img_name]
        if not persons:
            return False
        
        # 读取图片原始尺寸
        original_img = pygame.image.load(current_image)
        orig_width, orig_height = original_img.get_size()
        
        # 计算缩放比例
        scale_x = img_rect.width / orig_width
        scale_y = img_rect.height / orig_height
        
        # 检查鼠标是否在某个边界框内
        for i, person in enumerate(persons):
            person_id = person.get("id", i)
            bbox = person.get("bbox", [])
            
            if len(bbox) >= 4:
                # 计算缩放后的边界框
                x1 = img_rect.x + bbox[0] * scale_x
                y1 = img_rect.y + bbox[1] * scale_y
                x2 = img_rect.x + bbox[2] * scale_x
                y2 = img_rect.y + bbox[3] * scale_y
                
                if (x1 <= mouse_pos[0] <= x2 and y1 <= mouse_pos[1] <= y2):
                    self.selected_person_id = person_id
                    print(f"选中人物 ID: {person_id}")
                    return True
        
        return False
    
    def next_person_id(self):
        """切换到下一个可选的人物ID"""
        if self.person_id_list:
            current_idx = self.person_id_list.index(self.selected_person_id) if self.selected_person_id in self.person_id_list else -1
            if current_idx < len(self.person_id_list) - 1:
                self.selected_person_id = self.person_id_list[current_idx + 1]
                return True
        return False
    
    def prev_person_id(self):
        """切换到上一个可选的人物ID"""
        if self.person_id_list:
            current_idx = self.person_id_list.index(self.selected_person_id) if self.selected_person_id in self.person_id_list else -1
            if current_idx > 0:
                self.selected_person_id = self.person_id_list[current_idx - 1]
                return True
        return False
    
    def handle_key_repeats(self):
        """处理方向键长按"""   
        current_time = time.time() 
        
        if any(self.key_pressed.values()):
            if self.last_key_time == 0:
                if current_time - self.key_pressed_time > self.key_repeat_delay:
                    if self.key_pressed["left"]:
                        self.prev_image()
                    elif self.key_pressed["right"]:
                        self.next_image()
                    self.last_key_time = current_time
            elif current_time - self.last_key_time > self.key_repeat_interval:
                if self.key_pressed["left"]:
                    self.prev_image()
                elif self.key_pressed["right"]:  
                    self.next_image()
                self.last_key_time = current_time
    
    def save_state(self):
        """保存当前状态以便撤销"""
        if len(self.undo_stack) >= self.max_undo_steps:
            self.undo_stack.pop(0)
    
        # 深度复制labels,确保能够正确处理嵌套结构
        labels_copy = {}
        for img_path, person_labels in self.labels.items():
            if isinstance(person_labels, dict):
                # 新格式:{图片路径: {人物ID: 标签}}
                labels_copy[img_path] = person_labels.copy()
            else:
                # 旧格式:{图片路径: 标签字符串} - 转换为新格式
                labels_copy[img_path] = {0: person_labels}
    
        state = {
            "current_image_index": self.current_image_index,
            "labels": labels_copy,
            "selected_person_id": self.selected_person_id,
            "continuous_mode": self.continuous_mode,
            "continuous_start_index": self.continuous_start_index,
            "continuous_label": self.continuous_label
        }
    
        self.undo_stack.append(state)
    
    def undo(self):
        """撤销上一次操作"""
        if self.undo_stack:
            state = self.undo_stack.pop()
            self.current_image_index = state["current_image_index"]
            self.labels = state["labels"]
            self.selected_person_id = state["selected_person_id"]
            self.continuous_mode = state["continuous_mode"]
            self.continuous_start_index = state["continuous_start_index"]
            self.continuous_label = state["continuous_label"]
            self.update_person_id_list()
            return True
        return False
    
    def load_labels(self):
        """从文件加载标记状态,支持新旧格式"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        if os.path.exists(labels_file):
            try:
                with open(labels_file, 'r') as f:
                    loaded = json.load(f)
                
                # 转换旧格式到新格式
                self.labels = {}
                for img_path, label_data in loaded.items():
                    if isinstance(label_data, dict):
                        # 新格式:已经是 {人物ID: 标签}
                        self.labels[img_path] = label_data
                    else:
                        # 旧格式:标签字符串,转换为新格式,默认人物ID为0
                        self.labels[img_path] = {0: label_data}
                    
                print(f"成功加载标签,共 {len(self.labels)} 张图片的标签")
            except Exception as e:
                print(f"加载标记状态失败: {e}")
                self.labels = {}
        else:
            print(f"标签文件不存在: {labels_file}")
            self.labels = {}

    def save_labels(self):
        """保存标记状态到文件"""
        labels_file = os.path.join(self.root_path, "labels_backup.json")
        try:
            existing_labels = {}
            for img_path, person_labels in self.labels.items():
                if os.path.exists(img_path) and person_labels:
                    # 确保保存新格式
                    existing_labels[img_path] = person_labels
        
            with open(labels_file, 'w') as f:
                json.dump(existing_labels, f, indent=2)
        
            print(f"标签已保存到: {labels_file}")
        except Exception as e:
            print(f"保存标记状态失败: {e}")

    def load_annotations(self):
        """加载全局 annotations.json"""
        if os.path.exists(self.annotations_path):
            try:
                with open(self.annotations_path, 'r', encoding='utf-8') as f:
                    self.annotations_data = json.load(f)
                print(f"已加载 {len(self.annotations_data.get('samples', []))} 个标注片段")
            except Exception as e:
                print(f"加载 annotations 失败: {e}")
                self.annotations_data = {"samples": []}
        else:
            self.annotations_data = {"samples": []}

    def save_annotations(self):
        """保存 annotations.json"""
        try:
            with open(self.annotations_path, 'w', encoding='utf-8') as f:
                json.dump(self.annotations_data, f, indent=2, ensure_ascii=False)
            print(f"annotations 已保存到: {self.annotations_path}")
        except Exception as e:
            print(f"保存 annotations 失败: {e}")

    @staticmethod
    def extract_frame_number(img_path):
        """从文件名中提取最后一个数字串作为帧号"""
        import re
        name = os.path.splitext(os.path.basename(img_path))[0]
        numbers = re.findall(r'\d+', name)
        return int(numbers[-1]) if numbers else 0

    def build_video_name(self, segment_paths, person_id, label):
        """构建视频文件名(不含扩展名)"""
        folder_name = os.path.basename(self.folders[self.current_folder_index])
        start_frame = self.extract_frame_number(segment_paths[0])
        end_frame = self.extract_frame_number(segment_paths[-1])
        range_str = f"{start_frame:06d}to{end_frame:06d}"
        suffix = "A001" if label == "positive" else "A002"
        return f"{folder_name}_{range_str}_P{person_id}_{suffix}"

    def build_annotation(self, segment_paths, person_id, label, video_name):
        """
        根据连续帧片段构建 annotation 字典
        segment_paths : list of absolute image paths (已排序)
        person_id     : int
        label         : str ('positive' / 'negative')
        video_name    : str (不含扩展名)
        """
        # 1. 读取第一帧获取图像尺寸
        first_img = cv2.imread(segment_paths[0])
        if first_img is None:
            raise ValueError(f"无法读取图片: {segment_paths[0]}")
        h, w = first_img.shape[:2]

        # 2. 逐帧提取该人物的关键点
        keypoint_seq = []   # T x 17 x 2
        score_seq = []      # T x 17

        for img_path in segment_paths:
            img_name = os.path.basename(img_path)
            persons = self.inference_data.get(img_name, [])
            target_person = None
            for p in persons:
                if p.get("id") == person_id:
                    target_person = p
                    break

            if target_person is None:
                # 理论上不应该发生,但若缺失则用全零填充
                keypoint_seq.append([[0.0, 0.0]] * 17)
                score_seq.append([0.0] * 17)
                continue

            kps = target_person.get("keypoints", [])
            # 解析 keypoints 列表 (x,y,v) 共 51 个值
            coords = []
            scores = []
            for i in range(0, len(kps), 3):
                x = kps[i] if i < len(kps) else 0.0
                y = kps[i+1] if i+1 < len(kps) else 0.0
                v = kps[i+2] if i+2 < len(kps) else 0.0
                coords.append([x, y])
                scores.append(v)
            # 如果某些帧点数不足17,补齐
            while len(coords) < 17:
                coords.append([0.0, 0.0])
                scores.append(0.0)
            keypoint_seq.append(coords[:17])
            score_seq.append(scores[:17])

        # 3. 构建 annotation 字典
        annotation = {
            "frame_dir": video_name,
            "label": 0 if label == "positive" else 1,   # 正样本→0,负样本→1
            "img_shape": [h, w],
            "original_shape": [h, w],
            "total_frames": len(segment_paths),
            "keypoint": keypoint_seq,      # T x 17 x 2
            "keypoint_score": score_seq    # T x 17
        }
        return annotation

def draw_button(screen, text, rect, hover=False, color=None):
    """绘制按钮"""
    if color is None:
        color = BUTTON_HOVER_COLOR if hover else BUTTON_COLOR
    
    pygame.draw.rect(screen, color, rect, border_radius=5)
    pygame.draw.rect(screen, (100, 100, 100), rect, 2, border_radius=5)
    
    text_surface = small_font.render(text, True, TEXT_COLOR)
    txt_rect = text_surface.get_rect(center=rect.center)
    screen.blit(text_surface, txt_rect)

def draw_confirm_dialog(screen, message, width=400, height=200):
    """绘制确认对话框"""
    dialog_rect = pygame.Rect(
        (WINDOW_WIDTH - width) // 2,
        (WINDOW_HEIGHT - height) // 2, 
        width, height
    )
    
    pygame.draw.rect(screen, BG_COLOR, dialog_rect, border_radius=10)
    pygame.draw.rect(screen, TEXT_COLOR, dialog_rect, 2, border_radius=10)
    
    lines = []
    words = message.split()
    current_line = ""
    
    for word in words:
        test_line = current_line + word + " "
        if small_font.size(test_line)[0] < width - 40:
            current_line = test_line
        else:
            lines.append(current_line)
            current_line = word + " "
    
    if current_line:
        lines.append(current_line)
    
    for i, line in enumerate(lines):
        text_surface = small_font.render(line, True, TEXT_COLOR)
        screen.blit(text_surface, (dialog_rect.x + 20, dialog_rect.y + 30 + i * 25))
    
    yes_button = pygame.Rect(dialog_rect.x + width // 2 - 100, dialog_rect.y + height - 50, 80, 30)
    no_button = pygame.Rect(dialog_rect.x + width // 2 + 20, dialog_rect.y + height - 50, 80, 30)
    
    draw_button(screen, "是", yes_button, color=CONFIRM_COLOR)
    draw_button(screen, "否", no_button, color=WARNING_COLOR)
    
    return dialog_rect, yes_button, no_button

def main():
    root_path = ROOT_PATH
    
    tool = ImageLabelingTool(root_path)
    
    positive_dir = os.path.join(root_path, "1")
    negative_dir = os.path.join(root_path, "0")
    
    running = True
    clock = pygame.time.Clock()
    
    # 按钮布局
    button_height = 40
    button_width = 140
    button_margin = 15
    
    # 行坐标(从底部向上数)
    row1_y = WINDOW_HEIGHT - button_height - button_margin                     # 最下面一行(第0行)
    row2_y = WINDOW_HEIGHT - 2*button_height - 2*button_margin                # 第1行
    row3_y = WINDOW_HEIGHT - 3*button_height - 3*button_margin                # 第2行
    row4_y = WINDOW_HEIGHT - 4*button_height - 4*button_margin                # 第3行

    # 分析按钮(第4行)
    analysis_buttons = {
        "analyze_ids": pygame.Rect(button_margin, row4_y, button_width, button_height),
    }

    # 人物选择按钮(第3行)
    person_buttons = {
        "prev_person": pygame.Rect(button_margin, row3_y, button_width, button_height),
        "next_person": pygame.Rect(button_margin * 2 + button_width, row3_y, button_width, button_height),
        "toggle_bbox": pygame.Rect(button_margin * 3 + button_width * 2, row3_y, button_width, button_height),
        "toggle_kpts": pygame.Rect(button_margin * 4 + button_width * 3, row3_y, button_width, button_height),
        "cancel_label": pygame.Rect(button_margin * 5 + button_width * 4, row3_y, button_width, button_height),
    }

    # 导航按钮(第2行)
    nav_buttons = {
        "prev": pygame.Rect(button_margin, row2_y, button_width, button_height),
        "next": pygame.Rect(button_margin * 2 + button_width, row2_y, button_width, button_height),
        "prev_folder": pygame.Rect(button_margin * 3 + button_width * 2, row2_y, button_width, button_height),
        "next_folder": pygame.Rect(button_margin * 4 + button_width * 3, row2_y, button_width, button_height),
        "undo": pygame.Rect(button_margin * 5 + button_width * 4, row2_y, button_width, button_height),
    }

    # 标注按钮(第1行)
    label_buttons = {
        "positive": pygame.Rect(button_margin, row1_y, button_width, button_height),
        "negative": pygame.Rect(button_margin * 2 + button_width, row1_y, button_width, button_height),
        "continuous_start": pygame.Rect(button_margin * 3 + button_width * 2, row1_y, button_width, button_height),
        "continuous_end": pygame.Rect(button_margin * 4 + button_width * 3, row1_y, button_width, button_height),
        "move_files": pygame.Rect(button_margin * 5 + button_width * 4, row1_y, button_width, button_height),
    }

    # 图片显示区域
    image_area = pygame.Rect(50, 80, WINDOW_WIDTH - 100, WINDOW_HEIGHT - 340)
    
    tool.key_pressed_time = 0
    
    while running:
        mouse_pos = pygame.mouse.get_pos()
        
        tool.handle_key_repeats()
        
        if tool.playing:
            now = pygame.time.get_ticks()
            if now - tool.last_play_tick > tool.play_interval:
                if tool.play_direction == 1:
                    tool.next_image()
                else:
                    tool.prev_image()
                tool.last_play_tick = now
        
        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            elif event.type == KEYDOWN:
                if event.key == K_d:
                    tool.key_pressed["right"] = True
                    tool.key_pressed["left"] = False
                    tool.key_pressed_time = time.time()
                    tool.next_image()
                elif event.key == K_a:
                    tool.key_pressed["left"] = True
                    tool.key_pressed["right"] = False
                    tool.key_pressed_time = time.time()
                    tool.prev_image()
                elif event.key == K_q:  # 切换到上一个人物
                    tool.prev_person_id()
                elif event.key == K_e:  # 切换到下一个人物
                    tool.next_person_id()
                elif event.key == K_b:  # 切换边界框显示
                    tool.show_bboxes = not tool.show_bboxes
                elif event.key == K_k:  # 切换关键点显示
                    tool.show_keypoints = not tool.show_keypoints
                elif event.key == K_RIGHT:
                    tool.play_direction = 1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_LEFT:
                    tool.play_direction = -1
                    tool.playing = True
                    tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_SPACE:
                    tool.playing = not tool.playing
                    if tool.playing:
                        tool.last_play_tick = pygame.time.get_ticks()
                elif event.key == K_v:
                    tool.convert_to_video = not tool.convert_to_video
                    print("转视频模式:" + ("开启" if tool.convert_to_video else "关闭"))
                elif event.key == K_w:  # 标记选定人物为正样本
                    tool.label_current_image("positive")
                elif event.key == K_s:  # 标记选定人物为负样本
                    tool.label_current_image("negative")
                elif event.key == K_UP:
                    if not tool.start_continuous_labeling():
                        print("无法开始连续标记")
                elif event.key == K_DOWN:
                    if not tool.end_continuous_labeling():
                        print("没有激活的连续标记")
                elif event.key == K_x:
                    tool.move_labeled_files(positive_dir, negative_dir)
                elif event.key == K_c:
                    tool.next_folder()
                elif event.key == K_r:  # 按 R 取消当前选定人物的标记
                    if tool.cancel_current_person_label():
                        print(f"已取消人物 {tool.selected_person_id} 的标记")
                    else:
                        print("无法取消标记(没有标记或图片)")
                elif event.key == K_z:
                    if pygame.key.get_mods() & KMOD_CTRL:
                        # Ctrl+Z: 撤销
                        if tool.undo():
                            print("已撤销上一次操作")
                        else:
                            print("没有可撤销的操作")
                    else:
                        # 单独按 Z: 上一个文件夹
                        tool.prev_folder()
                elif event.key == K_ESCAPE:
                    if tool.show_confirm_dialog:
                        tool.show_confirm_dialog = False
            
            elif event.type == KEYUP:
                if event.key == K_d:
                    tool.key_pressed["right"] = False
                    tool.last_key_time = 0
                elif event.key == K_a:
                    tool.key_pressed["left"] = False
                    tool.last_key_time = 0
            
            elif event.type == MOUSEBUTTONDOWN:
                if event.button == 1:
                    if tool.show_confirm_dialog:
                        dialog_rect, yes_button, no_button = draw_confirm_dialog(screen, tool.confirm_message)
                        if yes_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                            if tool.confirm_action == "next_folder":
                                tool.current_folder_index += 1
                                tool.load_current_folder_images()
                        elif no_button.collidepoint(mouse_pos):
                            tool.show_confirm_dialog = False
                    else:
                        # 分析按钮
                        if analysis_buttons["analyze_ids"].collidepoint(mouse_pos):
                            tool.analyze_id_tracking()

                        # 人物选择按钮
                        if person_buttons["prev_person"].collidepoint(mouse_pos):
                            tool.prev_person_id()
                        elif person_buttons["next_person"].collidepoint(mouse_pos):
                            tool.next_person_id()
                        elif person_buttons["toggle_bbox"].collidepoint(mouse_pos):
                            tool.show_bboxes = not tool.show_bboxes
                        elif person_buttons["toggle_kpts"].collidepoint(mouse_pos):
                            tool.show_keypoints = not tool.show_keypoints
                        elif person_buttons["cancel_label"].collidepoint(mouse_pos):
                            if tool.cancel_current_person_label():
                                print(f"已取消人物 {tool.selected_person_id} 的标记")
                            else:
                                print("无法取消标记")
                        
                        # 导航按钮
                        elif nav_buttons["prev"].collidepoint(mouse_pos):
                            tool.prev_image()
                        elif nav_buttons["next"].collidepoint(mouse_pos):
                            tool.next_image()
                        elif nav_buttons["prev_folder"].collidepoint(mouse_pos):
                            tool.prev_folder()
                        elif nav_buttons["next_folder"].collidepoint(mouse_pos):
                            tool.next_folder()
                        elif nav_buttons["undo"].collidepoint(mouse_pos):
                            if tool.undo():
                                print("已撤销上一次操作")
                            else:
                                print("没有可撤销的操作")
                        
                        # 标注按钮
                        elif label_buttons["positive"].collidepoint(mouse_pos):
                            tool.label_current_image("positive")
                        elif label_buttons["negative"].collidepoint(mouse_pos):
                            tool.label_current_image("negative")
                        elif label_buttons["continuous_start"].collidepoint(mouse_pos):
                            if not tool.start_continuous_labeling():
                                print("无法开始连续标记")
                        elif label_buttons["continuous_end"].collidepoint(mouse_pos):
                            if not tool.end_continuous_labeling():
                                print("没有激活的连续标记")
                        elif label_buttons["move_files"].collidepoint(mouse_pos):
                            tool.move_labeled_files(positive_dir, negative_dir)
                        
                        # 检查是否点击了图片区域来选择人物
                        elif image_area.collidepoint(mouse_pos):
                            tool.select_person_by_click(mouse_pos, image_area)
        
        # 清屏
        screen.fill(BG_COLOR)
        
        # 显示文件信息
        if tool.folders:
            folder_text = f"当前文件夹: {os.path.basename(tool.folders[tool.current_folder_index])} ({tool.current_folder_index + 1}/{len(tool.folders)})"
            text_surface = small_font.render(folder_text, True, TEXT_COLOR)
            screen.blit(text_surface, (20, 20))
        
        # 显示当前图片
        current_image_path = tool.get_current_image()
        if current_image_path and os.path.exists(current_image_path):
            try:
                img = pygame.image.load(current_image_path)
                img_rect = img.get_rect()
                
                scale = min(image_area.width / img_rect.width, image_area.height / img_rect.height)
                new_size = (int(img_rect.width * scale), int(img_rect.height * scale))
                img = pygame.transform.smoothscale(img, new_size)
                img_rect = img.get_rect(center=image_area.center)
                
                screen.blit(img, img_rect)
                
                # 显示检测结果
                tool.draw_detections(screen, img_rect, current_image_path)
                
                # 显示图片信息
                info_text = f"{os.path.basename(current_image_path)} ({tool.current_image_index + 1}/{len(tool.images)})"

                # 显示当前图片所有已标记人物的动作类别(正样本 → 0,负样本 → 1)
                if current_image_path in tool.labels:
                    person_labels = tool.labels[current_image_path]
                    # 过滤出真正的人物ID(排除 __processed__ 等非数字键)
                    valid_labels = {pid: label for pid, label in person_labels.items()
                                    if isinstance(pid, int) or (isinstance(pid, str) and pid.isdigit())}
                    if valid_labels:
                        label_map = {'positive': 0, 'negative': 1}
                        labeled_dict = {int(pid): label_map[label] for pid, label in valid_labels.items()
                                        if label in label_map}
                        if labeled_dict:
                            # 格式化为 {1:0,2:1} 紧凑形式
                            dict_str = str(labeled_dict).replace(' ', '')
                            info_text += f" - 已标记人物动作{dict_str}"

                text_surface = font.render(info_text, True, TEXT_COLOR)
                text_rect = text_surface.get_rect(center=(WINDOW_WIDTH // 2, image_area.y - 20))
                screen.blit(text_surface, text_rect)
                
                # 显示选定人物ID
                person_text = f"选定人物ID: {tool.selected_person_id} (Q/E切换, 点击框选择)"
                person_surface = small_font.render(person_text, True, HIGHLIGHT_COLOR)
                screen.blit(person_surface, (20, 50))
                
                # 显示连续标记范围
                if tool.continuous_mode and tool.continuous_start_index is not None:
                    start_idx = min(tool.continuous_start_index, tool.current_image_index)
                    end_idx = max(tool.continuous_start_index, tool.current_image_index)
                    frame_diff = abs(tool.current_image_index - tool.continuous_start_index)
                    
                    # 第一行:连续标记范围
                    range_text1 = f"连续标记范围: {start_idx + 1} - {end_idx + 1}"
                    range_surface1 = small_font.render(range_text1, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface1, (20, 75))
                    
                    # 第二行:已过帧数
                    range_text2 = f"已过 {frame_diff} 帧 (人物 {tool.selected_person_id})"
                    range_surface2 = small_font.render(range_text2, True, HIGHLIGHT_COLOR)
                    screen.blit(range_surface2, (20, 100))
                    
                    marker_width = image_area.width / len(tool.images)
                    start_x = image_area.x + start_idx * marker_width
                    end_x = image_area.x + (end_idx + 1) * marker_width
                    
                    pygame.draw.rect(screen, HIGHLIGHT_COLOR,
                                    (start_x, image_area.y + image_area.height + 5,
                                    end_x - start_x, 5))
            
            except Exception as e:
                error_text = f"无法加载图片: {e}"
                text_surface = font.render(error_text, True, (255, 0, 0))
                screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))
        
        else:
            no_image_text = "没有图片可显示"
            text_surface = font.render(no_image_text, True, TEXT_COLOR)
            screen.blit(text_surface, (image_area.centerx - text_surface.get_width() // 2, image_area.centery - text_surface.get_height() // 2))
        
        # 显示显示设置状态
        display_status = f"显示: 边界框({'开' if tool.show_bboxes else '关'}/B) 关键点({'开' if tool.show_keypoints else '关'}/K)"
        status_surface = small_font.render(display_status, True, TEXT_COLOR)
        screen.blit(status_surface, (WINDOW_WIDTH - status_surface.get_width() - 20, 50))

        # 绘制分析按钮
        draw_button(screen, "分析ID跟踪", analysis_buttons["analyze_ids"], analysis_buttons["analyze_ids"].collidepoint(mouse_pos))
        
        # 绘制人物选择按钮
        draw_button(screen, "上个ID (Q)", person_buttons["prev_person"], person_buttons["prev_person"].collidepoint(mouse_pos))
        draw_button(screen, "下个ID (E)", person_buttons["next_person"], person_buttons["next_person"].collidepoint(mouse_pos))
        draw_button(screen, f"边界框(B):{'开' if tool.show_bboxes else '关'}", person_buttons["toggle_bbox"], person_buttons["toggle_bbox"].collidepoint(mouse_pos))
        draw_button(screen, f"关键点(K):{'开' if tool.show_keypoints else '关'}", person_buttons["toggle_kpts"], person_buttons["toggle_kpts"].collidepoint(mouse_pos))
        draw_button(screen, "取消标记 (R)", person_buttons["cancel_label"], person_buttons["cancel_label"].collidepoint(mouse_pos)) 

        # 绘制导航按钮
        draw_button(screen, "上一张 (a)", nav_buttons["prev"], nav_buttons["prev"].collidepoint(mouse_pos))
        draw_button(screen, "下一张 (d)", nav_buttons["next"], nav_buttons["next"].collidepoint(mouse_pos))
        draw_button(screen, "上个文件夹 (z)", nav_buttons["prev_folder"], nav_buttons["prev_folder"].collidepoint(mouse_pos))
        draw_button(screen, "下个文件夹 (c)", nav_buttons["next_folder"], nav_buttons["next_folder"].collidepoint(mouse_pos))
        draw_button(screen, "撤销 (Ctrl+Z)", nav_buttons["undo"], nav_buttons["undo"].collidepoint(mouse_pos))
        
        # 绘制标注按钮
        draw_button(screen, "正样本 (w)", label_buttons["positive"], label_buttons["positive"].collidepoint(mouse_pos))
        draw_button(screen, "负样本 (s)", label_buttons["negative"], label_buttons["negative"].collidepoint(mouse_pos))
        draw_button(screen, "开始连续标(↑)", label_buttons["continuous_start"], label_buttons["continuous_start"].collidepoint(mouse_pos))
        draw_button(screen, "结束连续标(↓)", label_buttons["continuous_end"], label_buttons["continuous_end"].collidepoint(mouse_pos))
        draw_button(screen, "移动文件 (x)", label_buttons["move_files"], label_buttons["move_files"].collidepoint(mouse_pos))
        

        # 显示确认对话框
        if tool.show_confirm_dialog:
            draw_confirm_dialog(screen, tool.confirm_message)
        
        # 更新屏幕
        pygame.display.flip()
        clock.tick(30)
    
    # 退出前保存
    tool.save_annotations()   # 若已有修改,确保保存
    # 退出前保存标记状态
    tool.save_labels()
    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    main()

2.2 生成的json格式

生成的annotations.json格式如下:

python 复制代码
{
    "samples": [
        {   # P1的1是指用户1
            # A001的001是指用户标记真实的动作类别是第1个动作,A后面的数字是动作id
            'frame_dir': 'fisheye_back_2026-01-29-14-16-15_358to389_P1_A001',
            'label': 0,  # 真实的动作类别索引,第1个动作类别对应索引0
            'img_shape': (1080, 1920),
            'original_shape': (1080, 1920),
            'total_frames': 103,
            'keypoint': (1, 103, 17, 2),
            'keypoint_score': (1, 103, 17),
        },
        { 
            'frame_dir': 'fisheye_back_2026-01-29-14-16-15_358to389_P2_A002',
            'label': 1,  # 真实的动作类别索引,第2个动作类别对应索引1
            'img_shape': (1080, 1920),
            'original_shape': (1080, 1920),
            'total_frames': 158,
            'keypoint': (1, 158, 17, 2),
            'keypoint_score': (1, 158, 17),
        },
        ...... 
     ]
}

三、基于annotations.json生成pkl

3.1 回顾一下pkl 文件格式

python 复制代码
data:{
    'split':{
                 'train': ['frame_dir名称', 'xxx', ...],  
                 'val':  ['xxx', 'xxx', ...],  
     }

    'annotations':{  
        {   # P1的1是指用户1
            # A001的001是指用户标记真实的动作类别是第1个动作,A后面的数字是动作id
            'frame_dir': 'fisheye_back_2026-01-29-14-16-15_358to389_P1_A001',
            'label': 0,  # 真实的动作类别索引,第1个动作类别对应索引0
            'img_shape': (1080, 1920),
            'original_shape': (1080, 1920),
            'total_frames': 103,
            'keypoint': (1, 103, 17, 2),
            'keypoint_score': (1, 103, 17),
        },
        { 
            'frame_dir': 'fisheye_back_2026-01-29-14-16-15_358to389_P2_A002',
            'label': 1,  # 真实的动作类别索引,第2个动作类别对应索引1
            'img_shape': (1080, 1920),
            'original_shape': (1080, 1920),
            'total_frames': 158,
            'keypoint': (1, 158, 17, 2),
            'keypoint_score': (1, 158, 17),
        },
        ......
    }
}

3.2 基于annotations.json生成pkl脚本

用于递归扫描标注工具生成的 annotations.json,合并所有样本,按 8:2 随机划分为训练集/验证集,并输出可直接用于动作识别模型训练的标准 pkl 文件。

D:\zero_track\mmaction2\my_tools\generate_ntu60_pkl.py

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# generate_ntu60_pkl.py

"""
生成 MMAction2 风格的关键点动作识别数据集 pkl 文件
用法:
    python generate_pkl.py --root /path/to/annotations_root --output my_ntu60_2d.pkl
"""

import os
import json
import pickle
import numpy as np
import argparse
from glob import glob
from sklearn.model_selection import train_test_split

def find_all_annotations(root_path):
    """递归查找 root_path 下所有的 annotations.json 文件"""
    pattern = os.path.join(root_path, "**", "annotations.json")
    return glob(pattern, recursive=True)

def load_and_merge_samples(json_files):
    """加载多个 annotations.json 文件,合并所有 samples 到一个列表"""
    all_samples = []
    for f in json_files:
        try:
            with open(f, 'r', encoding='utf-8') as fp:
                data = json.load(fp)
                samples = data.get('samples', [])
                print(f"  {f}: {len(samples)} 个样本")
                all_samples.extend(samples)
        except Exception as e:
            print(f"  警告: 无法加载 {f} - {e}")
    return all_samples

def convert_to_numpy(sample):
    """
    将 sample 中的 keypoint 和 keypoint_score 从 list 转为 numpy 数组,
    并 reshape 为 (1, T, 17, 2) 和 (1, T, 17)
    """
    kp = np.array(sample['keypoint'], dtype=np.float32)          # (T, 17, 2)
    score = np.array(sample['keypoint_score'], dtype=np.float32) # (T, 17)
    T = kp.shape[0]

    # 添加人物维度 (1, T, 17, 2)
    kp = kp.reshape(1, T, 17, 2)
    score = score.reshape(1, T, 17)

    return {
        'frame_dir': sample['frame_dir'],
        'label': sample['label'],
        'img_shape': tuple(sample['img_shape']),
        'original_shape': tuple(sample['original_shape']),
        'total_frames': sample['total_frames'],
        'keypoint': kp,
        'keypoint_score': score
    }

def main():
    parser = argparse.ArgumentParser(description='生成动作识别 pkl 文件')
    input_path = r"D:\zero_track\mmaction2\my_code\input"
    parser.add_argument('--root', type=str, default=input_path, 
                        help='标注工具根目录(会递归查找所有 annotations.json)')
    parser.add_argument('--output', type=str, default=input_path + os.sep + 'my_ntu60_2d.pkl', help='输出 pkl 文件名')
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='验证集比例,默认 0.2')
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')
    args = parser.parse_args()

    print(f"查找根目录下的 annotations.json: {args.root}")
    json_files = find_all_annotations(args.root)
    if not json_files:
        print("错误: 未找到任何 annotations.json 文件")
        return

    print(f"找到 {len(json_files)} 个 annotations.json 文件,开始合并...")
    samples = load_and_merge_samples(json_files)
    print(f"共合并 {len(samples)} 个样本")

    if len(samples) == 0:
        print("错误: 没有样本数据")
        return

    # 提取所有 frame_dir 用于划分
    frame_dirs = [s['frame_dir'] for s in samples]

    # 随机划分训练集/验证集
    train_names, val_names = train_test_split(
        frame_dirs,
        test_size=args.test_size,
        random_state=args.seed
    )
    print(f"训练集样本数: {len(train_names)}")
    print(f"验证集样本数: {len(val_names)}")

    # 构建 split 字典
    split = {
        'train': train_names,
        'val': val_names
    }

    # 转换所有样本为 numpy 格式
    annotations = []
    for s in samples:
        try:
            ann = convert_to_numpy(s)
            annotations.append(ann)
        except Exception as e:
            print(f"  跳过样本 {s.get('frame_dir', 'unknown')}: {e}")

    # 最终数据结构
    output_data = {
        'split': split,
        'annotations': annotations
    }

    # 保存 pkl
    with open(args.output, 'wb') as f:
        pickle.dump(output_data, f)

    print(f"\n成功生成 pkl 文件: {args.output}")
    print(f"  split 包含训练集 {len(train_names)} 个, 验证集 {len(val_names)} 个")
    print(f"  annotations 包含 {len(annotations)} 个样本")

if __name__ == '__main__':
    main()
相关推荐
才兄说1 小时前
机器人租售配合齐不齐?多人同步操作
python
好家伙VCC2 小时前
# 发散创新:基于Python的TTS语音合成实战与优化策略 在人工智能加速落地的今天,**文本转
java·开发语言·人工智能·python
沃达德软件2 小时前
模糊图像复原技术解析
图像处理·人工智能·深度学习·目标检测·机器学习·计算机视觉·目标跟踪
来两个炸鸡腿2 小时前
【Datawhale组队学习202602】Easy-Vibe task05 完整项目实战
人工智能·学习·大模型·vibe coding
烤鸭的世界我们不懂2 小时前
开箱即用:OpenCLaw集成飞书实用部署与使用心得
人工智能·ai编程·openclaw
沪漂阿龙2 小时前
第二章:RAG系统技术架构设计
人工智能·系统架构
货拉拉技术2 小时前
文本大模型评测实践
人工智能·深度学习·算法
yuankoudaodaokou2 小时前
如何让三维数字化技术落地?思看科技三级认证培训体系赋能用户成长
python·科技