PyTorch SymNode 的 _is_contiguous 從何而來?──sizes_strides_impl 實作詳解

前言

前篇 PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 介紹了 _make_node_sizes_strides 函數,而在該函數內,又定義了 sizes_strides_implsizes_strides_user 函數。

_make_node_sizes_strides 函數中的第一部份 sizes_strides_impl 是在為 SymNode 安裝成員函數;第二部份 sizes_strides_user 則是在為 torch.fx.experimental.symbolic_shapes 模組本身安裝模組層級的函數,本篇將會介紹第一部分 sizes_strides_impl 的實作細節及安裝過程。

安裝 sizes_strides_impl

因為 sizes_strides_impl 的定義位於 _make_node_sizes_strides 函數內部,所以 _make_node_sizes_strides 的參數 method, func 也對 sizes_strides_impl 可見,其中 method 代表函數的名稱,func 則為欲安裝的函數本身,詳見 sizes_strides_methods 安裝流程


先來看 sizes_strides_impl 的安裝,利用以下這行:

python 复制代码
    setattr(SymNode, f"_{method}", sizes_strides_impl)

它會把 sizes_strides_impl_method 的名稱安裝在 SymNode 上。

is_contiguous 為例。如果 _make_node_sizes_stridesmethod 參數是 is_contiguous,那麼它被安裝到 SymNode 後,名稱就會變為 _is_contiguous,即 SymNode._is_contiguous

主程式遍歷 sizes_strides_methods 中的六個函數,將這些方法安裝到 SymNode 類別上,但會加上下底線前綴,變成:

  • _is_contiguous
  • _is_channels_last_contiguous_2d
  • _is_channels_last_contiguous_3d
  • _is_channels_last_strides_2d
  • _is_channels_last_strides_3d
  • _is_non_overlapping_and_dense_indicator

那麼這類函數在哪裡被調用呢?其實只要看 跟 method 同名的 SymNode 的成員函數的定義 就知道了,在該函數中會調用帶下底線前綴的版本,那個帶下底線前綴的版本便是本篇的 sizes_strides_impl

sizes_strides_impl

正式進入 sizes_strides_impl

python 复制代码
    def sizes_strides_impl(self, sizes, strides):

參數

sizes_strides_impl 有三個參數:self, sizes, strides

前面已經看到 sizes_strides_impl 將來會被安裝到 SymNode 類別上,所以可以知道,第一個 self 參數就是 SymNode 型別。

另外兩個參數 sizes, strides 則皆為 list of SymNode,這是怎麼看出來的呢?稍後我們會看到:程式會遍歷 sizesstrides,存取其元素的 exprhint 成員變數,而 exprhint 正分別是 SymNode 的成員變數和 property,參考 PyTorch動態形狀系統的基石 - SymNode。這可證明 sizesstrides 兩個 list 中存的確實是 SymNode

op

python 复制代码
        op = getattr(sys.modules[__name__], method)

其中 sys.modules 是一個字典,key 是模組名稱字串,value 是對應的「已載入模組」物件。此處的 __name__ 是目前模組的完整名稱,也就是 torch.fx.experimental.symbolic_shapes,因此 sys.modules[__name__] 取到的就是這個模組物件。

getattr 是 Python 的內建函數,可以取出物件中名稱為指定名稱的屬性的值。此處用它來取出 torch.fx.experimental.symbolic_shapes 模組的 method 成員,如果傳入的 methodis_contiguous,則 op 會變成 torch.fx.experimental.symbolic_shapes.is_contiguous

但如果我們實際在 symbolic_shapes.py 裡尋找,卻找不到名為 is_contiguous 的函數,這是因為它其實是稍後才被安裝到 torch.fx.experimental.symbolic_shapes 模組上的模組層級函數,這跟 sizes_strides_user 有關,將在下一篇介紹。

那麼為何此處非得用 op,也就是 sizes_strides_user,而不是直接呼叫 func 呢?

我們先來看看 op 會被用在什麼地方:

  • if SYM_FUNCTION_MODE 分支內當作 _handle_sym_dispatch 函數的第一個參數
  • 計算 out_hint 時呼叫 op(size_hints, stride_hints)

在計算 out_hint 時會用到 op,但是為什麼必須用 op 而不能用 func 呢?其原因將在 hint 章節分曉。

SYM_FUNCTION_MODE

python 复制代码
        if SYM_FUNCTION_MODE:
            return to_node(
                self,
                _handle_sym_dispatch(
                    op,
                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
                    {}
                )
            )

SYM_FUNCTION_MODE 是一個全域變數,用來表示「目前有沒有外部程式想要攔截符號運算」。如果它非 None,表示外部已經設好攔截器,此時就會把這次的運算(op 函數和入參 sizesstrides)交給攔截器處理,再用 to_node 把結果包回 SymNode 回傳。

expr

在這段 code 中會用到外層函數 _make_node_sizes_strides 的參數 func

func 是用 list of sympy.Expr 做運算的,所以此處先取出 sizesstrides 列表中各 SymNode 物件的 expr 成員變數:

python 复制代码
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]

接著將它們丟到 func 裡做運算:

python 复制代码
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning(f"failed to eval {method}({size_exprs}, {stride_exprs})")
            raise

得到的 out 是函數 func 的判斷結果,是個 symbolic boolean expression。

hint

以下這段會先宣告 out_hint 變數,然後遍歷 sizesstrides,嘗試為 out_hint 賦值。

python 复制代码
        size_hints = []
        out_hint = None
        for s in sizes:
            if s.hint is None:
                break
            size_hints.append(s.hint)
        else:
            stride_hints = []
            for s in strides:
                if s.hint is None:
                    break
                stride_hints.append(s.hint)
            else:
                out_hint = op(size_hints, stride_hints)

上面的程式碼中用到了 Python 的 for-else 語法,詳見 Python For Else。簡單說就是當 for loop 中的 break 沒有被觸發時,就會執行 else 中的程式碼。

所以這段 code 可以分為兩三種情況:

  • sizes 列表中有任一個 SymNodehint 值為空,就 break 且不進入 else 分支,out_hint 保持為空
  • 如果 sizes 列表中每個 SymNodehint 值皆非空,就進入 else 分支,這時會檢查 strides 列表中每個 SymNodehint 是否非空。如果有任一個為空,就會 break 且不進入 else 分支,out_hint 保持為空
  • 如果 sizesstrides 列表中每個 SymNodehint 值皆非空,會再進入 else 分支,利用剛剛遍歷列表時順便填好的 size_hintsstride_hints 來計算 out_hint

這段程式碼看起來很長,其實核心只有一句:就是只有在兩個列表中的所有 SymNodehint 皆非空時,才做以下計算:out_hint = op(size_hints, stride_hints)

此處計算 out_hint 時用的不是外層 _make_node_sizes_strides 的參數 func,而是稍早透過 getattr(sys.modules[__name__], method) 取出的 op,即 sizes_strides_user,為何要多此一舉呢?

如果查看 SymNode類別的建構子,就會發現 hint 的型別是 Optional[Union[int, float]],也就是 Python 中單純的數字,這代表稍早取出的 size_hintsstride_hints 裡面都是 Python 中的 int 或 float。這兩個 list of Python int 或 float 會被當作 op 的參數。

func 函數接受的是 smypy.Expr 而非 Python 中的數字,所以才會需要使用其它能接受 Python 數字的函數。

我們在下一篇會看到,op(即 sizes_strides_user)正好滿足這一點,它既接受 list of SymInt, sympy.Expr,還接受 list of Python int 為輸入。當輸入是 list of Python int 時,函數會回傳 Python int 或 bool。

op 函數的回傳值會被賦給 out_hintout_hint 稍後會被當作 SymNode 建構子的參數,因為 Python bool 可以無痛被轉為 Python int,所以 SymNode 建構子也能接受。

pytype

最後是 pytype,表示的是輸出 SymNode 所包含的資料的型別。如果函數名稱以 _indicator 結尾,則指定底層資料的型別為 int,否則指定為 bool

python 复制代码
        # NB: This is the indicator function, not the actual bool!
        pytype: Type
        if method.endswith("_indicator"):
            pytype = int
        else:
            pytype = bool

回傳

函數最後一行會先建構一個 SymNode 物件再回傳,而 SymNode的建構子 需要 exprshape_envpytypehint 等參數,我們剛才那麼辛苦地計算 out, pytype, 和 out_hint 就是為了此處建構 SymNode 所用。

python 复制代码
        return SymNode(out, self.shape_env, pytype, out_hint)

所以 sizes_strides_impl 就是一個輸入為 list of SymNode,輸出為 SymNode 的函數。

如果要再細分的話,is_contiguous, is_channels_last_* 這些函數最後回傳的會是 bool-typed SymNodeis_non_overlapping_and_dense_indicator 回傳的則會是 int-typed SymNode

demo

is_contiguous

is_contiguous 為例,它是以 _is_contiguous 的名稱被安裝到 SymNode 上的,下例會直接呼叫它。

傳入的參數是兩個 list of SymNode,即 sizesstrides。以下程式會先建立一個 ShapeEnv 並用它創建 SymNode

python 复制代码
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymNode

shape_env = ShapeEnv()

# SymNode 的 expr 欄位型別為 sympy.Expr,而 sympy.Symbol 是 sympy.Expr 的子類別,
# 所以這裡可以直接傳入 sympy.Symbol。使用 Symbol 是為了表示「未具體化的符號維度」
# (動態 shape);若要表示靜態維度可改用 sympy.Integer,或傳入由兩者組合的運算式。
def create_symnode(name, hint):
    return SymNode(sympy.Symbol(name, integer=True), shape_env, int, hint)

# 對應形狀為 (2, 3, 4) 的 contiguous tensor
sizes = [create_symnode("s0", 2), create_symnode("s1", 3), create_symnode("s2", 4)]
strides = [create_symnode("t0", 12), create_symnode("t1", 4), create_symnode("t2", 1)]

result = sizes[0]._is_contiguous(sizes, strides)
print(result)
print(type(result))
print(result.pytype)
print(bool(result))

運行結果:

复制代码
Eq(s0, 0) | Eq(s1, 0) | Eq(s2, 0) | ((Eq(s1, 1) | Eq(t1, s2)) & (Eq(s2, 1) | Eq(t2, 1)) & (Eq(s0, 1) | Eq(t0, s1*s2)))
<class 'torch.fx.experimental.symbolic_shapes.SymNode'>
<class 'bool'>
True

可以看到輸出結果是一個 SymNode,並且其 pytype 為 bool。

is_non_overlapping_and_dense_indicator

呼叫 _is_non_overlapping_and_dense_indicator

python 复制代码
result = sizes[0]._is_non_overlapping_and_dense_indicator(sizes, strides)
print(result)
print(type(result))
print(result.pytype)
print(bool(result))
print(result.hint)

運行結果:

复制代码
IsNonOverlappingAndDenseIndicator(s0, s1, s2, t0, t1, t2)
<class 'torch.fx.experimental.symbolic_shapes.SymNode'>
<class 'int'>
True
1

輸出結果是一個 SymNode,其 pytypeint,而非剛才所見的 boolhint 值則為 0 或 1 其中之一。

調用流程

SymNode.is_contiguous 方法為例,從 SymNode 不帶下底線的方法到 sympy_* 函數的調用流程如下:

text 复制代码
SymNode.is_contiguous      → SymNode._is_contiguous      → sympy_is_contiguous
self + 兩個 list[SymNode]   → self + 兩個 list[SymNode]   → 兩個 list[sympy.Expr]
                                                                    ↓
SymNode(pytype=bool)        ← SymNode(pytype=bool)       ← SymPy Boolean expression
相关推荐
测试员周周1 小时前
【Appium 系列】第02节-环境搭建 — Android + iOS 双平台环境配置
开发语言·人工智能·功能测试·appium·自动化·测试用例·web app
imbackneverdie1 小时前
AI PPT工具实测分享
人工智能·ai作画·aigc·ppt·ai工具·aippt
AI搅拌机1 小时前
【一键安装】 Qwen3-TTS语音克隆三合一工作流!
人工智能
踏着七彩祥云的小丑1 小时前
AI——Dify数据备份与迁移
人工智能·ai
2603_954708311 小时前
微电网分布式电源接入技术:光伏、风电的适配设计
人工智能·分布式·物联网·架构·系统架构·能源
手写码匠1 小时前
手写 AI 智能路由系统:从零构建多模型调度与负载均衡
人工智能·深度学习·算法·aigc
七牛云行业应用1 小时前
MCP 服务器本地部署实战【2026】:Python/Node.js 搭建 + Claude/Cursor/TRAE
服务器·python·node.js
Web极客码1 小时前
Python Deque:构建实时滑动窗口与高性能缓存的“隐藏高手”
java·python·缓存
AI科技星1 小时前
全域数学·体积与表面积通项定理【乖乖数学】
人工智能·算法·数学建模·数据挖掘·机器人