绘制人体3D关键点

一背景

最近学习了3D人体骨骼关键点检测算法。需要修改可视化3D,在此记录可视化3D骨骼点绘画思路以及代码实现。

二可视化画需求

希望在一张图显示,标签的3D结果,模型预测的3D结果,预测和标签一起的结果,以及对应的图像,并保存视频。

三代码实现

1 读取标签数据

python 复制代码
import os, sys, copy, cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import imageio
import io
import matplotlib.animation as animation
import matplotlib as mpl

def string_to_float(data):
    
    return list(map(lambda x:float(x), data))


def read_label(label_txt_file):
    
    if os.path.exists(label_txt_file) == False:
        print('find not file : ', label_txt_file)
        sys.exit(1)
    
    label_dict = {}
    class_id = [0, 1, 2, 3, 4, 5]
    
    with open(label_txt_file, 'r') as f:
        lines = f.readlines()
        
        #line = lines[0:1] + lines[3:9]
        line = lines
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))
        line = np.array(line)
        
        point2d = line[1:, 2:4]
        point2d = np.array(point2d)
        
        line  = np.array(line[..., 4:])

        find_0_id = line[..., -1] == 0.0

        not_find_0_id = ~find_0_id
        line_temp = line[not_find_0_id]
        temp      = line_temp[..., 3] / 1000

        line_temp[..., 2]   = temp
        line[not_find_0_id] = line_temp
        
        data = line[..., 0:3] - line[0, 0:3]
        
    return data, point2d

标签中有2D坐标和3D坐标。

2 获取模型预测数据

python 复制代码
def pred_label(pred_txt_file):
    
    if os.path.exists(pred_txt_file) == False:
        print('find not file : ', pred_txt_file)
        sys.exit(1)
    
    all_point = []
    with open(pred_txt_file, 'r') as f:
        all_lines = f.readlines()
        line = all_lines[0:1] + all_lines[1:17]
        
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))

        all_point.append(line)
    
    all_point = np.array(all_point)
    all_points = []
    all_p  = all_point[0, 1:, -3:]
    base_p = all_point[0, 0, -3:]
    
    all_points.append(base_p)
    for i in range(8):
        all_points.append(all_p[i*2 + 1].tolist())
        all_points.append(all_p[i*2].tolist())
    
    all_points = np.array(all_points)
    all_points -= base_p
    
    return all_points

注意,注意,注意 读取标签和模型预测时候,我都减去了根节点的坐标的。

3 绘制3D骨骼图

python 复制代码
#画骨骼点代码
def draw3Dpose(label_pose_3d, pred_pose_3d, ax1, ax2, ax3, label_total_ids, pred_total_ids, lcolor="r", rcolor="g", add_labels=False):  # blue, orange
"""
label_pose_3d : 标签3D坐标
pred_pose_3d : 模型预测的3D坐标
ax1, ax2, ax3 子图
label_total_ids : 标签关键点个点连接关系
pred_total_ids : 模型预测的关键点连接关系
""
    colors_keys = [
             '#FF0000',  # 红色
             '#00FF00',  # 绿色
             '#0000FF',  # 蓝色
             '#FFFF00',  # 黄色
             '#FF00FF',  # 洋红
             '#00FFFF',  # 青色
             '#FFA500',  # 橙色
             '#800080',  # 紫色
             '#008000',  # 深绿
             '#000080',  # 深蓝
             '#808000',  # 橄榄绿
             '#800000',  # 栗色
             '#008080',  # 青色
             '#808080',  # 灰色
             '#A52A2A',  # 棕色
             '#D2691E',  # 巧克力色
             '#00FFFF'
         ]
    
    
    for k in range(len(label_total_ids)):
        l_ids = label_total_ids[k]
        p_ids = pred_total_ids[k]
        lx, ly, lz = [np.array([label_pose_3d[l_ids[0], j], label_pose_3d[l_ids[1], j]]) for j in range(3)]
        px, py, pz = [np.array([pred_pose_3d[p_ids[0], j], pred_pose_3d[p_ids[1], j]]) for j in range(3)]
        if l_ids[2] == 3:
            color = 'b'
            ax1.plot(lx, ly, lz, lw=2, c=color)
            ax2.plot(lx, ly, lz, lw=2, c=color)
            ax3.plot(lx, ly, lz, lw=2, c=color)
        
        elif p_ids[2] == 3:
            color = 'b'
            ax1.plot(px, py, pz, lw=2, c=color)
            ax2.plot(px, py, pz, lw=2, c=color)
            ax3.plot(px, py, pz, lw=2, c=color)            
            
        else:
            ax1.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax2.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)            
            ax3.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax3.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)                   
        
        key_color = colors_keys[k]
        ax1.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax2.scatter(px, py, pz, color=key_color, marker='o', s=5)        
        ax3.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax3.scatter(px, py, pz, color=key_color, marker='o', s=5)  
        
    
    ax1.set_xlim3d([-100, 100])
    ax1.set_zlim3d([70, 200])
    ax1.set_ylim3d([-100, 100])    
    
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_zlabel("z")
    
    
    ax2.set_xlim3d([-100, 100])
    ax2.set_zlim3d([70, 200])
    ax2.set_ylim3d([-100, 100])    
    
    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_zlabel("z")    
    
    
    ax3.set_xlim3d([-100, 100])
    ax3.set_zlim3d([70, 200])
    ax3.set_ylim3d([-100, 100])    
    
    ax3.set_xlabel("x")
    ax3.set_ylabel("y")
    ax3.set_zlabel("z")

#把fig转换成图片,用于保存视频.
def get_img_from_fig(fig, dpi=500):
    buf = io.BytesIO()
    #fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight', pad_inches=0)
    fig.savefig(buf, format='png', dpi=dpi, pad_inches=0.2)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    
    return img

def draw_label_test(label_file, label_ids, pred_file, pred_ids, img_file):
    
    names = os.listdir(label_file)
    names = sorted(names)
    
    plt.rcParams['axes.unicode_minus'] = False
    
    #解决中文乱码问题
    font_path = '/home/xx/Downloads/chinese.simhei.ttf'
    mpl.font_manager.fontManager.addfont(font_path)  
    mpl.rc('font', family='SimHei')
    
    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(141, projection='3d')
    ax2 = fig.add_subplot(142, projection='3d')
    ax3 = fig.add_subplot(143, projection='3d')
    ax4 = fig.add_subplot(144)
    
    ax1.view_init(elev=29, azim=-60)
    ax2.view_init(elev=16, azim=-75)
    ax3.view_init(elev=18, azim=-73)
         
    
    output_video = './3d_pose_animation.mp4'
    fps=10
    videwrite = imageio.get_writer(uri=output_video, fps=fps)
    
    plt.ion() 
    for name in names:
        na = name[:-3] + 'txt'
        ax1.cla()
        ax2.cla()
        ax3.cla()
        ax4.cla()
        
        ax1.title.set_text("标签结果结果")
        ax2.title.set_text("模型算法结果")
        ax3.title.set_text("标签和算法结果")
        ax4.title.set_text("原始图片")
        
        label_path = os.path.join(label_file, name)
        pred_path = os.path.join(pred_file, na)
        if os.path.exists(label_path) == False:
            continue
        
        l_data_3d, _ = read_label(label_path)
        p_data_3d = pred_label(pred_path)
        
        l_data_3d *= 100
        l_new_data_3d = l_data_3d[..., [0, 2, 1]]
        l_new_data_3d[..., 2] = 200 - l_new_data_3d[..., 2]
        
        p_data_3d *= 100
        p_new_data_3d = p_data_3d[..., [0, 2, 1]]
        p_new_data_3d[..., 2] = 200 - p_new_data_3d[..., 2]        
        
        img_name = name[:-3] + 'png'
        img_path = os.path.join(img_file, img_name)
        img = np.array(Image.open(img_path)) 
        draw3Dpose(l_new_data_3d, p_new_data_3d, ax1, ax2, ax3, label_ids, pred_ids)
        ax4.imshow(img)
        
        #plt.pause(0.01)
        frame_vis = get_img_from_fig(fig)
        videwrite.append_data(frame_vis)
                
    
    videwrite.close()
    plt.tight_layout()
 
    plt.ioff()
    print("save out video")
    plt.show()

if __name__ == "__main__":
	label_path = '/home/xx/Desktop/simcc_3d/temp/select_label_txt'
    pred_path  = '/home/xx/Desktop/simcc_3d/temp/out_txt'
    img_file   = '/home/xx/Desktop/simcc_3d/temp/val_img'
	
	label_ids = [[0, 1, 1], [1, 2, 1], [1, 3, 1], [1, 4, 1], [3, 5, 1], 
                 [5, 7, 1], [4, 6, 1], [6, 8, 1],
                 ]
    
    
    pred_ids = [[0, 6, 0], [6, 8, 0], [8, 10, 0], [0, 5, 0], [5, 7, 0], 
                [7, 9, 0], [0, 0, 0], [0, 0, 0],
                 ]
    
    draw_label_test(label_path, label_ids, pred_path, pred_ids, img_file)

四总结

以上代码都是只是演示,只适用于我自己的场景,其他场景需要修改标签数据,关键点连接关系,该代码仅供参考,不可照搬。

相关推荐
地球资源数据云4 小时前
从 DEM 到 3D 渲染:R 语言 rayshader 地形可视化全指南
3d·数据分析·r语言
换日线°5 小时前
前端3D炫酷展开效果
前端·3d
Funny_AI_LAB6 小时前
RAD基准重新定义多视角异常检测,传统2D方法为何战胜前沿3D与VLM?
人工智能·目标检测·3d·ai
新启航光学频率梳8 小时前
储能电池极柱深孔孔深光学3D轮廓测量-激光频率梳3D轮廓技术
科技·3d·制造
军军君0121 小时前
Three.js基础功能学习十三:太阳系实例上
前端·javascript·vue.js·学习·3d·前端框架·three
CG_MAGIC1 天前
Substance Painter 高效出图:贴图导出与后期优化技巧
3d·贴图·maya·substance painter·渲云渲染·3d软件
图生生1 天前
基于AI的电商产品2D转3D,降低3D建模开发成本
3d·ai
mocoding2 天前
Flutter 3D 翻转动画flip_card三方库在鸿蒙版天气预报卡片中的实战教程
flutter·3d·harmonyos
2501_948120152 天前
3D虚拟衣服动画系统关键技术的研究与实现
3d
应用市场2 天前
基于上下文感知分层深度修复的3D照片生成技术详解
3d