【3D-AICG 系列-12】Trellis 2 的 Shape VAE 的设计细节 Sparse Residual Autoencoding Layer

系列文章目录


文章目录

  • 系列文章目录
    • [论文公式 → 代码对应关系](#论文公式 → 代码对应关系)
      • [一、公式 (4):下采样(Spatial → Channel)](#一、公式 (4):下采样(Spatial → Channel))
      • [二、公式 (5):上采样(Channel → Spatial)](#二、公式 (5):上采样(Channel → Spatial))
      • 三、总结对照表

本文为论文中3.2.1. Network Architecture 中的 Sparse Residual Autoencoding Layer 段的 论文公式 → 代码对应关系 解读。

论文公式 → 代码对应关系

一、公式 (4):下采样(Spatial → Channel)

论文描述:

底层操作 对应 SparseSpatial2Channel

16:55:repo/TRELLIS.2/trellis2/modules/sparse/spatial/spatial2channel.py 复制代码
    def forward(self, x: SparseTensor) -> SparseTensor:
        DIM = x.coords.shape[-1] - 1
        cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
        if cache is None:
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
            # ... unique → new_coords, idx ...

        new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
        new_feats[idx * self.factor ** DIM + subidx] = x.feats

        out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, ...)
        # ...

这段做的就是论文中的 stack 操作:

  • 将每个 fine voxel 的坐标整除 2 得到 coarse 坐标
  • subidx 计算子 voxel 在 8 个 children 中的编号(0~7)
  • 初始化全零 tensor,按 [idx * 8 + subidx] 填入(缺失 voxel 自然为 0)
  • reshape 成 [N_coarse, 8*C] --- 这就是 (F_{\text{coarse}}^{\text{raw}} \in \mathbb{R}^{8C})

残差 shortcut(avg_groups) 对应 SparseResBlockS2C3d 中的 skip_connection

195:196:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py 复制代码
        self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
        self.updown = sp.SparseSpatial2Channel(2)

这里 reshape(N, C', 8C/C').mean(dim=-1) 就是论文中的 avg_groups:将 (8C) 通道分成 (C') 组,每组取均值,得到 (F_{\text{coarse}} \in \mathbb{R}^{C'})。

完整下采样 forward 流程(主路径 + shortcut 相加):

198:208:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py 复制代码
    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)              # C → C'/8, 卷积后通道变小
        h = self.updown(h)             # spatial2channel: 8*(C'/8) = C'
        x = self.updown(x)             # shortcut 也做 spatial2channel: C → 8C
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)  # avg_groups(8C) → C',然后残差相加
        return h

二、公式 (5):上采样(Channel → Spatial)

论文描述:

底层操作 对应 SparseChannel2Spatial

67:93:repo/TRELLIS.2/trellis2/modules/sparse/spatial/spatial2channel.py 复制代码
    def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
        # ...
        x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
        new_feats = x_feats[idx * self.factor ** DIM + subidx]
        out = SparseTensor(new_feats, new_coords, ...)

这里将 [N_coarse, 8*(C'/8)] reshape 成 [N_coarse*8, C'/8],然后按 subdivision mask 选出活跃子 voxel 的 feature --- 这就是论文的 unstack 操作,得到 (F_{\text{fine}}^{\text{raw}} \in \mathbb{R}^{C'/8})。

残差 shortcut(dup_groups) 对应 SparseResBlockC2S3d 中的 skip_connection

235:238:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py 复制代码
        self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
        # ...
        self.updown = sp.SparseChannel2Spatial(2)

repeat_interleave 就是论文中的 dup_groups:将 (C'/8) 通道中的每个元素重复 (C \cdot 8 / C') 次,扩展到目标通道数 (C)。

完整上采样 forward 流程

240:256:repo/TRELLIS.2/trellis2/models/sc_vaes/sparse_unet_vae.py 复制代码
    def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)          # 预测 subdivision mask
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)                       # C → C_out*8, 卷积扩大通道
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)    # channel2spatial: unstack → C_out
        x = self.updown(x, subdiv_binarized)    # shortcut 也做 channel2spatial: C → C/8
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)         # dup_groups(C/8) → C_out,残差相加
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h

三、总结对照表

论文概念 代码位置 说明
stack (Eq.4 第1行) SparseSpatial2Channel.forward (spatial2channel.py:39-42) 8 个子 voxel 特征按位置拼入全零 tensor,reshape 成 [N, 8C]
avg_groups (Eq.4 第2行) SparseResBlockS2C3d.__init__skip_connection (sparse_unet_vae.py:195) reshape(N, C', 8C/C').mean(-1)
unstack (Eq.5 第1行) SparseChannel2Spatial.forward (spatial2channel.py:87-88) reshape 成 [N*8, C'/8],按 subdivision mask 索引取出
dup_groups (Eq.5 第2行) SparseResBlockC2S3d.__init__skip_connection (sparse_unet_vae.py:235) repeat_interleave 复制通道
non-parametric residual shortcut 两个 block 的 h = h + self.skip_connection(x) shortcut 路径没有可学习参数,纯 reshape/mean/repeat

核心设计思想:下采样时 shortcut 用 avg_groups8C 通道压缩到 C',上采样时 shortcut 用 dup_groupsC'/8 通道扩展到 C,整个过程不引入任何可学习参数 (non-parametric),仅通过 空间维度与通道维度的重排列 来传递信息。

即正文所说:We adapt the Residual Autoencoding principle from DC-AE [1] to sparse voxel data by introducing non-parametric residual short-cuts within downsampling and upsampling blocks.

相关推荐
qq_24218863321 小时前
金融AI反欺诈系统构建指南
人工智能·笔记·金融·课程设计
新加坡内哥谈技术1 小时前
Claude C 编译器:它揭示了软件未来的什么
人工智能
予枫的编程笔记2 小时前
【Kafka进阶篇】Kafka消息重复消费?Exactly-Once语义落地指南,PID+事务消息吃透
人工智能·kafka·消息队列·exactly-once·分布式消息·kafka幂等性·kafka事务消息
踢足球09292 小时前
寒假打卡:2026-2-23
数据结构·算法
Loo国昌2 小时前
【AI应用开发实战】09_Prompt工程与模板管理:构建可演进的LLM交互层
大数据·人工智能·后端·python·自然语言处理·prompt
新缸中之脑2 小时前
Wellows:生成式AI搜索优化平台
人工智能·chatgpt
aiAIman2 小时前
OpenClaw 使用和管理 MCP 完全指南
人工智能·语言模型·开源
lusasky2 小时前
对比ZeroClaw 和 OpenClaw
人工智能
Clarence Liu2 小时前
用大白话讲解人工智能(16) 强化学习:教AI“玩游戏“学决策
人工智能·玩游戏