torch.autograd.Function.apply()

apply() 是 PyTorch C++ 层的入口。它会创建一个 BackwardCFunction 节点(也就是 autograd 图上的反向节点),然后在该节点的上下文 ctx 下调用你定义的 forward,并把输出张量的 grad_fn 指向这个节点。

apply = "建反向节点 → 在该节点 ctx 下以 no_grad 调用 forward → 把输出的 grad_fn 指向该节点"。所以 forward 是被它调用的,但调用之外它还干了一堆把张量挂进 autograd 图的事,从而自动注册反向图。

1. apply 是 Function的元类方法(C++ 实现)

torch.autograd.Function 的元类是 FunctionMeta

示例代码:

python 复制代码
class _VocabParallelRKLDivergence(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, ...): ...
复制代码
FunctionMeta 会给这个类绑定一个 apply classmethod,对应到 C++ 的 THPFunction_apply(源码位于 torch/csrc/autograd/python_function.cpp),所以 apply 是框架提供的入口。

2. apply(*args) 内部大致做的事情

执行 _VocabParallelRKLDivergence.apply(student_logits, ...) 时,C++ 端依次做:

(1) 新建一个反向节点 :构造一个 PyNode(包装成 BackwardCFunction 实例),这就是未来出现在 autograd 图里的节点。ctx 就是这个节点暴露给 Python 的接口。

(2) 解析输入 :检查每个输入张量的 requires_grad,记录哪些需要在反向时回传梯度,并在 ctx 上记录 next_functions(指向输入张量的 grad_fn,用来串联计算图)。

(3) 以 no_grad 模式调用你的 forward:相当于

python 复制代码
with torch.no_grad():
    outputs = cls.forward(ctx, *args)

所以 forward 内部所有运算不会 再被 autograd 追踪------这正是必须手写 backward 的原因。ctx.save_for_backward(...)ctx.mark_non_differentiable(...) 这些调用就是在向这个反向节点登记状态。

(4) 把输出挂到图上 :对每个返回的、可微的张量,把它的 grad_fn 设为这个新建的反向节点,output_nr 设为对应的输出位置。被 mark_non_differentiable 标记的张量则不会被挂上 grad_fn

(5) 返回 forward 的输出给调用者。

3. 反向是怎么"自动注册"的

输出张量的 grad_fn 指针指向了刚才那个反向节点。后续当对最终 loss 调用 loss.backward() 时,autograd 引擎沿 grad_fn 反向遍历,遇到这个节点就会调用其 apply(C++ 端),它内部再回调你写的 Python @staticmethod backward(ctx, *grad_outputs)ctx.saved_tensors / ctx.eps 等都从这里取出。

4. apply和直接调forward差异点

  • 不直接调 forward_VocabParallelRKLDivergence.forward(...) 是合法的 Python 调用,使用 .apply(...) 时它会绕过 autograd ------不会有 grad_fn,也不会触发 backward
  • forward 在 no_grad 下运行 :内部的 vocab_parallel_log_softmaxall_reduce 等不会被记录,因此手写 backward 必须自己处理跨 TP 的梯度聚合。
  • ctx 是这个反向节点的"槽位"save_for_backward 存张量到节点、mark_non_differentiable 通知引擎不要给这些输出生成反向边、ctx.next_functions 自动由 C++ 填充。
相关推荐
Agent_大师12 分钟前
WebSocket 行情重连成功,K线缺口不会自动消失
python
荣码12 分钟前
LLM结构化输出:让AI返回JSON而不是废话,我踩了4个坑
java·python
copyer_xyf25 分钟前
FastAPI 如何连接 MySQL
后端·python
apocelipes14 小时前
常用编程语言和库的正则表达式性能对比
c语言·c++·python·性能优化·golang·开发工具和环境
用户83562907805115 小时前
使用 Python 在 PDF 中创建与管理书签
后端·python
MeixianAgent20 小时前
Python 回测数据入口怎么验?历史 K 线入库前先做 5 个检查
后端·python
咕白m6251 天前
用 Python 实现一键批量查找与替换 Excel 数据
后端·python
SelectDB2 天前
Apache Doris Python UDF:让 SQL 直接调用 Python 生态,支撑 Agent 时代复杂业务逻辑
大数据·数据库·python
荣码2 天前
GraphRAG:普通RAG只能回答"点"的问题,我踩了4个坑才搞懂
java·python
金銀銅鐵2 天前
[Python] 基于欧几里得算法,实现分数约分计算器
python·数学