前言
在 PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 這篇文章中介紹了 _make_node_sizes_strides 函數,而在該函數內,又定義了 sizes_strides_impl 和 sizes_strides_user 函數。
關於 sizes_strides_impl 函數,我們已在 PyTorch SymNode 的 _is_contiguous 從何而來?──sizes_strides_impl 實作詳解 中看過;sizes_strides_user 則是將來會被安裝在 torch.fx.experimental.symbolic_shapes 模組上的函數,它正是本篇的主題。
安裝 sizes_strides_user
因為 sizes_strides_user 的定義位於 _make_node_sizes_strides 函數內部,所以 _make_node_sizes_strides 的參數 method, func 也對 sizes_strides_user 可見,其中 method 代表函數的名稱,func 則為欲安裝的函數本身,詳見 sizes_strides_methods 安裝流程。
來看 sizes_strides_user 的安裝,以下兩行程式碼會將它安裝到 sys.modules[__name__] 上面:
python
# Skip for is_non_overlapping_and_dense_indicator
if not hasattr(sys.modules[__name__], method):
setattr(sys.modules[__name__], method, sizes_strides_user)
setattr 是 Python 的內建函數,可以將值指派給物件的屬性,在該屬性不存在的情況下則會自動為指定物件新增該屬性。此處用它來會 torch.fx.experimental.symbolic_shapes 模組新增 method 成員,並賦值成 sizes_strides_user,將來就可以透過 torch.fx.experimental.symbolic_shapes.method 來調用 sizes_strides_user 函數。
主程式遍歷 sizes_strides_methods 中的六個函數,將這些方法安裝到 torch.fx.experimental.symbolic_shapes 模組上:
is_contiguousis_channels_last_contiguous_2dis_channels_last_contiguous_3dis_channels_last_strides_2dis_channels_last_strides_3dis_non_overlapping_and_dense_indicator
注釋中說:
python
# Skip for is_non_overlapping_and_dense_indicator
但 code 裡沒有做相關的特殊處理,不確定這行注釋是 TODO 還是已過時了。
sizes_strides_user
正式進入 sizes_strides_user 函數,從它名字中的 user 可以看出來,它是對外暴露給使用者的入口函數:
python
def sizes_strides_user(sizes, strides):
參數
跟 sizes_strides_impl 不同,sizes_strides_user 少了 self 參數,只有 sizes 和 strides 兩個參數,因為它不像 sizes_strides_impl 是某個類別的成員函數,而是一個模組層級的函數。這兩個參數可以是 list of SymInt、list of sympy.Expr、list of sympy.Integer,或 list of Python int。函數會根據輸入元素的型別走不同的路徑。
SymInt 路徑
python
for a in itertools.chain(sizes, strides):
if isinstance(a, SymInt):
return wrap_node(getattr(a.node, method)(
[to_node(a.node, b) for b in sizes],
[to_node(a.node, b) for b in strides],
))
這段 code 中用到了 itertools.chain(*iterables),其作用如下:
Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. This combines multiple data sources into a single iterator.
建立一個迭代器,會先依序回傳第一個可迭代物件(iterable)的元素,直到耗盡為止,接著再處理下一個 iterable,直到所有 iterable 都被耗盡。這可以將多個資料來源合併成一個單一的迭代器。
簡單來說就是 itertools.chain 會將它接受的多個可迭代物件(iterable)參數合併成一個。
所以 for a in itertools.chain(sizes, strides) 會遍歷由 sizes 和 strides 併在一起的大 list。
在遍歷的過程中,如果其中一個元素是 SymInt,則使用 getattr(a.node, method) 得到 SymNode 的 method 方法,該方法是 SymNode 的成員函數,接受 self, sizes 和 strides 三個參數,回傳 SymNode 物件。
SymNode.method 方法內部會調用帶下底線的 SymNode._method 方法,而 SymNode._method 方法其實就是 PyTorch SymNode 的 _is_contiguous 從何而來?──sizes_strides_impl 實作詳解 中介紹的 sizes_strides_impl 方法。
因為 SymNode.method 接受的 sizes 和 strides 參數都是 list of SymNode,而此處接收到的 sizes 和 strides 則是 list of SymInt,所以需要做轉換。
這個轉換由 to_node 函數負責,參考 torch.fx.experimental.symbolic_shapes.to_node,它可以將 SymInt 物件轉成 SymNode 物件。[to_node(a.node, b) for b in sizes] 和 [to_node(a.node, b) for b in strides] 這兩行的作用就是將 sizes 和 strides 兩個 list of SymInt 轉換為 list of SymNode。
得到兩個 list of SymNode 後,將它們傳入 SymNode.method 方法可得到一個 SymNode 物件,該物件會被傳入 wrap_node 函數。wrap_node 函數會根據 SymNode 物件的 pytype 屬性來決定要將該物件包裝成什麼:如果 pytype 是 bool,則 SymNode 會被包裝成 SymBool;如果 pytype 是 int,則 SymNode 會被包裝成 SymInt。
因此在這個路徑下,函數接受的是 list of SymInt,回傳的則是 SymInt 或 SymBool。
如果 sizes 和 strides 中包含了 SymInt,會進入 SymInt 路徑,在 for 迴圈內就會直接回傳;反之,則代表 sizes 和 strides 裡包含的都是 sympy.Expr, sympy.Integer 或是 Python int,會進入另一條「sympy.Expr 路徑」。
sympy.Expr 路徑
在這種情況下,會繼續做以下檢查:如果函數名稱是 is_non_overlapping_and_dense_indicator,因為 torch/fx/experimental/symbolic_shapes.py 中有一個接受 list of sympy.Expr 或 list of Python int 為參數的 eval_is_non_overlapping_and_dense 函數,所以可以直接調用它,並且該函數回傳的是 Python int 0 或 1,所以得到結果後也可以直接回傳。
python
if method == "is_non_overlapping_and_dense_indicator":
return eval_is_non_overlapping_and_dense(sizes, strides)
否則會用到外層的 func 參數,func 函數接受 list of int, sympy.Expr, sympy.Symbol, sympy.Integer 或 sympy 組合式為輸入,輸出則為 symbolic boolean expression。
以下程式碼先用 sympy.sympify 將 sizes 和 strides 中的各元素轉換為可供SymPy表達式使用的形式,傳入 func 函數後得到 symbolic boolean expression,再將它轉成 Python bool 後回傳。
python
else:
# TODO: this is an awful implementation
return bool(func(
[sympy.sympify(a) for a in sizes],
[sympy.sympify(a) for a in strides],
))
兩條路徑的差異
sizes_strides_user 中 SymInt 路徑和 sympy.Expr 路徑的核心差別在於:前者接受 SymInt 為輸入,並且會保留 symbolic 包裝,回傳 SymBool / SymInt;後者則接受 sympy.Expr 為輸入,會立即執行判斷邏輯,不保留 symbolic 包裝,直接回傳 Python bool 或 int。
demo
SymInt 路徑
以 SymInt 為輸入,會走 SymInt 路徑,呼叫對應 SymNode 成員函數,最後用 wrap_node 包裝成 SymInt 或 SymBool:
python
import sympy
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymNode
import torch.fx.experimental.symbolic_shapes as ss
shape_env = ShapeEnv()
s0 = torch.SymInt(SymNode(sympy.Symbol("s0", integer=True), shape_env, int, None))
result = ss.is_contiguous([s0, 3, 4], [12, 4, 1])
print(result)
print(type(result))
print(isinstance(result, bool))
這裡直接用 torch.SymInt(...) 建構子,把 SymNode 包成一個 SymInt。其中 SymNode 的 expr 是 sympy.Symbol("s0", integer=True),pytype 是 int,hint則是None`,表示它目前沒有對應的具體值。
因為 s0 的型別是 SymInt,ss.is_contiguous([s0, 3, 4], [12, 4, 1]) 會進入 SymInt 路徑,轉而呼叫對應的 SymNode 成員函數,最後回傳 SymBool。
運行結果:
True
<class 'torch.SymBool'>
False
第一行是 SymBool 印出其底層 expr 後的結果;第二、三行則說明回傳物件的型別其實是 SymBool,不是 Python bool。
sympy.Expr 路徑
以 sympy.Expr 或是 Python int 為輸入,則會走 sympy.Expr 路徑,最後回傳 Python bool:
python
import torch.fx.experimental.symbolic_shapes as ss
print(ss.is_contiguous([2, 3, 4], [12, 4, 1]))
print(ss.is_contiguous([2, 3, 4], [1, 2, 6]))
# NCHW: (N=2, C=3, H=4, W=5)
print(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 1, 15, 3]))
print(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 20, 5, 1]))
print(type(ss.is_channels_last_strides_2d([2, 3, 4, 5], [60, 20, 5, 1])))
運行結果:
True
False
True
False
<class 'bool'>
調用流程
以 torch.fx.experimental.symbolic_shapes.is_contiguous 函數為例,根據輸入型別會走兩條不同的路徑:
SymInt 路徑
text
ss.is_contiguous → SymNode.is_contiguous → SymNode._is_contiguous → sympy_is_contiguous
兩個 list[SymInt] → self + 兩個 list[SymNode] → self + 兩個 list[SymNode] → 兩個 list[sympy.Expr]
↓
SymBool (wrap_node後) ← SymNode(pytype=bool) ← SymNode(pytype=bool) ← SymPy Boolean expression
sympy.Expr 路徑
text
ss.is_contiguous → sympy_is_contiguous
兩個 list[sympy.Expr/int] → 兩個 list[sympy.Expr]
↓
Python bool ← SymPy Boolean expression