【3D-AICG 系列-11】Trellis 2 的 Shape VAE 训练流程梳理

系列文章目录


文章目录

  • 系列文章目录
    • [`shape_vae_next_dc_f16c32` 训练数据与流程全链路梳理](#shape_vae_next_dc_f16c32 训练数据与流程全链路梳理)
      • 0️⃣、整体数据流图
      • 一、总览
      • 二、数据准备(离线预处理)
      • 三、数据加载链路
        • [3.1 数据集初始化](#3.1 数据集初始化)
        • [3.2 单个样本读取 `get_instance`](#3.2 单个样本读取 get_instance)
        • [3.3 Collate 与 Batch Split](#3.3 Collate 与 Batch Split)
      • 四、模型前向传播链路
        • [4.1 Encoder: `FlexiDualGridVaeEncoder`](#4.1 Encoder: FlexiDualGridVaeEncoder)
        • [4.2 Decoder: `FlexiDualGridVaeDecoder`](#4.2 Decoder: FlexiDualGridVaeDecoder)
      • [五、Loss 计算链路](#五、Loss 计算链路)
        • [5.1 Direct Regression Loss](#5.1 Direct Regression Loss)
        • [5.2 Subdivision Prediction Loss](#5.2 Subdivision Prediction Loss)
        • [5.3 Rendering Loss(主要 loss)](#5.3 Rendering Loss(主要 loss))
        • [5.4 KL Regularization](#5.4 KL Regularization)
      • 六、训练循环

shape_vae_next_dc_f16c32 训练数据与流程全链路梳理

0️⃣、整体数据流图

复制代码
[离线] mesh.pickle + dual_grid.vxz
        ↓ FlexiDualGridDataset.get_instance()
        ↓ read_mesh → Mesh(vertices, faces)
        ↓ read_dual_grid → SparseTensor(vertices), SparseTensor(intersected)
        ↓ collate_fn → {vertices, intersected, mesh}
        ↓
[训练] ShapeVaeTrainer.training_losses(vertices, intersected, mesh)
        ↓
  Encoder: [vertices.feats-0.5, intersected.feats-0.5] → SparseUNet(16x下采样) → mean, logvar → z(32ch)
        ↓
  Decoder: z → SparseUNet(16x上采样, 预测subdivision) → 7ch output
           ├── ch0-2: pred_vertices (sigmoid)
           ├── ch3-5: pred_intersected (logits)
           └── ch6:   quad_lerp (softplus)
           → flexible_dual_grid_to_mesh(train=True) → recon Mesh
        ↓
  Loss =  λ_subdiv × BCE(subdivision)
        + λ_intersected × BCE(pred_intersected, gt_intersected)
        + λ_vertice × MSE(pred_vertices, gt_vertices)
        + λ_mask × L1(render_mask)
        + λ_depth × L1(render_depth)
        + λ_normal × [L1 + λ_ssim×SSIM + λ_lpips×LPIPS](render_normal)
        + λ_kl × KL(mean, logvar)

一、总览

配置文件:configs/scvae/shape_vae_next_dc_f16c32_fp16.json

涉及三大核心组件:

  • Dataset : FlexiDualGridDataset
  • Models : FlexiDualGridVaeEncoder + FlexiDualGridVaeDecoder
  • Trainer : ShapeVaeTrainer

二、数据准备(离线预处理)

根据 data_toolkit/README.md,训练所需数据需经过以下预处理:

  1. Step 4 : dump_mesh.py --- 将原始 3D 资产标准化为 mesh dump(pickle 文件)
  2. Step 5 : dual_grid.py --- 将 mesh 转为 O-Voxel 的 dual grid 表示(.vxz 文件)

训练启动命令中的 --data_dir 指定了数据位置:

json 复制代码
{
  "ObjaverseXL_sketchfab": {
    "base": "datasets/ObjaverseXL_sketchfab",
    "mesh_dump": "datasets/ObjaverseXL_sketchfab/mesh_dumps",
    "dual_grid": "datasets/ObjaverseXL_sketchfab/dual_grid_256",
    "asset_stats": "datasets/ObjaverseXL_sketchfab/asset_stats"
  }
}

三、数据加载链路

3.1 数据集初始化

train.py 第 70 行实例化数据集:

70:70:repo/TRELLIS.2/train.py 复制代码
    dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)

FlexiDualGridDataset(roots=data_dir_json, resolution=256, max_active_voxels=1000000, max_num_faces=1000000, min_aesthetic_score=4.5)

构造函数中(flexi_dual_grid.py),StandardDatasetBase.__init__ 解析 JSON 格式的 roots,读取每个数据集子目录下的 metadata.csv,并通过 filter_metadata 过滤:

92:104:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata[f'dual_grid_converted'] == True]
        # ...
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels]
        # ...
3.2 单个样本读取 get_instance
141:144:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
    def get_instance(self, root, instance):
        mesh = self.read_mesh(root['mesh_dump'], instance)
        dual_grid = self.read_dual_grid(root['dual_grid'], instance)
        return {**mesh, **dual_grid}

每个样本包含 三个字段

a) mesh --- 从 pickle 中读取并归一化到 [-0.5, 0.5]

106:126:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
    def read_mesh(self, root, instance):
        with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
            dump = pickle.load(f)
        # ... 合并多个 object 的 vertices/faces ...
        vertices = (vertices - center) * scale
        return {'mesh': [Mesh(vertices=vertices, faces=faces)]}

b) vertices --- 从 .vxz 文件读取的 dual grid 顶点偏移,封装为 SparseTensor(feats 为 3 通道,范围 [0,1]):

128:133:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
    def read_dual_grid(self, root, instance):
        coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
        vertices = sp.SparseTensor(
            (attr['vertices'] / 255.0).float(),
            torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
        )

c) intersected --- 每个 voxel 在 3 个轴方向上是否被面片穿过的布尔标志(3 通道),同样封装为 SparseTensor

134:139:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
        intersected = vertices.replace(torch.cat([
            attr['intersected'] % 2,
            attr['intersected'] // 2 % 2,
            attr['intersected'] // 4 % 2,
        ], dim=-1).bool())
3.3 Collate 与 Batch Split

collate_fn 将多个样本中的 SparseTensor 通过 sp.sparse_cat 拼接,同时支持 batch_split(梯度累积):

146:172:repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py 复制代码
    @staticmethod
    def collate_fn(batch, split_size=None):
        # ... load_balanced_group_indices 做负载均衡分组 ...
        # SparseTensor 用 sparse_cat 拼接
        # Tensor 用 torch.stack
        # list 用 sum 拼接

配置中 batch_size_per_gpu=8, batch_split=2,即每个 GPU 取 8 个样本,分成 2 组各 4 个做梯度累积。


四、模型前向传播链路

4.1 Encoder: FlexiDualGridVaeEncoder

继承自 SparseUnetVaeEncoder,输入通道为 6(3 通道 vertices + 3 通道 intersected):

51:55:repo/TRELLIS.2/trellis2/models/sc_vaes/fdg_vae.py 复制代码
    def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
        x = vertices.replace(torch.cat([
            vertices.feats - 0.5,
            intersected.feats.float() - 0.5,
        ], dim=1))
        return super().forward(x, sample_posterior, return_raw)

网络结构:5 级稀疏 U-Net,通道 [64, 128, 256, 512, 1024],各级 block 数 [0, 4, 8, 16, 4],block 类型为 SparseConvNeXtBlock3dconv → layernorm → mlp 残差块),下采样为 SparseResBlockS2C3d(spatial-to-channel 2x 下采样)。

输出经 to_latent 线性层映射为 2 * 32 = 64 通道,拆分为 meanlogvar,采样得到 32 通道的 latent z

4.2 Decoder: FlexiDualGridVaeDecoder

继承自 SparseUnetVaeDecoder,输出通道为 7

87:108:repo/TRELLIS.2/trellis2/models/sc_vaes/fdg_vae.py 复制代码
    def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs):
        decoded = super().forward(x, **kwargs)
        if self.training:
            h, subs_gt, subs = decoded
            vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
            intersected_logits = h.replace(h.feats[..., 3:6])
            quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
            mesh = [Mesh(*flexible_dual_grid_to_mesh(
                v.coords[:, 1:], v.feats, i.feats, q.feats,
                aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                grid_size=self.resolution,
                train=True
            )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)]
            return mesh, vertices, intersected_logits, subs_gt, subs

7 通道的含义:

  • 通道 0-2 : 顶点偏移(经 sigmoid 映射到 [-0.5, 1.5]
  • 通道 3-5: intersected 的 logits
  • 通道 6: quad 分割权重(经 softplus 确保正值)

训练时使用 GT 的 intersected 来做 mesh 提取(flexible_dual_grid_to_meshtrain=True 模式会在 quad 中心插入一个中间顶点,生成 4 个三角形以保持可微性)。

上采样使用 SparseResBlockC2S3d(channel-to-spatial 2x 上采样),每一级上采样时会预测 subdivision mask(to_subdiv 线性层),决定哪些 voxel 应被细分。


五、Loss 计算链路

ShapeVaeTrainer.training_losses 是核心 loss 计算方法:

199:239:repo/TRELLIS.2/trellis2/trainers/vae/shape_vae.py 复制代码
    def training_losses(self, vertices, intersected, mesh):
        z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True)
        recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected)
        # ...

传入参数直接对应 collate_fn 输出的 dict 的 key(vertices, intersected, mesh)。

Loss 由 5 类 组成:

5.1 Direct Regression Loss
python 复制代码
# intersected BCE loss (λ=0.1)
F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float())

# vertices MSE loss (λ=0.01)
F.mse_loss(pred_vertice.feats, vertices.feats)
5.2 Subdivision Prediction Loss

对每一级上采样产生的 subdivision 预测 subs 与 GT subs_gt 计算 BCE(λ=0.1):

python 复制代码
for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)):
    F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float())

GT 的 subdivision mask 来自 decoder 上采样时从输入 SparseTensor 的 spatial cache 中获取。

5.3 Rendering Loss(主要 loss)

先随机采样相机参数,然后分别渲染 GT mesh 和重建 mesh:

python 复制代码
cameras = self._randomize_camera(len(mesh))
gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth'])
pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth'])

相机随机化(_randomize_camera):在 1/r^2 空间均匀采样半径(范围 [2, 100]),计算对应 FOV,随机方向。

渲染 loss 包含:

  • Mask L1 (λ=1)
  • Depth L1 (λ=10)
  • Normal L1 (λ=1)
  • Normal SSIM (λ=0.2)
  • Normal LPIPS (λ=0.2)
python 复制代码
terms['loss'] += λ_mask * mask_l1 + λ_depth * depth_l1 + λ_normal * (normal_l1 + λ_ssim * normal_ssim + λ_lpips * normal_lpips)
5.4 KL Regularization

标准 VAE KL 散度(λ=1e-6,非常小):

python 复制代码
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)

六、训练循环

BasicTrainer.run() 驱动:

复制代码
while step < max_steps:
    data_list = self.load_data()          # 从 data_iterator 取一批数据,移到 GPU
    step_log = self.run_step(data_list)   # 前向 + 反向 + 优化
    step += 1

run_step 中遍历 data_list(batch_split=2,两组):

  1. 对每组调用 self.training_losses(**mb_data) 计算 loss
  2. loss 除以 len(data_list) 后反向传播(fp16 inflat_all 模式下乘以 2^log_scale
  3. 梯度累积后执行 AdaptiveGradClipper(max_norm=1.0, clip_percentile=95)
  4. AdamW 优化器 step(lr=1e-4)
  5. 更新 EMA(rate=0.9999)
相关推荐
tuotali20261 小时前
氢气压缩机技术规范亲测案例分享
人工智能·python
Coder_Boy_2 小时前
Java(Spring AI)传统项目智能化改造——商业化真实案例(含完整核心代码+落地指南)
java·人工智能·spring boot·spring·微服务
lintax2 小时前
计算pi值-积分法
python·算法·计算π·积分法
CoderJia程序员甲2 小时前
GitHub 热榜项目 - 日榜(2026-02-23)
人工智能·ai·大模型·github·ai教程
你的冰西瓜2 小时前
C++ STL算法——排序和相关操作
开发语言·c++·算法·stl
冬奇Lab2 小时前
MCP 集成实战:连接外部世界
人工智能·ai编程·claude
罗政3 小时前
AI图片识别批量提取医疗器械铭牌信息实战
人工智能
冬奇Lab3 小时前
一天一个开源项目(第32篇):Edit-Banana - 让不可编辑的图表变成可编辑,SAM3+多模态大模型驱动
人工智能·开源·资讯
今儿敲了吗3 小时前
29| 高考志愿
c++·笔记·学习·算法