sizes_strides_methods
PyTorch 模型支援動態形狀的輸入。在 PyTorch 的動態形狀系統中,會使用 torch.SymInt 來表示張量的尺寸(sizes)及步長(strides),而 torch.SymInt 的底層便是 torch.SymNode。
torch.SymNode 類別中定義了 sizes, strides相關函數,可以看到它們的參數皆為 self, sizes, strides:
python
def is_contiguous(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
如果查看這些函數的定義,會發現它們各自的實作內容都只是呼叫了一個帶下底線前綴的版本,但如果我們試圖進一步去尋找那些函數,會發現 SymNode 類別根本沒有定義那些函數。
其實這些方法不是沒有實作,而是稍後會在 torch/fx/experimental/symbolic_shapes.py 中安裝到 SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組上。這些方法在 PyTorch 中被稱為 sizes strides methods。
sizes_strides_methods 的定義位於 torch/fx/experimental/symbolic_shapes.py。它是一個 將方法的名稱對應到 lambda 函數 的字典,其中 key 代表方法的名字,value 則為該方法的實現:
python
sizes_strides_methods = {
# TODO: These could also be done with indicators, maybe it is better
# for reasoning to do it that way
'is_contiguous': lambda sizes, strides: sympy_is_contiguous(sizes, strides),
'is_channels_last_contiguous_2d': lambda sizes, strides: sympy_is_channels_last_contiguous_2d(sizes, strides),
'is_channels_last_contiguous_3d': lambda sizes, strides: sympy_is_channels_last_contiguous_3d(sizes, strides),
'is_channels_last_strides_2d': lambda sizes, strides: sympy_is_channels_last_strides_2d(sizes, strides),
'is_channels_last_strides_3d': lambda sizes, strides: sympy_is_channels_last_strides_3d(sizes, strides),
'is_non_overlapping_and_dense_indicator': lambda sizes, strides: IsNonOverlappingAndDenseIndicator(*sizes, *strides),
}
如果想要它們各自在做什麼,可參考以下文章:
sympy_is_contiguous:PyTorch 中的張量連續性檢查sympy_is_channels_last_contiguous_2d,sympy_is_channels_last_contiguous_3d,sympy_is_channels_last_strides_2d,sympy_is_channels_last_strides_3d:PyTorch 如何知道張量是 NCHW 還是 Channels Last 的?--sympy_is_channels_last_strides_generic函數解析IsNonOverlappingAndDenseIndicator:為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析
sizes_strides_methods 的輸入和輸出
前五個函數所接受的參數:
- sequence of
int,sympy.Expr,sympy.Symbol,sympy.Integer或 sympy 組合式 - 完全具體化的情況:sequence of
int,sympy.Integer
回傳值:
- symbolic boolean expression,如
sympy.Eq,sympy.Ne等的組合 - 完全具體化的情況:
sympy.true或sympy.false。注意即使入參是 Python int,回傳值也不是 python bool,因為它們的回傳值是基於sympy.true去做運算的
IsNonOverlappingAndDenseIndicator 所接受的參數:
- 未完全具體化的情況:多個
int,sympy.Expr,sympy.Symbol,sympy.Integer或 sympy 組合式 - 完全具體化的情況:多個
int,smypy.Integer
回傳值:
- 未完全具體化的情況:
IsNonOverlappingAndDenseIndicator的eval函數回傳None表示無法進一步化簡;外層IsNonOverlappingAndDenseIndicator(...)回傳的最終結果是未化簡的 symbolicIsNonOverlappingAndDenseIndicatorexpr。IsNonOverlappingAndDenseIndicator繼承自sympy.Function,具體可參考 Defining Automatic Evaluation with eval() - 完全具體化的情況:Python int 0 或 1
sizes_strides_methods 安裝流程
為 SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組安裝 sizes strides methods 的程式碼位於 torch/fx/experimental/symbolic_shapes.py。
主程式
以下是安裝 sizes strides methods 的主程式:
python
for method, func in sizes_strides_methods.items():
_make_node_sizes_strides(method, func)
它會遍歷 sizes_strides_methods,對(函數名稱 method,lambda 函數 func)的 pair 呼叫 _make_node_sizes_strides 函數。
_make_node_sizes_strides
_make_node_sizes_strides 函數的作用是為 SymNode 類別和 torch.fx.experimental.symbolic_shapes 模組安裝指定的 sizes strides method:
python
def _make_node_sizes_strides(method, func):
# NB: don't LRU cache, lots of arguments
def sizes_strides_impl(self, sizes, strides):
op = getattr(sys.modules[__name__], method)
if SYM_FUNCTION_MODE:
return to_node(
self,
_handle_sym_dispatch(
op,
([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
{}
)
)
size_exprs = [s.expr for s in sizes]
stride_exprs = [s.expr for s in strides]
try:
out = func(size_exprs, stride_exprs)
except Exception:
log.warning(f"failed to eval {method}({size_exprs}, {stride_exprs})")
raise
# bool is never expandable
size_hints = []
out_hint = None
for s in sizes:
if s.hint is None:
break
size_hints.append(s.hint)
else:
stride_hints = []
for s in strides:
if s.hint is None:
break
stride_hints.append(s.hint)
else:
out_hint = op(size_hints, stride_hints)
# NB: This is the indicator function, not the actual bool!
pytype: Type
if method.endswith("_indicator"):
pytype = int
else:
pytype = bool
return SymNode(out, self.shape_env, pytype, out_hint)
setattr(SymNode, f"_{method}", sizes_strides_impl)
# TODO: This is technically hotpath, but in the ideal end state
# guards on this will resolve at a higher level so you never
# spend time in this code
def sizes_strides_user(sizes, strides):
for a in itertools.chain(sizes, strides):
if isinstance(a, SymInt):
return wrap_node(getattr(a.node, method)(
[to_node(a.node, b) for b in sizes],
[to_node(a.node, b) for b in strides],
))
if method == "is_non_overlapping_and_dense_indicator":
return eval_is_non_overlapping_and_dense(sizes, strides)
else:
# TODO: this is an awful implementation
return bool(func(
[sympy.sympify(a) for a in sizes],
[sympy.sympify(a) for a in strides],
))
# Skip for is_non_overlapping_and_dense_indicator
if not hasattr(sys.modules[__name__], method):
setattr(sys.modules[__name__], method, sizes_strides_user)
它接受 method 和 func 兩個參數,其中 method 代表函數的名稱。函數會以 _method(帶下底線前綴)的名稱被安裝在 SymNode 上,以 method 的名稱被安裝在 torch.fx.experimental.symbolic_shapes 上。
func 則為欲安裝的函數本身,也就是剛剛在 sizes_strides_methods 裡看到的 sympy_is_contiguous, sympy_is_channels_last_contiguous_2d, sympy_is_channels_last_contiguous_3d, sympy_is_channels_last_strides_2d, sympy_is_channels_last_strides_3d 和 IsNonOverlappingAndDenseIndicator 等函數。
可以看到,在這個函數裡又定義了兩個函數 sizes_strides_impl 和 sizes_strides_user,並且還有兩行 setattr。它的意思是將 func 包裝成 sizes_strides_impl 和 sizes_strides_user 後,再分別將這兩個函數安裝到 SymNode 和 torch.fx.experimental.symbolic_shapes 上。
這個函數可以分兩部份看:第一部份 sizes_strides_impl 是在為 SymNode 安裝成員函數;第二部份 sizes_strides_user 則是在為 torch.fx.experimental.symbolic_shapes 模組本身安裝模組層級的函數,接下來的兩篇文章中將會細究 sizes_strides_impl 和 sizes_strides_user 的實作細節及安裝過程。