手部检测：YOLOv5 实战笔记

手部检测

yolo-v5

https://github.com/XIAN-HHappy/yolo-v5

对视频逐帧检测手部，并将检测结果保存为 LabelMe 格式的 JSON（同时保存对应帧图像）：

代码（Python）：
# -*- coding:utf-8 -*-
import warnings
warnings.filterwarnings("ignore")
import argparse
import os
import json
import random
import time
import cv2
from utils.datasets import *
from utils.utils import *

def _frames_dir(source):
    """Return the directory where per-frame JSON/JPEG files for *source* go.

    The directory is the source path with its extension stripped
    (``clip.mp4`` -> ``clip``). If the path has no extension, ``_frames`` is
    appended instead, so the directory can never collide with the video file
    itself.
    """
    root, ext = os.path.splitext(source)
    return root if ext else root + "_frames"


def _to_model_input(img0, imgsz, device, half):
    """Letterbox a BGR frame and turn it into a batched model input tensor.

    Args:
        img0: frame as returned by ``cv2.VideoCapture.read`` (indexed as
            ``[row, col, channel]`` below).
        imgsz: target size passed to ``letterbox``.
        device: torch device to move the tensor to.
        half: if True the tensor is cast to fp16, otherwise to fp32.

    Returns:
        Tensor with a leading batch dimension, pixel values divided by 255.
    """
    img = letterbox(img0, new_shape=imgsz)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)  # the reversed/transposed view is not contiguous
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add batch dimension
    return img


def _labelme_shapes(det):
    """Convert one image's detections into LabelMe rectangle shapes.

    Args:
        det: detections already rescaled to original-frame pixels; each row is
            unpacked as ``(x1, y1, x2, y2, conf, cls)``.

    Returns:
        list of shape dicts, one per detection. The label is always ``"hand"``
        regardless of the predicted class.
    """
    shapes = []
    for *xyxy, _conf, _cls in det:
        x1, y1, x2, y2 = (float(v) for v in xyxy)
        shapes.append({
            "label": "hand",
            "points": [[x1, y1], [x2, y2]],
            "group_id": 0,
            "description": "",
            "shape_type": "rectangle",
            "flags": {}
        })
    return shapes


def _save_labelme_frame(save_dir, frame_idx, img0, shapes):
    """Write ``frame_XXXX.json`` and ``frame_XXXX.jpg`` into *save_dir*.

    The JSON references the JPEG via ``imagePath`` and embeds no image data.
    Returns the path of the written JSON file.
    """
    stem = f"frame_{frame_idx:04d}"
    h0, w0 = img0.shape[:2]
    record = {
        "version": "5.3.1",
        "flags": {},
        "imageData": None,
        "imageHeight": h0,
        "imageWidth": w0,
        "imagePath": f"{stem}.jpg",
        "shapes": shapes
    }
    json_path = os.path.join(save_dir, f"{stem}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=4)
    cv2.imwrite(os.path.join(save_dir, f"{stem}.jpg"), img0)
    return json_path


def detect(save_img=False):
    """Run hand detection on a video and dump LabelMe-style annotations.

    Reads its configuration from the module-level ``opt`` namespace (source,
    weights, half, img_size, device, augment, conf_thres, iou_thres, classes,
    agnostic_nms). For every frame with at least one detection, a JSON file
    and the raw frame (JPEG) are written to a directory named after the video.
    Every frame is also shown in an OpenCV window; ESC stops early.

    Args:
        save_img: unused; kept for call compatibility.

    Raises:
        IOError: if the video cannot be opened.
    """
    source, weights, half, imgsz = \
        opt.source, opt.weights, opt.half, opt.img_size

    device = torch_utils.select_device(opt.device)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(weights, weights_only=False, map_location="cpu")
    model = ckpt['model'].float().to(device).eval()
    # Class names are only needed by the optional visualisation below.
    # A wrapped model keeps them on `.module` (`modules` is a method, not this).
    names = getattr(model, 'names', None)
    if names is None and hasattr(model, 'module'):
        names = getattr(model.module, 'names', None)

    # FP16 is only used on non-CPU devices.
    half = half and device.type != 'cpu'
    if half:
        model.half()

    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        raise IOError(f"无法打开视频: {source}")

    save_dir = _frames_dir(source)
    os.makedirs(save_dir, exist_ok=True)

    frame_count = 0
    t0 = time.time()
    try:
        while True:
            ok, img0 = cap.read()
            if not ok:
                break
            frame_count += 1

            img = _to_model_input(img0, imgsz, device, half)
            pred = model(img, augment=opt.augment)[0]
            if half:
                pred = pred.float()
            pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                       classes=opt.classes, agnostic=opt.agnostic_nms)

            for det in pred:
                if det is None or len(det) == 0:
                    continue
                # Map boxes from letterboxed input back to original frame pixels.
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()

                shapes = _labelme_shapes(det)
                json_path = _save_labelme_frame(save_dir, frame_count, img0, shapes)
                print(f"[Frame {frame_count}] 检测到 {len(shapes)} 个目标 → {json_path}")

                # Optional visualisation. It runs *after* saving so the stored
                # frame stays free of drawn boxes:
                # for *xyxy, conf, cls in det:
                #     plot_one_box(xyxy, img0, label=f"{names[int(cls)]} {conf:.2f}",
                #                  color=(0, 255, 0), line_thickness=2)

            cv2.imshow("YOLOv5 Detect Hands", img0)
            if cv2.waitKey(1) == 27:  # ESC quits
                break
    finally:
        # Release the capture and windows even if inference raised.
        cap.release()
        cv2.destroyAllWindows()

    print(f"推理完成,共 {frame_count} 帧,耗时 {time.time() - t0:.2f}s")


if __name__ == '__main__':
    def _str2bool(text):
        """Parse a command-line boolean ('true'/'false', '1'/'0', 'yes'/'no', ...).

        Plain ``default=False`` without a type would turn ``--half False``
        into the truthy string "False".
        """
        if isinstance(text, bool):
            return text
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='hand_m.pt', help='model.pt path')
    # parser.add_argument('--source', type=str, default=r"D:\data\jiezhi\det_1201\20251201-201906.mp4", help='source')
    parser.add_argument('--source', type=str, default=r"D:\data\jiezhi\det_1201\20251201-201944.mp4", help='source')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    # Not read by detect(); kept so existing command lines keep working.
    parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec')
    # Accepts both a bare `--half` and an explicit `--half true/false`.
    parser.add_argument('--half', type=_str2bool, nargs='?', const=True, default=False,
                        help='half precision FP16 inference')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or cpu')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', type=_str2bool, nargs='?', const=True, default=False,
                        help='augmented inference')
    opt = parser.parse_args()
    print(opt)

    # Inference only: disable autograd bookkeeping for the whole run.
    with torch.no_grad():
        detect(save_img=True)
相关推荐
寻星探路3 分钟前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
聆风吟º2 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子3 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder3 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能3 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5773 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎3 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h4 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切4 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话4 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python