What the apply() method of torch.autograd.Function does

import os
from typing import NamedTuple
import torch.nn as nn
import torch
from rtr_gs_rasterization import _C


def cpu_deep_copy_tuple(input_tuple):
    copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
    return tuple(copied_tensors)


def rasterize_gaussians(
        means3D,
        means2D,
        features,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
):
    return _RasterizeGaussians.apply(
        means3D,
        means2D,
        features,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    )


class _RasterizeGaussians(torch.autograd.Function):
    @staticmethod
    def forward(
            ctx,
            means3D,
            means2D,
            features,
            sh,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3Ds_precomp,
            raster_settings,
    ):

        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            features,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.cx,
            raster_settings.cy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.computer_pseudo_normal,
            raster_settings.debug
        )

        # Invoke C++/CUDA rasterizer
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
            try:
                num_rendered, num_contrib, color, opacity, depth, feature, normal, surface_xyz, weights, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(
                    *args)
            except Exception as ex:
                torch.save(cpu_args, "snapshot_fw.dump")
                print("\nAn error occurred in forward. Please forward snapshot_fw.dump for debugging.")
                raise ex
        else:
            num_rendered, num_contrib, color, opacity, depth, feature, normal, surface_xyz, weights, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(
                *args)
        
        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(colors_precomp, means3D, features, scales, rotations, cov3Ds_precomp,
                              radii, sh, geomBuffer, binningBuffer, imgBuffer)
        return num_rendered, num_contrib, color, opacity, depth, feature, normal, surface_xyz, weights, radii

    @staticmethod
    def backward(ctx, grad_num_rendered, grad_num_contrib, grad_out_color, grad_out_opacity, grad_out_depth,
                 grad_out_feature, grad_out_normal, grad_out_surface_xyz, grad_out_weights, grad_out_radii):
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        colors_precomp, means3D, features, scales, rotations, cov3Ds_precomp, radii, sh, \
            geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors

        # Restructure args as C++ method expects them
        args = (raster_settings.bg,
                means3D,
                features,
                radii,
                colors_precomp,
                scales,
                rotations,
                raster_settings.scale_modifier,
                cov3Ds_precomp,
                raster_settings.viewmatrix,
                raster_settings.projmatrix,
                raster_settings.tanfovx,
                raster_settings.tanfovy,
                grad_out_color,
                grad_out_opacity,
                grad_out_depth,
                grad_out_feature,
                sh,
                raster_settings.sh_degree,
                raster_settings.campos,
                geomBuffer,
                num_rendered,
                binningBuffer,
                imgBuffer,
                raster_settings.backward_geometry,
                raster_settings.debug)
        # Compute gradients for relevant tensors by invoking backward method
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
            try:
                grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_features, grad_cov3Ds_precomp, \
                    grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
            except Exception as ex:
                torch.save(cpu_args, "snapshot_bw.dump")
                print("\nAn error occurred in backward. Writing snapshot_bw.dump for debugging.\n")
                raise ex
        else:
            grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_features, grad_cov3Ds_precomp, \
                grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)

        grads = (
            grad_means3D,
            grad_means2D,
            grad_features,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )

        return grads


class GaussianRasterizationSettings(NamedTuple):
    image_height: int
    image_width: int
    tanfovx: float
    tanfovy: float
    cx: float
    cy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    prefiltered: bool
    backward_geometry: bool
    computer_pseudo_normal: bool
    debug: bool


class GaussianRasterizer(nn.Module):
    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        # Mark visible points (based on frustum culling for camera) with a boolean
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs=None, colors_precomp=None,
                scales=None, rotations=None, cov3D_precomp=None, features=None):

        raster_settings = self.raster_settings

        if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
            raise Exception('Please provide exactly one of either SHs or precomputed colors!')

        if ((scales is None or rotations is None) and cov3D_precomp is None) or (
                (scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])

        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])

        if features is None:
            features = torch.empty_like(means3D[..., :0])

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            features,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )

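In the code you provided, _RasterizeGaussians.apply(...) is a class method that _RasterizeGaussians inherits from torch.autograd.Function. Its core job is to trigger the custom operator's forward pass and register that operation in PyTorch's autograd graph.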


1. Core function: bridging Python and C++/CUDA

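In the 3D Gaussian Splatting (3DGS) implementation, rendering (projection, sorting, alpha compositing) is too complex to build efficiently by stacking native PyTorch operators. The developers therefore wrote dedicated CUDA kernels, which are exposed to Python through the compiled _C extension module.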

Under the hood, _RasterizeGaussians.apply(...) does the following:

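  1. Calls the forward static method: apply creates a ctx object and passes it, followed by your arguments (means3D, opacities, etc.), to the forward function of _RasterizeGaussians.
  2. Enters the C++ interface: inside forward, _C.rasterize_gaussians(*args) runs the underlying CUDA rendering code.
  3. Saves the context: ctx.save_for_backward(...) stores the tensors the backward pass (gradient computation) will need, such as means3D, radii, and geomBuffer. Non-tensor values (raster_settings, num_rendered) are kept as ordinary ctx attributes.
  4. Creates a graph node: it attaches a backward node (the grad_fn of the outputs) to PyTorch's dynamic computation graph. When you later call loss.backward() on a loss computed from the rendered image, PyTorch follows this node to the static method _RasterizeGaussians.backward.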
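The four steps can be traced on a toy operator. The sketch below is a minimal illustration, assuming a hypothetical ToyScale Function whose forward does a plain elementwise multiply in place of _C.rasterize_gaussians; the ctx handling, save_for_backward, and grad_fn behavior it shows are standard torch.autograd.Function mechanics.

```python
import torch


class ToyScale(torch.autograd.Function):
    """Hypothetical stand-in for _RasterizeGaussians: y = x * w, elementwise."""

    @staticmethod
    def forward(ctx, x, w, settings):
        # Step 1: apply() calls forward with a fresh ctx in front of the user's arguments.
        # Step 2: the "kernel" call; a plain tensor op here instead of _C.rasterize_gaussians.
        y = x * w
        # Step 3: tensors go through save_for_backward, other values become ctx attributes.
        ctx.save_for_backward(x, w)
        ctx.settings = settings
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        # One gradient per forward input; None for the non-tensor settings.
        return grad_y * w, grad_y * x, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = torch.tensor([0.5, 0.5, 2.0], requires_grad=True)
y = ToyScale.apply(x, w, {"debug": False})

# Step 4: the output's grad_fn is the graph node that routes back to ToyScale.backward.
print(type(y.grad_fn).__name__)  # ToyScaleBackward
y.sum().backward()
print(x.grad, w.grad)
```

The settings dict plays the role of raster_settings here: it is stored as a plain ctx attribute and receives None as its gradient.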

2. Why use .apply instead of calling the function directly?

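If you called _RasterizeGaussians.forward() directly (supplying a ctx object yourself), the code would still run and produce a rendered image, but gradients could not flow back: the tensors returned by the _C extension typically carry no grad_fn, so nothing connects the image to means3D, opacities, and the other inputs.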

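  • Plain function call: just numerical computation; the computation graph is cut at this point.
  • .apply(): tells PyTorch "this is a differentiable operator; run its forward pass in C++, and for the backward pass use the backward logic I defined."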
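The difference can be observed directly. In this sketch, OpaqueSquare and FakeCtx are hypothetical: forward squares its input through a NumPy round-trip, which, like a compiled extension kernel, is invisible to autograd, and FakeCtx exists only so that forward() can be called by hand.

```python
import numpy as np
import torch


class OpaqueSquare(torch.autograd.Function):
    """Hypothetical op whose forward runs outside autograd, like a _C kernel."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # The NumPy round-trip mimics an extension kernel that autograd cannot trace.
        return torch.from_numpy(np.square(x.detach().cpu().numpy()))

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x


class FakeCtx:
    """Minimal ctx replacement so forward() can be called by hand."""

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y_direct = OpaqueSquare.forward(FakeCtx(), x)
print(y_direct.requires_grad, y_direct.grad_fn)  # False None: the graph is cut here

y_apply = OpaqueSquare.apply(x)
print(y_apply.requires_grad)  # True
y_apply.sum().backward()      # runs OpaqueSquare.backward
print(x.grad)  # tensor([2., 4., 6.])
```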

3. How the arguments flow

When you run:

return _RasterizeGaussians.apply(means3D, means2D, ..., raster_settings)
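  1. Input: all the 3D Gaussian parameters and the camera settings.
  2. Intermediate steps:
    • forward receives these inputs.
    • _C (usually a module compiled with torch.utils.cpp_extension) hands the tensors, which are expected to already live in GPU memory, to the CUDA kernels.
    • CUDA threads perform projection, sorting, and alpha compositing.
  3. Output: the caller receives num_rendered, num_contrib, color, opacity, depth, feature, normal, surface_xyz, weights, and radii. The auxiliary buffers (geomBuffer, binningBuffer, imgBuffer) are not returned; they are kept in ctx for the backward pass.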
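The one-to-one correspondence between apply's arguments, forward's parameters, and backward's inputs and return values can be sketched with a cut-down renderer. ToySettings and ToyRender are hypothetical names, and the formula is chosen for illustration only; it is not the 3DGS rasterizer.

```python
from typing import NamedTuple

import torch


class ToySettings(NamedTuple):
    """Hypothetical, cut-down analogue of GaussianRasterizationSettings."""
    scale_modifier: float
    debug: bool


class ToyRender(torch.autograd.Function):
    @staticmethod
    def forward(ctx, colors, opacities, settings):
        # Arguments arrive in exactly the order they were passed to apply().
        color = colors * opacities.unsqueeze(-1) * settings.scale_modifier
        alpha = opacities.clone()
        ctx.save_for_backward(colors, opacities)
        ctx.settings = settings
        return color, alpha  # two outputs -> backward receives two gradients

    @staticmethod
    def backward(ctx, grad_color, grad_alpha):
        colors, opacities = ctx.saved_tensors
        s = ctx.settings.scale_modifier
        grad_colors = grad_color * opacities.unsqueeze(-1) * s
        grad_opacities = (grad_color * colors).sum(-1) * s + grad_alpha
        # Three inputs -> three return values; None for the non-tensor settings.
        return grad_colors, grad_opacities, None


colors = torch.rand(4, 3, requires_grad=True)
opacities = torch.rand(4, requires_grad=True)
settings = ToySettings(scale_modifier=1.0, debug=False)

color, alpha = ToyRender.apply(colors, opacities, settings)
(color.sum() + alpha.sum()).backward()
```

The real class follows the same rule at a larger scale: forward has 10 outputs, so backward takes 10 gradient arguments, and apply receives 10 inputs, so backward returns a 10-tuple ending in None for raster_settings.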

4. Summary: the two kinds of apply

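| Aspect | nn.Module.apply(fn) | autograd.Function.apply(*args) |
|---|---|---|
| Called on | A model instance (e.g. an nn.Sequential) | A custom operator class (a subclass of Function) |
| Main purpose | Recursively apply fn to every submodule (e.g. to initialize weights) | Run the computation and link it into autograd |
| How often | Usually once, after the model is built | On every forward pass |
| In your code | Not used | This one: runs the core CUDA rendering logic |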
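The contrast in the table can be checked in a few lines. The model, init_weights, and Double below are hypothetical examples; nn.Module.apply and autograd.Function.apply are the real PyTorch methods being compared.

```python
import torch
import torch.nn as nn

# nn.Module.apply(fn): an instance method that walks the module tree.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
visited = []


def init_weights(m):
    # Called on every child module first, then on the module itself.
    visited.append(type(m).__name__)
    if isinstance(m, nn.Linear):
        nn.init.zeros_(m.bias)


model.apply(init_weights)  # typically done once, right after construction
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential']


# autograd.Function.apply(*args): a class-level call made on every forward pass.
class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


x = torch.ones(2, requires_grad=True)
for _ in range(3):  # three forward/backward passes; gradients accumulate in x.grad
    Double.apply(x).sum().backward()
print(x.grad)  # tensor([6., 6.])
```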

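In one sentence: the .apply here hands a batch of tensors to the CUDA rendering engine and tells PyTorch how to compute their gradients from the rendered image's error, which the optimizer then uses to update the Gaussians' positions, colors, and other parameters.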
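The gradient-then-update loop described in the summary can be sketched end to end. ToySplat is a hypothetical one-pixel "renderer" (an opacity-weighted sum of colors) standing in for the CUDA rasterizer, and the SGD settings are arbitrary; the point is that loss.backward() only computes gradients through the custom backward, while optimizer.step() is what actually moves the parameters.

```python
import torch


class ToySplat(torch.autograd.Function):
    """Hypothetical 'renderer': one pixel = opacity-weighted sum of colors."""

    @staticmethod
    def forward(ctx, colors, opacities):
        ctx.save_for_backward(colors, opacities)
        return (opacities.unsqueeze(-1) * colors).sum(0)  # shape (3,)

    @staticmethod
    def backward(ctx, grad_pixel):
        colors, opacities = ctx.saved_tensors
        grad_colors = opacities.unsqueeze(-1) * grad_pixel  # shape (N, 3)
        grad_opacities = colors @ grad_pixel                 # shape (N,)
        return grad_colors, grad_opacities


torch.manual_seed(0)
colors = torch.rand(5, 3, requires_grad=True)
opacities = torch.rand(5, requires_grad=True)
target = torch.tensor([0.2, 0.6, 0.9])
optimizer = torch.optim.SGD([colors, opacities], lr=0.05)

losses = []
for step in range(50):
    optimizer.zero_grad()
    pixel = ToySplat.apply(colors, opacities)
    loss = ((pixel - target) ** 2).mean()
    loss.backward()   # runs ToySplat.backward to fill colors.grad / opacities.grad
    optimizer.step()  # the optimizer, not backward, updates the tensors
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```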

