前言
前篇 PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 介紹了 _make_node_sizes_strides 函數,而在該函數內,又定義了 sizes_strides_impl 和 sizes_strides_user 函數。
_make_node_sizes_strides 函數中的第一部份 sizes_strides_impl 是在為 SymNode 安裝成員函數;第二部份 sizes_strides_user 則是在為 torch.fx.experimental.symbolic_shapes 模組本身安裝模組層級的函數,本篇將會介紹第一部分 sizes_strides_impl 的實作細節及安裝過程。
安裝 sizes_strides_impl
因為 sizes_strides_impl 的定義位於 _make_node_sizes_strides 函數內部,所以 _make_node_sizes_strides 的參數 method, func 也對 sizes_strides_impl 可見,其中 method 代表函數的名稱,func 則為欲安裝的函數本身,詳見 sizes_strides_methods 安裝流程。
先來看 sizes_strides_impl 的安裝,利用以下這行:
python
setattr(SymNode, f"_{method}", sizes_strides_impl)
它會把 sizes_strides_impl 以 _method 的名稱安裝在 SymNode 上。
以 is_contiguous 為例。如果 _make_node_sizes_strides 的 method 參數是 is_contiguous,那麼它被安裝到 SymNode 後,名稱就會變為 _is_contiguous,即 SymNode._is_contiguous。
主程式遍歷 sizes_strides_methods 中的六個函數,將這些方法安裝到 SymNode 類別上,但會加上下底線前綴,變成:
_is_contiguous_is_channels_last_contiguous_2d_is_channels_last_contiguous_3d_is_channels_last_strides_2d_is_channels_last_strides_3d_is_non_overlapping_and_dense_indicator
那麼這類函數在哪裡被調用呢?其實只要看 跟 method 同名的 SymNode 的成員函數的定義 就知道了,在該函數中會調用帶下底線前綴的版本,那個帶下底線前綴的版本便是本篇的 sizes_strides_impl。
sizes_strides_impl
正式進入 sizes_strides_impl:
python
def sizes_strides_impl(self, sizes, strides):
參數
sizes_strides_impl 有三個參數:self, sizes, strides。
前面已經看到 sizes_strides_impl 將來會被安裝到 SymNode 類別上,所以可以知道,第一個 self 參數就是 SymNode 型別。
另外兩個參數 sizes, strides 則皆為 list of SymNode,這是怎麼看出來的呢?稍後我們會看到:程式會遍歷 sizes 和 strides,存取其元素的 expr 和 hint 成員變數,而 expr 和 hint 正分別是 SymNode 的成員變數和 property,參考 PyTorch動態形狀系統的基石 - SymNode。這可證明 sizes 和 strides 兩個 list 中存的確實是 SymNode。
op
python
op = getattr(sys.modules[__name__], method)
其中 sys.modules 是一個字典,key 是模組名稱字串,value 是對應的「已載入模組」物件。此處的 __name__ 是目前模組的完整名稱,也就是 torch.fx.experimental.symbolic_shapes,因此 sys.modules[__name__] 取到的就是這個模組物件。
getattr 是 Python 的內建函數,可以取出物件中名稱為指定名稱的屬性的值。此處用它來取出 torch.fx.experimental.symbolic_shapes 模組的 method 成員,如果傳入的 method 為 is_contiguous,則 op 會變成 torch.fx.experimental.symbolic_shapes.is_contiguous。
但如果我們實際在 symbolic_shapes.py 裡尋找,卻找不到名為 is_contiguous 的函數,這是因為它其實是稍後才被安裝到 torch.fx.experimental.symbolic_shapes 模組上的模組層級函數,這跟 sizes_strides_user 有關,將在下一篇介紹。
那麼為何此處非得用 op,也就是 sizes_strides_user,而不是直接呼叫 func 呢?
我們先來看看 op 會被用在什麼地方:
if SYM_FUNCTION_MODE分支內當作_handle_sym_dispatch函數的第一個參數- 計算
out_hint時呼叫op(size_hints, stride_hints)
在計算 out_hint 時會用到 op,但是為什麼必須用 op 而不能用 func 呢?其原因將在 hint 章節分曉。
SYM_FUNCTION_MODE
python
if SYM_FUNCTION_MODE:
return to_node(
self,
_handle_sym_dispatch(
op,
([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
{}
)
)
SYM_FUNCTION_MODE 是一個全域變數,用來表示「目前有沒有外部程式想要攔截符號運算」。如果它非 None,表示外部已經設好攔截器,此時就會把這次的運算(op 函數和入參 sizes、strides)交給攔截器處理,再用 to_node 把結果包回 SymNode 回傳。
expr
在這段 code 中會用到外層函數 _make_node_sizes_strides 的參數 func。
func 是用 list of sympy.Expr 做運算的,所以此處先取出 sizes 和 strides 列表中各 SymNode 物件的 expr 成員變數:
python
size_exprs = [s.expr for s in sizes]
stride_exprs = [s.expr for s in strides]
接著將它們丟到 func 裡做運算:
python
try:
out = func(size_exprs, stride_exprs)
except Exception:
log.warning(f"failed to eval {method}({size_exprs}, {stride_exprs})")
raise
得到的 out 是函數 func 的判斷結果,是個 symbolic boolean expression。
hint
以下這段會先宣告 out_hint 變數,然後遍歷 sizes 和 strides,嘗試為 out_hint 賦值。
python
size_hints = []
out_hint = None
for s in sizes:
if s.hint is None:
break
size_hints.append(s.hint)
else:
stride_hints = []
for s in strides:
if s.hint is None:
break
stride_hints.append(s.hint)
else:
out_hint = op(size_hints, stride_hints)
上面的程式碼中用到了 Python 的 for-else 語法,詳見 Python For Else。簡單說就是當 for loop 中的 break 沒有被觸發時,就會執行 else 中的程式碼。
所以這段 code 可以分為兩三種情況:
- 當
sizes列表中有任一個SymNode的hint值為空,就 break 且不進入else分支,out_hint保持為空 - 如果
sizes列表中每個SymNode的hint值皆非空,就進入else分支,這時會檢查strides列表中每個SymNode的hint是否非空。如果有任一個為空,就會 break 且不進入else分支,out_hint保持為空 - 如果
sizes和strides列表中每個SymNode的hint值皆非空,會再進入else分支,利用剛剛遍歷列表時順便填好的size_hints和stride_hints來計算out_hint
這段程式碼看起來很長,其實核心只有一句:就是只有在兩個列表中的所有 SymNode 的 hint 皆非空時,才做以下計算:out_hint = op(size_hints, stride_hints)。
此處計算 out_hint 時用的不是外層 _make_node_sizes_strides 的參數 func,而是稍早透過 getattr(sys.modules[__name__], method) 取出的 op,即 sizes_strides_user,為何要多此一舉呢?
如果查看 SymNode類別的建構子,就會發現 hint 的型別是 Optional[Union[int, float]],也就是 Python 中單純的數字,這代表稍早取出的 size_hints 和 stride_hints 裡面都是 Python 中的 int 或 float。這兩個 list of Python int 或 float 會被當作 op 的參數。
但 func 函數接受的是 smypy.Expr 而非 Python 中的數字,所以才會需要使用其它能接受 Python 數字的函數。
我們在下一篇會看到,op(即 sizes_strides_user)正好滿足這一點,它既接受 list of SymInt, sympy.Expr,還接受 list of Python int 為輸入。當輸入是 list of Python int 時,函數會回傳 Python int 或 bool。
op 函數的回傳值會被賦給 out_hint,out_hint 稍後會被當作 SymNode 建構子的參數,因為 Python bool 可以無痛被轉為 Python int,所以 SymNode 建構子也能接受。
pytype
最後是 pytype,表示的是輸出 SymNode 所包含的資料的型別。如果函數名稱以 _indicator 結尾,則指定底層資料的型別為 int,否則指定為 bool:
python
# NB: This is the indicator function, not the actual bool!
pytype: Type
if method.endswith("_indicator"):
pytype = int
else:
pytype = bool
回傳
函數最後一行會先建構一個 SymNode 物件再回傳,而 SymNode的建構子 需要 expr,shape_env,pytype,hint 等參數,我們剛才那麼辛苦地計算 out, pytype, 和 out_hint 就是為了此處建構 SymNode 所用。
python
return SymNode(out, self.shape_env, pytype, out_hint)
所以 sizes_strides_impl 就是一個輸入為 list of SymNode,輸出為 SymNode 的函數。
如果要再細分的話,is_contiguous, is_channels_last_* 這些函數最後回傳的會是 bool-typed SymNode;is_non_overlapping_and_dense_indicator 回傳的則會是 int-typed SymNode。
demo
is_contiguous
以 is_contiguous 為例,它是以 _is_contiguous 的名稱被安裝到 SymNode 上的,下例會直接呼叫它。
傳入的參數是兩個 list of SymNode,即 sizes 和 strides。以下程式會先建立一個 ShapeEnv 並用它創建 SymNode:
python
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymNode
shape_env = ShapeEnv()
# SymNode 的 expr 欄位型別為 sympy.Expr,而 sympy.Symbol 是 sympy.Expr 的子類別,
# 所以這裡可以直接傳入 sympy.Symbol。使用 Symbol 是為了表示「未具體化的符號維度」
# (動態 shape);若要表示靜態維度可改用 sympy.Integer,或傳入由兩者組合的運算式。
def create_symnode(name, hint):
return SymNode(sympy.Symbol(name, integer=True), shape_env, int, hint)
# 對應形狀為 (2, 3, 4) 的 contiguous tensor
sizes = [create_symnode("s0", 2), create_symnode("s1", 3), create_symnode("s2", 4)]
strides = [create_symnode("t0", 12), create_symnode("t1", 4), create_symnode("t2", 1)]
result = sizes[0]._is_contiguous(sizes, strides)
print(result)
print(type(result))
print(result.pytype)
print(bool(result))
運行結果:
Eq(s0, 0) | Eq(s1, 0) | Eq(s2, 0) | ((Eq(s1, 1) | Eq(t1, s2)) & (Eq(s2, 1) | Eq(t2, 1)) & (Eq(s0, 1) | Eq(t0, s1*s2)))
<class 'torch.fx.experimental.symbolic_shapes.SymNode'>
<class 'bool'>
True
可以看到輸出結果是一個 SymNode,並且其 pytype 為 bool。
is_non_overlapping_and_dense_indicator
呼叫 _is_non_overlapping_and_dense_indicator:
python
result = sizes[0]._is_non_overlapping_and_dense_indicator(sizes, strides)
print(result)
print(type(result))
print(result.pytype)
print(bool(result))
print(result.hint)
運行結果:
IsNonOverlappingAndDenseIndicator(s0, s1, s2, t0, t1, t2)
<class 'torch.fx.experimental.symbolic_shapes.SymNode'>
<class 'int'>
True
1
輸出結果是一個 SymNode,其 pytype 為 int,而非剛才所見的 bool,hint 值則為 0 或 1 其中之一。
調用流程
以 SymNode.is_contiguous 方法為例,從 SymNode 不帶下底線的方法到 sympy_* 函數的調用流程如下:
text
SymNode.is_contiguous → SymNode._is_contiguous → sympy_is_contiguous
self + 兩個 list[SymNode] → self + 兩個 list[SymNode] → 兩個 list[sympy.Expr]
↓
SymNode(pytype=bool) ← SymNode(pytype=bool) ← SymPy Boolean expression