Nerfstudio Gaussian Splatting 相机姿态可微/求导的具体实现方式

源码

代码

python 复制代码
class _ProjectGaussians(Function):
    """Project 3D gaussians to 2D."""

    @staticmethod
    def forward(
        ctx,
        means3d: Float[Tensor, "*batch 3"],
        scales: Float[Tensor, "*batch 3"],
        glob_scale: float,
        quats: Float[Tensor, "*batch 4"],
        viewmat: Float[Tensor, "4 4"],
        fx: float,
        fy: float,
        cx: float,
        cy: float,
        img_height: int,
        img_width: int,
        block_width: int,
        clip_thresh: float = 0.01,
    ):
        num_points = means3d.shape[-2]
        if num_points < 1 or means3d.shape[-1] != 3:
            raise ValueError(f"Invalid shape for means3d: {means3d.shape}")

        (
            cov3d,
            xys,
            depths,
            radii,
            conics,
            compensation,
            num_tiles_hit,
        ) = _C.project_gaussians_forward(
            num_points,
            means3d,
            scales,
            glob_scale,
            quats,
            viewmat,
            fx,
            fy,
            cx,
            cy,
            img_height,
            img_width,
            block_width,
            clip_thresh,
        )

        # Save non-tensors.
        ctx.img_height = img_height
        ctx.img_width = img_width
        ctx.num_points = num_points
        ctx.glob_scale = glob_scale
        ctx.fx = fx
        ctx.fy = fy
        ctx.cx = cx
        ctx.cy = cy

        # Save tensors.
        ctx.save_for_backward(
            means3d,
            scales,
            quats,
            viewmat,
            cov3d,
            radii,
            conics,
            compensation,
        )

        return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)

    @staticmethod
    def backward(
        ctx,
        v_xys,
        v_depths,
        v_radii,
        v_conics,
        v_compensation,
        v_num_tiles_hit,
        v_cov3d,
    ):
        (
            means3d,
            scales,
            quats,
            viewmat,
            cov3d,
            radii,
            conics,
            compensation,
        ) = ctx.saved_tensors

        (v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat) = _C.project_gaussians_backward(
            ctx.num_points,
            means3d,
            scales,
            ctx.glob_scale,
            quats,
            viewmat,
            ctx.fx,
            ctx.fy,
            ctx.cx,
            ctx.cy,
            ctx.img_height,
            ctx.img_width,
            cov3d,
            radii,
            conics,
            compensation,
            v_xys,
            v_depths,
            v_conics,
            v_compensation,
        )

        if viewmat.requires_grad:
            v_viewmat = torch.zeros_like(viewmat)
            R = viewmat[..., :3, :3]

            # Denote ProjectGaussians for a single Gaussian (mean3d, q, s)
            # viemwat = [R, t] as:
            #
            #   f(mean3d, q, s, R, t, intrinsics)
            #       = g(R @ mean3d + t,
            #           R @ cov3d_world(q, s) @ R^T ))
            #
            # Then, the Jacobian w.r.t., t is:
            #
            #   d f / d t = df / d mean3d @ R^T
            #
            # and, in the context of fine tuning camera poses, it is reasonable
            # to assume that
            #
            #   d f / d R_ij =~ \sum_l d f / d t_l * d (R @ mean3d)_l / d R_ij
            #                = d f / d_t_i * mean3d[j]
            #
            # Gradients for R and t can then be obtained by summing over
            # all the Gaussians.
            v_mean3d_cam = torch.matmul(v_mean3d, R.transpose(-1, -2))

            # gradient w.r.t. view matrix translation
            v_viewmat[..., :3, 3] = v_mean3d_cam.sum(-2)

            # gradent w.r.t. view matrix rotation
            for j in range(3):
                for l in range(3):
                    v_viewmat[..., j, l] = torch.dot(
                        v_mean3d_cam[..., j], means3d[..., l]
                    )
        else:
            v_viewmat = None

        # Return a gradient for each input.
        return (
            # means3d: Float[Tensor, "*batch 3"],
            v_mean3d,
            # scales: Float[Tensor, "*batch 3"],
            v_scale,
            # glob_scale: float,
            None,
            # quats: Float[Tensor, "*batch 4"],
            v_quat,
            # viewmat: Float[Tensor, "4 4"],
            v_viewmat,
            # fx: float,
            None,
            # fy: float,
            None,
            # cx: float,
            None,
            # cy: float,
            None,
            # img_height: int,
            None,
            # img_width: int,
            None,
            # block_width: int,
            None,
            # clip_thresh,
            None,
        )

解释

这段代码是用于计算视图矩阵(view matrix)的梯度,视图矩阵通常用于3D图形和计算机视觉中,用于将世界坐标转换为相机坐标。在深度学习中,视图矩阵通常用于表示相机的位置和朝向。

代码的目的是计算一个由多个高斯分布(每个分布有均值mean3d、协方差cov3d_world、以及其他参数qs)组成的投影高斯分布的梯度。这些高斯分布通过视图矩阵viewmat(包含旋转矩阵R和平移向量t)进行投影,并考虑了相机内参intrinsics

代码的逻辑如下:

  1. 检查视图矩阵是否需要梯度if viewmat.requires_grad: 检查视图矩阵是否需要计算梯度。如果需要,则继续执行后续代码。

  2. 初始化梯度张量v_viewmat = torch.zeros_like(viewmat) 创建一个与视图矩阵形状相同的全零张量,用于存储梯度。

  3. 提取旋转矩阵R = viewmat[..., :3, :3] 从视图矩阵中提取旋转矩阵部分。

  4. 计算均值在相机坐标系下的表示v_mean3d_cam = torch.matmul(v_mean3d, R.transpose(-1, -2)) 将世界坐标系下的均值通过旋转矩阵转换到相机坐标系。

  5. 计算相对于平移向量的梯度v_viewmat[..., :3, 3] = v_mean3d_cam.sum(-2) 计算所有高斯分布的梯度相对于平移向量的和。

  6. 计算相对于旋转矩阵的梯度 :通过两层循环,计算所有高斯分布的梯度相对于旋转矩阵的和。外层循环遍历旋转矩阵的行,内层循环遍历旋转矩阵的列。对于每个元素(j, l),计算v_mean3d_cam[..., j]means3d[..., l]的点积,并将结果赋给v_viewmat[..., j, l]

  7. 返回梯度 :最后,函数返回计算得到的梯度v_viewmat

总的来说,这段代码是用于计算视图矩阵的梯度,这些梯度可以用于优化相机的位置和朝向,以最小化投影高斯分布的某种损失函数。在实际应用中,这通常在训练神经网络时用于端到端的优化,以提高模型的性能。

相关推荐
梦云澜2 小时前
论文阅读(十二):全基因组关联研究中生物通路的图形建模
论文阅读·人工智能·深度学习
远洋录3 小时前
构建一个数据分析Agent:提升分析效率的实践
人工智能·ai·ai agent
IT古董4 小时前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师5 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)5 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui6 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20257 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥7 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
云空8 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代8 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt