torch.meshgrid网格代码解析

1. 初始化相对位置偏置嵌入

python 复制代码
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

假设window_size=7、slef.heads=4,则 2 * window_size - 1 = 13;嵌入层的大小为13*13=169,创建一个大小为169*4的嵌入矩阵。

2. 创建位置索引

python 复制代码
pos = torch.arange(window_size)   # tensor([0, 1, 2, 3, 4, 5, 6])

pos 是一个从 0window_size-1 的一维张量

3. 生成二维网格

python 复制代码
grid = torch.meshgrid(pos, pos, indexing='ij')

torch.meshgrid(pos, pos, indexing='ij') 创建一个二维网格,表示每个位置的(x,y)坐标。pos 是形状为 (N,) 的张量,那么输出将是两个形状为 (N, N) 的张量。第一个张量沿着行变化,第二个张量沿着列变化。

python 复制代码
# torch.meshgrid(pos, pos, indexing='ij')返回两个张量,并作为一个元组返回
(tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6, 6]]), tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]]))

# 如果想要输出的是两个单独的张量,则可以使用下面的代码
# grid_x, grid_y = torch.meshgrid(pos, pos, indexing='ij')

# 打印单独的张量
# print(grid_x)
# print(grid_y)

3.1 torch.stack 将上述两个二维张量沿新的维度堆叠

python 复制代码
grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
python 复制代码
tensor([[[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]],

        [[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]])

3.2 Rearrange函数

python 复制代码
# 其中 c=2 代表两个网格(x 和 y 坐标),i=7 和 j=7 代表网格的维度
grid = rearrange(grid, 'c i j -> (i j) c')

# (i j) c 表示将原始张量的维度重新排列成 (i j) 和 c。
# 即将网格的每一个点((i, j))展平为一维,并将每个点的 c 个值(这里是两个值)放在新的维度 c 中

张量 grid 的形状从 (2, window_size, window_size) 变为 (window_size * window_size, 2)指的意思就是将第二个维度i和第三个维度j合并为一个维度。形状为 (49, 2)。

python 复制代码
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [1, 5],
        [1, 6],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [2, 5],
        [2, 6],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [4, 5],
        [4, 6],
        [5, 0],
        [5, 1],
        [5, 2],
        [5, 3],
        [5, 4],
        [5, 5],
        [5, 6],
        [6, 0],
        [6, 1],
        [6, 2],
        [6, 3],
        [6, 4],
        [6, 5],
        [6, 6]])

在重排后的张量中,每个元素代表窗口内一个位置的 (x, y) 坐标对。通过这种方式,可以方便地处理窗口内位置的相对关系。

4. 计算相对位置

python 复制代码
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
python 复制代码
rearrange(grid, 'i ... -> i 1 ...') 

将这个张量的第一个维度保持不变,同时在第二个维度的位置插入一个新的维度,新的维度大小为1。

python 复制代码
rearrange(grid, 'i ... -> i 1 ...'):
tensor([
  [[0, 0]], [[0, 1]], [[0, 2]], [[0, 3]], [[0, 4]], [[0, 5]], [[0, 6]],
  [[1, 0]], [[1, 1]], [[1, 2]], [[1, 3]], [[1, 4]], [[1, 5]], [[1, 6]],
  [[2, 0]], [[2, 1]], [[2, 2]], [[2, 3]], [[2, 4]], [[2, 5]], [[2, 6]],
  [[3, 0]], [[3, 1]], [[3, 2]], [[3, 3]], [[3, 4]], [[3, 5]], [[3, 6]],
  [[4, 0]], [[4, 1]], [[4, 2]], [[4, 3]], [[4, 4]], [[4, 5]], [[4, 6]],
  [[5, 0]], [[5, 1]], [[5, 2]], [[5, 3]], [[5, 4]], [[5, 5]], [[5, 6]],
  [[6, 0]], [[6, 1]], [[6, 2]], [[6, 3]], [[6, 4]], [[6, 5]], [[6, 6]]
])

rearrange(grid, 'j ... -> 1 j ...'):
tensor([
  [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
   [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
   [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
   [3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6],
   [4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6],
   [5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6],
   [6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6]]
])

grid变形为(49,1,2)和(1,49,2),则相减时, i 维度的 49 会与第二个张量的 49 对齐,第二个维度 11 对齐。... 维度大小为 2,这两个维度自然对齐。

减法操作通过广播机制进行:

对每个 (i, j) 对应的元素,计算 A[i, 0, :] - B[0, j, :]最终形状 : (49, 49, 2)。因为 49(第一个维度)和 49(第二个维度)都被保留下来,并且 2 是原始的最后一个维度。

python 复制代码
rel_pos += window_size - 1
# 相减得到相对位置,并加上 window_size - 1 调整索引为正:
python 复制代码
tensor([
  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]]
])

5. 计算相对位置索引

python 复制代码
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

# rel_pose的形状为(49,49,2)
# [2 * window_size - 1, 1]的2 * window_size - 1维度需要被广播到rel_pose 49维度
# 13 会被广播到 49,1 会广播到 49

# torch.tensor([2 * window_size - 1, 1])将变为 (49, 49, 13)

相对位置索引是一个形状为 (height, width, 2) 的张量,其中 2 代表相对位置的 xy 坐标。

相对位置乘以 [13, 1] 并求和得到唯一索引:

python 复制代码
# 给定的张量表示了一个二维网格中每个位置的线性索引。每个数值在张量中指示了在一维数组中的线性位置
tensor([
  [84, 83, 82, 81, 80, 79, 78, 71, 70, 69, 68, 67, 66, 65],
  [72, 71, 70, 69, 68, 67, 66, 59, 58, 57, 56, 55, 54, 53],
  [60, 59, 58, 57, 56, 55, 54, 47, 46, 45, 44, 43, 42, 41],
  [48, 47, 46, 45, 44, 43, 42, 35, 34, 33, 32, 31, 30, 29],
  [36, 35, 34, 33, 32, 31, 30, 23, 22, 21, 20, 19, 18, 17],
  [24, 23, 22, 21, 20, 19, 18, 11, 10, 9,  8,  7,  6,  5],
  [12, 11, 10,  9,  8,  7,  6, -1, -2, -3, -4, -5, -6, -7]
])

rel_pos 是一个张量,通常表示元素之间的相对位置关系。

6. 注册缓冲区

python 复制代码
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)

注册缓冲区(或称缓冲区,Buffer)的主要作用是在内存中预留指定大小的存储空间,用于对输入/输出(I/O)的数据进行临时存储。

rel_pos_indices: 这是一个张量(tensor),它将被注册为缓冲区。这个张量可以是任何需要在前向传播中使用的、但不希望被优化器更新的数据。

persistent=False: 这是一个可选参数。默认情况下,persistentTrue。当 persistent=True 时,这个缓冲区会被保存在模型的 state_dict 中,这样当模型被保存和加载时,这个缓冲区也会被保存和恢复。当 persistent=False 时,这个缓冲区不会被保存在 state_dict 中。这意味着如果你保存模型然后加载它,这个缓冲区的数据将不会被恢复。

**为什么有时我们需要一个 persistent=False 的缓冲区呢?**因为这个缓冲区的数据可以在模型初始化时重新计算或获取,或者这个缓冲区中的数据不是模型状态的重要部分,因此不需要在模型保存和加载时一起保存。

相关推荐
bastgia36 分钟前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
菜狗woc44 分钟前
opencv-python的简单练习
人工智能·python·opencv
15年网络推广青哥1 小时前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
weixin_387545641 小时前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
engchina2 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
paixiaoxin2 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
OpenCSG3 小时前
CSGHub开源版本v1.2.0更新
人工智能
weixin_515202493 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习
Altair澳汰尔3 小时前
数据分析和AI丨知识图谱,AI革命中数据集成和模型构建的关键推动者
人工智能·算法·机器学习·数据分析·知识图谱
机器之心3 小时前
图学习新突破:一个统一框架连接空域和频域
人工智能·后端