前言
PyTorch 2.0 引入了 torch.compile,這個模組透過將 使用者以 Python 撰寫的模型程式碼 擷取成計算圖並進行編譯優化,來加速模型的訓練與推論。而為了在編譯期處理張量形狀未知或可變的情況,PyTorch 建立了一套動態形狀系統(dynamic shapes system):張量的尺寸與步長不再只是 Python 中具體的整數,而可以是 SymPy 符號表達式(如 s0、s1);另外還會透過約束求解與邏輯推導來判定各種性質是否成立。
torch/fx/experimental/symbolic_shapes.py 正是這套系統中處理符號形狀、guard 與形狀約束推理的核心模組,當中定義了我們的 _eval_is_non_overlapping_and_dense 函式。
該函數用來判斷張量的記憶體佈局是否「非重疊且稠密」(non-overlapping and dense),它本身只實現判斷邏輯,是一個純 Python 的計算函數。為了將這個判斷邏輯接入動態形狀系統(包括 SymNode, SymInt, SymBool 等自訂義類別),PyTorch 對它做了層層包裝,由內到外分別是:_eval_is_non_overlapping_and_dense、eval_is_non_overlapping_and_dense、IsNonOverlappingAndDenseIndicator、SymNode._is_non_overlapping_and_dense_indicator、SymNode.is_non_overlapping_and_dense_indicator、is_non_overlapping_and_dense,之後將會一一做介紹。
在正式進入 _eval_is_non_overlapping_and_dense 函數之前,我們得先了解「重疊」和「稠密」這兩個概念。
重疊
有時一個張量中的多個不同的邏輯索引(如:x[1, 0] 中的 [1, 0])會對應到同一個記憶體位置,我們將這種現象稱為重疊(overlapping)。常見的成因有兩種:
- 維度的步長相對於其尺寸太小,導致相鄰 row 的位置重疊
- 維度的步長為零,導致相鄰 row 的元素都指向同一塊記憶體
以下的兩個範例分別展示這兩種成因:
例 1:as_strided
首先構建一個一維張量 x:
python
x = torch.arange(5)
x
x.shape
x.stride()
輸出:
tensor([0, 1, 2, 3, 4])
torch.Size([5])
(1,)
PyTorch 中有個 torch.as_strided 函數:
torch.as_strided(input, size, stride, storage_offset=None) → Tensor
Create a view of an existing torch.Tensor input with specified size, stride and storage_offset.
Parameters
:
input (Tensor) -- the input tensor.
size (tuple or ints) -- the shape of the output tensor
stride (tuple or ints) -- the stride of the output tensor
storage_offset (int, optional) -- the offset in the underlying storage of the output tensor. If None, the storage_offset of the output tensor will match the input tensor.
as_strided 是 PyTorch 提供的底層 view 操作:它接受一個既有的 input 張量,使用者可指定輸出張量的 size, stride,以及 storage_offset(可選參數),在不複製資料的前提下,將同一塊底層記憶體重新解讀成新的張量視圖。
我們來用 torch.as_strided 函數將張量 x 變成一個二維張量 y:
python
y = torch.as_strided(x, size=(3, 3), stride=(1, 1))
y
輸出:
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4]])
從輸出可以看出來,y 張量的元素重複了,也就是所謂的「重疊」。
來看看 y 張量的形狀和步長:
python
y.shape
y.stride()
輸出:
torch.Size([3, 3])
(1, 1)
其中 stride=(1, 1) 表示無論沿第 0 維還是第 1 維前進一格,底層記憶體位址都只往後移動 1 個元素。因此不同的二維索引可能會對應到同一個記憶體位址,導致了重疊的現象。
例如由 [0, 0] 沿第 0 維前進一格,會得到 [0, 1],沿第 1 維前進一格,會得到 [1, 0],這兩個二維索引指向的底層記憶體位址相同。
我們可以進一步驗證一下,以下程式可以查看每個邏輯索引對應到底層記憶體的哪個 offset:
python
for i in range(y.size(0)):
for j in range(y.size(1)):
offset = y.storage_offset() + i * y.stride(0) + j * y.stride(1)
print(f"y[{i},{j}] -> storage offset {offset}")
y[0,0] -> storage offset 0
y[0,1] -> storage offset 1
y[0,2] -> storage offset 2
y[1,0] -> storage offset 1
y[1,1] -> storage offset 2
y[1,2] -> storage offset 3
y[2,0] -> storage offset 2
y[2,1] -> storage offset 3
y[2,2] -> storage offset 4
[0,1] 和 [1,0] 都對應到 storage offset 1,即:有多個邏輯索引對應到同一個記憶體位置,表示這個張量是「重疊」的。
如果還想更進一步驗證,可以用以下程式,它會改動 x 的第一個元素:
python
x[1] = 5
y
輸出:
tensor([[0, 5, 2],
[5, 2, 3],
[2, 3, 4]])
可以發現二維 view 張量 y 的 [0, 1] 和 [1, 0] 都一起被改動了!這再次證實了 y 張量中的這兩個二維索引確實指向同一個記憶體位置!換句話說:y 是重疊的張量。
例 2:expand
構建一個一維張量 x:
python
x = torch.ones(4)
x
輸出:
tensor([1., 1., 1., 1.])
PyTorch 中有個 torch.Tensor.expand 函數:
Tensor.expand(*sizes) → Tensor
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.
Passing -1 as the size for a dimension means not changing the size of that dimension.
Tensor can be also expanded to a larger number of dimensions, and the new ones will be appended at the front. For the new dimensions, the size cannot be set to -1.
Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.
Parameters
:
*sizes (torch.Size or int...) -- the desired expanded size
expand 接受 *sizes 參數,表示使用者期望張量被擴展後的尺寸;回傳一個新的 view,將原張量中尺寸為 1 的維度擴展到指定的更大尺寸。傳入 -1 表示該維度尺寸不變。張量也可以擴展到更高的維度數,新維度會加在最前面(但新維度不能傳 -1)。擴展張量時不會分配新的記憶體,而是透過將步長設為 0,讓尺寸為 1 的維度在不複製資料的情況下被「看成」更大的尺寸。
用 expand 函數將張量 x 沿著新維度廣播成一個 4x4 的二維張量 y:
python
y = x.expand(4, -1)
y
輸出:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
其實 y 的每個 row 指向的是同一塊記憶體,可以用以下程式來驗證。以下程式會修改 x 的第一個元素:
python
x[1] = 5
y
輸出:
tensor([[1., 5., 1., 1.],
[1., 5., 1., 1.],
[1., 5., 1., 1.],
[1., 5., 1., 1.]])
y 的步長為:
python
y.stride()
輸出:
(0, 1)
y 第一個維度的步長為 0,這表示沿著第一個維度往下移動一格時,底層記憶體位址完全不變;也就是說,y[0, :]、y[1, :]、y[2, :]、y[3, :] 其實都對應到同一個 row 的底層資料,只是被視為不同 row,可見 y 這個 view 張量是「重疊」的。
稠密
如果張量各元素之間沒有空洞,底層記憶體被完整填滿,無多餘空間,那麼該張量就可以被稱為「稠密」。
例:非稠密的張量
我們要生成有空洞的張量,先從一維張量開始:
python
base = torch.arange(12)
輸出:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
先把它變成兩個 row,再取每個 row 的前 3 個元素:
python
x = base.view(2, 6)[:, :3]
x
x.shape
x.stride()
輸出:
tensor([[0, 1, 2],
[6, 7, 8]])
torch.Size([2, 3])
(6, 1)
由以上輸出可知,x 包含的 [[0, 1, 2], [6, 7, 8]] 等六個元素在底層記憶體中並不是連續排在一起的:第 0 維步長為 6 代表每跨一個 row 要跳 6 個元素,但每個 row 只有 3 個元素,因此中間會留下 3 個元素的空洞。
以下程式會把每個邏輯索引對應到的 storage offset 列出來:
python
for i in range(x.size(0)):
for j in range(x.size(1)):
offset = x.storage_offset() + i * x.stride(0) + j * x.stride(1)
print(f"x[{i},{j}] -> storage offset {offset}")
輸出:
x[0,0] -> storage offset 0
x[0,1] -> storage offset 1
x[0,2] -> storage offset 2
x[1,0] -> storage offset 6
x[1,1] -> storage offset 7
x[1,2] -> storage offset 8
可以看到,第一個 row 用到的是 offset 0,1,2,第二個 row 直接跳到 6,7,8;中間的 3,4,5 完全沒有被 x 這個 view 用到,再次證明了這個張量有空洞,也就是「非稠密」。
正常的 2x3 張量則是這樣:
python
y = torch.arange(6).view(2, 3)
y
y.shape
y.stride()
for i in range(y.size(0)):
for j in range(y.size(1)):
offset = y.storage_offset() + i * y.stride(0) + j * y.stride(1)
print(f"y[{i},{j}] -> storage offset {offset}")
輸出:
tensor([[0, 1, 2],
[3, 4, 5]])
torch.Size([2, 3])
(3, 1)
y[0,0] -> storage offset 0
y[0,1] -> storage offset 1
y[0,2] -> storage offset 2
y[1,0] -> storage offset 3
y[1,1] -> storage offset 4
y[1,2] -> storage offset 5
可以看到它剛好連續用掉底層記憶體 offset 的 0~5,中間沒有任何空洞。
剛剛看到的重疊或是非稠密的張量,都是對 base 張量做了一些操作後才得到的。如果我們單純用 torch.ones 或 torch.empty 製造張量,那麼該張量自然會同時滿足「非重疊」和「稠密」兩個條件。可以想像,這樣的張量中每個元素在底層記憶體中都會對應到不同的位址,並且這些位址也是連續分布的。在 PyTorch 中,我們可以用一個專有名詞------contiguous(連續)來形容它。
contiguous
Tensor Implementation Details 中提到:
A tensor is contiguous in memory if its elements are laid out in the Storage in the same order as a standard C-style (row-major) traversal. For a contiguous tensor, the stride typically follows a pattern where the stride for the last dimension is 1, the stride for the second-to-last dimension is the size of the last dimension, and so on.
如果一個張量的元素在儲存空間(Storage)中的排列順序與標準 C-style (row-major) 的遍歷順序相同,則該張量在記憶體中是連續的。對於連續張量,步長通常遵循以下規律:最後一個維度的步長為 1,倒數第二個維度的步長等於最後一個維度的大小,以此類推。
可以想見,如果一個張量是 contiguous 的,其步長應該會是遞減的。
我們可以用 PyTorch 提供的 torch.Tensor.is_contiguous 函數來判斷張量的連續性,官網對函數的介紹如下:
torch.Tensor.is_contiguous
Tensor.is_contiguous(memory_format=torch.contiguous_format) → bool
Returns True if self tensor is contiguous in memory in the order specified by memory format.
Parameters
:
memory_format (torch.memory_format, optional) -- Specifies memory allocation order. Default: torch.contiguous_format.
若 self 張量在記憶體中是依照 memory_format 所指定的順序為連續的,則回傳 True。
可選參數 memory_format 用於指定記憶體配置順序,預設值為 torch.contiguous_format。
is_contiguous 當中的 memory_format 參數是什麼意思我們暫且按下不表,先從一個簡單的 demo 開始看起:
python
x = torch.arange(12).reshape(3, 4)
x
x.shape
x.stride()
x.is_contiguous()
輸出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
torch.Size([3, 4])
(4, 1)
True
x 的步長為 (4, 1),最後一個維度的步長為 1,倒數第二個維度的步長等於最後一個維度的大小,這符合 row-major 連續排列的規律,所以 x.is_contiguous() 會返回 True。
現在對 x 張量做轉置,改變其維度順序。使用的是 Tensor.T 函數:
Tensor.T
Returns a view of this tensor with its dimensions reversed.
If n is the number of dimensions in x, x.T is equivalent to x.permute(n-1, n-2, ..., 0).
Tensor.T 回傳此張量的一個 view,其維度順序被反轉。若 x 的維度數為 n,則 x.T 等同於 x.permute(n-1, n-2, ..., 0)。
python
y = x.T
y
y.shape
y.stride()
y.is_contiguous()
輸出:
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
torch.Size([4, 3])
(1, 4)
False
Tensor.T 只改變了觀察維度的順序,底層資料並沒有被搬動。但張量的步長卻變成 (1, 4),已經不再是遞減排序了,因此也不再滿足預設記憶體格式下的 contiguous 條件,所以 y.is_contiguous() 的結果是 False。
剛剛我們略過了 is_contiguous 函數的 memory_format(記憶體格式)參數:其預設值為 torch.contiguous_format(如:NCHW 或 NCTHW),其它可選的參數值包括 torch.channels_last, torch.channels_last_3d 等。
為什麼 is_contiguous 函數需要這個參數呢?這是因為記憶體格式會影響維度順序,而 contiguous 又對維度順序有要求,所以使用者必須指定記憶體格式,函數才能決定輸入張量是否滿足 contiguous 的條件。
舉個例子,對兩個形狀相同的 4 維 NCHW 張量來說,如果傳入不同的 memory_format 參數,函數就會有不同的判定結果:
python
x = torch.empty((2, 3, 5, 7))
x.stride()
x.is_contiguous()
x.is_contiguous(memory_format=torch.channels_last)
輸出:
(105, 35, 7, 1)
True
False
結果符合預期:使用預設參數時 x 會被判為 contiguous,如果使用。
以下程式碼創建一個 y 張量,使用的是 channels-last 記憶體格式(NHWC):
python
y = torch.empty((2, 3, 5, 7), memory_format=torch.channels_last)
y.stride()
y.is_contiguous()
y.is_contiguous(memory_format=torch.channels_last)
輸出:
(105, 1, 21, 3)
False
True
可以看到,x 採用預設的 torch.contiguous_format 記憶體格式,所以可被 is_contiguous() 判為 contiguous;而 y 雖然同樣稠密且非重疊,卻是依照 channels_last(NHWC)順序排列,因此對預設的 is_contiguous() 函數來說,它並不是 contiguous。只有在指定記憶體格式參數為 torch.channels_last 的情況下才會返回 True。
contiguous 這個性質要求張量的元素在記憶體中依照特定順序緊密排列,因此自然也具有以下兩個性質:
- 不同索引不會指向同一個位置,也就是「非重疊」
- 元素之間不會留下空洞,也就是「稠密」
現在我們知道張量的維度順序會影響 contiguous 的判斷結果;那麼,若一個張量雖然不符合 contiguous 的定義,卻仍同時滿足以上兩個條件,有沒有一個詞可以形容這種張量呢?有的,在 PyTorch 中我們把它叫做 non-overlapping and dense(非重疊且稠密)。
非重疊且稠密
non_overlapping_and_dense 與 contiguous 不同,是無方向性的。只要存在某個維度排列,使該張量在該排列下是 contiguous,我們就稱該張量滿足 non_overlapping_and_dense,並不在意維度的實際排列順序。
我們可以用以下程式比較這兩個概念。首先用 torch.randn 中規中矩地創建一個張量:
python
x = torch.randn(3, 4) # shape=(3,4), stride=(4,1)
x.is_contiguous()
torch.ops.aten.is_non_overlapping_and_dense(x)
輸出如下:
True
True
如預期般,這個張量既是 contiguous 又是 non overlapping and dense。
接著對 x 動一點手腳,對它轉置得到張量 y:
python
y = x.T # shape=(4,3), stride=(1,4)
y.is_contiguous()
torch.ops.aten.is_non_overlapping_and_dense(y)
輸出如下:
False
True
.T 會改變觀察維度的順序,導致 y 無法滿足 contiguous 的條件;另外因為 .T 不會搬動張量的底層資料,所以 y 仍然保持著 non overlapping and dense 的性質。可見 non_overlapping_and_dense 是比 contiguous 還寬鬆的條件:所有 contiguous 張量必然是 non_overlapping_and_dense,但反過來則不成立。
如何判斷
上一節使用的 torch.ops.aten.is_non_overlapping_and_dense 判斷函數實際上只是一個 wrapper,它真正的判斷邏輯實作於 _eval_is_non_overlapping_and_dense 函數,也就是本篇的主題。
_eval_is_non_overlapping_and_dense 函數只關心張量在記憶體中是否非重疊及稠密,並不在乎張量各維度的排列順序。函數中的判斷邏輯如下:如果存在某個維度排列,使得輸入張量在該排列下等價於某個 contiguous 張量,即:有效維度(指尺寸大於 1 的維度,尺寸為 1 的維度會被忽略)的步長恰好等於按該排列排序後前面所有維度尺寸的連乘積,那麼該張量就可被視為 non_overlapping_and_dense。
算法之所以只把尺寸大於 1 的維度視為有效維度而忽略尺寸為 1 的維度,原因有二:
- 尺寸為 1 的維度只有一個元素,無論步長為何都不會存取到新的記憶體位址,既不會與其他元素重疊,也不會造成空洞
- 這類維度的步長可以是任意值(詳見 PyTorch 張量尺寸為 1 時,步長為何不具語意?),若納入計算反而會干擾判斷
_eval_is_non_overlapping_and_dense
_eval_is_non_overlapping_and_dense 定義在 torch/fx/experimental/symbolic_shapes.py,是一個模組層級的函數,我們可以透過 torch.fx.experimental.symbolic_shapes._eval_is_non_overlapping_and_dense 來調用之。
python
def _eval_is_non_overlapping_and_dense(sizes, strides):
dim = len(sizes)
# Short-circuits for tensors of rank one, which are
# non-overlapping and "dense" if their stride is one
# or it is a 0/1 element tensor
if dim == 1:
return strides[0] == 1 or sizes[0] < 2
# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
# Sorts (length, stride) pairs by stride
lengths_and_strides = sorted(
zip(sizes, strides), key=operator.itemgetter(1)
)
# Unlike the C++ code, we don't move the 0/1 size dimensions to the
# end. So we have to keep going for this code.
expected_stride = 1
for length, stride in lengths_and_strides:
if length == 1:
continue
if stride != expected_stride:
return False
expected_stride *= length
return True
本函數只接收兩個參數:sizes 和 strides,分別代表輸入張量各維度的尺寸和步長,在理想情況下,兩者皆為 Python 中的 list of int。
值得注意的是,本函數並沒有接受跟「維度順序」和「維度邏輯意義」相關的參數:
- torch.Tensor.is_contiguous 有
memory_format參數,用來告知記憶體格式------因為判斷連續性必須知道維度順序,而維度順序會受記憶體格式影響 - PyTorch 如何知道張量是 NCHW 還是 Channels Last 的?--sympy_is_channels_last_strides_generic 函數解析 中介紹的
sympy_is_channels_last_strides_generic則有dim_order參數,用來指定各維度在邏輯層面的意義(如 N、C、H、W)
_eval_is_non_overlapping_and_dense 則兩者都不需要:它既不在意維度的排列順序,也不在意維度的邏輯意義,只關心張量在記憶體中元素是否重疊、是否稠密------而無論維度的順序為何,這兩個性質都不會改變。
回到函數內容,在輸入張量只有一維時有個短路優化:
python
if dim == 1:
return strides[0] == 1 or sizes[0] < 2
- 如果步長為 1,則張量必定連續
- 當一維張量的尺寸小於 2 時,表示張量只有 0 或 1 個元素。此時該張量必為稠密,並且無需在意無重疊問題
如果輸入不只一維,檢查將會繼續:
python
lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1))
這行代碼將 (size, stride) 按步長由小到大排序,這是在模擬「最可能符合 contiguous 的維度順序」:把各維度由內而外排列------因為在 contiguous 張量中,越內層的維度步長越小。
接著會按照此順序遍歷各維度,計算它們應有的步長,然後看實際的步長值是否符合。
expected_stride 變數正是遍歷過程中各維度應有的步長值,因為檢查順序是由內而外,而最內層維度的步長理應為 1,所以這裡將初始值設為 1:
python
expected_stride = 1
接著逐一檢查排序後的(尺寸、步長)對:
python
for length, stride in lengths_and_strides:
if length == 1:
continue # size=1 維度不影響連續性,跳過
if stride != expected_stride:
return False
expected_stride *= length
length == 1:因為長度為 1 的維度對記憶體排列和expected_stride都沒有影響,所以可以直接忽略- 否則檢查當前步長是否等於
expected_stride,不相等則判為非 non_overlapping_and_dense - 最後更新下一個外層維度的預期步長:
expected_stride *= length
如果某個 (sizes, strides) 對能在每一次的 if stride != expected_stride 判斷中都堅挺住不被 return False,表示該張量中每個有效維度的步長都符合預期,所以最後會:
python
return True
將該張量判為 non_overlapping_and_dense。
在理想情況下,函數的參數為 Python 的 list of int,內部的比較(==、!=、<)皆為普通 int 運算,因此回傳值會是 Python 的 bool。
但實際上函數並沒有限制輸入的類別,它們也可能是 list of SymInt,在這種情況下,如果輸入張量剛好只有一維,觸發短路優化(strides[0] == 1 or sizes[0] < 2),比較的結果將會是 SymBool 型別。
為了與 PyTorch 的動態形狀系統互動,_eval_is_non_overlapping_and_dense 外面其實還套了好幾層包裝。其中最直接的一層是 eval_is_non_overlapping_and_dense:它在接收到 SymBool 後,會調用 guard_bool 將其轉換為普通 bool,並把該表達式登記到 ShapeEnv.guards 列表。除此之外還有其他層次的包裝,共同負責將本函數串接到動態形狀系統,會在接下來的文章中一一介紹。
函數中有段註解:
python
# Unlike the C++ code, we don't move the 0/1 size dimensions to the
# end. So we have to keep going for this code.
不像 C++ 程式碼,我們不把 size=0/1 維度移到末尾
這是說對於尺寸為一的長度,Python 版是透過排序加跳過來處理,C++ 版本則有另外的處理方式。
C++ 版本 - _compute_non_overlapping_and_dense
C++ 版本位於 c10/core/TensorImpl.cpp 的 _compute_non_overlapping_and_dense:
cpp
bool _compute_non_overlapping_and_dense(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
auto dim = sizes.size();
if (dim == 1) {
return sizes[0] < 2 || strides[0] == 1;
}
SmallVector<int64_t, 5> perm;
perm.resize(dim);
for (const auto i : c10::irange(dim)) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes[a] < 2) {
return false;
} else if (sizes[b] < 2) {
return true;
}
return strides[a] < strides[b];
});
T require_stride = 1;
for (const auto i : c10::irange(dim)) {
const auto& size_perm_i = sizes[perm[i]];
if (size_perm_i < 2) {
return true;
}
if (strides[perm[i]] != require_stride) {
return false;
}
require_stride *= size_perm_i;
}
return true;
}
一開始一樣有個短路優化:
cpp
auto dim = sizes.size();
if (dim == 1) {
return sizes[0] < 2 || strides[0] == 1;
}
若是一維張量,當 size < 2 或 stride == 1 就直接返回 true。
若是多維張量,接下來需要對各維度按步長由小到大排序。在 Python 中,是用 sorted(zip(sizes, strides), ...) 直接對 (size, stride) 對本身做排序;C++ 版本則採用不同做法:先建立一個包含各維度索引的陣列 perm = [0, 1, ..., dim-1],對它依步長做升序排列,之後如果想要存取「排序後第 i 個維度」的尺寸或步長,得再透過 sizes[perm[i]]、strides[perm[i]] 間接取得。
以下程式碼就是在建立 perm 維度索引陣列:
cpp
SmallVector<int64_t, 5> perm;
perm.resize(dim);
for (const auto i : c10::irange(dim)) {
perm[i] = i;
}
接著實際用 std::sort 搭配自訂 lambda 比較函式對 perm 進行排序:
cpp
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes[a] < 2) {
return false;
} else if (sizes[b] < 2) {
return true;
}
return strides[a] < strides[b];
});
lambda 函數的最後一行很好懂,也是通例:
- 當兩個維度的尺寸都大於等於 2 時,比較
strides[a] < strides[b],依照步長大小做遞增排序
但在這個通例之前有兩個特例,在看這兩個特例之前,我們得先弄明白 std::sort 的比較函式的語意:當 a 應該被排在 b 前面時,比較函數會回傳 true,反之則回傳 false。了解這一點後再來看兩個特例:
- 若
sizes[a] < 2,回傳false:這時a會被排到b之後,也就是把尺寸小於 2 的a維度往後推 - 否則若
sizes[b] < 2,回傳true:這時b會被排在a之後,同樣是把尺寸小於 2 的b維度往後推
因此整體效果是:將 perm 依步長由小到大排序,並把所有尺寸小於 2 的維度放到最後。這是因為尺寸小於 2 的維度不影響判斷結果,所以後續遍歷遇到它們時,就可以直接短路返回 true。
各維度按照步長由小到大排序後,後面的 for 迴圈會檢查每個有效維度的步長是否等於 require_stride。require_stride 的初始值為 1,在每一次迭代後會被乘上當前維度的尺寸,表示下一個維度應有的步長。只要有任何一維不符合,就返回 false;若全部符合,則返回 true。
cpp
T require_stride = 1;
for (const auto i : c10::irange(dim)) {
const auto& size_perm_i = sizes[perm[i]];
if (size_perm_i < 2) {
return true;
}
if (strides[perm[i]] != require_stride) {
return false;
}
require_stride *= size_perm_i;
}
return true;
因為尺寸小於 2 的維度不影響判斷結果且有可能會造成干擾,所以判斷函數會將這類維度排除。C++ 與 Python 兩個版本的效果等價,差別只在排除的時機:C++ 在排序階段就把這類維度推到最後,遍歷時一遇到便直接短路返回 true;Python 則在遍歷階段直接跳過 size == 1 的維度,寫法更易懂。
案例解析
這樣設計為什麼有效?以下透過幾個具體例子,實際走一遍 Python 版本 _eval_is_non_overlapping_and_dense 的執行流程來驗證。
例 1:NCHW contiguous 張量
創建一個 sizes 為 [2, 3, 4] 的張量:
python
x = torch.empty(2,3,4)
x.shape
x.stride()
輸出:
torch.Size([2, 3, 4])
(12, 4, 1)
其 strides 為 (12, 4, 1)。
將(尺寸、步長)對按照步長由小到大排序:
(4,1), (3,4), (2,12)
檢查流程:
| step | size | stride | expected_stride | 判斷 | 新 expected_stride |
|---|---|---|---|---|---|
| 1 | 4 | 1 | 1 | 1==1 ✔ | 1×4=4 |
| 2 | 3 | 4 | 4 | 4==4 ✔ | 4×3=12 |
| 3 | 2 | 12 | 12 | 12==12 ✔ | 12×2=24 |
每個維度都可以通過檢查,函數最後返回 True。
例 2:對 NCHW contiguous 張量做轉置
嘗試轉置上例張量的第 0 和第 1 個維度:
python
x = torch.empty(2, 3, 4)
x = torch.transpose(x, 0, 1)
x.shape
x.stride()
輸出:
torch.Size([3, 2, 4])
(4, 12, 1)
排序後:
(4,1), (3,4), (2,12)
和上例完全一樣!
檢查流程也完全相同:
| step | size | stride | expected_stride | 判斷 | 新 expected_stride |
|---|---|---|---|---|---|
| 1 | 4 | 1 | 1 | 1==1 ✔ | 1×4=4 |
| 2 | 3 | 4 | 4 | 4==4 ✔ | 4×3=12 |
| 3 | 2 | 12 | 12 | 12==12 ✔ | 12×2=24 |
函數最終也返回 True。
這就是函數中按步長排序的作用:即使轉置後維度順序改變(張量不再是 contiguous),但經過排序後仍能通過檢查,最後仍可以被判為 non_overlapping_and_dense。
例 3:overlapping 張量
首先創建一個一維張量:
python
base = torch.arange(5)
base
輸出:
tensor([0, 1, 2, 3, 4])
接著用 as_strided 產生一個二維的 view:
python
x = base.as_strided((3, 3), (1, 1))
x
x.shape
x.stride()
輸出:
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4]])
torch.Size([3, 3])
(1, 1)
從輸出中可以看出來,x 張量的元素重複了。
再來追蹤函數的檢查流程,先對(尺寸、步長)對做排序:
(3,1), (3,1)
檢查流程:
| step | size | stride | expected_stride | 判斷 | 新 expected_stride |
|---|---|---|---|---|---|
| 1 | 3 | 1 | 1 | 1==1 ✔ | 1×3=3 |
| 2 | 3 | 1 | 3 | 1!=3 ✗ | --- |
張量第二個維度的步長為 1,不符合預期的 3,因此函數直接返回 False,確實地判為非 non_overlapping_and_dense。
例 4:有空洞(非稠密)
生成一個有空洞的張量:
python
x = torch.arange(12).view(2, 6)[:, :3]
x
x.shape
x.stride()
輸出:
tensor([[0, 1, 2],
[6, 7, 8]])
torch.Size([2, 3])
(6, 1)
函數首先對 x 的(尺寸、步長)對做排序:
(3,1), (2,6)
逐維檢查:
| step | size | stride | expected_stride | 判斷 |
|---|---|---|---|---|
| 1 | 3 | 1 | 1 | ✔ → expected=3 |
| 2 | 2 | 6 | 3 | 6≠3 ❌ |
函數在第 2 步檢測到步長為 6,不等於 expected_stride=3(即前一維度佔滿後的緊接位置),正確地判定記憶體中存在間隙,不符合 dense。
例 5:含 size=1 維度
用 PyTorch 張量尺寸為 1 時,步長為何不具語意? 中提到的方式創建一個 sizes = [2, 1, 3, 4], strides = [12, 240, 4, 1] 的張量:
python
x = torch.empty(1, 2, 3, 4)
x.shape
x.stride()
y = x[::10, :, :, :]
y.shape
y.stride()
z = torch.transpose(y, 0, 1)
z.shape
z.stride()
輸出:
torch.Size([1, 2, 3, 4])
(24, 12, 4, 1)
torch.Size([1, 2, 3, 4])
(240, 12, 4, 1)
torch.Size([2, 1, 3, 4])
(12, 240, 4, 1)
張量 z 的第二個維度尺寸為 1,步長任意(這裡是 240)。
排序後:
(4,1), (3,4), (2,12), (1,240)
檢查流程:
| step | size | stride | expected_stride | 判斷 | 新 expected_stride |
|---|---|---|---|---|---|
| 1 | 4 | 1 | 1 | ✔ | 4 |
| 2 | 3 | 4 | 4 | ✔ | 12 |
| 3 | 2 | 12 | 12 | ✔ | 24 |
| 4 | 1 | 240 | 24 | size=1 → skip | 24 |
最後一步的步長跳到了 240,但因為尺寸為 1,函數直接跳過,沒有被它干擾判斷。
demo
理想中 _eval_is_non_overlapping_and_dense 接受的參數應該是 Python list of int 型別,但因為函數並沒有對參數的型別做限制,實際上外部也可能傳入 list of SymInt 作為它的參數。
本節將對 _eval_is_non_overlapping_and_dense 函數傳入不同型別的參數,觀察函數的行為與回傳型別有何差異。
regular tensor with real sizes
一般預期的呼叫方式是傳入 Python list of int。以下程式先創建一個形狀為 2 x 3 x 4 的張量,再判斷它是否 non_overlapping_and_dense:
python
from torch.fx.experimental.symbolic_shapes import _eval_is_non_overlapping_and_dense
t = torch.empty(2, 3, 4)
sizes = t.size()
strides = t.stride()
result = _eval_is_non_overlapping_and_dense(list(sizes), list(strides))
print(sizes)
print(strides)
print(result, type(result).__name__)
輸出:
torch.Size([2, 3, 4])
(12, 4, 1)
True bool
因為輸入是 Python list of int,所以函數內部的 ==、!=、< 全都是 int 間的比較運算,函數最終的回傳值也是我們熟悉的 Python bool。
fake tensor with symbolic sizes(一維)
_eval_is_non_overlapping_and_dense 函數也可能被用於判斷具有動態形狀的張量,這時張量的尺寸與步長可能是 SymInt。
PyTorch動態形狀系統 簡述了 PyTorch 中建立動態形狀模型的流程,此處模擬當中的第 1、2 步:手動創建 ShapeEnv 和 FakeTensorMode,再用 from_tensor 函數把一個普通張量包裝成具有 symbolic shape 的 FakeTensor,如此一來便能得到尺寸與步長皆為 SymInt 的張量:
python
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, _eval_is_non_overlapping_and_dense
from torch._subclasses.fake_tensor import FakeTensorMode
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
real = torch.empty(5)
fake = fake_mode.from_tensor(real, static_shapes=False)
print(fake.size(), [type(s).__name__ for s in fake.size()])
print(fake.stride(), [type(s).__name__ for s in fake.stride()])
輸出:
torch.Size([s0]) ['SymInt']
(1,) ['SymInt']
fake 是一個一維的 fake tensor,其尺寸為符號 s0,步長為 1,兩者皆為 SymInt。
接著把 fake 的 size() 與 stride() 包裝成 Python list,丟進 _eval_is_non_overlapping_and_dense:
python
sizes = list(fake.size())
strides = list(fake.stride())
result = _eval_is_non_overlapping_and_dense(sizes, strides)
print(result, type(result).__name__)
輸出:
True SymBool
可以看到函數果真回傳 SymBool 了!
這是為什麼呢?我們回頭來看看函數中是怎麼處理 SymInt 輸入的,在短路優化那一行:
python
if dim == 1:
return strides[0] == 1 or sizes[0] < 2
strides[0]、sizes[0] 都是 SymInt,它們與 int 做比較的結果會是 SymBool;而兩個 SymBool 經過 Python 的 or 後,得到的還是一個 SymBool,因此函數最後回傳的也是 SymBool。
fake tensor with symbolic sizes(多維)
若改為多維張量,情況就不同了。以下範例先創建一個多維張量,再將它包裝成具有 symbolic shape 的 FakeTensor,接著對它呼叫 _eval_is_non_overlapping_and_dense:
python
real = torch.empty(2, 3, 4)
fake = fake_mode.from_tensor(real, static_shapes=False)
print(fake.size(), [type(s).__name__ for s in fake.size()])
print(fake.stride(), [type(s).__name__ for s in fake.stride()])
result = _eval_is_non_overlapping_and_dense(list(fake.size()), list(fake.stride()))
print(result, type(result).__name__)
輸出:
torch.Size([s0, s1, s2]) ['SymInt', 'SymInt', 'SymInt']
(s1*s2, s2, 1) ['SymInt', 'SymInt', 'SymInt']
True bool
跟一維時的情況不同,這次函數返回了Python bool!
回頭查看函數實現,如果輸入張量的各維都能通過迴圈中的判斷,函數最後一行就會 return True。這也是輸入此處輸入全是 SymInt,但函數最後卻能回傳 Python bool 的原因。
總結
從以上demo可以看出:唯一會讓本函數回傳 SymBool 的情況,就是「輸入是 SymInt 且維度為 1」時。三種情況整理如下:
| 輸入型別 | 維度 | 回傳型別 |
|---|---|---|
| list of int | 任意 | bool |
| list of SymInt | 1 | SymBool |
| list of SymInt | ≥ 2 | bool |
下一篇將介紹 eval_is_non_overlapping_and_dense,該函數是整個 PyTorch 中唯一會調用 _eval_is_non_overlapping_and_dense 的函數,因為 _eval_is_non_overlapping_and_dense 可能回傳 bool 或 SymBool,eval_is_non_overlapping_and_dense 為了確保回傳型別一致,才在 _eval_is_non_overlapping_and_dense 外面多包了一層 guard_bool,用來將它可能收到的 SymBool 轉回 Python bool。