為什麼這個 Tensor 算 dense?PyTorch _eval_is_non_overlapping_and_dense 深入解析

前言

PyTorch 2.0 引入了 torch.compile,這個模組透過將 使用者以 Python 撰寫的模型程式碼 擷取成計算圖並進行編譯優化,來加速模型的訓練與推論。而為了在編譯期處理張量形狀未知或可變的情況,PyTorch 建立了一套動態形狀系統(dynamic shapes system):張量的尺寸與步長不再只是 Python 中具體的整數,而可以是 SymPy 符號表達式(如 s0s1);另外還會透過約束求解與邏輯推導來判定各種性質是否成立。

torch/fx/experimental/symbolic_shapes.py 正是這套系統中處理符號形狀、guard 與形狀約束推理的核心模組,當中定義了我們的 _eval_is_non_overlapping_and_dense 函式。

該函數用來判斷張量的記憶體佈局是否「非重疊且稠密」(non-overlapping and dense),它本身只實現判斷邏輯,是一個純 Python 的計算函數。為了將這個判斷邏輯接入動態形狀系統(包括 SymNode, SymInt, SymBool 等自訂義類別),PyTorch 對它做了層層包裝,由內到外分別是:_eval_is_non_overlapping_and_denseeval_is_non_overlapping_and_denseIsNonOverlappingAndDenseIndicatorSymNode._is_non_overlapping_and_dense_indicatorSymNode.is_non_overlapping_and_dense_indicatoris_non_overlapping_and_dense,之後將會一一做介紹。

在正式進入 _eval_is_non_overlapping_and_dense 函數之前,我們得先了解「重疊」和「稠密」這兩個概念。

重疊

有時一個張量中的多個不同的邏輯索引(如:x[1, 0] 中的 [1, 0])會對應到同一個記憶體位置,我們將這種現象稱為重疊(overlapping)。常見的成因有兩種:

  • 維度的步長相對於其尺寸太小,導致相鄰 row 的位置重疊
  • 維度的步長為零,導致相鄰 row 的元素都指向同一塊記憶體

以下的兩個範例分別展示這兩種成因:

例 1:as_strided

首先構建一個一維張量 x

python 复制代码
x = torch.arange(5)
x
x.shape
x.stride()

輸出:

复制代码
tensor([0, 1, 2, 3, 4])
torch.Size([5])
(1,)

PyTorch 中有個 torch.as_strided 函數:

复制代码
torch.as_strided(input, size, stride, storage_offset=None) → Tensor
Create a view of an existing torch.Tensor input with specified size, stride and storage_offset.

Parameters
:
input (Tensor) -- the input tensor.

size (tuple or ints) -- the shape of the output tensor

stride (tuple or ints) -- the stride of the output tensor

storage_offset (int, optional) -- the offset in the underlying storage of the output tensor. If None, the storage_offset of the output tensor will match the input tensor.

as_strided 是 PyTorch 提供的底層 view 操作:它接受一個既有的 input 張量,使用者可指定輸出張量的 size, stride,以及 storage_offset(可選參數),在不複製資料的前提下,將同一塊底層記憶體重新解讀成新的張量視圖。

我們來用 torch.as_strided 函數將張量 x 變成一個二維張量 y

python 复制代码
y = torch.as_strided(x, size=(3, 3), stride=(1, 1))
y

輸出:

复制代码
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4]])

從輸出可以看出來,y 張量的元素重複了,也就是所謂的「重疊」。


來看看 y 張量的形狀和步長:

python 复制代码
y.shape
y.stride()

輸出:

复制代码
torch.Size([3, 3])
(1, 1)

其中 stride=(1, 1) 表示無論沿第 0 維還是第 1 維前進一格,底層記憶體位址都只往後移動 1 個元素。因此不同的二維索引可能會對應到同一個記憶體位址,導致了重疊的現象。

例如由 [0, 0] 沿第 0 維前進一格,會得到 [0, 1],沿第 1 維前進一格,會得到 [1, 0],這兩個二維索引指向的底層記憶體位址相同。


我們可以進一步驗證一下,以下程式可以查看每個邏輯索引對應到底層記憶體的哪個 offset:

python 复制代码
for i in range(y.size(0)):
    for j in range(y.size(1)):
        offset = y.storage_offset() + i * y.stride(0) + j * y.stride(1)
        print(f"y[{i},{j}] -> storage offset {offset}")
复制代码
y[0,0] -> storage offset 0
y[0,1] -> storage offset 1
y[0,2] -> storage offset 2
y[1,0] -> storage offset 1
y[1,1] -> storage offset 2
y[1,2] -> storage offset 3
y[2,0] -> storage offset 2
y[2,1] -> storage offset 3
y[2,2] -> storage offset 4

[0,1][1,0] 都對應到 storage offset 1,即:有多個邏輯索引對應到同一個記憶體位置,表示這個張量是「重疊」的。


如果還想更進一步驗證,可以用以下程式,它會改動 x 的第一個元素:

python 复制代码
x[1] = 5
y

輸出:

复制代码
tensor([[0, 5, 2],
        [5, 2, 3],
        [2, 3, 4]])

可以發現二維 view 張量 y[0, 1][1, 0] 都一起被改動了!這再次證實了 y 張量中的這兩個二維索引確實指向同一個記憶體位置!換句話說:y 是重疊的張量。

例 2:expand

構建一個一維張量 x

python 复制代码
x = torch.ones(4)
x

輸出:

复制代码
tensor([1., 1., 1., 1.])

PyTorch 中有個 torch.Tensor.expand 函數:

复制代码
Tensor.expand(*sizes) → Tensor
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

Passing -1 as the size for a dimension means not changing the size of that dimension.

Tensor can be also expanded to a larger number of dimensions, and the new ones will be appended at the front. For the new dimensions, the size cannot be set to -1.

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

Parameters
:
*sizes (torch.Size or int...) -- the desired expanded size

expand 接受 *sizes 參數,表示使用者期望張量被擴展後的尺寸;回傳一個新的 view,將原張量中尺寸為 1 的維度擴展到指定的更大尺寸。傳入 -1 表示該維度尺寸不變。張量也可以擴展到更高的維度數,新維度會加在最前面(但新維度不能傳 -1)。擴展張量時不會分配新的記憶體,而是透過將步長設為 0,讓尺寸為 1 的維度在不複製資料的情況下被「看成」更大的尺寸。

expand 函數將張量 x 沿著新維度廣播成一個 4x4 的二維張量 y

python 复制代码
y = x.expand(4, -1)
y

輸出:

复制代码
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

其實 y 的每個 row 指向的是同一塊記憶體,可以用以下程式來驗證。以下程式會修改 x 的第一個元素:

python 复制代码
x[1] = 5
y

輸出:

复制代码
tensor([[1., 5., 1., 1.],
        [1., 5., 1., 1.],
        [1., 5., 1., 1.],
        [1., 5., 1., 1.]])

y 的步長為:

python 复制代码
y.stride()

輸出:

复制代码
(0, 1)

y 第一個維度的步長為 0,這表示沿著第一個維度往下移動一格時,底層記憶體位址完全不變;也就是說,y[0, :]y[1, :]y[2, :]y[3, :] 其實都對應到同一個 row 的底層資料,只是被視為不同 row,可見 y 這個 view 張量是「重疊」的。

稠密

如果張量各元素之間沒有空洞,底層記憶體被完整填滿,無多餘空間,那麼該張量就可以被稱為「稠密」。

例:非稠密的張量

我們要生成有空洞的張量,先從一維張量開始:

python 复制代码
base = torch.arange(12)

輸出:

复制代码
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

先把它變成兩個 row,再取每個 row 的前 3 個元素:

python 复制代码
x = base.view(2, 6)[:, :3]
x
x.shape
x.stride()

輸出:

复制代码
tensor([[0, 1, 2],
        [6, 7, 8]])
torch.Size([2, 3])
(6, 1)

由以上輸出可知,x 包含的 [[0, 1, 2], [6, 7, 8]] 等六個元素在底層記憶體中並不是連續排在一起的:第 0 維步長為 6 代表每跨一個 row 要跳 6 個元素,但每個 row 只有 3 個元素,因此中間會留下 3 個元素的空洞。


以下程式會把每個邏輯索引對應到的 storage offset 列出來:

python 复制代码
for i in range(x.size(0)):
    for j in range(x.size(1)):
        offset = x.storage_offset() + i * x.stride(0) + j * x.stride(1)
        print(f"x[{i},{j}] -> storage offset {offset}")

輸出:

复制代码
x[0,0] -> storage offset 0
x[0,1] -> storage offset 1
x[0,2] -> storage offset 2
x[1,0] -> storage offset 6
x[1,1] -> storage offset 7
x[1,2] -> storage offset 8

可以看到,第一個 row 用到的是 offset 0,1,2,第二個 row 直接跳到 6,7,8;中間的 3,4,5 完全沒有被 x 這個 view 用到,再次證明了這個張量有空洞,也就是「非稠密」。


正常的 2x3 張量則是這樣:

python 复制代码
y = torch.arange(6).view(2, 3)
y
y.shape
y.stride()
for i in range(y.size(0)):
    for j in range(y.size(1)):
        offset = y.storage_offset() + i * y.stride(0) + j * y.stride(1)
        print(f"y[{i},{j}] -> storage offset {offset}")

輸出:

复制代码
tensor([[0, 1, 2],
        [3, 4, 5]])
torch.Size([2, 3])
(3, 1)
y[0,0] -> storage offset 0
y[0,1] -> storage offset 1
y[0,2] -> storage offset 2
y[1,0] -> storage offset 3
y[1,1] -> storage offset 4
y[1,2] -> storage offset 5

可以看到它剛好連續用掉底層記憶體 offset 的 0~5,中間沒有任何空洞。


剛剛看到的重疊或是非稠密的張量,都是對 base 張量做了一些操作後才得到的。如果我們單純用 torch.onestorch.empty 製造張量,那麼該張量自然會同時滿足「非重疊」和「稠密」兩個條件。可以想像,這樣的張量中每個元素在底層記憶體中都會對應到不同的位址,並且這些位址也是連續分布的。在 PyTorch 中,我們可以用一個專有名詞------contiguous(連續)來形容它。

contiguous

Tensor Implementation Details 中提到:

A tensor is contiguous in memory if its elements are laid out in the Storage in the same order as a standard C-style (row-major) traversal. For a contiguous tensor, the stride typically follows a pattern where the stride for the last dimension is 1, the stride for the second-to-last dimension is the size of the last dimension, and so on.

如果一個張量的元素在儲存空間(Storage)中的排列順序與標準 C-style (row-major) 的遍歷順序相同,則該張量在記憶體中是連續的。對於連續張量,步長通常遵循以下規律:最後一個維度的步長為 1,倒數第二個維度的步長等於最後一個維度的大小,以此類推。

可以想見,如果一個張量是 contiguous 的,其步長應該會是遞減的。

我們可以用 PyTorch 提供的 torch.Tensor.is_contiguous 函數來判斷張量的連續性,官網對函數的介紹如下:

复制代码
torch.Tensor.is_contiguous
Tensor.is_contiguous(memory_format=torch.contiguous_format) → bool
Returns True if self tensor is contiguous in memory in the order specified by memory format.

Parameters
:
memory_format (torch.memory_format, optional) -- Specifies memory allocation order. Default: torch.contiguous_format.

self 張量在記憶體中是依照 memory_format 所指定的順序為連續的,則回傳 True。

可選參數 memory_format 用於指定記憶體配置順序,預設值為 torch.contiguous_format


is_contiguous 當中的 memory_format 參數是什麼意思我們暫且按下不表,先從一個簡單的 demo 開始看起:

python 复制代码
x = torch.arange(12).reshape(3, 4)
x
x.shape
x.stride()
x.is_contiguous()

輸出:

复制代码
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
torch.Size([3, 4])
(4, 1)
True

x 的步長為 (4, 1),最後一個維度的步長為 1,倒數第二個維度的步長等於最後一個維度的大小,這符合 row-major 連續排列的規律,所以 x.is_contiguous() 會返回 True


現在對 x 張量做轉置,改變其維度順序。使用的是 Tensor.T 函數:

复制代码
Tensor.T
Returns a view of this tensor with its dimensions reversed.

If n is the number of dimensions in x, x.T is equivalent to x.permute(n-1, n-2, ..., 0).

Tensor.T 回傳此張量的一個 view,其維度順序被反轉。若 x 的維度數為 n,則 x.T 等同於 x.permute(n-1, n-2, ..., 0)

python 复制代码
y = x.T
y
y.shape
y.stride()
y.is_contiguous()

輸出:

复制代码
tensor([[ 0,  4,  8],
        [ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11]])
torch.Size([4, 3])
(1, 4)
False

Tensor.T 只改變了觀察維度的順序,底層資料並沒有被搬動。但張量的步長卻變成 (1, 4),已經不再是遞減排序了,因此也不再滿足預設記憶體格式下的 contiguous 條件,所以 y.is_contiguous() 的結果是 False


剛剛我們略過了 is_contiguous 函數的 memory_format(記憶體格式)參數:其預設值為 torch.contiguous_format(如:NCHW 或 NCTHW),其它可選的參數值包括 torch.channels_last, torch.channels_last_3d 等。

為什麼 is_contiguous 函數需要這個參數呢?這是因為記憶體格式會影響維度順序,而 contiguous 又對維度順序有要求,所以使用者必須指定記憶體格式,函數才能決定輸入張量是否滿足 contiguous 的條件。

舉個例子,對兩個形狀相同的 4 維 NCHW 張量來說,如果傳入不同的 memory_format 參數,函數就會有不同的判定結果:

python 复制代码
x = torch.empty((2, 3, 5, 7))
x.stride()
x.is_contiguous()
x.is_contiguous(memory_format=torch.channels_last)

輸出:

复制代码
(105, 35, 7, 1)
True
False

結果符合預期:使用預設參數時 x 會被判為 contiguous,如果使用。

以下程式碼創建一個 y 張量,使用的是 channels-last 記憶體格式(NHWC):

python 复制代码
y = torch.empty((2, 3, 5, 7), memory_format=torch.channels_last)
y.stride()
y.is_contiguous()
y.is_contiguous(memory_format=torch.channels_last)

輸出:

复制代码
(105, 1, 21, 3)
False
True

可以看到,x 採用預設的 torch.contiguous_format 記憶體格式,所以可被 is_contiguous() 判為 contiguous;而 y 雖然同樣稠密且非重疊,卻是依照 channels_last(NHWC)順序排列,因此對預設的 is_contiguous() 函數來說,它並不是 contiguous。只有在指定記憶體格式參數為 torch.channels_last 的情況下才會返回 True

contiguous 這個性質要求張量的元素在記憶體中依照特定順序緊密排列,因此自然也具有以下兩個性質:

  • 不同索引不會指向同一個位置,也就是「非重疊」
  • 元素之間不會留下空洞,也就是「稠密」

現在我們知道張量的維度順序會影響 contiguous 的判斷結果;那麼,若一個張量雖然不符合 contiguous 的定義,卻仍同時滿足以上兩個條件,有沒有一個詞可以形容這種張量呢?有的,在 PyTorch 中我們把它叫做 non-overlapping and dense(非重疊且稠密)。

非重疊且稠密

non_overlapping_and_dense 與 contiguous 不同,是無方向性的。只要存在某個維度排列,使該張量在該排列下是 contiguous,我們就稱該張量滿足 non_overlapping_and_dense,並不在意維度的實際排列順序。

我們可以用以下程式比較這兩個概念。首先用 torch.randn 中規中矩地創建一個張量:

python 复制代码
x = torch.randn(3, 4)  # shape=(3,4), stride=(4,1)
x.is_contiguous()
torch.ops.aten.is_non_overlapping_and_dense(x)

輸出如下:

复制代码
True
True

如預期般,這個張量既是 contiguous 又是 non overlapping and dense。

接著對 x 動一點手腳,對它轉置得到張量 y

python 复制代码
y = x.T  # shape=(4,3), stride=(1,4)
y.is_contiguous()
torch.ops.aten.is_non_overlapping_and_dense(y)

輸出如下:

复制代码
False
True

.T 會改變觀察維度的順序,導致 y 無法滿足 contiguous 的條件;另外因為 .T 不會搬動張量的底層資料,所以 y 仍然保持著 non overlapping and dense 的性質。可見 non_overlapping_and_dense 是比 contiguous 還寬鬆的條件:所有 contiguous 張量必然是 non_overlapping_and_dense,但反過來則不成立。

如何判斷

上一節使用的 torch.ops.aten.is_non_overlapping_and_dense 判斷函數實際上只是一個 wrapper,它真正的判斷邏輯實作於 _eval_is_non_overlapping_and_dense 函數,也就是本篇的主題。

_eval_is_non_overlapping_and_dense 函數只關心張量在記憶體中是否非重疊及稠密,並不在乎張量各維度的排列順序。函數中的判斷邏輯如下:如果存在某個維度排列,使得輸入張量在該排列下等價於某個 contiguous 張量,即:有效維度(指尺寸大於 1 的維度,尺寸為 1 的維度會被忽略)的步長恰好等於按該排列排序後前面所有維度尺寸的連乘積,那麼該張量就可被視為 non_overlapping_and_dense。

算法之所以只把尺寸大於 1 的維度視為有效維度而忽略尺寸為 1 的維度,原因有二:

  • 尺寸為 1 的維度只有一個元素,無論步長為何都不會存取到新的記憶體位址,既不會與其他元素重疊,也不會造成空洞
  • 這類維度的步長可以是任意值(詳見 PyTorch 張量尺寸為 1 時,步長為何不具語意?),若納入計算反而會干擾判斷

_eval_is_non_overlapping_and_dense

_eval_is_non_overlapping_and_dense 定義在 torch/fx/experimental/symbolic_shapes.py,是一個模組層級的函數,我們可以透過 torch.fx.experimental.symbolic_shapes._eval_is_non_overlapping_and_dense 來調用之。

python 复制代码
def _eval_is_non_overlapping_and_dense(sizes, strides):
    dim = len(sizes)

    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    # or it is a 0/1 element tensor
    if dim == 1:
        return strides[0] == 1 or sizes[0] < 2

    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
    # Sorts (length, stride) pairs by stride
    lengths_and_strides = sorted(
        zip(sizes, strides), key=operator.itemgetter(1)
    )

    # Unlike the C++ code, we don't move the 0/1 size dimensions to the
    # end.  So we have to keep going for this code.
    expected_stride = 1
    for length, stride in lengths_and_strides:

        if length == 1:
            continue

        if stride != expected_stride:
            return False

        expected_stride *= length

    return True

本函數只接收兩個參數:sizesstrides,分別代表輸入張量各維度的尺寸和步長,在理想情況下,兩者皆為 Python 中的 list of int。


值得注意的是,本函數並沒有接受跟「維度順序」和「維度邏輯意義」相關的參數:

_eval_is_non_overlapping_and_dense 則兩者都不需要:它既不在意維度的排列順序,也不在意維度的邏輯意義,只關心張量在記憶體中元素是否重疊、是否稠密------而無論維度的順序為何,這兩個性質都不會改變。


回到函數內容,在輸入張量只有一維時有個短路優化:

python 复制代码
if dim == 1:
    return strides[0] == 1 or sizes[0] < 2
  • 如果步長為 1,則張量必定連續
  • 當一維張量的尺寸小於 2 時,表示張量只有 0 或 1 個元素。此時該張量必為稠密,並且無需在意無重疊問題

如果輸入不只一維,檢查將會繼續:

python 复制代码
lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1))

這行代碼將 (size, stride) 按步長由小到大排序,這是在模擬「最可能符合 contiguous 的維度順序」:把各維度由內而外排列------因為在 contiguous 張量中,越內層的維度步長越小。

接著會按照此順序遍歷各維度,計算它們應有的步長,然後看實際的步長值是否符合。

expected_stride 變數正是遍歷過程中各維度應有的步長值,因為檢查順序是由內而外,而最內層維度的步長理應為 1,所以這裡將初始值設為 1:

python 复制代码
expected_stride = 1

接著逐一檢查排序後的(尺寸、步長)對:

python 复制代码
for length, stride in lengths_and_strides:
    if length == 1:
        continue  # size=1 維度不影響連續性,跳過
    
    if stride != expected_stride:
        return False
    
    expected_stride *= length
  • length == 1:因為長度為 1 的維度對記憶體排列和 expected_stride 都沒有影響,所以可以直接忽略
  • 否則檢查當前步長是否等於 expected_stride,不相等則判為非 non_overlapping_and_dense
  • 最後更新下一個外層維度的預期步長:expected_stride *= length

如果某個 (sizes, strides) 對能在每一次的 if stride != expected_stride 判斷中都堅挺住不被 return False,表示該張量中每個有效維度的步長都符合預期,所以最後會:

python 复制代码
    return True

將該張量判為 non_overlapping_and_dense。


在理想情況下,函數的參數為 Python 的 list of int,內部的比較(==!=<)皆為普通 int 運算,因此回傳值會是 Python 的 bool。

但實際上函數並沒有限制輸入的類別,它們也可能是 list of SymInt,在這種情況下,如果輸入張量剛好只有一維,觸發短路優化(strides[0] == 1 or sizes[0] < 2),比較的結果將會是 SymBool 型別。


為了與 PyTorch 的動態形狀系統互動,_eval_is_non_overlapping_and_dense 外面其實還套了好幾層包裝。其中最直接的一層是 eval_is_non_overlapping_and_dense:它在接收到 SymBool 後,會調用 guard_bool 將其轉換為普通 bool,並把該表達式登記到 ShapeEnv.guards 列表。除此之外還有其他層次的包裝,共同負責將本函數串接到動態形狀系統,會在接下來的文章中一一介紹。


函數中有段註解:

python 复制代码
# Unlike the C++ code, we don't move the 0/1 size dimensions to the 
# end. So we have to keep going for this code.

不像 C++ 程式碼,我們不把 size=0/1 維度移到末尾

這是說對於尺寸為一的長度,Python 版是透過排序加跳過來處理,C++ 版本則有另外的處理方式。

C++ 版本 - _compute_non_overlapping_and_dense

C++ 版本位於 c10/core/TensorImpl.cpp_compute_non_overlapping_and_dense

cpp 复制代码
bool _compute_non_overlapping_and_dense(
    ArrayRef<T> sizes,
    ArrayRef<T> strides) {
  auto dim = sizes.size();
  if (dim == 1) {
    return sizes[0] < 2 || strides[0] == 1;
  }
  SmallVector<int64_t, 5> perm;
  perm.resize(dim);
  for (const auto i : c10::irange(dim)) {
    perm[i] = i;
  }
  // Sort by strides, leaving 0 and 1 sized dims at the end of the array
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    if (sizes[a] < 2) {
      return false;
    } else if (sizes[b] < 2) {
      return true;
    }
    return strides[a] < strides[b];
  });
  T require_stride = 1;
  for (const auto i : c10::irange(dim)) {
    const auto& size_perm_i = sizes[perm[i]];
    if (size_perm_i < 2) {
      return true;
    }
    if (strides[perm[i]] != require_stride) {
      return false;
    }
    require_stride *= size_perm_i;
  }
  return true;
}

一開始一樣有個短路優化:

cpp 复制代码
  auto dim = sizes.size();
  if (dim == 1) {
    return sizes[0] < 2 || strides[0] == 1;
  }

若是一維張量,當 size < 2stride == 1 就直接返回 true

若是多維張量,接下來需要對各維度按步長由小到大排序。在 Python 中,是用 sorted(zip(sizes, strides), ...) 直接對 (size, stride) 對本身做排序;C++ 版本則採用不同做法:先建立一個包含各維度索引的陣列 perm = [0, 1, ..., dim-1],對它依步長做升序排列,之後如果想要存取「排序後第 i 個維度」的尺寸或步長,得再透過 sizes[perm[i]]strides[perm[i]] 間接取得。

以下程式碼就是在建立 perm 維度索引陣列:

cpp 复制代码
  SmallVector<int64_t, 5> perm;
  perm.resize(dim);
  for (const auto i : c10::irange(dim)) {
    perm[i] = i;
  }

接著實際用 std::sort 搭配自訂 lambda 比較函式對 perm 進行排序:

cpp 复制代码
  // Sort by strides, leaving 0 and 1 sized dims at the end of the array
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    if (sizes[a] < 2) {
      return false;
    } else if (sizes[b] < 2) {
      return true;
    }
    return strides[a] < strides[b];
  });

lambda 函數的最後一行很好懂,也是通例:

  • 當兩個維度的尺寸都大於等於 2 時,比較 strides[a] < strides[b],依照步長大小做遞增排序

但在這個通例之前有兩個特例,在看這兩個特例之前,我們得先弄明白 std::sort 的比較函式的語意:當 a 應該被排在 b 前面時,比較函數會回傳 true,反之則回傳 false。了解這一點後再來看兩個特例:

  • sizes[a] < 2,回傳 false:這時 a 會被排到 b 之後,也就是把尺寸小於 2 的 a 維度往後推
  • 否則若 sizes[b] < 2,回傳 true:這時 b 會被排在 a 之後,同樣是把尺寸小於 2 的 b 維度往後推

因此整體效果是:將 perm 依步長由小到大排序,並把所有尺寸小於 2 的維度放到最後。這是因為尺寸小於 2 的維度不影響判斷結果,所以後續遍歷遇到它們時,就可以直接短路返回 true

各維度按照步長由小到大排序後,後面的 for 迴圈會檢查每個有效維度的步長是否等於 require_striderequire_stride 的初始值為 1,在每一次迭代後會被乘上當前維度的尺寸,表示下一個維度應有的步長。只要有任何一維不符合,就返回 false;若全部符合,則返回 true

cpp 复制代码
  T require_stride = 1;
  for (const auto i : c10::irange(dim)) {
    const auto& size_perm_i = sizes[perm[i]];
    if (size_perm_i < 2) {
      return true;
    }
    if (strides[perm[i]] != require_stride) {
      return false;
    }
    require_stride *= size_perm_i;
  }
  return true;

因為尺寸小於 2 的維度不影響判斷結果且有可能會造成干擾,所以判斷函數會將這類維度排除。C++ 與 Python 兩個版本的效果等價,差別只在排除的時機:C++ 在排序階段就把這類維度推到最後,遍歷時一遇到便直接短路返回 true;Python 則在遍歷階段直接跳過 size == 1 的維度,寫法更易懂。

案例解析

這樣設計為什麼有效?以下透過幾個具體例子,實際走一遍 Python 版本 _eval_is_non_overlapping_and_dense 的執行流程來驗證。

例 1:NCHW contiguous 張量

創建一個 sizes[2, 3, 4] 的張量:

python 复制代码
x = torch.empty(2,3,4)
x.shape
x.stride()

輸出:

复制代码
torch.Size([2, 3, 4])
(12, 4, 1)

strides(12, 4, 1)

將(尺寸、步長)對按照步長由小到大排序:

复制代码
(4,1), (3,4), (2,12)

檢查流程:

step size stride expected_stride 判斷 新 expected_stride
1 4 1 1 1==1 ✔ 1×4=4
2 3 4 4 4==4 ✔ 4×3=12
3 2 12 12 12==12 ✔ 12×2=24

每個維度都可以通過檢查,函數最後返回 True。

例 2:對 NCHW contiguous 張量做轉置

嘗試轉置上例張量的第 0 和第 1 個維度:

python 复制代码
x = torch.empty(2, 3, 4)
x = torch.transpose(x, 0, 1)
x.shape
x.stride()

輸出:

复制代码
torch.Size([3, 2, 4])
(4, 12, 1)

排序後:

复制代码
(4,1), (3,4), (2,12)

和上例完全一樣!

檢查流程也完全相同:

step size stride expected_stride 判斷 新 expected_stride
1 4 1 1 1==1 ✔ 1×4=4
2 3 4 4 4==4 ✔ 4×3=12
3 2 12 12 12==12 ✔ 12×2=24

函數最終也返回 True。

這就是函數中按步長排序的作用:即使轉置後維度順序改變(張量不再是 contiguous),但經過排序後仍能通過檢查,最後仍可以被判為 non_overlapping_and_dense。

例 3:overlapping 張量

首先創建一個一維張量:

python 复制代码
base = torch.arange(5)
base

輸出:

复制代码
tensor([0, 1, 2, 3, 4])

接著用 as_strided 產生一個二維的 view:

python 复制代码
x = base.as_strided((3, 3), (1, 1))
x
x.shape
x.stride()

輸出:

复制代码
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4]])
torch.Size([3, 3])
(1, 1)

從輸出中可以看出來,x 張量的元素重複了。


再來追蹤函數的檢查流程,先對(尺寸、步長)對做排序:

复制代码
(3,1), (3,1)

檢查流程:

step size stride expected_stride 判斷 新 expected_stride
1 3 1 1 1==1 ✔ 1×3=3
2 3 1 3 1!=3 ✗ ---

張量第二個維度的步長為 1,不符合預期的 3,因此函數直接返回 False,確實地判為非 non_overlapping_and_dense。

例 4:有空洞(非稠密)

生成一個有空洞的張量:

python 复制代码
x = torch.arange(12).view(2, 6)[:, :3]
x
x.shape
x.stride()

輸出:

复制代码
tensor([[0, 1, 2],
        [6, 7, 8]])
torch.Size([2, 3])
(6, 1)

函數首先對 x 的(尺寸、步長)對做排序:

复制代码
(3,1), (2,6)

逐維檢查:

step size stride expected_stride 判斷
1 3 1 1 ✔ → expected=3
2 2 6 3 6≠3 ❌

函數在第 2 步檢測到步長為 6,不等於 expected_stride=3(即前一維度佔滿後的緊接位置),正確地判定記憶體中存在間隙,不符合 dense。

例 5:含 size=1 維度

PyTorch 張量尺寸為 1 時,步長為何不具語意? 中提到的方式創建一個 sizes = [2, 1, 3, 4], strides = [12, 240, 4, 1] 的張量:

python 复制代码
x = torch.empty(1, 2, 3, 4)
x.shape
x.stride()
y = x[::10, :, :, :]
y.shape
y.stride()
z = torch.transpose(y, 0, 1)
z.shape
z.stride()

輸出:

复制代码
torch.Size([1, 2, 3, 4])
(24, 12, 4, 1)
torch.Size([1, 2, 3, 4])
(240, 12, 4, 1)
torch.Size([2, 1, 3, 4])
(12, 240, 4, 1)

張量 z 的第二個維度尺寸為 1,步長任意(這裡是 240)。

排序後:

复制代码
(4,1), (3,4), (2,12), (1,240)

檢查流程:

step size stride expected_stride 判斷 新 expected_stride
1 4 1 1 4
2 3 4 4 12
3 2 12 12 24
4 1 240 24 size=1 → skip 24

最後一步的步長跳到了 240,但因為尺寸為 1,函數直接跳過,沒有被它干擾判斷。

demo

理想中 _eval_is_non_overlapping_and_dense 接受的參數應該是 Python list of int 型別,但因為函數並沒有對參數的型別做限制,實際上外部也可能傳入 list of SymInt 作為它的參數。

本節將對 _eval_is_non_overlapping_and_dense 函數傳入不同型別的參數,觀察函數的行為與回傳型別有何差異。

regular tensor with real sizes

一般預期的呼叫方式是傳入 Python list of int。以下程式先創建一個形狀為 2 x 3 x 4 的張量,再判斷它是否 non_overlapping_and_dense:

python 复制代码
from torch.fx.experimental.symbolic_shapes import _eval_is_non_overlapping_and_dense

t = torch.empty(2, 3, 4)
sizes = t.size()
strides = t.stride()
result = _eval_is_non_overlapping_and_dense(list(sizes), list(strides))
print(sizes)
print(strides)
print(result, type(result).__name__)

輸出:

复制代码
torch.Size([2, 3, 4])
(12, 4, 1)
True bool

因為輸入是 Python list of int,所以函數內部的 ==!=< 全都是 int 間的比較運算,函數最終的回傳值也是我們熟悉的 Python bool。

fake tensor with symbolic sizes(一維)

_eval_is_non_overlapping_and_dense 函數也可能被用於判斷具有動態形狀的張量,這時張量的尺寸與步長可能是 SymInt

PyTorch動態形狀系統 簡述了 PyTorch 中建立動態形狀模型的流程,此處模擬當中的第 1、2 步:手動創建 ShapeEnvFakeTensorMode,再用 from_tensor 函數把一個普通張量包裝成具有 symbolic shape 的 FakeTensor,如此一來便能得到尺寸與步長皆為 SymInt 的張量:

python 复制代码
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, _eval_is_non_overlapping_and_dense
from torch._subclasses.fake_tensor import FakeTensorMode

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

real = torch.empty(5)
fake = fake_mode.from_tensor(real, static_shapes=False)
print(fake.size(), [type(s).__name__ for s in fake.size()])
print(fake.stride(), [type(s).__name__ for s in fake.stride()])

輸出:

复制代码
torch.Size([s0]) ['SymInt']
(1,) ['SymInt']

fake 是一個一維的 fake tensor,其尺寸為符號 s0,步長為 1,兩者皆為 SymInt


接著把 fakesize()stride() 包裝成 Python list,丟進 _eval_is_non_overlapping_and_dense

python 复制代码
sizes = list(fake.size())
strides = list(fake.stride())
result = _eval_is_non_overlapping_and_dense(sizes, strides)
print(result, type(result).__name__)

輸出:

复制代码
True SymBool

可以看到函數果真回傳 SymBool 了!

這是為什麼呢?我們回頭來看看函數中是怎麼處理 SymInt 輸入的,在短路優化那一行:

python 复制代码
if dim == 1:
    return strides[0] == 1 or sizes[0] < 2

strides[0]sizes[0] 都是 SymInt,它們與 int 做比較的結果會是 SymBool;而兩個 SymBool 經過 Python 的 or 後,得到的還是一個 SymBool,因此函數最後回傳的也是 SymBool

fake tensor with symbolic sizes(多維)

若改為多維張量,情況就不同了。以下範例先創建一個多維張量,再將它包裝成具有 symbolic shape 的 FakeTensor,接著對它呼叫 _eval_is_non_overlapping_and_dense

python 复制代码
real = torch.empty(2, 3, 4)
fake = fake_mode.from_tensor(real, static_shapes=False)
print(fake.size(), [type(s).__name__ for s in fake.size()])
print(fake.stride(), [type(s).__name__ for s in fake.stride()])

result = _eval_is_non_overlapping_and_dense(list(fake.size()), list(fake.stride()))
print(result, type(result).__name__)

輸出:

复制代码
torch.Size([s0, s1, s2]) ['SymInt', 'SymInt', 'SymInt']
(s1*s2, s2, 1) ['SymInt', 'SymInt', 'SymInt']
True bool

跟一維時的情況不同,這次函數返回了Python bool!

回頭查看函數實現,如果輸入張量的各維都能通過迴圈中的判斷,函數最後一行就會 return True。這也是輸入此處輸入全是 SymInt,但函數最後卻能回傳 Python bool 的原因。

總結

從以上demo可以看出:唯一會讓本函數回傳 SymBool 的情況,就是「輸入是 SymInt 且維度為 1」時。三種情況整理如下:

輸入型別 維度 回傳型別
list of int 任意 bool
list of SymInt 1 SymBool
list of SymInt ≥ 2 bool

下一篇將介紹 eval_is_non_overlapping_and_dense,該函數是整個 PyTorch 中唯一會調用 _eval_is_non_overlapping_and_dense 的函數,因為 _eval_is_non_overlapping_and_dense 可能回傳 bool 或 SymBooleval_is_non_overlapping_and_dense 為了確保回傳型別一致,才在 _eval_is_non_overlapping_and_dense 外面多包了一層 guard_bool,用來將它可能收到的 SymBool 轉回 Python bool。

相关推荐
2501_901006477 小时前
Golang map底层实现原理_Golang map哈希表原理教程【收藏】
jvm·数据库·python
寒山独见君~7 小时前
自动化-消息推送Server酱3,APP推送
运维·数据库·python·自动化·通知
IT_陈寒7 小时前
为什么我的Python multiprocessing总是卡在join()?
前端·人工智能·后端
云天AI实战派7 小时前
ChatGPT/AI 智能体功能异常排查指南:账号安全、权限灰度到审批流卡点的全流程解决方案
人工智能·安全·chatgpt
薛定猫AI7 小时前
【深度解析】Open Code Skills 工作流:用知识图谱、Spec 驱动与 UI 设计系统提升 AI Coding Agent 生产力
人工智能·ui·知识图谱
qq_392690667 小时前
Go语言怎么做DNS查询_Go语言DNS域名解析教程【完整】
jvm·数据库·python
m0_738120727 小时前
后渗透维权提权基础——CTF模拟红队进行权限维持(二)
前端·网络·windows·python·安全·php
speop7 小时前
Reasoning kingdom chapter13
android·java·python
袋子(PJ)7 小时前
2026年pytorch基础学习(基于jupyter notebook开发)——从原理到落地:PyTorch神经网络架构与工程优化解析
人工智能·pytorch·深度学习·学习·jupyter