读Shape-Guided代码③训练

关于训练,是先把pretrain也做成npz

bash 复制代码
                            BS: 32
                            LR: 0.0001
                     POINT_NUM: 500
                     ckpt_path: ./checkpoint_debug/2024-03-11-16-38-30-knn500
                       classes: ['*']
                 datasets_path: None
              dict_n_component: 3
                         epoch: 1000
                     grid_path: data_new/
                     group_mul: 5
                    image_size: 224
                      k_number: 1
                   method_name: None
                    output_dir: None
                    rgb_method: Dict
                  sampled_size: 20
                           viz: False

这个训练batch还必须满足batch_size,如果不满足直接跳过

于是每epoch的每个batch,有points和samples,二者都是32,500,3

输入point作为gt,sample作为inputs_points

python 复制代码
SDF_Model(
  (encoder): encoder_BN(
    (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
    (fc1): Linear(in_features=1024, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=128, bias=True)
    (relu): ReLU()
    (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  把gt经过encoder得到32,128的feature
  再把它tile成32,500,128 point_feature
  然后point_feature和inputs_points被输入给NIL的get_gradient
  (NIF): local_NIF(
    (fc1): Linear(in_features=128, out_features=512, bias=True)
    (fc2): Linear(in_features=3, out_features=512, bias=True)
    (fc3): Linear(in_features=1024, out_features=512, bias=True)
    (fc40): Linear(in_features=512, out_features=512, bias=True)
    (fc41): Linear(in_features=512, out_features=512, bias=True)
    (fc42): Linear(in_features=512, out_features=512, bias=True)
    (fc43): Linear(in_features=512, out_features=512, bias=True)
    (fc44): Linear(in_features=512, out_features=512, bias=True)
    (fc45): Linear(in_features=512, out_features=512, bias=True)
    (fc46): Linear(in_features=512, out_features=512, bias=True)
    (fc5): Linear(in_features=512, out_features=1, bias=True)
  )
)

get_gradient方法在local_NIF类中用于计算输入点相对于输出SDF(有符号距离字段)值的梯度,并使用这个梯度来更新点的位置。下面是数据流向和处理步骤的详细分析:

输入梯度设置:input_points.requires_grad_(True)确保input_points可以计算梯度。这是为了使得在SDF输出值forward方法中关于input_points的梯度能够被计算。

前向传播:使用forward方法计算SDF值(32,500,1)。这个方法先通过两个独立的全连接层处理points_feature和input_points,然后将这两个特征合并,并通过一系列额外的全连接层进一步处理。

计算梯度(一个32,500,3组成的tuple):torch.autograd.grad用于计算SDF输出相对于input_points的梯度。这里使用torch.ones_like(sdf)作为梯度的权重,因为我们对SDF本身的梯度感兴趣,而不是对它的某个函数的梯度。

梯度归一化:计算梯度的模长normal_p_length,并用它来归一化梯度,得到单位向量grad_norm。这一步确保了梯度向量在每个方向上具有相同的长度,这在计算新的点位置时是有用的。pytorch_safe_norm函数用于计算向量的模长,避免除零错误,通过在求和后添加一个很小的数(epsilon)来实现。

计算更新的点位置:使用梯度信息更新input_points的位置。g_point = input_points - sdf * grad_norm根据梯度方向和大小调整每个点的位置,这里sdf * grad_norm计算每个点沿梯度方向的位移,然后从原始位置减去这个位移得到新的位置(32,500,3这也就是网络的最终输出了)

整体来看,get_gradient方法通过计算SDF相对于输入点的梯度,并使用这个梯度来更新点的位置,这在许多应用中是有用的,例如形状优化、网格重构等。

计算g_point和points的l2范数作为损失

相关推荐
棒棒的皮皮6 分钟前
【Python】Open3d用于3D测高项目
python·3d·open3d
CV实验室1 天前
CV论文速递:覆盖视频生成与理解、3D视觉与运动迁移、多模态与跨模态智能、专用场景视觉技术等方向 (11.17-11.21)
人工智能·计算机视觉·3d·论文·音视频·视频生成
Highcharts.js2 天前
使用 Highcharts 3D图表入门
3d·highcharts·使用文档·3d图表·交互图表·三维图表·3d 可视化
O***p6042 天前
C++在游戏中的Ogre3D
游戏·3d·ogre
sdjnled2292 天前
山东裸眼3D立体LED显示屏专业服务商
人工智能·3d
徒慕风流2 天前
GeoSight:基于 Open3D 与 PySide6 的参数化 3D 模型处理与实时点云监控工具
计算机视觉·3d·信号处理
三条猫3 天前
将3D CAD 模型结构树转换为图结构,用于训练CAD AI的思路
人工智能·3d·ai·cad·模型训练·图结构·结构树
二川bro3 天前
第59节:常见问题汇编 - 60个典型问题解答
javascript·3d·threejs
zhangfeng11334 天前
aigc 从2d 到 3d的形式转变,李飞飞在介绍WorldLabs的Marble平台,会围绕“空间智能“的核心理念,自动驾驶就是2d形式
3d·自动驾驶·aigc
da_vinci_x5 天前
PS 3D Viewer (Beta):概念美术的降维打击,白模直接在PS里转光打影出5张大片
人工智能·游戏·3d·prompt·aigc·材质·游戏美术