pytorch 中_call_impl()函数

记录pytorch 版本中的 nn.Module() 重要函数

1. _call_impl()

1.1 torch1.7.1 版本

python 复制代码
    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

以上的函数的功能作用解释如下:

提供的代码是 PyTorch 模块方法的 _call_impl 实现。当模块用作可调用对象时,通常使用输入数据调用它时,将调用此方法。让我们逐步分解代码以了解其功能:
for hook in itertools.chain(_global_forward_pre_hooks.values(), self._forward_pre_hooks.values()):

此循环遍历两个钩子集合:

_global_forward_pre_hooks 和 _forward_pre_hooks 。

钩子是可以注册为在神经网络向前或向后传递期间在特定点执行的函数。

这些预置挂钩将在模块的实际前向传递之前执行。

result = hook(self, input):

对于每个钩子,使用 self (模块)和 input 参数调用 hook 该函数。

if result is not None:

如果钩子返回非 None 值,则表示钩子修改了输入数据,并且此修改后的数据将成为循环中下一个挂钩的新输入。

if not isinstance(result, tuple): result = (result,)

钩子的结果将转换为元组(如果它还没有元组)。这是为了处理钩子可能返回单个值而不是元组的情况。

input = result

修改后的输入将成为循环中下一个钩子的新输入。
if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs)

此条件块检查是否正在跟踪当前正向传递。如果是,则使用修改后的输入数据调用该方_slow_forward 。

else: result = self.forward(*input, **kwargs)

如果未跟踪正向传递,则使用修改后的输入数据调用模块的正常 forward 方法。
for hook in itertools.chain(_global_forward_hooks.values(), self._forward_hooks.values()):

此循环遍历两个前向钩子集合: _global_forward_hooks 和 _forward_hooks 。这些钩子在模块的正向传递之后执行。

hook_result = hook(self, input, result):

对于每个钩子,使用 self 、 input 和 result 作为参数调用 hook 函数。

if hook_result is not None: result = hook_result

如果钩子返回非 None 值,则表示钩子修改了正向传递的结果,并且此修改后的结果将成为循环中下一个挂钩的新结果。
if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):

此条件块检查是否有任何全局或为此特定模块注册的向后钩子。

var = result

正向传递的结果存储在变量 var 中。
while not isinstance(var, torch.Tensor):

此循环迭代 till 是 var 的 torch.Tensor 实例。

if isinstance(var, dict): var = next((v for v in var.values() if isinstance(v, torch.Tensor)))

如果是一个字典,它 var 查找字典中的第一个值,即 torch.Tensor .

else: var = var[0]

如果不是字典,它 var 假定它是一个序列(例如,列表,元组)并获取其第一个元素。

grad_fn = var.grad_fn

grad_fn torch.Tensor 实例的属性被分配给变量 grad_fn 。此属性表示在反向传播期间计算张量梯度的函数。
if grad_fn is not None:

如果 不是 grad_fn None,则表示张量参与了需要梯度的计算,我们需要向其附加向后钩子。

for hook in itertools.chain(_global_backward_hooks.values(), self._backward_hooks.values()):

此循环遍历两个向后钩子集合: _global_backward_hooks 和 _backward_hooks 。这些钩子在模块的向后传递期间执行。

wrapper = functools.partial(hook, self)

对于每个向后钩子,通过将钩子函数与模块 self 作为参数部分应用来创建一个新函数。这样做是为了确保钩子函数可以访问模块。

functools.update_wrapper(wrapper, hook)

包装器函数使用原始钩子函数中的信息进行更新,例如其名称和文档字符串。

grad_fn.register_hook(wrapper)

包装器函数注册为挂接到 grad_fn .这意味着在反向传播期间计算梯度时,将执行钩子以对梯度执行其他操作。

return result

返回前向传递的最终结果。

总之,该方法 _call_impl 执行前向挂钩(前向前和后向前),执行前向传递,执行向后挂钩(如果需要),并返回向前传递的结果。它还处理钩子修改输入或结果数据的情况,并确保向后钩子附加到相关张量,以便在反向传播期间进行梯度计算。

相关推荐
许泽宇的技术分享4 分钟前
「让AI大脑直连Windows桌面」:深度解析Windows-MCP,开启操作系统下一代智能交互
人工智能·windows·交互
crushqqi12 分钟前
【跨服务器的数据自动化下载--安装公钥,免密下载】
服务器·python·自动化
Baihai_IDP20 分钟前
许多 AI 智能体评测基准并不可靠
人工智能·面试·llm
qq_4639448621 分钟前
如何将新建的Anaconda虚拟环境导入Juputer内核中?
linux·windows·python
钢铁男儿29 分钟前
PyTorch基础(使用Tensor及Antograd实现机器学习)
人工智能·pytorch·机器学习
Shun_Tianyou40 分钟前
Python Day28 HTML 与 CSS 核心知识点 及例题分析
开发语言·前端·css·python·算法·html
Hcoco_me1 小时前
【8】Transformers快速入门:Decoder 分支和统计语言模型区别?
人工智能·语言模型·自然语言处理
_oP_i1 小时前
Model Context Protocol (MCP)标准化应用程序向大型语言模型 (LLM) 提供上下文协议
人工智能·语言模型·自然语言处理
平行绳1 小时前
智能体一键生成火遍全网的火柴人视频,工作流详细搭建教程来了
人工智能·coze
chenchihwen1 小时前
腾讯codebuddy.ai 安装实测【从零开始开发在线五子棋游戏:完整开发记录】
人工智能·codebuddy