前言
前篇 為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析 介紹了 _eval_is_non_overlapping_and_dense 函數,該函數位於 torch/fx/experimental/symbolic_shapes.py,實作了判斷張量的記憶體佈局是否「非重疊且稠密」(non-overlapping and dense)的邏輯,是個純 Python 函數。
本篇介紹的 eval_is_non_overlapping_and_dense 函數則是對 _eval_is_non_overlapping_and_dense 函數的包裝,與它同樣是一個模組層級的函數,並開始跟 PyTorch 的編譯系統產生關聯。
注:從本節開始,會牽涉到 PyTorch 中的動態形狀系統,建議先閱讀 PyTorch 的官方文檔:Dynamic Shapes Core Concepts。
eval_is_non_overlapping_and_dense
eval_is_non_overlapping_and_dense 位於 torch/fx/experimental/symbolic_shapes.py,實作如下:
python
# TODO: Deduplicate this with torch/_prims_common/__init__.py
def eval_is_non_overlapping_and_dense(sizes, strides):
return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))
可以看到它的參數跟 _eval_is_non_overlapping_and_dense 一樣是 sizes 和 strides,呼叫 _eval_is_non_overlapping_and_dense 後對其回傳值作了兩層包裝,第一層是 guard_bool,第二層則是 int。
為何要做這兩層包裝呢?回顧前篇 為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析,當中提到了 _eval_is_non_overlapping_and_dense 函數回傳的型別在一般情況下是 bool,但也有回傳 SymBool 的情況,所以此處才在它回傳後多包一層 guard_bool 做保險,即使 _eval_is_non_overlapping_and_dense 真的回傳了 SymBool,guard_bool 也能將它轉成 bool 型別。
第二層包裝則是將 bool 轉為 int 給後續函數做使用。
調用端
IsNonOverlappingAndDenseIndicator
torch/fx/experimental/symbolic_shapes.py
eval_is_non_overlapping_and_dense 會在 IsNonOverlappingAndDenseIndicator.eval 方法中被調用:
python
class IsNonOverlappingAndDenseIndicator(sympy.Function):
# ...
@classmethod
def eval(cls, *args):
# ...
return eval_is_non_overlapping_and_dense(
[int(a) for a in size_args],
[int(a) for a in stride_args]
)
return None
而 IsNonOverlappingAndDenseIndicator 正是我們下一篇的主題。
SYMPY_INTERP
torch/fx/experimental/symbolic_shapes.py
python
SYMPY_INTERP = {
'Eq': operator.eq,
'Ne': operator.ne,
'Gt': operator.gt,
'Lt': operator.lt,
'Le': operator.le,
'Ge': operator.ge,
'Min': min,
'Max': max,
'Mod': operator.mod,
'FloorDiv': operator.floordiv,
'TrueDiv': operator.truediv,
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
}
在 SYMPY_INTERP 字典中,作為 'IsNonOverlappingAndDenseIndicator' 鍵的值,在執行期驗證時會用到。
sizes_strides_user
_make_node_sizes_strides 函數同樣位於 torch/fx/experimental/symbolic_shapes.py,是一個模組層級的函數,當中調用了 eval_is_non_overlapping_and_dense:
python
def _make_node_sizes_strides(method, func):
# ...
# TODO: This is technically hotpath, but in the ideal end state
# guards on this will resolve at a higher level so you never
# spend time in this code
def sizes_strides_user(sizes, strides):
for a in itertools.chain(sizes, strides):
if isinstance(a, SymInt):
return wrap_node(getattr(a.node, method)(
[to_node(a.node, b) for b in sizes],
[to_node(a.node, b) for b in strides],
))
if method == "is_non_overlapping_and_dense_indicator":
return eval_is_non_overlapping_and_dense(sizes, strides)
else:
# TODO: this is an awful implementation
return bool(func(
[sympy.sympify(a) for a in sizes],
[sympy.sympify(a) for a in strides],
))
PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 中有對 sizes_strides_methods 的詳細介紹。