為什麼要有 eval_is_non_overlapping_and_dense?PyTorch 包裝層與調用端解析

前言

前篇 為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析 介紹了 _eval_is_non_overlapping_and_dense 函數,該函數位於 torch/fx/experimental/symbolic_shapes.py,實作了判斷張量的記憶體佈局是否「非重疊且稠密」(non-overlapping and dense)的邏輯,是個純 Python 函數。

本篇介紹的 eval_is_non_overlapping_and_dense 函數則是對 _eval_is_non_overlapping_and_dense 函數的包裝,與它同樣是一個模組層級的函數,並開始跟 PyTorch 的編譯系統產生關聯。

注:從本節開始,會牽涉到 PyTorch 中的動態形狀系統,建議先閱讀 PyTorch 的官方文檔:Dynamic Shapes Core Concepts

eval_is_non_overlapping_and_dense

eval_is_non_overlapping_and_dense 位於 torch/fx/experimental/symbolic_shapes.py,實作如下:

python 复制代码
# TODO: Deduplicate this with torch/_prims_common/__init__.py
def eval_is_non_overlapping_and_dense(sizes, strides):
    return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))

可以看到它的參數跟 _eval_is_non_overlapping_and_dense 一樣是 sizesstrides,呼叫 _eval_is_non_overlapping_and_dense 後對其回傳值作了兩層包裝,第一層是 guard_bool,第二層則是 int

為何要做這兩層包裝呢?回顧前篇 為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析,當中提到了 _eval_is_non_overlapping_and_dense 函數回傳的型別在一般情況下是 bool,但也有回傳 SymBool 的情況,所以此處才在它回傳後多包一層 guard_bool 做保險,即使 _eval_is_non_overlapping_and_dense 真的回傳了 SymBoolguard_bool 也能將它轉成 bool 型別。

第二層包裝則是將 bool 轉為 int 給後續函數做使用。

調用端

IsNonOverlappingAndDenseIndicator

torch/fx/experimental/symbolic_shapes.py

eval_is_non_overlapping_and_dense 會在 IsNonOverlappingAndDenseIndicator.eval 方法中被調用:

python 复制代码
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    # ...
    @classmethod
    def eval(cls, *args):
        # ...
            return eval_is_non_overlapping_and_dense(
                [int(a) for a in size_args],
                [int(a) for a in stride_args]
            )
        return None

IsNonOverlappingAndDenseIndicator 正是我們下一篇的主題。

SYMPY_INTERP

torch/fx/experimental/symbolic_shapes.py

python 复制代码
SYMPY_INTERP = {
    'Eq': operator.eq,
    'Ne': operator.ne,
    'Gt': operator.gt,
    'Lt': operator.lt,
    'Le': operator.le,
    'Ge': operator.ge,
    'Min': min,
    'Max': max,
    'Mod': operator.mod,
    'FloorDiv': operator.floordiv,
    'TrueDiv': operator.truediv,
    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
    'floor': math.floor,
    'ceiling': math.ceil,
}

SYMPY_INTERP 字典中,作為 'IsNonOverlappingAndDenseIndicator' 鍵的值,在執行期驗證時會用到。

sizes_strides_user

_make_node_sizes_strides 函數同樣位於 torch/fx/experimental/symbolic_shapes.py,是一個模組層級的函數,當中調用了 eval_is_non_overlapping_and_dense

python 复制代码
def _make_node_sizes_strides(method, func):
    # ...
    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(getattr(a.node, method)(
                    [to_node(a.node, b) for b in sizes],
                    [to_node(a.node, b) for b in strides],
                ))
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        else:
            # TODO: this is an awful implementation
            return bool(func(
                [sympy.sympify(a) for a in sizes],
                [sympy.sympify(a) for a in strides],
            ))

PyTorch SymNode 為何找不到方法實作?──sizes_strides_methods 動態安裝機制解析 中有對 sizes_strides_methods 的詳細介紹。

相关推荐
九酒5 小时前
AI Agent 开发踩坑记:口播功能非得用 APP 原生实现吗?
前端·人工智能·agent
蝎子莱莱爱打怪5 小时前
DSpark 讲透:DeepSeek 不换模型,硬把 V4 提速 85%,是怎么做到的?
人工智能·面试·程序员
巫山老妖7 小时前
置身AI内
人工智能
IT_陈寒8 小时前
JavaScript项目实战经验分享
前端·人工智能·后端
vanuan9 小时前
两个AI智能体第一次对话-A2A双Agent协作实战
人工智能
Warson_L10 小时前
Python `Annotated` 与 LangGraph Reducer 学习笔记
python
韩师傅10 小时前
海天线算法的前世今生
python·计算机视觉
韩师傅10 小时前
当你的甲方设备过烂,要如何快速出效果?
python·计算机视觉
Warson_L10 小时前
LangGraph的MessageState and HumanMessage
python
韩师傅11 小时前
当你的甲方吐槽天空不够蓝,你应该如何应对
python·计算机视觉