apply() 是 PyTorch C++ 层的入口。它会创建一个 BackwardCFunction 节点(也就是 autograd 图上的反向节点),然后在该节点的上下文 ctx 下调用你定义的 forward,并把输出张量的 grad_fn 指向这个节点。
apply = "建反向节点 → 在该节点 ctx 下以 no_grad 调用 forward → 把输出的 grad_fn 指向该节点"。所以 forward 是被它调用的,但调用之外它还干了一堆把张量挂进 autograd 图的事,从而自动注册反向图。
1. apply 是 Function的元类方法(C++ 实现)
torch.autograd.Function 的元类是 FunctionMeta。
示例代码:
python
class _VocabParallelRKLDivergence(torch.autograd.Function):
@staticmethod
def forward(ctx, ...): ...
@staticmethod
def backward(ctx, ...): ...
FunctionMeta 会给这个类绑定一个 apply classmethod,对应到 C++ 的 THPFunction_apply(源码位于 torch/csrc/autograd/python_function.cpp),所以 apply 是框架提供的入口。
2. apply(*args) 内部大致做的事情
执行 _VocabParallelRKLDivergence.apply(student_logits, ...) 时,C++ 端依次做:
(1) 新建一个反向节点 :构造一个 PyNode(包装成 BackwardCFunction 实例),这就是未来出现在 autograd 图里的节点。ctx 就是这个节点暴露给 Python 的接口。
(2) 解析输入 :检查每个输入张量的 requires_grad,记录哪些需要在反向时回传梯度,并在 ctx 上记录 next_functions(指向输入张量的 grad_fn,用来串联计算图)。
(3) 以 no_grad 模式调用你的 forward:相当于
python
with torch.no_grad():
outputs = cls.forward(ctx, *args)
所以 forward 内部所有运算不会 再被 autograd 追踪------这正是必须手写 backward 的原因。ctx.save_for_backward(...)、ctx.mark_non_differentiable(...) 这些调用就是在向这个反向节点登记状态。
(4) 把输出挂到图上 :对每个返回的、可微的张量,把它的 grad_fn 设为这个新建的反向节点,output_nr 设为对应的输出位置。被 mark_non_differentiable 标记的张量则不会被挂上 grad_fn。
(5) 返回 forward 的输出给调用者。
3. 反向是怎么"自动注册"的
输出张量的 grad_fn 指针指向了刚才那个反向节点。后续当对最终 loss 调用 loss.backward() 时,autograd 引擎沿 grad_fn 反向遍历,遇到这个节点就会调用其 apply(C++ 端),它内部再回调你写的 Python @staticmethod backward(ctx, *grad_outputs)。ctx.saved_tensors / ctx.eps 等都从这里取出。
4. apply和直接调forward差异点
- 不直接调
forward:_VocabParallelRKLDivergence.forward(...)是合法的 Python 调用,使用 .apply(...) 时它会绕过 autograd ------不会有grad_fn,也不会触发backward。 forward在 no_grad 下运行 :内部的vocab_parallel_log_softmax、all_reduce等不会被记录,因此手写backward必须自己处理跨 TP 的梯度聚合。ctx是这个反向节点的"槽位" :save_for_backward存张量到节点、mark_non_differentiable通知引擎不要给这些输出生成反向边、ctx.next_functions自动由 C++ 填充。