torch.autograd.Function.apply()

apply() 是 PyTorch C++ 层的入口。它会创建一个 BackwardCFunction 节点(也就是 autograd 图上的反向节点),然后在该节点的上下文 ctx 下调用你定义的 forward,并把输出张量的 grad_fn 指向这个节点。

apply = "建反向节点 → 在该节点 ctx 下以 no_grad 调用 forward → 把输出的 grad_fn 指向该节点"。所以 forward 是被它调用的,但调用之外它还干了一堆把张量挂进 autograd 图的事,从而自动注册反向图。

1. apply 是 Function的元类方法(C++ 实现)

torch.autograd.Function 的元类是 FunctionMeta

示例代码:

python 复制代码
class _VocabParallelRKLDivergence(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, ...): ...
复制代码
FunctionMeta 会给这个类绑定一个 apply classmethod,对应到 C++ 的 THPFunction_apply(源码位于 torch/csrc/autograd/python_function.cpp),所以 apply 是框架提供的入口。

2. apply(*args) 内部大致做的事情

执行 _VocabParallelRKLDivergence.apply(student_logits, ...) 时,C++ 端依次做:

(1) 新建一个反向节点 :构造一个 PyNode(包装成 BackwardCFunction 实例),这就是未来出现在 autograd 图里的节点。ctx 就是这个节点暴露给 Python 的接口。

(2) 解析输入 :检查每个输入张量的 requires_grad,记录哪些需要在反向时回传梯度,并在 ctx 上记录 next_functions(指向输入张量的 grad_fn,用来串联计算图)。

(3) 以 no_grad 模式调用你的 forward:相当于

python 复制代码
with torch.no_grad():
    outputs = cls.forward(ctx, *args)

所以 forward 内部所有运算不会 再被 autograd 追踪------这正是必须手写 backward 的原因。ctx.save_for_backward(...)ctx.mark_non_differentiable(...) 这些调用就是在向这个反向节点登记状态。

(4) 把输出挂到图上 :对每个返回的、可微的张量,把它的 grad_fn 设为这个新建的反向节点,output_nr 设为对应的输出位置。被 mark_non_differentiable 标记的张量则不会被挂上 grad_fn

(5) 返回 forward 的输出给调用者。

3. 反向是怎么"自动注册"的

输出张量的 grad_fn 指针指向了刚才那个反向节点。后续当对最终 loss 调用 loss.backward() 时,autograd 引擎沿 grad_fn 反向遍历,遇到这个节点就会调用其 apply(C++ 端),它内部再回调你写的 Python @staticmethod backward(ctx, *grad_outputs)ctx.saved_tensors / ctx.eps 等都从这里取出。

4. apply和直接调forward差异点

  • 不直接调 forward_VocabParallelRKLDivergence.forward(...) 是合法的 Python 调用,使用 .apply(...) 时它会绕过 autograd ------不会有 grad_fn,也不会触发 backward
  • forward 在 no_grad 下运行 :内部的 vocab_parallel_log_softmaxall_reduce 等不会被记录,因此手写 backward 必须自己处理跨 TP 的梯度聚合。
  • ctx 是这个反向节点的"槽位"save_for_backward 存张量到节点、mark_non_differentiable 通知引擎不要给这些输出生成反向边、ctx.next_functions 自动由 C++ 填充。
相关推荐
AI科技星1 小时前
《数术工坊:非欧射影录》类型:硬核光影·几何本源
c语言·开发语言·网络·量子计算·agi
花间相见1 小时前
【LeetCode01】—— 无重复字符的最长子串:滑动窗口经典题详解
python·算法·leetcode
何以解忧,唯有..1 小时前
Python 中的继承机制:从基础到高级用法详解
java·开发语言·python
try2find2 小时前
agent环境安装spacy
python·智能体
ellenwan20262 小时前
期货程序化开平标志错了总拒单:天勤 last_msg 排查思路
python
绵绵细雨中的乡音2 小时前
监控显示一切正常,可用户根本打不开网站——Blackbox Exporter帮我找到了真相(1)
开发语言·php
c++之路2 小时前
CMake 系列教程(五):进阶技巧
c语言·开发语言·c++
踏着七彩祥云的小丑2 小时前
Go学习第5天:变量作用域 + 数组 + 指针
开发语言·学习·golang·go
装不满的克莱因瓶2 小时前
自动微分的原理:计算图与前向传播
人工智能·pytorch·python·数学·ai·微积分·计算图