系列文章目录
- 【3D AICG 系列-1】Trellis v1 和 Trellis v2 的区别和改进
- 【3D AICG 系列-2】Trellis 2 的O-voxel (上) Shape: Flexible Dual Grid
- 【3D AICG 系列-3】Trellis 2 的O-voxel (下) Material: Volumetric Surface Attributes
- 【3D AICG 系列-4】Trellis 2 的Shape SLAT Flow Matching DiT 训练流程
- 【3D AICG 系列-5】Trellis 2 的 Pipeline 推理流程的各个中间结果和形状
- 【3D AICG 系列-6】OmniPart 训练流程梳理
- 【3D AICG 系列-7】PartUV 代码流程深度解析
- 【3D AICG 系列-8】PartUV 流程图详解
- 【3D AICG 系列-9】Trellis2 推理流程图超详细介绍
- 【3D-AICG 系列-10】Trellis v2 只在 512/1024上训却能生成 1536
- 【3D-AICG 系列-11】Trellis 2 的 Shape VAE 训练流程梳理
文章目录
- 系列文章目录
-
- [论文公式 → 代码对应关系](#论文公式 → 代码对应关系)
-
- [一、公式 (4):下采样(Spatial → Channel)](#一、公式 (4):下采样(Spatial → Channel))
- [二、公式 (5):上采样(Channel → Spatial)](#二、公式 (5):上采样(Channel → Spatial))
- 三、总结对照表

本文为论文中3.2.1. Network Architecture 中的 Sparse Residual Autoencoding Layer 段的 论文公式 → 代码对应关系 解读。
论文公式 → 代码对应关系
一、公式 (4):下采样(Spatial → Channel)
论文描述:

底层操作 对应 SparseSpatial2Channel:
16:55:repo/TRELLIS.2/trellis2/modules/sparse/spatial/spatial2channel.py
def forward(self, x: SparseTensor) -> SparseTensor:
DIM = x.coords.shape[-1] - 1
cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
if cache is None:
coord = list(x.coords.unbind(dim=-1))
for i in range(DIM):
coord[i+1] = coord[i+1] // self.factor
subidx = x.coords[:, 1:] % self.factor
subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
# ... unique → new_coords, idx ...
new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
new_feats[idx * self.factor ** DIM + subidx] = x.feats
out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, ...)
# ...
这段做的就是论文中的 stack 操作:
- 将每个 fine voxel 的坐标整除 2 得到 coarse 坐标
subidx计算子 voxel 在 8 个 children 中的编号(0~7)- 初始化全零 tensor,按
[idx * 8 + subidx]填入(缺失 voxel 自然为 0) - reshape 成
[N_coarse, 8*C]--- 这就是 (F_{\text{coarse}}^{\text{raw}} \in \mathbb{R}^{8C})
残差 shortcut(avg_groups) 对应 SparseResBlockS2C3d 中的 skip_connection:
195:196:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py
self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
self.updown = sp.SparseSpatial2Channel(2)
这里 reshape(N, C', 8C/C').mean(dim=-1) 就是论文中的 avg_groups:将 (8C) 通道分成 (C') 组,每组取均值,得到 (F_{\text{coarse}} \in \mathbb{R}^{C'})。
完整下采样 forward 流程(主路径 + shortcut 相加):
198:208:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h) # C → C'/8, 卷积后通道变小
h = self.updown(h) # spatial2channel: 8*(C'/8) = C'
x = self.updown(x) # shortcut 也做 spatial2channel: C → 8C
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x) # avg_groups(8C) → C',然后残差相加
return h
二、公式 (5):上采样(Channel → Spatial)
论文描述:

底层操作 对应 SparseChannel2Spatial:
67:93:repo/TRELLIS.2/trellis2/modules/sparse/spatial/spatial2channel.py
def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
# ...
x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
new_feats = x_feats[idx * self.factor ** DIM + subidx]
out = SparseTensor(new_feats, new_coords, ...)
这里将 [N_coarse, 8*(C'/8)] reshape 成 [N_coarse*8, C'/8],然后按 subdivision mask 选出活跃子 voxel 的 feature --- 这就是论文的 unstack 操作,得到 (F_{\text{fine}}^{\text{raw}} \in \mathbb{R}^{C'/8})。
残差 shortcut(dup_groups) 对应 SparseResBlockC2S3d 中的 skip_connection:
235:238:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
# ...
self.updown = sp.SparseChannel2Spatial(2)
repeat_interleave 就是论文中的 dup_groups:将 (C'/8) 通道中的每个元素重复 (C \cdot 8 / C') 次,扩展到目标通道数 (C)。
完整上采样 forward 流程:
240:256:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py
def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
if self.pred_subdiv:
subdiv = self.to_subdiv(x) # 预测 subdivision mask
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h) # C → C_out*8, 卷积扩大通道
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized) # channel2spatial: unstack → C_out
x = self.updown(x, subdiv_binarized) # shortcut 也做 channel2spatial: C → C/8
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x) # dup_groups(C/8) → C_out,残差相加
if self.pred_subdiv:
return h, subdiv
else:
return h
三、总结对照表
| 论文概念 | 代码位置 | 说明 |
|---|---|---|
| stack (Eq.4 第1行) | SparseSpatial2Channel.forward (spatial2channel.py:39-42) |
8 个子 voxel 特征按位置拼入全零 tensor,reshape 成 [N, 8C] |
| avg_groups (Eq.4 第2行) | SparseResBlockS2C3d.__init__ 的 skip_connection (sparse_unet_vae.py:195) |
reshape(N, C', 8C/C').mean(-1) |
| unstack (Eq.5 第1行) | SparseChannel2Spatial.forward (spatial2channel.py:87-88) |
reshape 成 [N*8, C'/8],按 subdivision mask 索引取出 |
| dup_groups (Eq.5 第2行) | SparseResBlockC2S3d.__init__ 的 skip_connection (sparse_unet_vae.py:235) |
repeat_interleave 复制通道 |
| non-parametric residual shortcut | 两个 block 的 h = h + self.skip_connection(x) |
shortcut 路径没有可学习参数,纯 reshape/mean/repeat |
核心设计思想:下采样时 shortcut 用 avg_groups 将 8C 通道压缩到 C',上采样时 shortcut 用 dup_groups 将 C'/8 通道扩展到 C,整个过程不引入任何可学习参数 (non-parametric),仅通过 空间维度与通道维度的重排列 来传递信息。
即正文所说:We adapt the Residual Autoencoding principle from DC-AE [1] to sparse voxel data by introducing non-parametric residual short-cuts within downsampling and upsampling blocks.