绘制人体3D关键点

一背景

最近学习了3D人体骨骼关键点检测算法。需要修改可视化3D,在此记录可视化3D骨骼点绘画思路以及代码实现。

二可视化画需求

希望在一张图显示,标签的3D结果,模型预测的3D结果,预测和标签一起的结果,以及对应的图像,并保存视频。

三代码实现

1 读取标签数据

python 复制代码
import os, sys, copy, cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import imageio
import io
import matplotlib.animation as animation
import matplotlib as mpl

def string_to_float(data):
    
    return list(map(lambda x:float(x), data))


def read_label(label_txt_file):
    
    if os.path.exists(label_txt_file) == False:
        print('find not file : ', label_txt_file)
        sys.exit(1)
    
    label_dict = {}
    class_id = [0, 1, 2, 3, 4, 5]
    
    with open(label_txt_file, 'r') as f:
        lines = f.readlines()
        
        #line = lines[0:1] + lines[3:9]
        line = lines
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))
        line = np.array(line)
        
        point2d = line[1:, 2:4]
        point2d = np.array(point2d)
        
        line  = np.array(line[..., 4:])

        find_0_id = line[..., -1] == 0.0

        not_find_0_id = ~find_0_id
        line_temp = line[not_find_0_id]
        temp      = line_temp[..., 3] / 1000

        line_temp[..., 2]   = temp
        line[not_find_0_id] = line_temp
        
        data = line[..., 0:3] - line[0, 0:3]
        
    return data, point2d

标签中有2D坐标和3D坐标。

2 获取模型预测数据

python 复制代码
def pred_label(pred_txt_file):
    
    if os.path.exists(pred_txt_file) == False:
        print('find not file : ', pred_txt_file)
        sys.exit(1)
    
    all_point = []
    with open(pred_txt_file, 'r') as f:
        all_lines = f.readlines()
        line = all_lines[0:1] + all_lines[1:17]
        
        line = list(map(lambda x:x.strip(), line))
        line = list(map(lambda x:x.split(' '), line))
        line = list(map(lambda x:string_to_float(x), line))

        all_point.append(line)
    
    all_point = np.array(all_point)
    all_points = []
    all_p  = all_point[0, 1:, -3:]
    base_p = all_point[0, 0, -3:]
    
    all_points.append(base_p)
    for i in range(8):
        all_points.append(all_p[i*2 + 1].tolist())
        all_points.append(all_p[i*2].tolist())
    
    all_points = np.array(all_points)
    all_points -= base_p
    
    return all_points

注意,注意,注意 读取标签和模型预测时候,我都减去了根节点的坐标的。

3 绘制3D骨骼图

python 复制代码
#画骨骼点代码
def draw3Dpose(label_pose_3d, pred_pose_3d, ax1, ax2, ax3, label_total_ids, pred_total_ids, lcolor="r", rcolor="g", add_labels=False):  # blue, orange
"""
label_pose_3d : 标签3D坐标
pred_pose_3d : 模型预测的3D坐标
ax1, ax2, ax3 子图
label_total_ids : 标签关键点个点连接关系
pred_total_ids : 模型预测的关键点连接关系
""
    colors_keys = [
             '#FF0000',  # 红色
             '#00FF00',  # 绿色
             '#0000FF',  # 蓝色
             '#FFFF00',  # 黄色
             '#FF00FF',  # 洋红
             '#00FFFF',  # 青色
             '#FFA500',  # 橙色
             '#800080',  # 紫色
             '#008000',  # 深绿
             '#000080',  # 深蓝
             '#808000',  # 橄榄绿
             '#800000',  # 栗色
             '#008080',  # 青色
             '#808080',  # 灰色
             '#A52A2A',  # 棕色
             '#D2691E',  # 巧克力色
             '#00FFFF'
         ]
    
    
    for k in range(len(label_total_ids)):
        l_ids = label_total_ids[k]
        p_ids = pred_total_ids[k]
        lx, ly, lz = [np.array([label_pose_3d[l_ids[0], j], label_pose_3d[l_ids[1], j]]) for j in range(3)]
        px, py, pz = [np.array([pred_pose_3d[p_ids[0], j], pred_pose_3d[p_ids[1], j]]) for j in range(3)]
        if l_ids[2] == 3:
            color = 'b'
            ax1.plot(lx, ly, lz, lw=2, c=color)
            ax2.plot(lx, ly, lz, lw=2, c=color)
            ax3.plot(lx, ly, lz, lw=2, c=color)
        
        elif p_ids[2] == 3:
            color = 'b'
            ax1.plot(px, py, pz, lw=2, c=color)
            ax2.plot(px, py, pz, lw=2, c=color)
            ax3.plot(px, py, pz, lw=2, c=color)            
            
        else:
            ax1.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax2.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)            
            ax3.plot(lx, ly, lz, lw=2, c=lcolor if l_ids[2] else rcolor)
            ax3.plot(px, py, pz, lw=2, c=lcolor if p_ids[2] else rcolor)                   
        
        key_color = colors_keys[k]
        ax1.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax2.scatter(px, py, pz, color=key_color, marker='o', s=5)        
        ax3.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax3.scatter(px, py, pz, color=key_color, marker='o', s=5)  
        
    
    ax1.set_xlim3d([-100, 100])
    ax1.set_zlim3d([70, 200])
    ax1.set_ylim3d([-100, 100])    
    
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_zlabel("z")
    
    
    ax2.set_xlim3d([-100, 100])
    ax2.set_zlim3d([70, 200])
    ax2.set_ylim3d([-100, 100])    
    
    ax2.set_xlabel("x")
    ax2.set_ylabel("y")
    ax2.set_zlabel("z")    
    
    
    ax3.set_xlim3d([-100, 100])
    ax3.set_zlim3d([70, 200])
    ax3.set_ylim3d([-100, 100])    
    
    ax3.set_xlabel("x")
    ax3.set_ylabel("y")
    ax3.set_zlabel("z")

#把fig转换成图片,用于保存视频.
def get_img_from_fig(fig, dpi=500):
    buf = io.BytesIO()
    #fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight', pad_inches=0)
    fig.savefig(buf, format='png', dpi=dpi, pad_inches=0.2)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    
    return img

def draw_label_test(label_file, label_ids, pred_file, pred_ids, img_file):
    
    names = os.listdir(label_file)
    names = sorted(names)
    
    plt.rcParams['axes.unicode_minus'] = False
    
    #解决中文乱码问题
    font_path = '/home/xx/Downloads/chinese.simhei.ttf'
    mpl.font_manager.fontManager.addfont(font_path)  
    mpl.rc('font', family='SimHei')
    
    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(141, projection='3d')
    ax2 = fig.add_subplot(142, projection='3d')
    ax3 = fig.add_subplot(143, projection='3d')
    ax4 = fig.add_subplot(144)
    
    ax1.view_init(elev=29, azim=-60)
    ax2.view_init(elev=16, azim=-75)
    ax3.view_init(elev=18, azim=-73)
         
    
    output_video = './3d_pose_animation.mp4'
    fps=10
    videwrite = imageio.get_writer(uri=output_video, fps=fps)
    
    plt.ion() 
    for name in names:
        na = name[:-3] + 'txt'
        ax1.cla()
        ax2.cla()
        ax3.cla()
        ax4.cla()
        
        ax1.title.set_text("标签结果结果")
        ax2.title.set_text("模型算法结果")
        ax3.title.set_text("标签和算法结果")
        ax4.title.set_text("原始图片")
        
        label_path = os.path.join(label_file, name)
        pred_path = os.path.join(pred_file, na)
        if os.path.exists(label_path) == False:
            continue
        
        l_data_3d, _ = read_label(label_path)
        p_data_3d = pred_label(pred_path)
        
        l_data_3d *= 100
        l_new_data_3d = l_data_3d[..., [0, 2, 1]]
        l_new_data_3d[..., 2] = 200 - l_new_data_3d[..., 2]
        
        p_data_3d *= 100
        p_new_data_3d = p_data_3d[..., [0, 2, 1]]
        p_new_data_3d[..., 2] = 200 - p_new_data_3d[..., 2]        
        
        img_name = name[:-3] + 'png'
        img_path = os.path.join(img_file, img_name)
        img = np.array(Image.open(img_path)) 
        draw3Dpose(l_new_data_3d, p_new_data_3d, ax1, ax2, ax3, label_ids, pred_ids)
        ax4.imshow(img)
        
        #plt.pause(0.01)
        frame_vis = get_img_from_fig(fig)
        videwrite.append_data(frame_vis)
                
    
    videwrite.close()
    plt.tight_layout()
 
    plt.ioff()
    print("save out video")
    plt.show()

if __name__ == "__main__":
	label_path = '/home/xx/Desktop/simcc_3d/temp/select_label_txt'
    pred_path  = '/home/xx/Desktop/simcc_3d/temp/out_txt'
    img_file   = '/home/xx/Desktop/simcc_3d/temp/val_img'
	
	label_ids = [[0, 1, 1], [1, 2, 1], [1, 3, 1], [1, 4, 1], [3, 5, 1], 
                 [5, 7, 1], [4, 6, 1], [6, 8, 1],
                 ]
    
    
    pred_ids = [[0, 6, 0], [6, 8, 0], [8, 10, 0], [0, 5, 0], [5, 7, 0], 
                [7, 9, 0], [0, 0, 0], [0, 0, 0],
                 ]
    
    draw_label_test(label_path, label_ids, pred_path, pred_ids, img_file)

四总结

以上代码都是只是演示,只适用于我自己的场景,其他场景需要修改标签数据,关键点连接关系,该代码仅供参考,不可照搬。

相关推荐
云空13 小时前
《PyQt6-3D应用开发技术文档》
3d·pyqt
鹧鸪云光伏19 小时前
光伏无人机3D建模:毫秒级精度设计
3d·无人机
杀生丸学AI21 小时前
【三维生成】FlashDreamer:基于扩散模型的单目图像到3D场景
人工智能·3d·大模型·aigc·蒸馏与迁移学习·扩散模型与生成模型
gis分享者1 天前
学习threejs,使用自定义GLSL 着色器,生成漂流的3D能量球
3d·threejs·着色器·glsl·shadermaterial·能量球
m0_743106462 天前
【论文笔记】BlockGaussian:巧妙解决大规模场景重建中的伪影问题
论文阅读·计算机视觉·3d·aigc·几何学
向宇it2 天前
【unity小技巧】在 Unity 中将 2D 精灵添加到 3D 游戏中,并实现阴影投射效果,实现类《八分旅人》《饥荒》等等的2.5D游戏效果
游戏·3d·unity·编辑器·游戏引擎·材质
荔枝味啊~2 天前
相机位姿估计
人工智能·计算机视觉·3d
在下胡三汉3 天前
什么是 3D 文件?
3d
点云登山者4 天前
登山第二十六梯:单目3D检测一切——一只眼看世界
3d·3d检测·检测一切·单目3d检测
xhload3d4 天前
智慧航天运载体系全生命周期监测 | 图扑数字孪生
物联网·3d·智慧城市·html5·webgl·数字孪生·可视化·工业互联网·三维建模·工控·航空航天·火箭升空·智慧航空·智慧航天·火箭发射·火箭回收