Series index
- [3D AICG Series 1] Differences and Improvements Between Trellis v1 and Trellis v2
- [3D AICG Series 2] Trellis 2's O-voxel (Part 1), Shape: Flexible Dual Grid
- [3D AICG Series 3] Trellis 2's O-voxel (Part 2), Material: Volumetric Surface Attributes
- [3D AICG Series 4] Trellis 2's Shape SLAT Flow Matching DiT Training Pipeline
- [3D AICG Series 5] Trellis 2 Pipeline: Intermediate Results and Shapes During Inference
- [3D AICG Series 6] OmniPart Training Pipeline Walkthrough
- [3D AICG Series 7] PartUV Code Flow Deep Dive
- [3D AICG Series 8] PartUV Flowcharts Explained
- [3D AICG Series 9] Trellis2 Inference Flowchart in Detail
- [3D AICG Series 10] How Trellis v2 Generates at 1536 While Trained Only at 512/1024
Table of contents
- Series index
- `shape_vae_next_dc_f16c32` Training Data and Pipeline: End-to-End Walkthrough
  - 0️⃣ Overall Data Flow
  - 1. Overview
  - 2. Data Preparation (Offline Preprocessing)
  - 3. Data Loading Chain
    - 3.1 Dataset Initialization
    - 3.2 Reading a Single Sample: `get_instance`
    - 3.3 Collate and Batch Split
  - 4. Model Forward Pass
    - 4.1 Encoder: `FlexiDualGridVaeEncoder`
    - 4.2 Decoder: `FlexiDualGridVaeDecoder`
  - 5. Loss Computation
    - 5.1 Direct Regression Loss
    - 5.2 Subdivision Prediction Loss
    - 5.3 Rendering Loss (main loss)
    - 5.4 KL Regularization
  - 6. Training Loop
# `shape_vae_next_dc_f16c32` Training Data and Pipeline: End-to-End Walkthrough
## 0️⃣ Overall Data Flow

```
[Offline] mesh.pickle + dual_grid.vxz
    ↓ FlexiDualGridDataset.get_instance()
    ↓ read_mesh → Mesh(vertices, faces)
    ↓ read_dual_grid → SparseTensor(vertices), SparseTensor(intersected)
    ↓ collate_fn → {vertices, intersected, mesh}
    ↓
[Training] ShapeVaeTrainer.training_losses(vertices, intersected, mesh)
    ↓
Encoder: [vertices.feats - 0.5, intersected.feats - 0.5] → SparseUNet (16x downsample) → mean, logvar → z (32 ch)
    ↓
Decoder: z → SparseUNet (16x upsample, predicts subdivision) → 7-ch output
    ├── ch 0-2: pred_vertices (sigmoid)
    ├── ch 3-5: pred_intersected (logits)
    └── ch 6:   quad_lerp (softplus)
    → flexible_dual_grid_to_mesh(train=True) → recon Mesh
    ↓
Loss = λ_subdiv × BCE(subdivision)
     + λ_intersected × BCE(pred_intersected, gt_intersected)
     + λ_vertice × MSE(pred_vertices, gt_vertices)
     + λ_mask × L1(render_mask)
     + λ_depth × L1(render_depth)
     + λ_normal × [L1 + λ_ssim×SSIM + λ_lpips×LPIPS](render_normal)
     + λ_kl × KL(mean, logvar)
```
## 1. Overview

Config file: `configs/scvae/shape_vae_next_dc_f16c32_fp16.json`

Three core components are involved:
- Dataset: `FlexiDualGridDataset`
- Models: `FlexiDualGridVaeEncoder` + `FlexiDualGridVaeDecoder`
- Trainer: `ShapeVaeTrainer`
## 2. Data Preparation (Offline Preprocessing)

According to `data_toolkit/README.md`, the training data goes through the following preprocessing steps:
- Step 4 (`dump_mesh.py`): normalize raw 3D assets into mesh dumps (pickle files)
- Step 5 (`dual_grid.py`): convert each mesh into the O-Voxel dual grid representation (`.vxz` files)

The `--data_dir` argument in the training launch command points at the data location:

```json
{
    "ObjaverseXL_sketchfab": {
        "base": "datasets/ObjaverseXL_sketchfab",
        "mesh_dump": "datasets/ObjaverseXL_sketchfab/mesh_dumps",
        "dual_grid": "datasets/ObjaverseXL_sketchfab/dual_grid_256",
        "asset_stats": "datasets/ObjaverseXL_sketchfab/asset_stats"
    }
}
```
## 3. Data Loading Chain

### 3.1 Dataset Initialization

`train.py` instantiates the dataset at line 70:

```python
# repo/TRELLIS.2/train.py:70
dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)
```

This resolves to `FlexiDualGridDataset(roots=data_dir_json, resolution=256, max_active_voxels=1000000, max_num_faces=1000000, min_aesthetic_score=4.5)`.

In the constructor (`flexi_dual_grid.py`), `StandardDatasetBase.__init__` parses the JSON-formatted `roots`, reads the `metadata.csv` under each dataset subdirectory, and filters it via `filter_metadata`:

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:92-104
def filter_metadata(self, metadata):
    stats = {}
    metadata = metadata[metadata[f'dual_grid_converted'] == True]
    # ...
    metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
    metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels]
    # ...
```
### 3.2 Reading a Single Sample: `get_instance`

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:141-144
def get_instance(self, root, instance):
    mesh = self.read_mesh(root['mesh_dump'], instance)
    dual_grid = self.read_dual_grid(root['dual_grid'], instance)
    return {**mesh, **dual_grid}
```

Each sample contains three fields:

a) `mesh`: read from the pickle file and normalized to [-0.5, 0.5]:

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:106-126
def read_mesh(self, root, instance):
    with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
        dump = pickle.load(f)
    # ... merge the vertices/faces of all objects ...
    vertices = (vertices - center) * scale
    return {'mesh': [Mesh(vertices=vertices, faces=faces)]}
```
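The normalization step can be sketched as follows (a minimal NumPy version; `read_mesh` elides how `center` and `scale` are computed, so the bbox-midpoint center and longest-edge scale used here are assumptions):

```python
import numpy as np

def normalize_to_unit_cube(vertices: np.ndarray) -> np.ndarray:
    """Center a vertex array and scale it into the [-0.5, 0.5] cube.

    Assumed convention: center = bbox midpoint, scale = 1 / longest bbox edge.
    """
    vmin, vmax = vertices.min(axis=0), vertices.max(axis=0)
    center = (vmin + vmax) / 2
    scale = 1.0 / (vmax - vmin).max()
    return (vertices - center) * scale

# a 2-vertex toy "mesh" spanning a 2 x 1 x 1 box
verts = np.array([[0.0, 0.0, 0.0], [2.0, 1.0, 1.0]])
normed = normalize_to_unit_cube(verts)
```

With this convention the longest axis exactly spans [-0.5, 0.5] and the shorter axes stay inside it.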
b) `vertices`: dual grid vertex offsets read from the `.vxz` file, wrapped as a `SparseTensor` (`feats` has 3 channels in the range [0, 1]):

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:128-133
def read_dual_grid(self, root, instance):
    coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
    vertices = sp.SparseTensor(
        (attr['vertices'] / 255.0).float(),
        torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
    )
```
c) `intersected`: per-voxel boolean flags (3 channels) indicating whether a face crosses the voxel along each of the 3 axes, also wrapped as a `SparseTensor`:

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:134-139
intersected = vertices.replace(torch.cat([
    attr['intersected'] % 2,
    attr['intersected'] // 2 % 2,
    attr['intersected'] // 4 % 2,
], dim=-1).bool())
```
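The `% 2` / `// 2 % 2` / `// 4 % 2` decomposition above unpacks one packed integer per voxel into three per-axis flags (bit 0 = x, bit 1 = y, bit 2 = z). A minimal NumPy sketch of the same decoding:

```python
import numpy as np

def unpack_intersected(packed: np.ndarray) -> np.ndarray:
    """Decode 3 per-axis intersection flags from a packed integer per voxel.

    Bit 0 = x axis, bit 1 = y axis, bit 2 = z axis, matching the
    %2 / //2%2 / //4%2 decomposition used in read_dual_grid.
    """
    return np.stack([packed % 2, packed // 2 % 2, packed // 4 % 2], axis=-1).astype(bool)

# 0 = no intersection, 1 = x only, 5 = x and z, 7 = all three axes
flags = unpack_intersected(np.array([0, 1, 5, 7]))
```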
### 3.3 Collate and Batch Split

`collate_fn` concatenates the `SparseTensor`s from multiple samples via `sp.sparse_cat`, and also supports `batch_split` (gradient accumulation):

```python
# repo/TRELLIS.2/trellis2/datasets/flexi_dual_grid.py:146-172
@staticmethod
def collate_fn(batch, split_size=None):
    # ... load_balanced_group_indices does load-balanced grouping ...
    # SparseTensors are concatenated with sparse_cat
    # Tensors are stacked with torch.stack
    # lists are concatenated with sum
```

With `batch_size_per_gpu=8` and `batch_split=2` in the config, each GPU draws 8 samples and splits them into 2 groups of 4 for gradient accumulation.
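The load-balanced grouping can be sketched with a greedy heuristic (a hypothetical stand-in for `load_balanced_group_indices`, whose actual algorithm may differ; the per-sample cost would plausibly be its active-voxel count):

```python
def load_balanced_groups(costs, num_groups):
    """Greedily assign sample indices to groups so total cost per group is balanced.

    Hypothetical sketch: sort samples by descending cost, then always add the
    next sample to the currently lightest group.
    """
    order = sorted(range(len(costs)), key=lambda i: -costs[i])
    groups = [[] for _ in range(num_groups)]
    totals = [0] * num_groups
    for i in order:
        g = totals.index(min(totals))  # lightest group so far
        groups[g].append(i)
        totals[g] += costs[i]
    return groups

# e.g. 8 samples with varying active-voxel counts, split into 2 groups
costs = [90, 10, 80, 20, 70, 30, 60, 40]
groups = load_balanced_groups(costs, 2)
```

Balancing voxel counts keeps the two accumulation groups' memory and compute roughly equal.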
## 4. Model Forward Pass

### 4.1 Encoder: `FlexiDualGridVaeEncoder`

Inherits from `SparseUnetVaeEncoder`; the input has 6 channels (3-channel `vertices` + 3-channel `intersected`):

```python
# repo/TRELLIS.2/trellis2/models/sc_vaes/fdg_vae.py:51-55
def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
    x = vertices.replace(torch.cat([
        vertices.feats - 0.5,
        intersected.feats.float() - 0.5,
    ], dim=1))
    return super().forward(x, sample_posterior, return_raw)
```

Network structure: a 5-level sparse U-Net with channels [64, 128, 256, 512, 1024] and [0, 4, 8, 16, 4] blocks per level. The block type is `SparseConvNeXtBlock3d` (a conv → layernorm → mlp residual block); downsampling uses `SparseResBlockS2C3d` (spatial-to-channel 2x downsampling).

The output is mapped by the `to_latent` linear layer to 2 × 32 = 64 channels, split into `mean` and `logvar`, and sampled to obtain the 32-channel latent `z`.
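The split-and-sample step can be sketched as follows (a NumPy stand-in for the standard VAE reparameterization; the real model operates on `SparseTensor` feats):

```python
import numpy as np

def sample_latent(h: np.ndarray, rng=np.random.default_rng(0)):
    """Split a 64-channel to_latent output into mean/logvar and sample z.

    z = mean + exp(0.5 * logvar) * eps  (standard VAE reparameterization).
    """
    mean, logvar = h[..., :32], h[..., 32:]
    eps = rng.standard_normal(mean.shape)
    z = mean + np.exp(0.5 * logvar) * eps
    return z, mean, logvar

h = np.zeros((5, 64))  # 5 active voxels, 64 channels out of to_latent
z, mean, logvar = sample_latent(h)
```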
### 4.2 Decoder: `FlexiDualGridVaeDecoder`

Inherits from `SparseUnetVaeDecoder`; the output has 7 channels:

```python
# repo/TRELLIS.2/trellis2/models/sc_vaes/fdg_vae.py:87-108
def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs):
    decoded = super().forward(x, **kwargs)
    if self.training:
        h, subs_gt, subs = decoded
        vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
        intersected_logits = h.replace(h.feats[..., 3:6])
        quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
        mesh = [Mesh(*flexible_dual_grid_to_mesh(
            v.coords[:, 1:], v.feats, i.feats, q.feats,
            aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
            grid_size=self.resolution,
            train=True
        )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)]
        return mesh, vertices, intersected_logits, subs_gt, subs
```

Meaning of the 7 channels:
- Channels 0-2: vertex offsets (mapped by sigmoid to [-0.5, 1.5])
- Channels 3-5: `intersected` logits
- Channel 6: quad split weight (softplus keeps it positive)

During training, the GT `intersected` flags are used for mesh extraction (the `train=True` mode of `flexible_dual_grid_to_mesh` inserts a middle vertex at each quad's center, producing 4 triangles to keep the extraction differentiable).
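The quad split used in `train=True` mode can be illustrated as follows (a NumPy sketch with an unweighted center; the real `flexible_dual_grid_to_mesh` positions the center using the `quad_lerp` weights):

```python
import numpy as np

def split_quad_with_center(quad: np.ndarray):
    """Split one quad (4 vertices in loop order) into 4 triangles via a center vertex.

    Sketch of the train=True behavior: a new vertex at the quad center keeps
    all 4 corner vertices in the gradient path.
    """
    center = quad.mean(axis=0)            # unweighted midpoint for simplicity
    vertices = np.vstack([quad, center])  # index 4 is the inserted center
    faces = np.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
    return vertices, faces

# a unit quad in the z = 0 plane
quad = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=float)
vertices, faces = split_quad_with_center(quad)
```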
Upsampling uses `SparseResBlockC2S3d` (channel-to-spatial 2x upsampling); at each upsampling level, a subdivision mask is predicted (the `to_subdiv` linear layer) that decides which voxels should be subdivided.
## 5. Loss Computation

`ShapeVaeTrainer.training_losses` is the core loss computation:

```python
# repo/TRELLIS.2/trellis2/trainers/vae/shape_vae.py:199-239
def training_losses(self, vertices, intersected, mesh):
    z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True)
    recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected)
    # ...
```

Its parameters correspond directly to the keys of the dict produced by `collate_fn` (`vertices`, `intersected`, `mesh`).

The loss combines the following terms:
### 5.1 Direct Regression Loss

```python
# intersected BCE loss (λ = 0.1)
F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float())
# vertices MSE loss (λ = 0.01)
F.mse_loss(pred_vertice.feats, vertices.feats)
```
### 5.2 Subdivision Prediction Loss

For each upsampling level, BCE between the predicted subdivision `subs` and the GT `subs_gt` (λ = 0.1):

```python
for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)):
    F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float())
```

The GT subdivision masks come from the input `SparseTensor`'s spatial cache, collected as the decoder upsamples.
### 5.3 Rendering Loss (main loss)

Camera parameters are randomized first, then the GT mesh and the reconstructed mesh are rendered with the same cameras:

```python
cameras = self._randomize_camera(len(mesh))
gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth'])
pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth'])
```

Camera randomization (`_randomize_camera`): the radius is sampled uniformly in 1/r² space over the range [2, 100], the corresponding FOV is computed, and the viewing direction is random.
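The radius sampling can be sketched as follows (the uniform-in-1/r² sampling matches the description above; the FOV formula, framing a sphere of radius 0.5 at distance r, is an assumption about what "corresponding FOV" means):

```python
import math
import random

def randomize_radius_fov(r_min=2.0, r_max=100.0, rng=random.Random(0)):
    """Sample a camera radius uniformly in 1/r^2 space, then a matching FOV.

    u ~ U[1/r_max^2, 1/r_min^2], r = 1/sqrt(u).
    FOV formula below is an assumption: frame a radius-0.5 sphere at distance r.
    """
    u = rng.uniform(1.0 / r_max**2, 1.0 / r_min**2)
    r = 1.0 / math.sqrt(u)
    fov = 2.0 * math.atan(0.5 / r)
    return r, fov

r, fov = randomize_radius_fov()
```

Sampling in 1/r² space biases cameras toward small radii (large apparent object size), which weights close-up views more heavily than uniform-in-r sampling would.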
The rendering loss includes:
- Mask L1 (λ = 1)
- Depth L1 (λ = 10)
- Normal L1 (λ = 1)
- Normal SSIM (λ = 0.2)
- Normal LPIPS (λ = 0.2)

```python
terms['loss'] += λ_mask * mask_l1 + λ_depth * depth_l1 + λ_normal * (normal_l1 + λ_ssim * normal_ssim + λ_lpips * normal_lpips)
```
### 5.4 KL Regularization

The standard VAE KL divergence (λ = 1e-6, deliberately tiny):

```python
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
```
## 6. Training Loop

Driven by `BasicTrainer.run()`:

```python
while step < max_steps:
    data_list = self.load_data()         # pull a batch from data_iterator, move it to the GPU
    step_log = self.run_step(data_list)  # forward + backward + optimize
    step += 1
```

`run_step` iterates over `data_list` (with `batch_split=2`, two groups):
- For each group, call `self.training_losses(**mb_data)` to compute the loss
- Divide the loss by `len(data_list)` before backpropagation (in fp16 `inflat_all` mode, also multiply by `2^log_scale`)
- After gradient accumulation, apply `AdaptiveGradClipper(max_norm=1.0, clip_percentile=95)`, take an `AdamW` optimizer step (lr = 1e-4), and update the EMA (rate = 0.9999)
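The accumulation logic of `run_step` can be sketched as follows (a simplified stand-in: plain gradient clipping instead of `AdaptiveGradClipper`, no fp16 loss scaling or EMA update, and a toy squared-output "loss"):

```python
import torch

def run_step(model, optimizer, data_list, max_norm=1.0):
    """One optimization step with gradient accumulation over batch_split groups.

    Per-group loss is divided by the number of groups; gradients accumulate
    across backward() calls; then a single clip + optimizer step.
    """
    optimizer.zero_grad()
    total = 0.0
    for mb in data_list:
        loss = model(mb).pow(2).mean() / len(data_list)  # toy loss placeholder
        loss.backward()                                   # gradients accumulate
        total += loss.item()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return total

# toy usage: a stand-in model, 2 accumulation groups of 4 samples each
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
step_loss = run_step(model, optimizer, [torch.randn(4, 3), torch.randn(4, 3)])
```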