STORM代码阅读笔记

默认的 分辨率是 [160,240] ,基于 Transformer 的方法不能做高分辨率。

Dataloader

输入是 带有 pose 信息的 RGB 图像

eval datasets

bash 复制代码
## 采样帧数目 = 20
num_max_future_frames = int(self.timespan * fps) 
## 每次间隔多少个时间 timesteps 取一个context image
num_context_timesteps  = 4

按照STORM 原来的 setting, future_frames = 20 context_image 每次间隔4帧,所以是 context_frame_idx = [0,5,10,15], 在 target_frame 包含了 从[0,20]的所有20帧。

以这样20 帧的 image 作为一个基本的 batch, 进行预测: 进入 model

所以,输入网络的 context_image 对应的 shape (1,4,3,160,240) 输入4个时刻帧的 frame, 每一个 frame 有 3个相机;对应的 context_camtoworlds shape (1,4,3,4,4)

train datasets

第一帧 ID 随机采样, 之后的 context_image 每次间隔 5 帧,比如: [47 52 57 62]target·_frame_id 也是进行随机选取:

python 复制代码
  if self.equispaced:
    context_frame_idx = np.arange(
           context_frame_idx,
           context_frame_idx + num_max_future_frames,
           num_max_future_frames // self.num_context_timesteps,
       )

随机在 num_future_id 里面 选择 self.num_target_timesteps 选择 4帧作为 target_image 的监督帧

Network

输入网络的 有3个 input: context_image, ray 和 time 的信息

  • context_image: (1,4,3,3,160,240)
  • Ray embedding (1,4,3,6,160,240)
  • time_embedding (1,4,3)
  • 将 image 和 ray_embedding 进行 concat 操作, 得到 x:(12,9,160,240)
python 复制代码
 x = rearrange(x, "b t v c h w -> (b t v) c h w")
plucker_embeds = rearrange(plucker_embeds, "b t v h w c-> (b t v) c h w")
x = torch.cat([x, plucker_embeds], dim=1) ## (12,9,160,240)

然后经过3个 embedding , 将这些 feature 映射成为 token:

python 复制代码
x = self.patch_embed(x)  # (b t v) h w c2
 x = self._pos_embed(x)  # (b t v) (h w) c2
 x = self._time_embed(x, time, num_views=v)

得到 x.shape (12,600,768), 表示一共有12张图像,每个图象 是 600 个 token, 每个 token 的 channel 是768. 然后将这些 token concat 在一起 得到了 (7200,768) 的 feature;

给得到的 token 分别加上可学习的 motion_token, affine_token 和 sky_token. 连接方式都是 concat

这样得到的 feature 为 (7220,768)的 feature

python 复制代码
if self.num_motion_tokens > 0:
    motion_tokens = repeat(self.motion_tokens, "1 k d -> b k d", b=x.shape[0])
    x = torch.cat([motion_tokens, x], dim=-2)
if self.use_affine_token:
    affine_token = repeat(self.affine_token, "1 k d -> b k d", b=b)
    x = torch.cat([affine_token, x], dim=-2)
if self.use_sky_token:
    sky_token = repeat(self.sky_token, "1 1 d -> b 1 d", b=x.shape[0])
    x = torch.cat([sky_token, x], dim=-2)
  • 使用 Transformer 进行学习, 得到的 feature 维度不变。:
python 复制代码
 x = self.transformer(x)
 x = self.norm(x) ## shape(7220,768)

运行完之后,可以将学习到的 token提取出来:

python 复制代码
if self.use_sky_token:
     sky_token = x[:, :1] ## (1,1,768)
     x = x[:, 1:]

 if self.use_affine_token:
     affine_tokens = x[:, : self.num_cams] ## (1,3,768)
     x = x[:, self.num_cams :]

 if self.num_motion_tokens > 0:
     motion_tokens = x[:, : self.num_motion_tokens]  ## (1,16,768)
     x = x[:, self.num_motion_tokens :]

在 Transformer 内部,没有上采样层,也可以实现 这种 per-pixel feature 的学习。

对于 x 进行 GS 的预测,得到 pixel_align 的高斯。 对于每个 patch, 得到的 feature 是 (12,600,768), 通过一个CNN,虽然通道数没有变 (12,600,768), 但是之前 768 可以理解为全局的 语义, 之后的 768 为 一个patch 内部不同像素的语义,他们共享着 全局的 语义信息,但是每个pixel 却又不一样。 通过下面的 unpatchify 函数将将一个patch 的语义拆成 per-pixel 的语义,将每个768维token展开为8×8像素。

python 复制代码
 b, t, v, h, w, _ = origins.shape
 ## x_shape: (12,600,768)
 x = rearrange(x, "b (t v hw) c -> (b t v) hw c", t=t, v=v)
 ## gs_params_shape: (12,600,768),这一步虽然通道没变,但其实是将一个 token 的全局 语义,映射成
 ## token 内部的像素级别的语义
 gs_params = self.gs_pred(x)
 ## gs_params_shape: (12,12,160,240)
 ### 关键步骤:unpatchify将每个768维token展开为8×8像素
 gs_params = self.unpatchify(gs_params, hw=(h, w), patch_size=self.unpatch_size)

根据 token展开的 per-pixel feature, 进行3DGS 的属性预测

python 复制代码
gs_params = rearrange(gs_params, "(b t v) c h w -> b t v h w c", t=t, v=v)
depth, scales, quats, opacitys, colors = gs_params.split([1, 3, 4, 1, self.gs_dim], dim=-1)
scales = self.scale_act_fn(scales)
opacitys = self.opacity_act_fn(opacitys)
depths = self.depth_act_fn(depth)
colors = self.rgb_act_fn(colors)
means = origins + directions * depths

除了3DGS 的一半属性之外, storm 还额外预测了其他的运动属性,包括:

其中: x: (1,7200,768) 代表 image_token, motion_tokens 是(1,16,768)代表 motion_token. 处理的大致思路是 motion_token 作为 query, 然后 image_token 映射的feature 作为 key, 去结合计算每一个 高斯的 moition_weightsmoition_bases

python 复制代码
gs_params = self.forward_motion_predictor(x, motion_tokens, gs_params)
其中:
forward_flow = torch.einsum(
                "b t v h w k, b k c -> b t v h w c", motion_weights, motion_bases
            )

moition_bases: shape: [1,16,3]
moition_weights: shape: [1,4,3,160,240,16]
forward_flow: shape: [1,4,3,160,240,3]: 是 weights 和bases 结合的结果

GS_param Rendering

  • 取出高斯的各项属性,尤其是 means 和 速度 forward_v: STORM 假设 在这 20帧是出于 匀速直线运动, 其速度时不变的,可能并不合理。我们的方法直接预测 BBX,可能更为准确。
python 复制代码
means = rearrange(gs_params["means"], "b t v h w c -> b (t v h w) c")
scales = rearrange(gs_params["scales"], "b t v h w c -> b (t v h w) c")
quats = rearrange(gs_params["quats"], "b t v h w c -> b (t v h w) c")
opacities = rearrange(gs_params["opacities"], "b t v h w -> b (t v h w)")
colors = rearrange(gs_params["colors"], "b t v h w c -> b (t v h w) c")
forward_v = rearrange(gs_params["forward_flow"], "b t v h w c -> b (t v h w) c")

这里得到的 高斯的 mean 是全部由 context_image 得到的, shape (46800,3), 但这其实是 4个 时刻context_frame_idx = [0,5,10,15], 得到的高斯,并不处于同一时间刻度。

通过比较 target_timecontext_time 之间的插值,去得到每一个 target_time 的 3D Gaussian 的坐标means_batched

python 复制代码
  if tgt_time.ndim == 3:
      tdiff_forward = tgt_time.unsqueeze(2) - ctx_time.unsqueeze(1)
      tdiff_forward = tdiff_forward.view(b * tgt_t, t * v, 1)
      tdiff_forward_batched = tdiff_forward.repeat_interleave(h * w, dim=1)
  else:
      tdiff_forward = tgt_time.unsqueeze(-1) - ctx_time.unsqueeze(-2)
      tdiff_forward = tdiff_forward.view(b * tgt_t, t, 1)
      tdiff_forward_batched = tdiff_forward.repeat_interleave(v * h * w, dim=1)
  forward_translation = forward_v_batched * tdiff_forward_batched
  means_batched = means_batched + forward_translation ## (20,460800,3) 

使用 gsplatbatch_rasterization 函数:

python 复制代码
  rendered_color, rendered_alpha, _ = rasterization(
                        means=means_batched.float(),  ## (20,460800,3)
                        quats=quats_batched.float(),
                        scales=scales_batched.float(),
                        opacities=opacities_batched.float(),
                        colors=colors_batched.float(),
                        viewmats=viewmats_batched,  ## (20,3,4,4)
                        Ks=Ks_batched,  ## (20,3,3,3)
                        width=tgt_w,
                        height=tgt_h,
                        render_mode="RGB+ED",
                        near_plane=self.near,
                        far_plane=self.far,
                        packed=False,
                        radius_clip=radius_clip,
                    )

bug 记录:

当使用单个相机的时候,下面这段代码会把 维度搞错:

python 复制代码
  if self.use_affine_token:
    affine = self.affine_linear(affine_tokens)  # b v (gs_dim * (gs_dim + 1))
    affine = rearrange(affine, "b v (p q) -> b v p q", p=self.gs_dim)
    images = torch.einsum("b t v h w p, b v p q -> b t v h w p", images, affine)
    gs_params["affine"] = affine
相关推荐
【上下求索】2 小时前
学习笔记090——Ubuntu 中 UFW 防火墙的使用
笔记·学习·ubuntu
UQWRJ2 小时前
菜鸟教程Linux ViVimYumApt笔记
linux·运维·笔记
程序员Xu2 小时前
【OD机试题解法笔记】查找接口成功率最优时间段
笔记·算法
wb1894 小时前
企业WEB应用服务器TOMCAT
运维·前端·笔记·tomcat·云计算
Se_ren_di_pity5 小时前
CS231n2017-Lecture9经典CNN架构笔记
人工智能·笔记·cnn
Yueeyuee_5 小时前
【C#学习Day15笔记】拆箱装箱、 Equals与== 、文件读取IO
笔记·学习·c#
এ旧栎6 小时前
Gitee
笔记·gitee·学习方法
kfepiza6 小时前
vim的`:q!` 与 `ZQ` 笔记250729
linux·笔记·编辑器·vim