PyTorch symbolic_shapes 模組的 is_contiguous 從哪來?── sizes_strides_user 安裝與實作解析

前言

PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 這篇文章中介紹了 _make_node_sizes_strides 函數,而在該函數內,又定義了 sizes_strides_implsizes_strides_user 函數。

關於 sizes_strides_impl 函數,我們已在 PyTorch SymNode 的 _is_contiguous 從何而來?──sizes_strides_impl 實作詳解 中看過;sizes_strides_user 則是將來會被安裝在 torch.fx.experimental.symbolic_shapes 模組上的函數,它正是本篇的主題。

安裝 sizes_strides_user

因為 sizes_strides_user 的定義位於 _make_node_sizes_strides 函數內部,所以 _make_node_sizes_strides 的參數 method, func 也對 sizes_strides_user 可見,其中 method 代表函數的名稱,func 則為欲安裝的函數本身,詳見 sizes_strides_methods 安裝流程


來看 sizes_strides_user 的安裝,以下兩行程式碼會將它安裝到 sys.modules[__name__] 上面:

python 复制代码
    # Skip for is_non_overlapping_and_dense_indicator
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)

setattr 是 Python 的內建函數,可以將值指派給物件的屬性,在該屬性不存在的情況下則會自動為指定物件新增該屬性。此處用它來會 torch.fx.experimental.symbolic_shapes 模組新增 method 成員,並賦值成 sizes_strides_user,將來就可以透過 torch.fx.experimental.symbolic_shapes.method 來調用 sizes_strides_user 函數。

主程式遍歷 sizes_strides_methods 中的六個函數,將這些方法安裝到 torch.fx.experimental.symbolic_shapes 模組上:

  • is_contiguous
  • is_channels_last_contiguous_2d
  • is_channels_last_contiguous_3d
  • is_channels_last_strides_2d
  • is_channels_last_strides_3d
  • is_non_overlapping_and_dense_indicator

注釋中說:

python 复制代码
# Skip for is_non_overlapping_and_dense_indicator

但 code 裡沒有做相關的特殊處理,不確定這行注釋是 TODO 還是已過時了。

sizes_strides_user

正式進入 sizes_strides_user 函數,從它名字中的 user 可以看出來,它是對外暴露給使用者的入口函數:

python 复制代码
    def sizes_strides_user(sizes, strides):

參數

sizes_strides_impl 不同,sizes_strides_user 少了 self 參數,只有 sizesstrides 兩個參數,因為它不像 sizes_strides_impl 是某個類別的成員函數,而是一個模組層級的函數。這兩個參數可以是 list of SymInt、list of sympy.Expr、list of sympy.Integer,或 list of Python int。函數會根據輸入元素的型別走不同的路徑。

SymInt 路徑

python 复制代码
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(getattr(a.node, method)(
                    [to_node(a.node, b) for b in sizes],
                    [to_node(a.node, b) for b in strides],
                ))

這段 code 中用到了 itertools.chain(*iterables),其作用如下:

复制代码
Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. This combines multiple data sources into a single iterator.

建立一個迭代器,會先依序回傳第一個可迭代物件(iterable)的元素,直到耗盡為止,接著再處理下一個 iterable,直到所有 iterable 都被耗盡。這可以將多個資料來源合併成一個單一的迭代器。

簡單來說就是 itertools.chain 會將它接受的多個可迭代物件(iterable)參數合併成一個。

所以 for a in itertools.chain(sizes, strides) 會遍歷由 sizesstrides 併在一起的大 list。

在遍歷的過程中,如果其中一個元素是 SymInt,則使用 getattr(a.node, method) 得到 SymNodemethod 方法,該方法是 SymNode 的成員函數,接受 self, sizesstrides 三個參數,回傳 SymNode 物件。

SymNode.method 方法內部會調用帶下底線的 SymNode._method 方法,而 SymNode._method 方法其實就是 PyTorch SymNode 的 _is_contiguous 從何而來?──sizes_strides_impl 實作詳解 中介紹的 sizes_strides_impl 方法。

因為 SymNode.method 接受的 sizesstrides 參數都是 list of SymNode,而此處接收到的 sizesstrides 則是 list of SymInt,所以需要做轉換。

這個轉換由 to_node 函數負責,參考 torch.fx.experimental.symbolic_shapes.to_node,它可以將 SymInt 物件轉成 SymNode 物件。[to_node(a.node, b) for b in sizes][to_node(a.node, b) for b in strides] 這兩行的作用就是將 sizesstrides 兩個 list of SymInt 轉換為 list of SymNode

得到兩個 list of SymNode 後,將它們傳入 SymNode.method 方法可得到一個 SymNode 物件,該物件會被傳入 wrap_node 函數wrap_node 函數會根據 SymNode 物件的 pytype 屬性來決定要將該物件包裝成什麼:如果 pytypebool,則 SymNode 會被包裝成 SymBool;如果 pytypeint,則 SymNode 會被包裝成 SymInt

因此在這個路徑下,函數接受的是 list of SymInt,回傳的則是 SymIntSymBool


如果 sizesstrides 中包含了 SymInt,會進入 SymInt 路徑,在 for 迴圈內就會直接回傳;反之,則代表 sizesstrides 裡包含的都是 sympy.Expr, sympy.Integer 或是 Python int,會進入另一條「sympy.Expr 路徑」。

sympy.Expr 路徑

在這種情況下,會繼續做以下檢查:如果函數名稱是 is_non_overlapping_and_dense_indicator,因為 torch/fx/experimental/symbolic_shapes.py 中有一個接受 list of sympy.Expr 或 list of Python int 為參數的 eval_is_non_overlapping_and_dense 函數,所以可以直接調用它,並且該函數回傳的是 Python int 0 或 1,所以得到結果後也可以直接回傳。

python 复制代码
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)

否則會用到外層的 func 參數,func 函數接受 list of int, sympy.Expr, sympy.Symbol, sympy.Integer 或 sympy 組合式為輸入,輸出則為 symbolic boolean expression。

以下程式碼先用 sympy.sympifysizesstrides 中的各元素轉換為可供SymPy表達式使用的形式,傳入 func 函數後得到 symbolic boolean expression,再將它轉成 Python bool 後回傳。

python 复制代码
        else:
            # TODO: this is an awful implementation
            return bool(func(
                [sympy.sympify(a) for a in sizes],
                [sympy.sympify(a) for a in strides],
            ))

兩條路徑的差異

sizes_strides_userSymInt 路徑和 sympy.Expr 路徑的核心差別在於:前者接受 SymInt 為輸入,並且會保留 symbolic 包裝,回傳 SymBool / SymInt;後者則接受 sympy.Expr 為輸入,會立即執行判斷邏輯,不保留 symbolic 包裝,直接回傳 Python bool 或 int。

demo

SymInt 路徑

SymInt 為輸入,會走 SymInt 路徑,呼叫對應 SymNode 成員函數,最後用 wrap_node 包裝成 SymIntSymBool

python 复制代码
import sympy
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymNode
import torch.fx.experimental.symbolic_shapes as ss

shape_env = ShapeEnv()
s0 = torch.SymInt(SymNode(sympy.Symbol("s0", integer=True), shape_env, int, None))
result = ss.is_contiguous([s0, 3, 4], [12, 4, 1])

print(result)
print(type(result))
print(isinstance(result, bool))

這裡直接用 torch.SymInt(...) 建構子,把 SymNode 包成一個 SymInt。其中 SymNodeexprsympy.Symbol("s0", integer=True)pytypeint,hint則是None`,表示它目前沒有對應的具體值。

因為 s0 的型別是 SymIntss.is_contiguous([s0, 3, 4], [12, 4, 1]) 會進入 SymInt 路徑,轉而呼叫對應的 SymNode 成員函數,最後回傳 SymBool

運行結果:

复制代码
True
<class 'torch.SymBool'>
False

第一行是 SymBool 印出其底層 expr 後的結果;第二、三行則說明回傳物件的型別其實是 SymBool,不是 Python bool

sympy.Expr 路徑

sympy.Expr 或是 Python int 為輸入,則會走 sympy.Expr 路徑,最後回傳 Python bool:

python 复制代码
import torch.fx.experimental.symbolic_shapes as ss

print(ss.is_contiguous([2, 3, 4], [12, 4, 1]))
print(ss.is_contiguous([2, 3, 4], [1, 2, 6]))

# NCHW: (N=2, C=3, H=4, W=5)
print(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 1, 15, 3]))
print(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 20, 5, 1]))

print(type(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 20, 5, 1])))

運行結果:

复制代码
True
False
True
False
<class 'bool'>

調用流程

torch.fx.experimental.symbolic_shapes.is_contiguous 函數為例,根據輸入型別會走兩條不同的路徑:

SymInt 路徑

text 复制代码
ss.is_contiguous            → SymNode.is_contiguous      → SymNode._is_contiguous      → sympy_is_contiguous
兩個 list[SymInt]            → self + 兩個 list[SymNode]   → self + 兩個 list[SymNode]   → 兩個 list[sympy.Expr]
                                                                                                ↓
SymBool (wrap_node後)        ← SymNode(pytype=bool)       ← SymNode(pytype=bool)        ← SymPy Boolean expression

sympy.Expr 路徑

text 复制代码
ss.is_contiguous            → sympy_is_contiguous
兩個 list[sympy.Expr/int]    → 兩個 list[sympy.Expr]
                                       ↓
Python bool                  ← SymPy Boolean expression
相关推荐
大模型推理1 小时前
Nano-vLLM 源码解读 - 7. Continuous Batching
深度学习·自然语言处理·vllm
MXsoft6181 小时前
**智能运维如何实现全栈监控与****AI****告警?****——****一体化平台实战解析**
运维·人工智能
C137的本贾尼1 小时前
别怕异步:`async` 和 `await` 的简单理解
开发语言·python
__log1 小时前
ComfyUI 集成技术方案分析报告
javascript·python·django
想你依然心痛1 小时前
HarmonyOS 6(API 23)实战:基于悬浮导航、沉浸光感与HMAF的“代码哨兵“——AI智能体代码安全审计平台
人工智能·安全·harmonyos·智能体
云安全助手1 小时前
谁能定义云安全AI时代?——具有“安全原生”的聚合与防护平台
人工智能·ai·claude
梅西库里RNG2 小时前
AI学习纪要——基础篇
人工智能·学习
梦想的颜色2 小时前
2026最新Claude Code 规范文件 CLAUDE.md 全面解析与超全模板
人工智能·小程序