PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析

sizes_strides_methods

PyTorch 模型支援動態形狀的輸入。在 PyTorch 的動態形狀系統中,會使用 torch.SymInt 來表示張量的尺寸(sizes)及步長(strides),而 torch.SymInt 的底層便是 torch.SymNode

torch.SymNode 類別中定義了 sizes, strides相關函數,可以看到它們的參數皆為 self, sizes, strides

python 复制代码
    def is_contiguous(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":  # noqa: F811
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

如果查看這些函數的定義,會發現它們各自的實作內容都只是呼叫了一個帶下底線前綴的版本,但如果我們試圖進一步去尋找那些函數,會發現 SymNode 類別根本沒有定義那些函數。

其實這些方法不是沒有實作,而是稍後會在 torch/fx/experimental/symbolic_shapes.py 中安裝到 SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組上。這些方法在 PyTorch 中被稱為 sizes strides methods

sizes_strides_methods 的定義位於 torch/fx/experimental/symbolic_shapes.py。它是一個 將方法的名稱對應到 lambda 函數 的字典,其中 key 代表方法的名字,value 則為該方法的實現:

python 复制代码
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    'is_contiguous': lambda sizes, strides: sympy_is_contiguous(sizes, strides),
    'is_channels_last_contiguous_2d': lambda sizes, strides: sympy_is_channels_last_contiguous_2d(sizes, strides),
    'is_channels_last_contiguous_3d': lambda sizes, strides: sympy_is_channels_last_contiguous_3d(sizes, strides),
    'is_channels_last_strides_2d': lambda sizes, strides: sympy_is_channels_last_strides_2d(sizes, strides),
    'is_channels_last_strides_3d': lambda sizes, strides: sympy_is_channels_last_strides_3d(sizes, strides),
    'is_non_overlapping_and_dense_indicator': lambda sizes, strides: IsNonOverlappingAndDenseIndicator(*sizes, *strides),
}

如果想要它們各自在做什麼,可參考以下文章:

sizes_strides_methods 的輸入和輸出

前五個函數所接受的參數:

  • sequence of int, sympy.Expr, sympy.Symbol, sympy.Integer 或 sympy 組合式
  • 完全具體化的情況:sequence of int, sympy.Integer

回傳值:

  • symbolic boolean expression,如 sympy.Eq, sympy.Ne 等的組合
  • 完全具體化的情況:sympy.truesympy.false。注意即使入參是 Python int,回傳值也不是 python bool,因為它們的回傳值是基於 sympy.true 去做運算的

IsNonOverlappingAndDenseIndicator 所接受的參數:

  • 未完全具體化的情況:多個 int, sympy.Expr, sympy.Symbol, sympy.Integer 或 sympy 組合式
  • 完全具體化的情況:多個 int, smypy.Integer

回傳值:

  • 未完全具體化的情況:IsNonOverlappingAndDenseIndicatoreval 函數回傳 None 表示無法進一步化簡;外層 IsNonOverlappingAndDenseIndicator(...) 回傳的最終結果是未化簡的 symbolic IsNonOverlappingAndDenseIndicator expr。IsNonOverlappingAndDenseIndicator 繼承自 sympy.Function,具體可參考 Defining Automatic Evaluation with eval()
  • 完全具體化的情況:Python int 0 或 1

sizes_strides_methods 安裝流程

SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組安裝 sizes strides methods 的程式碼位於 torch/fx/experimental/symbolic_shapes.py

主程式

以下是安裝 sizes strides methods 的主程式:

python 复制代码
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)

它會遍歷 sizes_strides_methods,對(函數名稱 method,lambda 函數 func)的 pair 呼叫 _make_node_sizes_strides 函數。

_make_node_sizes_strides

_make_node_sizes_strides 函數的作用是為 SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組安裝指定的 sizes strides method:

python 复制代码
def _make_node_sizes_strides(method, func):
    # NB: don't LRU cache, lots of arguments

    def sizes_strides_impl(self, sizes, strides):
        op = getattr(sys.modules[__name__], method)
        if SYM_FUNCTION_MODE:
            return to_node(
                self,
                _handle_sym_dispatch(
                    op,
                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
                    {}
                )
            )
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning(f"failed to eval {method}({size_exprs}, {stride_exprs})")
            raise
        # bool is never expandable

        size_hints = []
        out_hint = None
        for s in sizes:
            if s.hint is None:
                break
            size_hints.append(s.hint)
        else:
            stride_hints = []
            for s in strides:
                if s.hint is None:
                    break
                stride_hints.append(s.hint)
            else:
                out_hint = op(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: Type
        if method.endswith("_indicator"):
            pytype = int
        else:
            pytype = bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(getattr(a.node, method)(
                    [to_node(a.node, b) for b in sizes],
                    [to_node(a.node, b) for b in strides],
                ))
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        else:
            # TODO: this is an awful implementation
            return bool(func(
                [sympy.sympify(a) for a in sizes],
                [sympy.sympify(a) for a in strides],
            ))

    # Skip for is_non_overlapping_and_dense_indicator
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)

它接受 methodfunc 兩個參數,其中 method 代表函數的名稱。函數會以 _method(帶下底線前綴)的名稱被安裝在 SymNode 上,以 method 的名稱被安裝在 torch.fx.experimental.symbolic_shapes 上。

func 則為欲安裝的函數本身,也就是剛剛在 sizes_strides_methods 裡看到的 sympy_is_contiguous, sympy_is_channels_last_contiguous_2d, sympy_is_channels_last_contiguous_3d, sympy_is_channels_last_strides_2d, sympy_is_channels_last_strides_3dIsNonOverlappingAndDenseIndicator 等函數。

可以看到,在這個函數裡又定義了兩個函數 sizes_strides_implsizes_strides_user,並且還有兩行 setattr。它的意思是將 func 包裝成 sizes_strides_implsizes_strides_user 後,再分別將這兩個函數安裝到 SymNodetorch.fx.experimental.symbolic_shapes 上。

這個函數可以分兩部份看:第一部份 sizes_strides_impl 是在為 SymNode 安裝成員函數;第二部份 sizes_strides_user 則是在為 torch.fx.experimental.symbolic_shapes 模組本身安裝模組層級的函數,接下來的兩篇文章中將會細究 sizes_strides_implsizes_strides_user 的實作細節及安裝過程。

相关推荐
Traving Yu1 小时前
向量数据库Milvus
数据库·人工智能·milvus
苏生十一_Nojambot1 小时前
AI浏览器——Tabbit使用教程
人工智能
AI科技星1 小时前
【无标题】
人工智能·决策树·机器学习·数据挖掘·机器人
2501_901006471 小时前
golang如何使用DTM分布式事务框架_golang DTM分布式事务框架使用方法
jvm·数据库·python
一点一木1 小时前
2026 终端 AI 编码 Agent 六大工具深度横评
前端·人工智能·claude
qq_411262421 小时前
四博 AI 双目智能音箱方案:四路触控、震动反馈、姿态感应、语音克隆和专属知识库全拉满
人工智能·智能音箱
2501_901200531 小时前
Golang如何做Clean Architecture_Golang整洁架构教程【详解】
jvm·数据库·python
沪漂阿龙1 小时前
面试题:卷积神经网络(CNN)是什么?核心层、卷积核、池化、1×1 卷积、VGG、ResNet 一文讲透
人工智能·神经网络·cnn
weixin_459753941 小时前
Go 中嵌入类型字段在派生结构体字面量中的初始化规则详解
jvm·数据库·python