绘制人体3D关键点

一背景

最近学习了3D人体骨骼关键点检测算法。需要修改可视化3D,在此记录可视化3D骨骼点绘画思路以及代码实现。

二可视化画需求

希望在一张图显示,标签的3D结果,模型预测的3D结果,预测和标签一起的结果,以及对应的图像,并保存视频。

三代码实现

1 读取标签数据

python 复制代码
import os, sys, copy, cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import imageio
import io
import matplotlib.animation as animation
import matplotlib as mpl

def string_to_float(data):
    
    return list(map(lambda x:float(x), data))


def read_label(label_txt_file):
    
    if os.path.exists(label_txt_file) == False:
        print('find not file : ', label_txt_file)
        sys.exit(1)
    
    label_dict = {}
    class_id = [0, 1, 2, 3, 4, 5]
    
    with open(label_txt_file, 'r') as f:
        lines = f.readlines()
        
        #line = lines[0:1] + lines[3:9]
        line = lines
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))
        line = np.array(line)
        
        point2d = line[1:, 2:4]
        point2d = np.array(point2d)
        
        line  = np.array(line[..., 4:])

        find_0_id = line[..., -1] == 0.0

        not_find_0_id = ~find_0_id
        line_temp = line[not_find_0_id]
        temp      = line_temp[..., 3] / 1000

        line_temp[..., 2]   = temp
        line[not_find_0_id] = line_temp
        
        data = line[..., 0:3] - line[0, 0:3]
        
    return data, point2d

标签中有2D坐标和3D坐标。

2 获取模型预测数据

python 复制代码
def pred_label(pred_txt_file):
    
    if os.path.exists(pred_txt_file) == False:
        print('find not file : ', pred_txt_file)
        sys.exit(1)
    
    all_point = []
    with open(pred_txt_file, 'r') as f:
        all_lines = f.readlines()
        line = all_lines[0:1] + all_lines[1:17]
        
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))

        all_point.append(line)
    
    all_point = np.array(all_point)
    all_points = []
    all_p  = all_point[0, 1:, -3:]
    base_p = all_point[0, 0, -3:]
    
    all_points.append(base_p)
    for i in range(8):
        all_points.append(all_p[i*2 + 1].tolist())
        all_points.append(all_p[i*2].tolist())
    
    all_points = np.array(all_points)
    all_points -= base_p
    
    return all_points

注意,注意,注意 读取标签和模型预测时候,我都减去了根节点的坐标的。

3 绘制3D骨骼图

python 复制代码
#画骨骼点代码
def draw3Dpose(label_pose_3d, pred_pose_3d, ax1, ax2, ax3, label_total_ids, pred_total_ids, lcolor="r", rcolor="g", add_labels=False):  # blue, orange
"""
label_pose_3d : 标签3D坐标
pred_pose_3d : 模型预测的3D坐标
ax1, ax2, ax3 子图
label_total_ids : 标签关键点个点连接关系
pred_total_ids : 模型预测的关键点连接关系
""
    colors_keys = [
             '#FF0000',  # 红色
             '#00FF00',  # 绿色
             '#0000FF',  # 蓝色
             '#FFFF00',  # 黄色
             '#FF00FF',  # 洋红
             '#00FFFF',  # 青色
             '#FFA500',  # 橙色
             '#800080',  # 紫色
             '#008000',  # 深绿
             '#000080',  # 深蓝
             '#808000',  # 橄榄绿
             '#800000',  # 栗色
             '#008080',  # 青色
             '#808080',  # 灰色
             '#A52A2A',  # 棕色
             '#D2691E',  # 巧克力色
             '#00FFFF'
         ]
    
    
    for k in range(len(label_total_ids)):
        l_ids = label_total_ids[k]
        p_ids = pred_total_ids[k]
        lx, ly, lz = [np.array([label_pose_3d[l_ids[0], j], label_pose_3d[l_ids[1], j]]) for j in range(3)]
        px, py, pz = [np.array([pred_pose_3d[p_ids[0], j], pred_pose_3d[p_ids[1], j]]) for j in range(3)]
        if l_ids[2] == 3:
            color = 'b'
            ax1.plot(lx, ly, lz, lw=2, c=color)
            ax2.plot(lx, ly, lz, lw=2, c=color)
            ax3.plot(lx, ly, lz, lw=2, c=color)
        
        elif p_ids[2] == 3:
            color = 'b'
            ax1.plot(px, py, pz, lw=2, c=color)
            ax2.plot(px, py, pz, lw=2, c=color)
            ax3.plot(px, py, pz, lw=2, c=color)            
            
        else:
            ax1.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax2.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)            
            ax3.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax3.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)                   
        
        key_color = colors_keys[k]
        ax1.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax2.scatter(px, py, pz, color=key_color, marker='o', s=5)        
        ax3.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax3.scatter(px, py, pz, color=key_color, marker='o', s=5)  
        
    
    ax1.set_xlim3d([-100, 100])
    ax1.set_zlim3d([70, 200])
    ax1.set_ylim3d([-100, 100])    
    
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_zlabel("z")
    
    
    ax2.set_xlim3d([-100, 100])
    ax2.set_zlim3d([70, 200])
    ax2.set_ylim3d([-100, 100])    
    
    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_zlabel("z")    
    
    
    ax3.set_xlim3d([-100, 100])
    ax3.set_zlim3d([70, 200])
    ax3.set_ylim3d([-100, 100])    
    
    ax3.set_xlabel("x")
    ax3.set_ylabel("y")
    ax3.set_zlabel("z")

#把fig转换成图片,用于保存视频.
def get_img_from_fig(fig, dpi=500):
    buf = io.BytesIO()
    #fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight', pad_inches=0)
    fig.savefig(buf, format='png', dpi=dpi, pad_inches=0.2)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    
    return img

def draw_label_test(label_file, label_ids, pred_file, pred_ids, img_file):
    
    names = os.listdir(label_file)
    names = sorted(names)
    
    plt.rcParams['axes.unicode_minus'] = False
    
    #解决中文乱码问题
    font_path = '/home/xx/Downloads/chinese.simhei.ttf'
    mpl.font_manager.fontManager.addfont(font_path)  
    mpl.rc('font', family='SimHei')
    
    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(141, projection='3d')
    ax2 = fig.add_subplot(142, projection='3d')
    ax3 = fig.add_subplot(143, projection='3d')
    ax4 = fig.add_subplot(144)
    
    ax1.view_init(elev=29, azim=-60)
    ax2.view_init(elev=16, azim=-75)
    ax3.view_init(elev=18, azim=-73)
         
    
    output_video = './3d_pose_animation.mp4'
    fps=10
    videwrite = imageio.get_writer(uri=output_video, fps=fps)
    
    plt.ion() 
    for name in names:
        na = name[:-3] + 'txt'
        ax1.cla()
        ax2.cla()
        ax3.cla()
        ax4.cla()
        
        ax1.title.set_text("标签结果结果")
        ax2.title.set_text("模型算法结果")
        ax3.title.set_text("标签和算法结果")
        ax4.title.set_text("原始图片")
        
        label_path = os.path.join(label_file, name)
        pred_path = os.path.join(pred_file, na)
        if os.path.exists(label_path) == False:
            continue
        
        l_data_3d, _ = read_label(label_path)
        p_data_3d = pred_label(pred_path)
        
        l_data_3d *= 100
        l_new_data_3d = l_data_3d[..., [0, 2, 1]]
        l_new_data_3d[..., 2] = 200 - l_new_data_3d[..., 2]
        
        p_data_3d *= 100
        p_new_data_3d = p_data_3d[..., [0, 2, 1]]
        p_new_data_3d[..., 2] = 200 - p_new_data_3d[..., 2]        
        
        img_name = name[:-3] + 'png'
        img_path = os.path.join(img_file, img_name)
        img = np.array(Image.open(img_path)) 
        draw3Dpose(l_new_data_3d, p_new_data_3d, ax1, ax2, ax3, label_ids, pred_ids)
        ax4.imshow(img)
        
        #plt.pause(0.01)
        frame_vis = get_img_from_fig(fig)
        videwrite.append_data(frame_vis)
                
    
    videwrite.close()
    plt.tight_layout()
 
    plt.ioff()
    print("save out video")
    plt.show()

if __name__ == "__main__":
	label_path = '/home/xx/Desktop/simcc_3d/temp/select_label_txt'
    pred_path  = '/home/xx/Desktop/simcc_3d/temp/out_txt'
    img_file   = '/home/xx/Desktop/simcc_3d/temp/val_img'
	
	label_ids = [[0, 1, 1], [1, 2, 1], [1, 3, 1], [1, 4, 1], [3, 5, 1], 
                 [5, 7, 1], [4, 6, 1], [6, 8, 1],
                 ]
    
    
    pred_ids = [[0, 6, 0], [6, 8, 0], [8, 10, 0], [0, 5, 0], [5, 7, 0], 
                [7, 9, 0], [0, 0, 0], [0, 0, 0],
                 ]
    
    draw_label_test(label_path, label_ids, pred_path, pred_ids, img_file)

四总结

以上代码都是只是演示,只适用于我自己的场景,其他场景需要修改标签数据,关键点连接关系,该代码仅供参考,不可照搬。

相关推荐
Tian_Hang33 分钟前
eclipse ditto 学习笔记
运维·服务器·开发语言·javascript·3d
AI视觉网奇4 小时前
BambuStudio 编译实战 2026
3d
AI前沿资讯4 小时前
AI3D角色生产如何减少返工?用 V2Fun 前移建模与动画流程
人工智能·3d
蓝速科技4 小时前
蓝速科技视觉 3D 全息舱 AI 数字人一体机带灯与无灯款深度评测
人工智能·科技·3d
林恒smileZAZ21 小时前
Three.js 3D 地图特效与材质实现指南
javascript·3d·材质
蓝速科技1 天前
蓝速科技 3D 全息舱 AI 数字人博物馆导览效果实录
人工智能·科技·3d
chaoyuanl1 天前
沉浸式飞行影院进场安装前期筹备事项
大数据·科技·3d·xr·娱乐
探物 AI18 天前
【3D·感知】从PointNet到PointPillars:如何让自动驾驶汽车“实时“看见3D世界?
3d·自动驾驶·汽车
苏州邦恩精密18 天前
GOM三维扫描在制造中的真实价值:让“修模”从经验动作变成数据动作
人工智能·科技·机器学习·3d·自动化·制造
YHHLAI18 天前
CSS 3D 硬核解析:四个属性手写旋转立方体
前端·css·3d