user magic methods
- [user magic methods](#user magic methods)
- [user magic methods安裝流程](#user magic methods安裝流程)
-
- 主程式
- _make_user_magic
-
- [\_make\_user\_magic - method_attr](#_make_user_magic - method_attr)
- [\_make\_user\_magic - unary\_magic\_impl](#_make_user_magic - unary_magic_impl)
- [\_make\_user\_magic - binary\_magic\_impl](#_make_user_magic - binary_magic_impl)
- [\_make\_user\_magic - rbinary\_magic\_impl](#_make_user_magic - rbinary_magic_impl)
- [安裝user magic method](#安裝user magic method)
- 調用流程
user magic methods
PyTorch模型支援動態形狀的輸入。在PyTorch的動態形狀系統中,除了之前看過的torch.SymNode之外,還會用到torch.SymInt, torch.SymFloat和torch.SymBool這三個類別。其中torch.SymInt和torch.SymFloat用於表示計算形狀期間產生的symbolic sizes;另外計算形狀期間有可能會需要做邏輯判斷,這時便會用到torch.SymBool來表示symbolic的邏輯值。
但如果我們去查看torch.SymBool, torch.SymInt, torch.SymFloat等類別的定義,卻會發現很多未實作的方法,如__eq__, __lt__等。
其實這些方法不是沒有實作,而是稍後會由torch.fx.experimental.symbolic_shapes模組安裝到torch.SymBool, torch.SymInt, torch.SymFloat上,這些方法被稱為user magic methods。
magic_methods
在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?處我們已經看過magic methods的定義,以及它們是如何被安裝到torch.SymNode上的。
如果回去查看magic_methods的定義,可以知道magic methods包含了unary magic methods和binary magic methods兩個子集合。
在binary magic methods子集合中,有所謂的reflectable_magic_methods,來看看它的定義。
reflectable_magic_methods
reflectable_magic_methods的定義位於torch.fx.experimental.symbolic_shapes.py。它是一個 將方法的名稱對應到lambda函數 的字典,其中key代表方法的名字,value則為該方法的實現:
python
# Methods that have a `__foo__` as well as `__rfoo__`
reflectable_magic_methods = {
'add': lambda a, b: a + b,
'sub': lambda a, b: a - b,
'mul': lambda a, b: a * b,
'mod': lambda a, b: a % b,
'pow': lambda a, b: Pow(a, b),
'and': lambda a, b: a & b,
'or': lambda a, b: a | b,
'truediv': lambda a, b: TrueDiv(a, b),
'floordiv': lambda a, b: FloorDiv(a, b),
}
注意以上方法的入參和回傳值皆為sympy.Expr。
待會在[user magic methods安裝流程](#user magic methods安裝流程)章節會看到,如果一個binary method本來的名稱是foo,則它會被以__foo__的名稱安裝到SymInt, SymFloat或SymBool上。
例如sub方法會被以__sub__的名稱安裝到SymInt和SymFloat上,之後使用者便可以透過SymInt.__sub__(other)或SymFloat.__sub__(other)來調用這個方法。
如果一個binary method屬於reflectable_magic_methods,那麼除了SymInt.__sub__和SymFloat.__sub__之外,還會多安裝一個__rsub__方法。
那麼__sub__和__rsub__有何不同之處呢?
SymInt.__sub__(other)是由自己減去對方,即由self._expr減去other._expr;SymInt.__rsub__(other)則反過來,是由對方減去自己,即由other._expr減去self._expr。
reflectable_magic_methods中大部份方法在做什麼都一目瞭然,只有pow, truediv, floordir三個方法用到了PyTorch中自定義的類別Pow, TrueDiv和FloorDiv,讓我們來看看它們的定義。
Pow
reflectable_magic_methods中的pow方法對應到PyTorch中自定義的Pow類別,其定義如下:
python
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function):
@classmethod
def eval(cls, base, exp):
if exp.is_zero:
return sympy.Integer(1)
elif base.is_zero and exp < 0:
raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
else:
return base ** exp
可以看到Pow類別繼承了sympy.Function,並且定義了class method eval方法,這實際上是在按照sympy的規則來撰寫自定義函數,詳見Creating a Custom Function。
Pow.eval函數的入參cls是Pow類別本身,底數base和指數exp則皆為sympy.Expr。
Pow.eval函數用於指數運算,分以下幾種情況:
- 當指數
exp是0時:直接回傳1 - 當底數
base是0且指數exp為負時:不合法,raiseZeroDivisionError錯誤。數學細節詳見Are exponents with base 0 even defined? - 在正常情況下則會回傳
base的exp次方
可以看出,Pow.eval函數對的核心是sympy的**運算子,PyTorch中為了處理指數為0的特殊情況和底數為0且指數為負的錯誤,才對sympy的**運算子進行了包裝。
TrueDiv
reflectable_magic_methods中的true_div方法對應到TrueDiv類別,其定義如下:
python
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class TrueDiv(sympy.Function):
@classmethod
def eval(cls, base, divisor):
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
else:
return base / divisor
此處TrueDiv繼承自sympy.Function,並定義了eval方法,可知TrueDiv也是按照sympy規則來撰寫的自定義函數。
TrueDiv.eval函數的入參cls是TrueDiv類別本身,分子base和分母divisor則皆為sympy.Expr。
TrueDiv.eval函數用於除法運算,分以下幾種情況:
- 在分母
divisor為0的情況下,會raiseZeroDivisionError錯誤 - 否則其行為跟
/運算子一樣
這裡也可以看出,TrueDiv.eval函數的核心是sympy的/運算子,PyTorch中為了處理分母為0的錯誤,才對sympy的/運算子進行了包裝。
FloorDiv
reflectable_magic_methods中的floor_div方法對應到FloorDiv類別,其定義如下:
python
class FloorDiv(sympy.Function):
"""
We maintain this so that:
1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
"""
nargs = (2,)
precedence = 50 # precedence of mul # noqa: F811
# Default return type for SymPy assumptions.
# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
is_real = True
@property
def base(self):
return self.args[0]
@property
def divisor(self):
return self.args[1]
def _sympystr(self, printer):
base = printer.parenthesize(self.base, self.precedence)
divisor = printer.parenthesize(self.divisor, self.precedence)
return f"{base}//{divisor}"
# SymPy assumptions based on argument types.
def _eval_is_real(self):
return fuzzy_or([self.base.is_real, self.divisor.is_real])
def _eval_is_integer(self):
return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(cls, base, divisor):
def check_supported_type(x):
if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
raise TypeError(
f"unsupported operand type(s) for //: "
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
f", expected integer or real")
check_supported_type(base)
check_supported_type(divisor)
# We don't provide the same error message as in Python because SymPy
# makes it difficult to check the types.
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
if base.is_zero:
return sympy.S.Zero
if base.is_integer and divisor == 1:
return base
if base.is_real and divisor == 1:
return sympy.floor(base)
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return base // divisor
if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
return sympy.floor(base / divisor)
if isinstance(base, FloorDiv):
return FloorDiv(base.args[0], base.args[1] * divisor)
if isinstance(base, sympy.Add):
for a in base.args:
gcd = sympy.gcd(a, divisor)
if gcd == divisor:
return FloorDiv(base - a, divisor) + a / gcd
gcd = sympy.gcd(base, divisor)
if gcd != 1:
return FloorDiv(
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
)
FloorDiv.eval函數的入參cls是FloorDiv類別本身,分子base和分母divisor則皆為sympy.Expr。
FloorDiv.eval函數會做除法後取floor,分以下幾種情況:
- 在分母
divisor為0的情況下,會raiseZeroDivisionError錯誤 - 在分子
base為0的情況下,會回傳0 - 在分子
base是整數且分母divisor是1的情況下,不需做除法也不需取floor,會直接回傳分子base - 在分子
base是實數且分母divisor是1的情況下,不需做除法,只需要取floor,回傳sympy.floor(base) - 在分子
base是整數且分母divisor是整數的情況下,會直接做整數除法,回傳base // divisor - 在分子
base是sympy.Integer或sympy.Float且分母divisor亦是sympy.Integer或sympy.Float的情況下,會真的做除法,然後取floor,回傳sympy.floor(base / divisor) - 在分子
base是FloorDiv的情況下,會將分母divisor與base.args[1]相乘,當作新的分母,然後回傳FloorDiv(base.args[0], base.args[1] * divisor) - 在分子
base是sympy.Add的情況下,會遍歷base.args(加數、b)中的每個元素,嘗試找出與分母divisor的最大公因數gcd與divisor相等的元素a。如果找到了,則先單獨計算該元素與divisor的商,即a / gcd,因為a是gcd的倍數,所以a / gcd是一個整數。然後再加上其它元素之和與divisor的商的floor值,即FloorDiv(base - a, divisor),最後回傳 - 最後,如果分子
base與分母divisor的最大公因數gcd不等於1,則會先將分子分母各自除以gcd、各自簡化後再做FloorDiv運算。回傳值為:FloorDiv(sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)) - 如果最後的if條件不成立不會return值回去?
bool_magic_methods
在magic_methods的子集合中,以下方法屬於bool magic methods,會被安裝到SymBool上:
python
bool_magic_methods = {"and", "or", "sym_not"}
magic methods中不屬於bool magic methods者則將會同時被安裝到SymInt和SymFloat兩個類別上。
wrap_node
在等一下會看到的_make_user_magic函數中會大量用到wrap_node函數,其定義位於 torch/fx/experimental/symbolic_shapes.py:
python
def wrap_node(x):
# TODO: let C++ also take advantage of this
if isinstance(x, SymNode) and x.constant is not None:
return x.constant
if x.is_int():
return SymInt(x)
elif x.is_float():
return SymFloat(x)
elif x.is_bool():
return SymBool(x)
else:
raise AssertionError(f"unrecognized return type {x}")
wrap_node接受的參數x為一SymNode物件,而SymNode有個constant成員變數:
python
self.constant: Optional[Union[int, float, bool]] = constant
wrap_node函數會檢查SymNode x的constant成員變數,分以下幾種情況:
- 在
SymNode x的constant成員變數非空的情況下,wrap_node會取出constant回傳,而constant的型別是int,float或bool其中之一。 - 如果
constant成員變數為空,則wrap_node會依據型別檢查函數回傳的結果,決定將SymNode包裝成SymInt,SymFloat或者是SymBool,最後將包裝後的物件回傳。 - 如果
SymNode x不滿足上述情況,則會將x的型別視為不合法,會raiseAssertionError錯誤。
to_node
在等一下會看到的binary_magic_impl函數中會用到to_node函數,其定義位於 torch/fx/experimental/symbolic_shapes.py:
python
def to_node(self, num):
if isinstance(num, SymTypes):
return num.node
elif type(num) is bool:
return self.wrap_bool(num)
elif type(num) is int:
return self.wrap_int(num)
elif type(num) is float:
return self.wrap_float(num)
else:
# NotImplemented is important so that Python tries the
# other magic method
return NotImplemented
to_node函數接受以下參數:
self:SymNode物件num:可能是SymTypes(包含SymInt,SymFloat和SymBool)或Python中的bool, float或int
首先檢查num是否為SymTypes:
python
SymTypes = (SymInt, SymFloat, SymBool)
如果num是SymTypes,則to_node函數會取出其node成員變數(型別為SymNode)回傳;如果是Python中的bool, float或int,則會調用對應的wrap函數,將其包裝成SymNode後回傳。
在wrap_bool函數中會檢查入參是否為Python中的bool,如果是,便將sympy.true或sympy.false當作expr參數傳入SymNode的建構子,重新建構一個SymNode後回傳。
wrap_int、wrap_float也類似,分別會將Python中的int和float包裝成SymNode後回傳。
另外要注意的一點是,在wrap_bool, wrap_int和wrap_float等函數回傳的SymNode物件中會將constant成員變數設定為num。
user magic methods安裝流程
為SymInt, SymFloat或SymBool安裝user magic methods的程式碼位於torch/fx/experimental/symbolic_shapes.py。
主程式
在安裝user magic methods的主程式中,會遍歷magic_methods,一一對 函數名稱method和lambda函數func的pair 呼叫_make_user_magic函數:
python
for method, func in magic_methods.items():
if method in bool_magic_methods:
_make_user_magic(method, SymBool)
else:
_make_user_magic(method, SymInt)
_make_user_magic(method, SymFloat)
這段程式碼會為SymBool安裝magic_methods中屬於bool_magic_methods的方法,包括__and__, __or__, __sym__not__。
為SymInt和SymFloat安裝magic_methods中不屬於bool_magic_methods的方法,包括__add__, __sub__, __mul__, __mod__, __pow__, __and__, __or__, __truediv__, __floordiv__, __sym__not__, __eq__, __ne__, __gt__, __lt__, __le__, __ge__, __floor__, __sym__float__, __ceil__, __neg__, __sym__min__, __sym__max__, __sym__sqrt。
最後還會為SymInt和SymFloat安裝reflectable magic methods,包括__radd__, __rsub__, __rmul__, __rmod__, __rpow__, __rand__, __ror__, __rtruediv__, __rfloordiv__。
_make_user_magic
_make_user_magic函數的作用是為SymInt, SymFloat或SymBool安裝名為__method__的方法,等一下會看到,其實__method__方法就是我們在PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」?見過的SymNode._method_attr方法的包裝。
_make_user_magic接受method和user_type兩個參數:
method代表方法的名稱,方法會以__method__的名稱被安裝user_type:要為哪個類別安裝方法,可以是SymInt,SymFloat或SymBool其中之一
python
def _make_user_magic(method, user_type):
# User magic takes care of wrapping the other operand into a node,
# so that our internal logic can assume everything is nodes
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"{method}_"
else:
method_attr = method
def unary_magic_impl(self):
return wrap_node(getattr(self.node, method_attr)())
def binary_magic_impl(self, other):
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
return wrap_node(getattr(self.node, method_attr)(other_node))
def rbinary_magic_impl(self, other):
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
return wrap_node(getattr(other_node, method_attr)(self.node))
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
這個函數比較長,我們可以將它拆成五個部份來看:
- 一開始定義了
method_attr變數 - 中間定義了
unary_magic_impl子函數 - 接著定義了
binary_magic_impl子函數 - 接著定義了
rbinary_magic_impl子函數 - 最後則是實際把
unary_magic_impl,binary_magic_impl或rbinary_magic_impl安裝在user_type上
_make_user_magic - method_attr
如果入參 method 在 magic_methods_on_operator_with_trailing_underscore (包括and和or)中,則會將method_attr變數設為method + "_",否則將method_attr設為method。
python
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"{method}_"
else:
method_attr = method
待會會看到,unary_magic_impl, binary_magic_impl和rbinary_magic_impl都是對SymNode方法的包裝,而它們調用的SymNode方法名稱即為method_attr。
_make_user_magic - unary_magic_impl
因為unary_magic_impl函數即將被安裝在的SymInt, SymFloat或SymBool身上,可知unary_magic_impl的參數self就是SymInt, SymFloat或SymBool其中之一:
python
def unary_magic_impl(self):
return wrap_node(getattr(self.node, method_attr)())
SymBool, SymInt或SymFloat都有一個SymNode型別的成員變數node,此處的self.node將取出其node成員變數。
getattr(self.node, method_attr)會獲取SymNode的method_attr方法。
參考四則運算函數,在SymNode.method_attr方法中會調用SymNode._method_attr方法。
注意到SymNode._method_attr方法就是PyTorch SymNode 的設計之謎:為何magic methods「看起來沒實作」? - magic method安裝流程處由_make_node_magic安裝的unary_magic_impl,它會接受此處傳入的self.node(型別為SymNode)作為參數。
在_make_node_magic裡的unary_magic_impl(它與此處_make_user_magic裡的unary_magic_impl是兩個不同的函數)中,會透過SymNode.expr方法把SymNode的_expr成員變數取出,把它當作參數傳入magic methods字典裡的lambda函數,lambda函數用輸入的sympy.Expr做運算後會回傳另一個sympy.Expr物件。_make_node_magic裡的unary_magic_impl會把回傳的sympy.Expr做包裝,得到一個SymNode之後回傳。
在此處_make_user_magic裡的unary_magic_impl函數中,會將這個SymNode用wrap_node函數包起來。
wrap_node函數在入參SymNode的constant成員變數非空的情況下會回傳底層的Python int, float或bool;如果constant成員變數為空,則會將入參SymNode包裝成SymInt, SymFloat或SymBool後回傳。
因為在_make_node_magic - unary_magic_impl處創建的SymNode並未設定constant成員變數,所以此處wrap_node函數回傳的是SymInt, SymFloat或SymBool其中之一。
總結一下,unary_magic_impl接受SymInt, SymFloat或SymBool為參數,調用底層的SymNode magic method對它做運算,最後同樣回傳SymInt, SymFloat或者是SymBool。
_make_user_magic - binary_magic_impl
比起unary_magic_impl,binary_magic_impl多了一個參數other:
python
def binary_magic_impl(self, other):
在無法保證other之型別的情況下,需要做以下前處理:
python
other_node = to_node(self.node, other)
to_node函數的作用依照other的型別有所不同:
- 在
other屬於SymTypes = (SymInt, SymFloat, SymBool)的情況下,to_node會直接取出它們的node成員變數(也就是SymNode)回傳 - 在
other屬於int, float, bool的情況下,to_node會用wrap_int, wrap_float或wrap_bool函數對它們做包裝,得到一個SymNode後回傳
接著透過()運算子調用self.node.method_attr方法,傳入other_node,得到一個SymNode物件。最後用wrap_node函數將SymNode包裝成SymInt, SymFloat或者是SymBool後回傳:
python
return wrap_node(getattr(self.node, method_attr)(other_node))
_make_user_magic - rbinary_magic_impl
在method屬於reflectable_magic_methods的情況下,會額外安裝所謂的rbinary函數。
rbinary_magic_impl函數與binary_magic_impl大體相同,不同之處在於以下這行程式碼:
python
getattr(other_node, method_attr)(self.node)
binary_magic_impl是在self.node上調用method_attr,並把other_node當作參數傳入;rbinary_magic_impl則反過來,在other_node上調用method_attr,把self.node當作參數傳入。
以pow函數為例,binary版本是以self.expr為底數,other.expr為指數;rbinary版本則是相反,以other.expr為底數,self.expr為指數。
安裝user magic method
定義完必要的子函數後,在_make_user_magic的最後會檢查入參method是否屬於unary_magic_methods、將它們分為unary和binary兩類。
對於unary method,就用unary_magic_impl包裝,然後安裝到user_type上;對於binary method,則用binary_magic_impl做包裝,同樣安裝到user_type上。
在binary method底下,還有個子集合reflectable_magic_methods,如果method屬於reflectable_magic_methods,則會額外用rbinary_magic_impl包裝,然後安裝到user_type上。
python
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
之後我們就可以透過對SymInt、SymFloat或SymBool物件呼叫__method_attr__或__rmethod_attr__方法來調用。讓我們來驗證一下,在Python命令行裡查看SymBool.__and__的方法,會出現以下輸出:
python
>>> torch.fx.experimental.symbolic_shapes.SymBool.__and__
<function _make_user_magic.<locals>.binary_magic_impl at 0x7fa81db14280>
在Python命令行裡查看SymInt.__add__和SymInt.__radd__的方法,會出現以下輸出:
python
import torch
>>> torch.fx.experimental.symbolic_shapes.SymInt.__add__
<function _make_user_magic.<locals>.binary_magic_impl at 0x7fa81db13700>
>>> torch.fx.experimental.symbolic_shapes.SymInt.__radd__
<function _make_user_magic.<locals>.rbinary_magic_impl at 0x7fa81db13790>
從輸出可以看出來,SymBool.__and__確實跟_make_user_magic和binary_magic_impl有關,SymInt.__add__和SymInt.__radd__也是如此,代表它們確實是由torch/fx/experimental/symbolic_shapes.py安裝的user magic methods。
調用流程
使用者可以透過SymInt / 1.的語法來做除法運算,此處用到了SymInt的/運算子,但從torch.SymInt中卻找不到/運算子的定義。
參考Python truediv() Magic Method:
!
to evaluate the expression x / y, Python attempts to call x.__truediv__(y)
在Python中使用/運算子時,底層會呼叫__truediv__。
而我們知道,此處調用的SymInt.__truediv__方法是在SymInt被定義後才被安裝上去的user magic method。
注意到SymInt的__truediv__方法接受的運算元是SymInt本身和一個數字1.,兩個不同型別的運算元該如何做運算呢?
參考binary_magic_impl,在binary_magic_impl函數中有個前處理,會調用to_node函數將1.包裝成SymNode,如此一來__truediv__的兩個運算元便都是SymNode型別了。
如[user magic methods安裝流程](#user magic methods安裝流程)章節所述,SymInt.__truediv__會呼叫SymNode.truediv,而根據四則運算函數,SymNode.truediv會進一步呼叫SymNode._truediv。
參考_make_node_magic - binary_magic_impl,在SymNode._truediv中,會先取出兩個SymNode的expr成員變數(型別為sympy.Expr),對它們做TrueDiv操作,得到新的sympy.Expr。在SymNode._truediv的最後,會將運算得到的sympy.Expr當作SymNode建構子的expr參數,創建一個SymNode後回傳。注意此處創建出來的SymNode的pytype為float。
SymInt.__truediv__得到SymNode._truediv和SymNode.truediv回傳的SymNode後,會再調用wrap_node函數。
在wrap_node函數中,如果SymNode的pytype為float,則會將SymNode包裝成SymFloat後回傳。
所以SymInt / 1.會得到一個SymFloat。
以下三行分別代表這整個過程中的函數調用鏈,各函數的參數和回傳值,整理如下:
SymInt.__truediv__ → SymNode.truediv → SymNode._truediv → lambda函數 → TrueDiv
SymInt和float → 兩個SymNode → 兩個SymNode → 兩個sympy.Expr → 兩個sympy.Expr
↓
SymFloat ← SymNode ← SymNode ← sympy.Expr ← sympy.Expr
注:其中SymInt.__truediv__即透過_make_user_magic被安裝到SymInt上的binary_magic_impl,SymNode._truediv即透過_make_node_magic被安裝到SymNode上的binary_magic_impl。