1. 初始化相对位置偏置嵌入
python
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
假设window_size=7、slef.heads=4,则 2 * window_size - 1 = 13;嵌入层的大小为13*13=169,创建一个大小为169*4的嵌入矩阵。
2. 创建位置索引
python
pos = torch.arange(window_size) # tensor([0, 1, 2, 3, 4, 5, 6])
pos
是一个从 0
到 window_size-1
的一维张量
3. 生成二维网格
python
grid = torch.meshgrid(pos, pos, indexing='ij')
torch.meshgrid(pos, pos, indexing='ij') 创建一个二维网格,表示每个位置的(x,y)坐标。pos
是形状为 (N,)
的张量,那么输出将是两个形状为 (N, N)
的张量。第一个张量沿着行变化,第二个张量沿着列变化。
python
# torch.meshgrid(pos, pos, indexing='ij')返回两个张量,并作为一个元组返回
(tensor([[0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[6, 6, 6, 6, 6, 6, 6]]), tensor([[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6]]))
# 如果想要输出的是两个单独的张量,则可以使用下面的代码
# grid_x, grid_y = torch.meshgrid(pos, pos, indexing='ij')
# 打印单独的张量
# print(grid_x)
# print(grid_y)
3.1 torch.stack
将上述两个二维张量沿新的维度堆叠
python
grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
python
tensor([[[0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[6, 6, 6, 6, 6, 6, 6]],
[[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6]]])
3.2 Rearrange函数
python
# 其中 c=2 代表两个网格(x 和 y 坐标),i=7 和 j=7 代表网格的维度
grid = rearrange(grid, 'c i j -> (i j) c')
# (i j) c 表示将原始张量的维度重新排列成 (i j) 和 c。
# 即将网格的每一个点((i, j))展平为一维,并将每个点的 c 个值(这里是两个值)放在新的维度 c 中
张量 grid
的形状从 (2, window_size, window_size)
变为 (window_size * window_size, 2)指的意思就是将第二个维度i和第三个维度j合并为一个维度。
形状为 (49, 2)。
python
tensor([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5],
[0, 6],
[1, 0],
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[1, 5],
[1, 6],
[2, 0],
[2, 1],
[2, 2],
[2, 3],
[2, 4],
[2, 5],
[2, 6],
[3, 0],
[3, 1],
[3, 2],
[3, 3],
[3, 4],
[3, 5],
[3, 6],
[4, 0],
[4, 1],
[4, 2],
[4, 3],
[4, 4],
[4, 5],
[4, 6],
[5, 0],
[5, 1],
[5, 2],
[5, 3],
[5, 4],
[5, 5],
[5, 6],
[6, 0],
[6, 1],
[6, 2],
[6, 3],
[6, 4],
[6, 5],
[6, 6]])
在重排后的张量中,每个元素代表窗口内一个位置的 (x, y) 坐标对。通过这种方式,可以方便地处理窗口内位置的相对关系。
4. 计算相对位置
python
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
python
rearrange(grid, 'i ... -> i 1 ...')
将这个张量的第一个维度保持不变,同时在第二个维度的位置插入一个新的维度,新的维度大小为1。
python
rearrange(grid, 'i ... -> i 1 ...'):
tensor([
[[0, 0]], [[0, 1]], [[0, 2]], [[0, 3]], [[0, 4]], [[0, 5]], [[0, 6]],
[[1, 0]], [[1, 1]], [[1, 2]], [[1, 3]], [[1, 4]], [[1, 5]], [[1, 6]],
[[2, 0]], [[2, 1]], [[2, 2]], [[2, 3]], [[2, 4]], [[2, 5]], [[2, 6]],
[[3, 0]], [[3, 1]], [[3, 2]], [[3, 3]], [[3, 4]], [[3, 5]], [[3, 6]],
[[4, 0]], [[4, 1]], [[4, 2]], [[4, 3]], [[4, 4]], [[4, 5]], [[4, 6]],
[[5, 0]], [[5, 1]], [[5, 2]], [[5, 3]], [[5, 4]], [[5, 5]], [[5, 6]],
[[6, 0]], [[6, 1]], [[6, 2]], [[6, 3]], [[6, 4]], [[6, 5]], [[6, 6]]
])
rearrange(grid, 'j ... -> 1 j ...'):
tensor([
[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
[1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
[2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
[3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6],
[4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6],
[5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6],
[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6]]
])
grid变形为(49,1,2)和(1,49,2),则相减时, i
维度的 49
会与第二个张量的 49
对齐,第二个维度 1
和 1
对齐。...
维度大小为 2
,这两个维度自然对齐。
减法操作通过广播机制进行:
对每个 (i, j)
对应的元素,计算 A[i, 0, :] - B[0, j, :]
。最终形状 : (49, 49, 2)
。因为 49
(第一个维度)和 49
(第二个维度)都被保留下来,并且 2
是原始的最后一个维度。
python
rel_pos += window_size - 1
# 相减得到相对位置,并加上 window_size - 1 调整索引为正:
python
tensor([
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],
[[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
[5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
[4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
[3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
[2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
[1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
[0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]]
])
5. 计算相对位置索引
python
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)
# rel_pose的形状为(49,49,2)
# [2 * window_size - 1, 1]的2 * window_size - 1维度需要被广播到rel_pose 49维度
# 13 会被广播到 49,1 会广播到 49
# torch.tensor([2 * window_size - 1, 1])将变为 (49, 49, 13)
相对位置索引是一个形状为 (height, width, 2)
的张量,其中 2
代表相对位置的 x
和 y
坐标。
相对位置乘以 [13, 1]
并求和得到唯一索引:
python
# 给定的张量表示了一个二维网格中每个位置的线性索引。每个数值在张量中指示了在一维数组中的线性位置
tensor([
[84, 83, 82, 81, 80, 79, 78, 71, 70, 69, 68, 67, 66, 65],
[72, 71, 70, 69, 68, 67, 66, 59, 58, 57, 56, 55, 54, 53],
[60, 59, 58, 57, 56, 55, 54, 47, 46, 45, 44, 43, 42, 41],
[48, 47, 46, 45, 44, 43, 42, 35, 34, 33, 32, 31, 30, 29],
[36, 35, 34, 33, 32, 31, 30, 23, 22, 21, 20, 19, 18, 17],
[24, 23, 22, 21, 20, 19, 18, 11, 10, 9, 8, 7, 6, 5],
[12, 11, 10, 9, 8, 7, 6, -1, -2, -3, -4, -5, -6, -7]
])
rel_pos
是一个张量,通常表示元素之间的相对位置关系。
6. 注册缓冲区
python
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)
注册缓冲区(或称缓冲区,Buffer)的主要作用是在内存中预留指定大小的存储空间,用于对输入/输出(I/O)的数据进行临时存储。
rel_pos_indices
: 这是一个张量(tensor),它将被注册为缓冲区。这个张量可以是任何需要在前向传播中使用的、但不希望被优化器更新的数据。
persistent=False
: 这是一个可选参数。默认情况下,persistent
是 True
。当 persistent=True
时,这个缓冲区会被保存在模型的 state_dict
中,这样当模型被保存和加载时,这个缓冲区也会被保存和恢复。当 persistent=False
时,这个缓冲区不会被保存在 state_dict
中。这意味着如果你保存模型然后加载它,这个缓冲区的数据将不会被恢复。
**为什么有时我们需要一个 persistent=False
的缓冲区呢?**因为这个缓冲区的数据可以在模型初始化时重新计算或获取,或者这个缓冲区中的数据不是模型状态的重要部分,因此不需要在模型保存和加载时一起保存。