绘制人体3D关键点

一背景

最近学习了3D人体骨骼关键点检测算法。需要修改可视化3D,在此记录可视化3D骨骼点绘画思路以及代码实现。

二可视化画需求

希望在一张图显示,标签的3D结果,模型预测的3D结果,预测和标签一起的结果,以及对应的图像,并保存视频。

三代码实现

1 读取标签数据

python 复制代码
import os, sys, copy, cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import imageio
import io
import matplotlib.animation as animation
import matplotlib as mpl

def string_to_float(data):
    
    return list(map(lambda x:float(x), data))


def read_label(label_txt_file):
    
    if os.path.exists(label_txt_file) == False:
        print('find not file : ', label_txt_file)
        sys.exit(1)
    
    label_dict = {}
    class_id = [0, 1, 2, 3, 4, 5]
    
    with open(label_txt_file, 'r') as f:
        lines = f.readlines()
        
        #line = lines[0:1] + lines[3:9]
        line = lines
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))
        line = np.array(line)
        
        point2d = line[1:, 2:4]
        point2d = np.array(point2d)
        
        line  = np.array(line[..., 4:])

        find_0_id = line[..., -1] == 0.0

        not_find_0_id = ~find_0_id
        line_temp = line[not_find_0_id]
        temp      = line_temp[..., 3] / 1000

        line_temp[..., 2]   = temp
        line[not_find_0_id] = line_temp
        
        data = line[..., 0:3] - line[0, 0:3]
        
    return data, point2d

标签中有2D坐标和3D坐标。

2 获取模型预测数据

python 复制代码
def pred_label(pred_txt_file):
    
    if os.path.exists(pred_txt_file) == False:
        print('find not file : ', pred_txt_file)
        sys.exit(1)
    
    all_point = []
    with open(pred_txt_file, 'r') as f:
        all_lines = f.readlines()
        line = all_lines[0:1] + all_lines[1:17]
        
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))

        all_point.append(line)
    
    all_point = np.array(all_point)
    all_points = []
    all_p  = all_point[0, 1:, -3:]
    base_p = all_point[0, 0, -3:]
    
    all_points.append(base_p)
    for i in range(8):
        all_points.append(all_p[i*2 + 1].tolist())
        all_points.append(all_p[i*2].tolist())
    
    all_points = np.array(all_points)
    all_points -= base_p
    
    return all_points

注意,注意,注意 读取标签和模型预测时候,我都减去了根节点的坐标的。

3 绘制3D骨骼图

python 复制代码
#画骨骼点代码
def draw3Dpose(label_pose_3d, pred_pose_3d, ax1, ax2, ax3, label_total_ids, pred_total_ids, lcolor="r", rcolor="g", add_labels=False):  # blue, orange
"""
label_pose_3d : 标签3D坐标
pred_pose_3d : 模型预测的3D坐标
ax1, ax2, ax3 子图
label_total_ids : 标签关键点个点连接关系
pred_total_ids : 模型预测的关键点连接关系
""
    colors_keys = [
             '#FF0000',  # 红色
             '#00FF00',  # 绿色
             '#0000FF',  # 蓝色
             '#FFFF00',  # 黄色
             '#FF00FF',  # 洋红
             '#00FFFF',  # 青色
             '#FFA500',  # 橙色
             '#800080',  # 紫色
             '#008000',  # 深绿
             '#000080',  # 深蓝
             '#808000',  # 橄榄绿
             '#800000',  # 栗色
             '#008080',  # 青色
             '#808080',  # 灰色
             '#A52A2A',  # 棕色
             '#D2691E',  # 巧克力色
             '#00FFFF'
         ]
    
    
    for k in range(len(label_total_ids)):
        l_ids = label_total_ids[k]
        p_ids = pred_total_ids[k]
        lx, ly, lz = [np.array([label_pose_3d[l_ids[0], j], label_pose_3d[l_ids[1], j]]) for j in range(3)]
        px, py, pz = [np.array([pred_pose_3d[p_ids[0], j], pred_pose_3d[p_ids[1], j]]) for j in range(3)]
        if l_ids[2] == 3:
            color = 'b'
            ax1.plot(lx, ly, lz, lw=2, c=color)
            ax2.plot(lx, ly, lz, lw=2, c=color)
            ax3.plot(lx, ly, lz, lw=2, c=color)
        
        elif p_ids[2] == 3:
            color = 'b'
            ax1.plot(px, py, pz, lw=2, c=color)
            ax2.plot(px, py, pz, lw=2, c=color)
            ax3.plot(px, py, pz, lw=2, c=color)            
            
        else:
            ax1.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax2.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)            
            ax3.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax3.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)                   
        
        key_color = colors_keys[k]
        ax1.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax2.scatter(px, py, pz, color=key_color, marker='o', s=5)        
        ax3.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax3.scatter(px, py, pz, color=key_color, marker='o', s=5)  
        
    
    ax1.set_xlim3d([-100, 100])
    ax1.set_zlim3d([70, 200])
    ax1.set_ylim3d([-100, 100])    
    
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_zlabel("z")
    
    
    ax2.set_xlim3d([-100, 100])
    ax2.set_zlim3d([70, 200])
    ax2.set_ylim3d([-100, 100])    
    
    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_zlabel("z")    
    
    
    ax3.set_xlim3d([-100, 100])
    ax3.set_zlim3d([70, 200])
    ax3.set_ylim3d([-100, 100])    
    
    ax3.set_xlabel("x")
    ax3.set_ylabel("y")
    ax3.set_zlabel("z")

#把fig转换成图片,用于保存视频.
def get_img_from_fig(fig, dpi=500):
    buf = io.BytesIO()
    #fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight', pad_inches=0)
    fig.savefig(buf, format='png', dpi=dpi, pad_inches=0.2)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    
    return img

def draw_label_test(label_file, label_ids, pred_file, pred_ids, img_file):
    
    names = os.listdir(label_file)
    names = sorted(names)
    
    plt.rcParams['axes.unicode_minus'] = False
    
    #解决中文乱码问题
    font_path = '/home/xx/Downloads/chinese.simhei.ttf'
    mpl.font_manager.fontManager.addfont(font_path)  
    mpl.rc('font', family='SimHei')
    
    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(141, projection='3d')
    ax2 = fig.add_subplot(142, projection='3d')
    ax3 = fig.add_subplot(143, projection='3d')
    ax4 = fig.add_subplot(144)
    
    ax1.view_init(elev=29, azim=-60)
    ax2.view_init(elev=16, azim=-75)
    ax3.view_init(elev=18, azim=-73)
         
    
    output_video = './3d_pose_animation.mp4'
    fps=10
    videwrite = imageio.get_writer(uri=output_video, fps=fps)
    
    plt.ion() 
    for name in names:
        na = name[:-3] + 'txt'
        ax1.cla()
        ax2.cla()
        ax3.cla()
        ax4.cla()
        
        ax1.title.set_text("标签结果结果")
        ax2.title.set_text("模型算法结果")
        ax3.title.set_text("标签和算法结果")
        ax4.title.set_text("原始图片")
        
        label_path = os.path.join(label_file, name)
        pred_path = os.path.join(pred_file, na)
        if os.path.exists(label_path) == False:
            continue
        
        l_data_3d, _ = read_label(label_path)
        p_data_3d = pred_label(pred_path)
        
        l_data_3d *= 100
        l_new_data_3d = l_data_3d[..., [0, 2, 1]]
        l_new_data_3d[..., 2] = 200 - l_new_data_3d[..., 2]
        
        p_data_3d *= 100
        p_new_data_3d = p_data_3d[..., [0, 2, 1]]
        p_new_data_3d[..., 2] = 200 - p_new_data_3d[..., 2]        
        
        img_name = name[:-3] + 'png'
        img_path = os.path.join(img_file, img_name)
        img = np.array(Image.open(img_path)) 
        draw3Dpose(l_new_data_3d, p_new_data_3d, ax1, ax2, ax3, label_ids, pred_ids)
        ax4.imshow(img)
        
        #plt.pause(0.01)
        frame_vis = get_img_from_fig(fig)
        videwrite.append_data(frame_vis)
                
    
    videwrite.close()
    plt.tight_layout()
 
    plt.ioff()
    print("save out video")
    plt.show()

if __name__ == "__main__":
	label_path = '/home/xx/Desktop/simcc_3d/temp/select_label_txt'
    pred_path  = '/home/xx/Desktop/simcc_3d/temp/out_txt'
    img_file   = '/home/xx/Desktop/simcc_3d/temp/val_img'
	
	label_ids = [[0, 1, 1], [1, 2, 1], [1, 3, 1], [1, 4, 1], [3, 5, 1], 
                 [5, 7, 1], [4, 6, 1], [6, 8, 1],
                 ]
    
    
    pred_ids = [[0, 6, 0], [6, 8, 0], [8, 10, 0], [0, 5, 0], [5, 7, 0], 
                [7, 9, 0], [0, 0, 0], [0, 0, 0],
                 ]
    
    draw_label_test(label_path, label_ids, pred_path, pred_ids, img_file)

四总结

以上代码都是只是演示,只适用于我自己的场景,其他场景需要修改标签数据,关键点连接关系,该代码仅供参考,不可照搬。

相关推荐
工业3D_大熊3 天前
3D Web轻量引擎HOOPS赋能BIM/工程施工:实现超大模型的轻量化加载与高效浏览!
3d·3d web轻量化·3d模型可视化·3d渲染引擎·工业3d开发引擎·bim模型可视化
研梦非凡3 天前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
fanged3 天前
Cesium4--地形(OSGB到3DTiles)
3d·gis·web
桃花键神3 天前
从传统到智能:3D 建模流程的演进与 AI 趋势 —— 以 Blender 为例
人工智能·3d·blender
东风西巷4 天前
全能的3D创作平台,Blender(免费开源3D建模工具)
学习·3d·开源·blender·软件需求
查里王4 天前
AI 3D 生成工具知识库:当前产品格局与测评总结
人工智能·3d
Hello123网站4 天前
Champ-基于3D的人物图像到动画视频生成框架
3d·ai工具
Magnum Lehar4 天前
3d wpf游戏引擎的导入文件功能c++的.h实现
3d·游戏引擎·wpf
博图光电4 天前
Motioncam Color S + 蓝激光:3D 视觉革新,重塑工业与科研应用新格局
3d
新启航光学频率梳4 天前
[新启航]深孔加工尺寸精度检测方法 - 激光频率梳 3D 轮廓测量
科技·3d·制造