PyTorch動態形狀系統的基石 - SymNode
- PyTorch動態形狀系統
- SymNode
- [SymBool, SymInt, SymFloat](#SymBool, SymInt, SymFloat)
PyTorch動態形狀系統
在使用深度學習模型時,有時候輸入的形狀會不一樣:
- 推理服務在做動態 batching 時,每批次收到的請求數量不同,這導致了每批次的 batch size 不同
- NLP 模型的序列長度以該 batch 中的最長序列為主,這在每批次可能會不一樣
為了應付這種情況,PyTorch提供了動態形狀系統,其文檔位於Dynamic Shapes。動態形狀系統的核心部份包括:
- torch._dynamo:TorchDynamo(簡稱 Dynamo)是一個 Python 層級的即時(Just-In-Time, JIT)編譯器,旨在讓未經修改的 PyTorch 程式執行得更快。Dynamo 以即時方式運作,並會根據動態特性對計算圖進行特化。
- torch.fx.experimental.symbolic_shapes:它提供了一組介面,用於與PyTorch的「符號形狀(symbolic shapes)推理系統」互動。這個系統在 torch.compile 中被大量使用。雖然這些介面通常不被視為公開 API,但在撰寫 PyTorch 的框架層程式碼,或是為 PyTorch 撰寫擴充功能(例如自訂運算子實作)時,你可能需要使用這些 API,才能正確設定對動態形狀(dynamic shapes)的支援。
- torch.export:
torch.export.export()接收一個torch.nn.Module,並以AOT(Ahead-of-Time,提前)的方式產生一個僅包含 Tensor 計算的追蹤(traced)計算圖。這個計算圖之後可以用於不同的輸入或被序列化。預設情況下,torch.export會假設所有輸入的形狀(shapes)都是靜態的,並將匯出的程式特化(specialize)到這些維度。這會導致以下結果:在執行階段,如果輸入的形狀與匯出時不同,即使在 eager 模式下是合法的,匯出的程式也無法正常運作。
如果某些維度在不同執行之間是動態的(例如 batch 維度),就必須使用torch.export.Dim()API 來建立,並透過dynamic_shapes參數傳入torch.export.export(),以明確指定它們為動態維度。
建立動態形狀模型的流程如下:
- 當我們在 Dynamo 中開始編譯一個 frame 時,我們會配置一個
ShapeEnv(附加在 FakeTensorMode 上),用來追蹤符號形狀(symbolic shapes)的狀態。 - 當張量進入用戶定義的 Python 函式(user-defined function,如被
@torch.compile裝飾的函式)或torch.compile觸發的整體編譯過程時,PyTorch 會為它們分配符號化尺寸(symbolic sizes),而哪些維度是靜態或動態則取決於策略設定(可以透過一些參數調整)。 - 我們會讓 symbolic sizes 在運算子之間傳遞,同時維護以下內容:
- FX IR(intermediate representation, 中間語言):確保能忠實地匯出符號化計算。
- Sympy 表達式:代表各個尺寸變數,使我們能對它們進行推理。
- 當程式在依據 symbolic sizes 做條件判斷時(無論是 Dynamo 追蹤過程中(Dynamo tracing)或 Inductor 最佳化階段(Inductor optimization)),我們會根據條件加入對應的 guard。這些 guard 可以來自 Python 或 C++。
- 這些 guard 可能會觸發對符號變數(symbolic variables)的進一步簡化。例如,如果你斷言
s0 == 4,那麼我們之後就能將所有s0都替換成4。 - 當追蹤(tracing)與最佳化(optimizing)結束後,我們會將所有 guard 安裝到已編譯的程式中;只有當所有 guard 都評估為真時,已編譯的程式才能重複使用。
SymNode
以上動態形狀系統中提到了symbolic shape,其底層使用的便是PyTorch中的Python類別SymNode。
SymNode定義於torch/fx/experimental/symbolic_shapes.py,是一個「符號節點(symbolic node)」的抽象封裝,用於表示張量形狀中的單個維度(例如 batch size)或符號運算的結果。它可以讓模型的shape計算、控制流等部分「符號化(symbolic)」,即:不直接依賴當前執行的具體數值,而能以「符號」的方式記錄運算邏輯。
SymNode在未啟用動態形狀系統時也會用到,如果一個張量的形狀為一已知的常數,PyTorch仍會把該常數包裝為SymNode,這是為了維持統一的符號框架,避免重新編譯。
PyTorch中的注釋:
這是一個 類型抹消(type erased)的 SymInt/SymFloat ,我們用它來執行實際的運算。
最終使用者不會直接接觸到這個物件。在這個物件上 不會定義任何魔術方法(magic methods)。
成員變數
_expr
SymNode是一個符號節點的抽象封裝,expr 成員變數則是 SymNode 的核心屬性,它代表該符號的數學表達式(symbolic expression)。它通常是 SymPy 表達式(sympy.Expr) 或 PyTorch 自訂的簡化表達式形式。
包括以下幾個功能:
-
符號運算:所有加減乘除等操作其實都在操作 expr。這一點可以從magic methods中的
binary_magic_impl等方法中得到印證。關於程式中的符號運算,更多詳見SymPy - Introduction。 -
生成 guard:針對 expr 的關係生成guard條件,如:
expr1 == expr2。
SymPy文檔中關於Expr的介紹,sympy doc - Expr:
!
Expr
Expr is the superclass of all algebraic SymPy expressions. It is itself a subclass of Basic. SymPy expressions that can be in an Add, Mul, or Pow should be Expr subclasses. Not all SymPy classes are subclasses of Expr, for example, Boolean objects are Basic but not Expr, because boolean expressions do not make mathematical sense in classes like Add or Mul.
Expr 是所有代數 SymPy 表達式的超類別。它本身是 Basic 的子類別。能夠出現在 Add、Mul 或 Pow 中的 SymPy 表達式應為 Expr 的子類別。並非所有 SymPy 類別都是 Expr 的子類別,例如,布林物件是 Basic 但不是 Expr,因為布林表達式在 Add 或 Mul 等類別中不具備數學意義。
pytype
pytype儲存了SymNode底層資料的型別。注意pytype儲存的是Python中資料型別,如:bool, int, float等。
_hint
_hint表示該符號在「目前 trace 過程」觀察到的具體整數值。例如 x.size(0) 在這次執行中為 32,那麼 hint=32。
用途:
- 加速運算:不用每次都解
expr。 - guard 比對:下次執行時確認值是否一致。
PyTorch中的注釋:
指的是我們追蹤(tracing)的某一次執行過程中的 特定值 ,但它在下一次執行時可能會不同。
保留 hint 很有用,因為如果我們需要從 SymNode 取得具體值,我們會回傳這個 hint,並在產生它的表達式上加上 guard,以便下一次仍能得到相同的 hint。
不過,在以下情況 hint 不一定會存在:
- 如果你有一個 unbacked SymNode(沒有具體值支撐的符號節點),就不會有 hint。
- 如果它是某個依賴於張量計算的結果,但是因為那個張量運算尚未真正被執行,所以我們並不知道它的具體值實際上是多少。
_hint_expr
_hint_expr是能產生 hint 的符號表達式(symbolic expression)。當 hint 不可用時才存在。
當 SymNode 無法立即提供 concrete hint(如 tensor-dependent 計算尚未執行,或依賴其他 unbacked SymNode)時,PyTorch會記錄產生該 hint 的符號計算過程,例如 hint_expr = (S1 + 1) // 2,其中 S1 是另一符號變數。這允許後續延遲求值:透過 shape_env.replacements 替換已知符號值來重新評估。
PyTorch中的注釋:
hint_expr只有在沒有 hint 的情況下才會被設定。
它儲存的是一個包含 unbacked SymNode 的表達式(expression);
如果後來被約束(constrained),就可以讓這個表達式(expression)再次產生一個 hint。
constant
PyTorch中的注釋:
指的是在模型的多次調用(invocations of the model)過程中 保持不變的值 ;它永遠就是這個值。
我們只有在遇到「真正的字面常數」時才知道這個值(在將其包裝進 SymNode 時,我們會設定constant成員變數)。
在大多數情況下,constant 會是 None。
shape_env
ShapeEnv 是整個符號化張量形狀(symbolic shapes)的「上下文管理器」,用於管理符號、約束expr、化簡expr。
建構子
torch/fx/experimental/symbolic_shapes.py
python
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
"""
This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object.
"""
def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
self._expr = expr
self.shape_env = shape_env
self.pytype = pytype
# What's the difference between hint and constant?
#
# - A constant is known to be invariant across invocations of the model;
# it will always be this value. We only really know this when we
# encounter an honest-to-goodness literal (when wrapping it into
# a SymNode, we set constant.) Most of the time, constant is None
#
# - A hint is a *particular* value from the particular run we are
# tracing, but it may vary the next time around. It's useful to
# keep this around, as if we need a concrete value from a SymNode,
# we will return the hint and guard on the expression that produced
# it giving the same hint next time around. The hint is not
# guaranteed to be set either: if you have an unbacked SymNode,
# there won't be any hint; it was the result of some tensor-dependent
# computation, but we don't know what it actually is because we
# haven't actually run the tensor computation.
#
# hint_expr is only set if we don't have a hint. When it is set, it
# contains the expression which contains the unbacked symnodes that,
# if constrained, would allow this expression to be hinted again.
if hint is None:
self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
self._hint = None
self._update_hint() # check if the replacement actually was enough
else:
self._hint_expr = None
self._hint = hint
self.constant: Optional[Union[int, float, bool]] = constant
建構子接受expr, shape_env, pytype, hint等四個必需參數,和constant一個可選參數。
其中hint的型別為Optional[Union[int, float]],表示它可以是int或float,也可以是None;constant的型別為Optional[Union[int, float, bool]],表示它可以是int、float、bool或None。
在建構子中會將_expr, shape_env, pytype, constant等成員變數設為對應的參數值。hint和hint_expr則會根據hint是否為None來決定是否設定:
- 如果入參
hint為None,則會由expr計算得到hint_expr,且_hint成員變數會被初始化為None。在_update_hint()函數中,會嘗試更新_hint和_hint_expr。 - 如果入參
hint不為None,則_hint_expr會被設定為None、_hint被設定為入參hint。
可以看出_hint和_hint_expr勢不兩立,如果其中一個非None,另一個就必須為None。
expr函數
python
@property
def expr(self):
self._update_expr()
return self._expr
python
def _update_expr(self):
self._expr = self.shape_env.replace(self._expr)
_update_expr函數中用到了ShapeEnv.replace函數,其文檔為PyTorch doc - ShapeEnv.replace:
replace(expr)[source]
Apply symbol replacements to any symbols in the given expression.
Return type
_SympyT
ShapeEnv.replace函數會對給定運算式中的所有符號,套用符號替換(symbol replacements)。
字串轉換函數
python
def str(self):
return f"{self.expr}"
def __str__(self):
return self.str()
def __repr__(self):
return self.str()
__str__和__repr__函數都是str函數的套殼,而str函數則會回傳expr成員變數的字串表示。
hint相關函數
python
# Check if we have replacements hint_expr that would allow us to
# simplify it into a hint
def _update_hint(self):
if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
new_hint = self.shape_env.replace(self._hint_expr)
# NB: unification constraints could result in a replacement that
# doesn't actually solve the hint! Check for this.
if new_hint.free_symbols:
self._hint_expr = new_hint
return
self._hint = self.pytype(new_hint)
self._hint_expr = None
@property
def hint(self):
if self._hint is None:
self._update_hint()
return self._hint
def has_hint(self):
return self._hint is not None
def require_hint(self):
if self._hint is None:
self._update_hint()
if self._hint is None:
raise self.shape_env._make_data_dependent_error(self._hint_expr, self.expr)
else:
return self._hint
else:
return self._hint
_update_hint函數會嘗試更新_hint和_hint_expr。
型別檢查函數
python
def is_int(self):
return self.pytype is int
def is_float(self):
return self.pytype is float
def is_bool(self):
return self.pytype is bool
is_int、is_float和is_bool函數會檢查pytype成員變數是否為Python中的int、float和bool型別。
wrap函數
python
def wrap_int(self, num):
assert type(num) is int
return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
def wrap_float(self, num):
assert type(num) is float
return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
def wrap_bool(self, num):
assert type(num) is bool
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
wrap函數首先檢查入參num的型別,如果是Python中的int、float或bool,則會調用SymNode建構子,傳入對應的expr、shape_env、pytype、hint,另外還會傳入可選參數constant,最後將新建的SymNode物件回傳。
在建構SymNode物件時,傳入的expr參數用到了sympy中的類別,可參考class sympy.core.expr.Expr(*args), class sympy.core.numbers.Integer(i), class sympy.core.numbers.Float(num, dps=None, precision=None), class sympy.logic.boolalg.BooleanTrue和class sympy.logic.boolalg.BooleanFalse。
sympy.Integer(num), sympy.Float(num), sympy.true和sympy.false等四個物件都是被當作expr參數傳入SymNode的建構子,但如果我們實際來檢查一下他們是否屬於sympy.Expr:
python
print(issubclass(sympy.Integer, sympy.Expr)) # True
print(issubclass(sympy.Float, sympy.Expr)) # True
print(isinstance(sympy.true, sympy.Expr)) # False
print(isinstance(sympy.false, sympy.Expr)) # False
會發現sympy.true和sympy.false並不屬於sympy.Expr。所以其實PyTorch並沒有嚴格要求SymNode的expr成員必須是sympy.Expr型別。
clone函數
python
def clone(self):
return self
clone函數直接回傳self。
四則運算函數
python
# These methods call the metaprogrammed methods, they're hand written
# here so we get good stack traces
def add(self, other) -> "SymNode": # noqa: F811
return self._add(other) # type: ignore[attr-defined]
def sub(self, other) -> "SymNode": # noqa: F811
return self._sub(other) # type: ignore[attr-defined]
def mul(self, other) -> "SymNode": # noqa: F811
return self._mul(other) # type: ignore[attr-defined]
def mod(self, other) -> "SymNode": # noqa: F811
return self._mod(other) # type: ignore[attr-defined]
def pow(self, other) -> "SymNode": # noqa: F811
return self._pow(other) # type: ignore[attr-defined]
python
def truediv(self, other) -> "SymNode": # noqa: F811
return self._truediv(other) # type: ignore[attr-defined]
def floordiv(self, other) -> "SymNode": # noqa: F811
return self._floordiv(other) # type: ignore[attr-defined]
SymNode類別中定義了add, sub, mul, mod, pow, truediv, floordiv等方法。而這些方法的實現則是呼叫對應的帶下底線的版本,例如add會呼叫_add,sub會呼叫_sub,以此類推。但是如果我們實際在SymNode類別尋找,卻找不到_add, _sub, _mul, _mod, _pow, _truediv, _floordiv這些帶下底線的方法的定義。
其實這些帶下底線的方法們是透過 metaprogramming(元編程) 自動生成的方法,會在稍後才被安裝,它們被稱作magic methods,具體詳見magic methods。
在安裝magic methods後,如果我們對一個SymNode物件呼叫了truediv方法,它會呼叫magic method:self._floordiv(other)。參考binary_magic_impl,在SymNode._truediv中,會先取出self和other兩個SymNode的expr成員變數(型別為sympy.Expr),對它們做TrueDiv操作,得到新的sympy.Expr。在SymNode._truediv的最後,會將運算得到的sympy.Expr當作SymNode建構子的expr參數,創建一個SymNode後回傳。注意此處創建出來的SymNode的pytype為float。
這整個過程的函數調用鏈,函數的入參和回傳可視化如下:
SymNode.truediv → SymNode._truediv → TrueDiv
兩個SymNode → 兩個SymNode → 兩個sympy.Expr
↓
SymNode ← SymNode ← sympy.Expr
bool運算函數
python
def and_(self, other) -> "SymNode": # noqa: F811
return self._and_(other) # type: ignore[attr-defined]
def or_(self, other) -> "SymNode": # noqa: F811
return self._or_(other) # type: ignore[attr-defined]
and_, or_方法會呼叫_and_和_or_方法,這兩個前面有下底線的方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見magic methods。
大小比較函數
python
def eq(self, other) -> "SymNode": # noqa: F811
return self._eq(other) # type: ignore[attr-defined]
def ne(self, other) -> "SymNode": # noqa: F811
return self._ne(other) # type: ignore[attr-defined]
def gt(self, other) -> "SymNode": # noqa: F811
return self._gt(other) # type: ignore[attr-defined]
def lt(self, other) -> "SymNode": # noqa: F811
return self._lt(other) # type: ignore[attr-defined]
def le(self, other) -> "SymNode": # noqa: F811
return self._le(other) # type: ignore[attr-defined]
def ge(self, other) -> "SymNode": # noqa: F811
return self._ge(other) # type: ignore[attr-defined]
此處被調用的_eq, _ne, _gt, _lt, _le, _ge等方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見SymBool, SymInt, SymFloat user magic method。
sym相關函數
python
def sym_not(self) -> "SymNode": # noqa: F811
return self._sym_not() # type: ignore[attr-defined]
python
def sym_float(self) -> "SymNode": # noqa: F811
return self._sym_float() # type: ignore[attr-defined]
def sym_int(self) -> "SymNode": # noqa: F811
return self._sym_int() # type: ignore[attr-defined]
python
def sym_min(self, other) -> "SymNode": # noqa: F811
return self._sym_min(other) # type: ignore[attr-defined]
def sym_max(self, other) -> "SymNode": # noqa: F811
return self._sym_max(other) # type: ignore[attr-defined]
def sym_sqrt(self) -> "SymNode": # noqa: F811
return self._sym_sqrt() # type: ignore[attr-defined]
此處被調用的_sym_not, _sym_float, _sym_int, _sym_min, _sym_max, _sym_sqrt等方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見magic methods。
python
# Make C++ happy
def sym_or(self, other): # noqa: F811
return self.or_(other)
def sym_and(self, other): # noqa: F811
return self.and_(other)
這兩個函數調用的是稍早看到的or_和and_。
取整函數
python
def floor(self) -> "SymNode": # noqa: F811
return self._floor() # type: ignore[attr-defined]
python
def ceil(self) -> "SymNode": # noqa: F811
return self._ceil() # type: ignore[attr-defined]
此處被調用的_floor, _ceil等方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見magic methods。
neg函數
python
def neg(self) -> "SymNode": # noqa: F811
return self._neg() # type: ignore[attr-defined]
此處被調用的_neg方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見magic methods。
sizes, strides相關函數
python
def is_contiguous(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": # noqa: F811
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
python
def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
此處被調用的_is_contiguous, _is_channels_last_contiguous_2d, _is_channels_last_contiguous_3d, _is_channels_last_strides_2d, _is_channels_last_strides_3d, _is_non_overlapping_and_dense_indicator等方法同樣是透過metaprogramming自動生成的,稍後會被安裝到SymNode類別上。詳見SymNode sizes strides methods。
guard相關函數
python
def int_(self):
return self.guard_int("", 0) # NB: uses Python backtrace
# You can manually trigger a guard with this function
def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint)
try:
return int(r)
except Exception:
log.warning(f"Failed to convert to int: {r}")
raise
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint)
try:
return float(r)
except Exception:
log.warning(f"Failed to convert to float: {r}")
raise
def guard_bool(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint)
try:
return bool(r)
except Exception:
log.warning(f"Failed to convert to bool: {r}")
raise
def bool_(self):
return self.guard_bool("", 0)
guard_int, guard_float, guard_bool等三個函數都調用了shape_env.evaluate_expr,其文檔為PyTorch doc - ShapeEnv.evaluate_expr:
evaluate_expr(orig_expr, hint=None, fx_node=None, size_oblivious=False, fallback_value=None, *, forcing_spec=False)[source]
Given an expression, evaluates it, adding guards if necessary When fallback_value is not None the function return fallback_value instead of failing with data dependent error.
Return type
Basic
給定一個運算式(expression),evaluate_expr函式會對其進行求值(evaluate),並在必要時加入保護條件(guards)。
當 fallback_value 不是 None 時,函式不會拋出與資料相關的錯誤(data-dependent error),而是返回 fallback_value 作為替代結果。
SymBool, SymInt, SymFloat
torch/__init__.py中定義了與SymNode相關的類別SymBool, SymInt, SymFloat,不過它們並不是繼承關係:這三個類別的建構子都會接受一個node參數,並將node參數賦值給self.node成員變數。