默认的 分辨率是 [160,240] ,基于 Transformer 的方法不能做高分辨率。
Dataloader
输入是 带有 pose 信息的 RGB 图像
eval datasets
bash
## 采样帧数目 = 20
num_max_future_frames = int(self.timespan * fps)
## 每次间隔多少个时间 timesteps 取一个context image
num_context_timesteps = 4
按照STORM
原来的 setting, future_frames = 20
context_image
每次间隔4帧,所以是 context_frame_idx = [0,5,10,15]
, 在 target_frame
包含了 从[0,20]的所有20帧。
以这样20 帧的 image 作为一个基本的 batch, 进行预测: 进入 model
所以,输入网络的 context_image
对应的 shape (1,4,3,160,240)
输入4个时刻帧的 frame, 每一个 frame 有 3个相机;对应的 context_camtoworlds
shape (1,4,3,4,4)
train datasets
第一帧 ID
随机采样, 之后的 context_image
每次间隔 5 帧,比如: [47 52 57 62]
: target·_frame_id
也是进行随机选取:
python
if self.equispaced:
context_frame_idx = np.arange(
context_frame_idx,
context_frame_idx + num_max_future_frames,
num_max_future_frames // self.num_context_timesteps,
)
随机在 num_future_id 里面 选择 self.num_target_timesteps
选择 4帧作为 target_image
的监督帧
Network
输入网络的 有3个 input: context_image, ray 和 time 的信息
- context_image: (1,4,3,3,160,240)
- Ray embedding (1,4,3,6,160,240)
- time_embedding (1,4,3)
- 将 image 和 ray_embedding 进行
concat
操作, 得到 x:(12,9,160,240):
python
x = rearrange(x, "b t v c h w -> (b t v) c h w")
plucker_embeds = rearrange(plucker_embeds, "b t v h w c-> (b t v) c h w")
x = torch.cat([x, plucker_embeds], dim=1) ## (12,9,160,240)
然后经过3个 embedding , 将这些 feature 映射成为 token:
python
x = self.patch_embed(x) # (b t v) h w c2
x = self._pos_embed(x) # (b t v) (h w) c2
x = self._time_embed(x, time, num_views=v)
得到 x.shape (12,600,768)
, 表示一共有12张图像,每个图象 是 600 个 token, 每个 token 的 channel 是768. 然后将这些 token concat 在一起 得到了 (7200,768) 的 feature;
给得到的 token 分别加上可学习的 motion_token, affine_token 和 sky_token. 连接方式都是 concat
这样得到的 feature 为 (7220,768)的 feature
python
if self.num_motion_tokens > 0:
motion_tokens = repeat(self.motion_tokens, "1 k d -> b k d", b=x.shape[0])
x = torch.cat([motion_tokens, x], dim=-2)
if self.use_affine_token:
affine_token = repeat(self.affine_token, "1 k d -> b k d", b=b)
x = torch.cat([affine_token, x], dim=-2)
if self.use_sky_token:
sky_token = repeat(self.sky_token, "1 1 d -> b 1 d", b=x.shape[0])
x = torch.cat([sky_token, x], dim=-2)
- 使用 Transformer 进行学习, 得到的 feature 维度不变。:
python
x = self.transformer(x)
x = self.norm(x) ## shape(7220,768)
运行完之后,可以将学习到的 token
提取出来:
python
if self.use_sky_token:
sky_token = x[:, :1] ## (1,1,768)
x = x[:, 1:]
if self.use_affine_token:
affine_tokens = x[:, : self.num_cams] ## (1,3,768)
x = x[:, self.num_cams :]
if self.num_motion_tokens > 0:
motion_tokens = x[:, : self.num_motion_tokens] ## (1,16,768)
x = x[:, self.num_motion_tokens :]
在 Transformer 内部,没有上采样层,也可以实现 这种 per-pixel feature 的学习。
对于 x 进行 GS 的预测,得到 pixel_align 的高斯。 对于每个 patch, 得到的 feature 是 (12,600,768)
, 通过一个CNN,虽然通道数没有变 (12,600,768)
, 但是之前 768 可以理解为全局的 语义, 之后的 768 为 一个patch 内部不同像素的语义,他们共享着 全局的 语义信息,但是每个pixel 却又不一样。 通过下面的 unpatchify
函数将将一个patch 的语义拆成 per-pixel 的语义,将每个768维token
展开为8×8
像素。
python
b, t, v, h, w, _ = origins.shape
## x_shape: (12,600,768)
x = rearrange(x, "b (t v hw) c -> (b t v) hw c", t=t, v=v)
## gs_params_shape: (12,600,768),这一步虽然通道没变,但其实是将一个 token 的全局 语义,映射成
## token 内部的像素级别的语义
gs_params = self.gs_pred(x)
## gs_params_shape: (12,12,160,240)
### 关键步骤:unpatchify将每个768维token展开为8×8像素
gs_params = self.unpatchify(gs_params, hw=(h, w), patch_size=self.unpatch_size)
根据 token
展开的 per-pixel feature, 进行3DGS 的属性预测
python
gs_params = rearrange(gs_params, "(b t v) c h w -> b t v h w c", t=t, v=v)
depth, scales, quats, opacitys, colors = gs_params.split([1, 3, 4, 1, self.gs_dim], dim=-1)
scales = self.scale_act_fn(scales)
opacitys = self.opacity_act_fn(opacitys)
depths = self.depth_act_fn(depth)
colors = self.rgb_act_fn(colors)
means = origins + directions * depths
除了3DGS 的一半属性之外, storm 还额外预测了其他的运动属性,包括:
其中: x: (1,7200,768)
代表 image_token
, motion_tokens 是(1,16,768)
代表 motion_token. 处理的大致思路是 motion_token
作为 query,
然后 image_token
映射的feature 作为 key,
去结合计算每一个 高斯的 moition_weights
和 moition_bases
python
gs_params = self.forward_motion_predictor(x, motion_tokens, gs_params)
其中:
forward_flow = torch.einsum(
"b t v h w k, b k c -> b t v h w c", motion_weights, motion_bases
)
moition_bases: shape: [1,16,3]
moition_weights: shape: [1,4,3,160,240,16]
forward_flow: shape: [1,4,3,160,240,3]:
是 weights 和bases 结合的结果
GS_param Rendering
- 取出高斯的各项属性,尤其是 means 和 速度
forward_v
: STORM 假设 在这 20帧是出于匀速直线运动
, 其速度时不变的,可能并不合理。我们的方法直接预测 BBX,可能更为准确。
python
means = rearrange(gs_params["means"], "b t v h w c -> b (t v h w) c")
scales = rearrange(gs_params["scales"], "b t v h w c -> b (t v h w) c")
quats = rearrange(gs_params["quats"], "b t v h w c -> b (t v h w) c")
opacities = rearrange(gs_params["opacities"], "b t v h w -> b (t v h w)")
colors = rearrange(gs_params["colors"], "b t v h w c -> b (t v h w) c")
forward_v = rearrange(gs_params["forward_flow"], "b t v h w c -> b (t v h w) c")
这里得到的 高斯的 mean
是全部由 context_image
得到的, shape (46800,3)
, 但这其实是 4个 时刻context_frame_idx = [0,5,10,15]
, 得到的高斯,并不处于同一时间刻度。
通过比较 target_time
和 context_time
之间的插值,去得到每一个 target_time
的 3D Gaussian 的坐标means_batched
:
python
if tgt_time.ndim == 3:
tdiff_forward = tgt_time.unsqueeze(2) - ctx_time.unsqueeze(1)
tdiff_forward = tdiff_forward.view(b * tgt_t, t * v, 1)
tdiff_forward_batched = tdiff_forward.repeat_interleave(h * w, dim=1)
else:
tdiff_forward = tgt_time.unsqueeze(-1) - ctx_time.unsqueeze(-2)
tdiff_forward = tdiff_forward.view(b * tgt_t, t, 1)
tdiff_forward_batched = tdiff_forward.repeat_interleave(v * h * w, dim=1)
forward_translation = forward_v_batched * tdiff_forward_batched
means_batched = means_batched + forward_translation ## (20,460800,3)
使用 gsplat
的 batch_rasterization
函数:
python
rendered_color, rendered_alpha, _ = rasterization(
means=means_batched.float(), ## (20,460800,3)
quats=quats_batched.float(),
scales=scales_batched.float(),
opacities=opacities_batched.float(),
colors=colors_batched.float(),
viewmats=viewmats_batched, ## (20,3,4,4)
Ks=Ks_batched, ## (20,3,3,3)
width=tgt_w,
height=tgt_h,
render_mode="RGB+ED",
near_plane=self.near,
far_plane=self.far,
packed=False,
radius_clip=radius_clip,
)
bug 记录:
当使用单个相机的时候,下面这段代码会把 维度搞错:
python
if self.use_affine_token:
affine = self.affine_linear(affine_tokens) # b v (gs_dim * (gs_dim + 1))
affine = rearrange(affine, "b v (p q) -> b v p q", p=self.gs_dim)
images = torch.einsum("b t v h w p, b v p q -> b t v h w p", images, affine)
gs_params["affine"] = affine