Python开源工具库使用之运动姿势追踪库mediapipe

文章目录

  • 前言
  • 一、姿势估计
    • [1.1 姿态关键点](#1.1 姿态关键点)
    • [1.2 旧版 solution API](#1.2 旧版 solution API)
    • [1.3 新版 solution API](#1.3 新版 solution API)
    • [1.4 俯卧撑计数](#1.4 俯卧撑计数)
  • 二、手部追踪
    • [2.1 手部姿态](#2.1 手部姿态)
    • [2.2 API 使用](#2.2 API 使用)
    • [2.3 识别手势含义](#2.3 识别手势含义)
  • 参考

前言

Mediapipe 是谷歌出品的一种开源框架,旨在为开发者提供一种简单而强大的工具,用于实现各种视觉和感知应用程序。它包括一系列预训练的机器学习模型和用于处理多媒体数据的工具,可以用于姿势估计、手部追踪、人脸检测与跟踪、面部标志、对象检测、图片分割和语言检测等任务

Mediapipe 是支持跨平台的,可以部署在手机端(Android, iOS), web, desktop, edge devices, IoT 等各种平台,编程语言也支持C++, Python, Java, Swift, Objective-C, Javascript等

在本文中,我们将通过Python实现 Mediapipe 在姿势估计和手部追踪不同领域的应用

一、姿势估计

1.1 姿态关键点

序号 部位 Pose Landmark
0 鼻子 PoseLandmark.NOSE
1 左眼(内侧) PoseLandmark.LEFT_EYE_INNER
2 左眼 PoseLandmark.LEFT_EYE
3 左眼(外侧) PoseLandmark.LEFT_EYE_OUTER
4 右眼(内侧) PoseLandmark.RIGHT_EYE_INNER
5 右眼 PoseLandmark.RIGHT_EYE
6 右眼(外侧) PoseLandmark.RIGHT_EYE_OUTER
7 左耳 PoseLandmark.LEFT_EAR
8 右耳 PoseLandmark.RIGHT_EAR
9 嘴巴(左侧) PoseLandmark.MOUTH_LEFT
10 嘴巴(右侧) PoseLandmark.MOUTH_RIGHT
11 左肩 PoseLandmark.LEFT_SHOULDER
12 右肩 PoseLandmark.RIGHT_SHOULDER
13 左肘 PoseLandmark.LEFT_ELBOW
14 右肘 PoseLandmark.RIGHT_ELBOW
15 左腕 PoseLandmark.LEFT_WRIST
16 右腕 PoseLandmark.RIGHT_WRIST
17 左小指 PoseLandmark.LEFT_PINKY
18 右小指 PoseLandmark.RIGHT_PINKY
19 左食指 PoseLandmark.LEFT_INDEX
20 右食指 PoseLandmark.RIGHT_INDEX
21 左拇指 PoseLandmark.LEFT_THUMB
22 右拇指 PoseLandmark.RIGHT_THUMB
23 左臀 PoseLandmark.LEFT_HIP
24 右臀 PoseLandmark.RIGHT_HIP
25 左膝 PoseLandmark.LEFT_KNEE
26 右膝 PoseLandmark.RIGHT_KNEE
27 左踝 PoseLandmark.LEFT_ANKLE
28 右踝 PoseLandmark.RIGHT_ANKLE
29 左脚跟 PoseLandmark.LEFT_HEEL
30 右脚跟 PoseLandmark.RIGHT_HEEL
31 左脚趾 PoseLandmark.LEFT_FOOT_INDEX
32 右脚趾 PoseLandmark.RIGHT_FOOT_INDEX

1.2 旧版 solution API

Mediapipe 提供 solution API 来实现快速检测, 不过这种方式在2023年5月10日停止更新了,不过目前还可以使用,可通过 mediapose.solutions.pose.Pose 来实现,配置参数如下

选项 含义 值范围 默认值
static_image_mode 如果设置为 False,会将输入图像视为视频流。它将尝试检测第一张图像中最突出的人,并在成功检测后进一步定位姿势。在随后的图像中,它只是跟踪这些标记,而不调用另一个检测,直到它失去跟踪,从而减少计算和延迟。如果设置为 True,则人员检测将运行每个输入图像,非常适合处理一批静态(可能不相关的)图像 Boolean False
model_complexity 模型的复杂度,准确性和推理延迟通常随着模型复杂性的增加而增加 {0,1,2} 1
smooth_landmarks 如果设置为 True,则solution 过滤器会在不同的输入图像中设置标记以减少抖动,但如果 static_image_mode 也设置为 True,则忽略该筛选器 Boolean True
enable_segmentation 如果设置为 True,则除了姿态标记外,还会生成分割蒙版 Boolean False
smooth_segmentation 如果设置为 True,则会过滤不同输入图像中的分割掩码,以减少抖动。如果enable_segmentation为 false 或 static_image_mode为 True,则忽略 Boolean True
min_detection_confidence 人员检测模型的最小置信度值 ,用于将检测视为成功 Float [0.0,1.0] 0.5
min_tracking_confidence 来自姿态跟踪模型的最小置信度值 , 用于将姿态标记视为成功跟踪,否则将在下一个输入图像上自动调用人员检测。将其设置为更高的值可以提高解决方案的可靠性,但代价是延迟更高。如果static_image_mode为 True,则忽略,其中人员检测仅对每个图像运行。 Float [0.0,1.0] 0.5
python 复制代码
import cv2
import numpy as np
import mediapipe as mp

def main():
    FILE_PATH = 'data/1.png'
    img = cv2.imread(FILE_PATH)
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=True,
                        min_detection_confidence=0.5, min_tracking_confidence=0.5)

    res = pose.process(img)
    img_copy = img.copy()

    if res.pose_landmarks is not None:
        mp_drawing = mp.solutions.drawing_utils
        # mp_drawing.draw_landmarks(
        #     img_copy, res.pose_landmarks, mp.solutions.pose.POSE_CONNECTIONS)
        mp_drawing.draw_landmarks(
            img_copy,
            res.pose_landmarks,
            mp_pose.POSE_CONNECTIONS,  # frozenset,定义了哪些关键点要连接
            mp_drawing.DrawingSpec(color=(255, 255, 255),  # 姿态关键点
                                   thickness=2,
                                   circle_radius=2),
            mp_drawing.DrawingSpec(color=(174, 139, 45),   # 连线颜色
                                   thickness=2,
                                   circle_radius=2),
        )

    cv2.imshow('MediaPipe Pose Estimation', img_copy)
    cv2.waitKey(0)


if __name__ == '__main__':
    main()
python 复制代码
import cv2
import numpy as np
import mediapipe as mp


def video():
    # 读取摄像头
    # cap = cv2.VideoCapture(0)
    # 读取视频
    cap = cv2.VideoCapture('data/1.mp4')
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=False,
                        min_detection_confidence=0.5, min_tracking_confidence=0.5)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
            # 摄像头
            # continue

        # 将 BGR 图像转换为 RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 进行姿势估计
        results = pose.process(rgb_frame)

        if results.pose_landmarks is not None:
            # 绘制关键点和连接线
            mp_drawing = mp.solutions.drawing_utils
            mp_drawing.draw_landmarks(
                frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        # 显示结果
        cv2.imshow('MediaPipe Pose Estimation', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 释放资源
    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    video()

1.3 新版 solution API

旧版 API 并不能检测多个姿态,新版 API 可以实现多个姿态检测

选项 含义 值范围 默认值
running_mode 设置任务的运行模式,有三种模式可选: IMAGE: 单一照片输入. VIDEO: 视频. LIVE_STREAM: 输入数据(例如来自摄像机)为实时流。在此模式下,必须调用 resultListener 来设置侦听器以异步接收结果. {IMAGE, VIDEO, LIVE_STREAM} IMAGE
num_poses 姿势检测器可以检测到的最大姿势数 Integer > 0 1
min_pose_detection_confidence 姿势检测被认为是成功的最小置信度得分 Float [0.0,1.0] 0.5
min_pose_presence_confidence 姿态检测中的姿态存在分数的最小置信度分数 Float [0.0,1.0] 0.5
min_tracking_confidence 姿势跟踪被视为成功的最小置信度分数 Float [0.0,1.0] 0.5
output_segmentation_masks 是否为检测到的姿势输出分割掩码 Boolean False
result_callback 将结果侦听器设置为在Pose Landmark处于LIVE_STREAM模式时异步接收Landmark结果。仅当运行模式设置为LIVE_STREAM时才能使用 ResultListener N/A
python 复制代码
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
import cv2
import numpy as np
import mediapipe as mp

mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

def draw_landmarks_on_image(rgb_image, detection_result):
    pose_landmarks_list = detection_result.pose_landmarks
    annotated_image = np.copy(rgb_image)

    # Loop through the detected poses to visualize.
    for idx in range(len(pose_landmarks_list)):
        pose_landmarks = pose_landmarks_list[idx]

        # Draw the pose landmarks.
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        pose_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in pose_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style())
    return annotated_image


def newSolution():
    BaseOptions = mp.tasks.BaseOptions
    PoseLandmarker = mp.tasks.vision.PoseLandmarker
    PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
    VisionRunningMode = mp.tasks.vision.RunningMode
    model_path = 'data/pose_landmarker_heavy.task'
    options = PoseLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        running_mode=VisionRunningMode.IMAGE,
        num_poses=10)

    FILE_PATH = 'data/4.jpg'
    image = cv2.imread(FILE_PATH)
    img = mp.Image.create_from_file(FILE_PATH)
    with PoseLandmarker.create_from_options(options) as detector:
        res = detector.detect(img)
        image = draw_landmarks_on_image(image, res)

    cv2.imshow('MediaPipe Pose Estimation', image)
    cv2.waitKey(0)


if __name__ == '__main__':
    newSolution()

1.4 俯卧撑计数

通过计算胳膊弯曲角度来判断状态,并计算俯卧撑个数

python 复制代码
import cv2
import mediapipe as mp
import numpy as np

mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose


def calculate_angle(a, b, c):
    radians = np.arctan2(c.y - b.y, c.x - b.x) - \
        np.arctan2(a.y - b.y, a.x - b.x)
    angle = np.abs(np.degrees(radians))
    return angle if angle <= 180 else 360 - angle


def angle_of_arm(landmarks, shoulder, elbow, wrist):
    shoulder_coord = landmarks[mp_pose.PoseLandmark[shoulder].value]
    elbow_coord = landmarks[mp_pose.PoseLandmark[elbow].value]
    wrist_coord = landmarks[mp_pose.PoseLandmark[wrist].value]
    return calculate_angle(shoulder_coord, elbow_coord, wrist_coord)


def count_push_up(landmarks, counter, status):
    left_arm_angle = angle_of_arm(
        landmarks, "LEFT_SHOULDER", "LEFT_ELBOW", "LEFT_WRIST")
    right_arm_angle = angle_of_arm(
        landmarks, "RIGHT_SHOULDER", "RIGHT_ELBOW", "RIGHT_WRIST")
    avg_arm_angle = (left_arm_angle + right_arm_angle) // 2

    if status:
        if avg_arm_angle < 70:
            counter += 1
            status = False
    else:
        if avg_arm_angle > 160:
            status = True
    return counter, status


def main():
    cap = cv2.VideoCapture('data/test.mp4')
    counter = 0
    status = False
    with mp_pose.Pose(min_detection_confidence=0.7, min_tracking_confidence=0.7) as pose:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("empty camera")
                break
            result = pose.process(image)
            if result.pose_landmarks:
                mp_drawing.draw_landmarks(
                    image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
                counter, status = count_push_up(
                    result.pose_landmarks.landmark, counter, status)
            cv2.putText(image, text=str(counter), org=(100, 100), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=4, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
            cv2.imshow("push-up counter", image)
            key = cv2.waitKey(1)
            if key == ord('q'):
                break
    cap.release()


if __name__ == '__main__':
    main()

二、手部追踪

2.1 手部姿态

2.2 API 使用

照片

选项 含义 值范围 默认值
static_image_mode 如果设置为 False,会将输入图像视为视频流。它将尝试在第一个输入图像中检测手,并在成功检测后进一步定位手部标志。在随后的图像中,一旦检测到所有 max_num_hands 手并定位了相应的手部标志,它就会简单地跟踪这些标志,而不会调用其他检测,直到它失去对任何手的跟踪。这减少了延迟,是处理视频帧的理想选择。如果设置为 True,则对每个输入图像运行手动检测,非常适合处理一批静态(可能不相关的)图像 Boolean False
max_num_hands 要检测的最大手数 Integer 2
model_complexity 模型的复杂度,准确性和推理延迟通常随着模型复杂性的增加而增加 {0,1} 1
min_detection_confidence 检测模型的最小置信度值 ,用于将检测视为成功 Float [0.0,1.0] 0.5
min_tracking_confidence 来自手部跟踪模型的最小置信度值 , 用于将手部标记视为成功跟踪,否则将在下一个输入图像上自动调用检测。将其设置为更高的值可以提高解决方案的可靠性,但代价是延迟更高。如果static_image_mode为 True,则忽略,其中手部检测仅对每个图像运行。 Float [0.0,1.0] 0.5
python 复制代码
import cv2
import mediapipe as mp

mp_hands = mp.solutions.hands


def main():

    cv2.namedWindow("MediaPipe Hand", cv2.WINDOW_NORMAL)
    hands = mp_hands.Hands(static_image_mode=False, max_num_hands=2,
                           min_detection_confidence=0.5, min_tracking_confidence=0.5)
    img = cv2.imread('data/finger/1.jpg')
    rgb_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 进行手部追踪
    results = hands.process(rgb_frame)

    if results.multi_hand_landmarks:
        # 绘制手部关键点和连接线
        for hand_landmarks in results.multi_hand_landmarks:
            mp_drawing = mp.solutions.drawing_utils
            mp_drawing.draw_landmarks(
                img, hand_landmarks, mp_hands.HAND_CONNECTIONS)

    # 显示结果
    cv2.imshow('MediaPipe Hand', img)
    cv2.waitKey(0)
    
if __name__ == '__main__':
    main()
python 复制代码
import cv2
import mediapipe as mp

mp_hands = mp.solutions.hands

def video():

    hands = mp_hands.Hands(static_image_mode=False, max_num_hands=2,
                           min_detection_confidence=0.4, min_tracking_confidence=0.4)

    # 读取视频
    cap = cv2.VideoCapture('data/hand.mp4')

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # 将 BGR 图像转换为 RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 进行手部追踪
        results = hands.process(rgb_frame)

        if results.multi_hand_landmarks:
            # 绘制手部关键点和连接线
            for hand_landmarks in results.multi_hand_landmarks:
                mp_drawing = mp.solutions.drawing_utils
                mp_drawing.draw_landmarks(
                    frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)

        # 显示结果
        cv2.imshow('MediaPipe Hand Tracking', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 释放资源
    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    video()

2.3 识别手势含义

使用 KNN 对手势进行预测

python 复制代码
import mediapipe as mp
import numpy as np
import cv2
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmarkList
from sklearn.neighbors import KNeighborsClassifier

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands

# 压缩特征点
class Embedder(object):
    def __init__(self):
        self._landmark_names = mp.solutions.hands.HandLandmark

    def __call__(self, landmarks):
        # modify the call func can both handle a 3-dim dataset and a single referencing result.
        if isinstance(landmarks, np.ndarray):
            if landmarks.ndim == 3:  # for dataset
                embeddings = []
                for lmks in landmarks:
                    embedding = self.__call__(lmks)
                    embeddings.append(embedding)
                return np.array(embeddings)
            elif landmarks.ndim == 2:  # for inference
                assert landmarks.shape[0] == len(list(
                    self._landmark_names)), 'Unexpected number of landmarks: {}'.format(landmarks.shape[0])
                # Normalize landmarks.
                landmarks = self._normalize_landmarks(landmarks)
                # Get embedding.
                embedding = self._get_embedding(landmarks)
                return embedding
            else:
                print('ERROR: Can NOT embedding the data you provided !')
        else:
            if isinstance(landmarks, list):  # for dataset
                embeddings = []
                for lmks in landmarks:
                    embedding = self.__call__(lmks)
                    embeddings.append(embedding)
                return np.array(embeddings)
            elif isinstance(landmarks, NormalizedLandmarkList):  # for inference
                # Normalize landmarks.
                landmarks = np.array([[lmk.x, lmk.y, lmk.z]
                                     for lmk in landmarks.landmark], dtype=np.float32)
                assert landmarks.shape[0] == len(list(
                    self._landmark_names)), 'Unexpected number of landmarks: {}'.format(landmarks.shape[0])
                landmarks = self._normalize_landmarks(landmarks)
                # Get embedding.
                embedding = self._get_embedding(landmarks)
                return embedding
            else:
                print('ERROR: Can NOT embedding the data you provided !')

    def _get_center(self, landmarks):
        # MIDDLE_FINGER_MCP:9
        return landmarks[9]

    def _get_size(self, landmarks):
        landmarks = landmarks[:, :2]
        max_dist = np.max(np.linalg.norm(
            landmarks - self._get_center(landmarks), axis=1))
        return max_dist * 2

    def _normalize_landmarks(self, landmarks):
        landmarks = np.copy(landmarks)
        # Normalize
        center = self._get_center(landmarks)
        size = self._get_size(landmarks)
        landmarks = (landmarks - center) / size
        landmarks *= 100  # optional, but makes debugging easier.
        return landmarks

    def _get_embedding(self, landmarks):
        # we can add and delete any embedding features
        test = np.array([
            np.dot((landmarks[2]-landmarks[0]),
                   (landmarks[3]-landmarks[4])),   # thumb bent
            np.dot((landmarks[5]-landmarks[0]), (landmarks[6]-landmarks[7])),
            np.dot((landmarks[9]-landmarks[0]), (landmarks[10]-landmarks[11])),
            np.dot((landmarks[13]-landmarks[0]),
                   (landmarks[14]-landmarks[15])),
            np.dot((landmarks[17]-landmarks[0]), (landmarks[18]-landmarks[19]))
        ]).flatten()
        return test


def init_knn(file='data/dataset_embedded.npz'):
    npzfile = np.load(file)
    X = npzfile['X']
    y = npzfile['y']

    neigh = KNeighborsClassifier(n_neighbors=5)
    neigh.fit(X, y)
    return neigh


def hand_pose_recognition(stream_img):
    # For static images:
    stream_img = cv2.cvtColor(stream_img, cv2.COLOR_BGR2RGB)
    embedder = Embedder()
    neighbors = init_knn()
    with mp_hands.Hands(
            static_image_mode=True,
            max_num_hands=2,
            min_detection_confidence=0.5) as hands:

        results = hands.process(stream_img)
        if not results.multi_hand_landmarks:
            return ['no_gesture'], stream_img
        else:
            annotated_image = stream_img.copy()

            multi_landmarks = results.multi_hand_landmarks
            # KNN inference
            embeddings = embedder(multi_landmarks)
            hand_class = neighbors.predict(embeddings)

            # hand_class_prob = neighbors.predict_proba(embeddings)
            # print(hand_class_prob)

            for landmarks in results.multi_hand_landmarks:
                mp_drawing.draw_landmarks(annotated_image,
                                          landmarks,
                                          mp_hands.HAND_CONNECTIONS,
                                          mp_drawing_styles.get_default_hand_landmarks_style(),
                                          mp_drawing_styles.get_default_hand_connections_style())
            return hand_class, annotated_image


# 手势有10种,数字有8种,1-10之间7和9没有,还有两种是OK手势,和蜘蛛侠spide手势
# `eight_sign`, `five_sign`, `four_sign`, `ok`, `one_sign`, `six_sign`, `spider`, `ten_sign`, `three_sign`, `two_sign`

def image():
    FILE_PATH = 'data/ok.png'
    img = cv2.imread(FILE_PATH)
    handclass, img_final = hand_pose_recognition(img)
    cv2.putText(img_final, text=handclass[0], org=(200, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=2, color=(255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    cv2.imshow('test', cv2.cvtColor(img_final, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)


def video():
    cap = cv2.VideoCapture('data/ok.mp4')
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        handclass, img_final = hand_pose_recognition(frame)
        cv2.putText(img_final, text=handclass[0], org=(50, 50), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=2, color=(255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
        cv2.imshow('test', cv2.cvtColor(img_final, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break


if __name__ == '__main__':
    video()

参考

  1. https://developers.google.cn/mediapipe/solutions/
  2. https://github.com/googlesamples/mediapipe
  3. https://github.com/Furkan-Gulsen/Sport-With-AI
  4. https://github.com/Chuanfang-Neptune/DLAV-G9
相关推荐
----云烟----1 小时前
QT中QString类的各种使用
开发语言·qt
lsx2024061 小时前
SQL SELECT 语句:基础与进阶应用
开发语言
小二·1 小时前
java基础面试题笔记(基础篇)
java·笔记·python
开心工作室_kaic2 小时前
ssm161基于web的资源共享平台的共享与开发+jsp(论文+源码)_kaic
java·开发语言·前端
向宇it2 小时前
【unity小技巧】unity 什么是反射?反射的作用?反射的使用场景?反射的缺点?常用的反射操作?反射常见示例
开发语言·游戏·unity·c#·游戏引擎
武子康2 小时前
Java-06 深入浅出 MyBatis - 一对一模型 SqlMapConfig 与 Mapper 详细讲解测试
java·开发语言·数据仓库·sql·mybatis·springboot·springcloud
转世成为计算机大神2 小时前
易考八股文之Java中的设计模式?
java·开发语言·设计模式
宅小海3 小时前
scala String
大数据·开发语言·scala
小喵要摸鱼3 小时前
Python 神经网络项目常用语法
python
qq_327342733 小时前
Java实现离线身份证号码OCR识别
java·开发语言