torch.meshgrid网格代码解析

1. 初始化相对位置偏置嵌入

python 复制代码
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

假设window_size=7、slef.heads=4,则 2 * window_size - 1 = 13;嵌入层的大小为13*13=169,创建一个大小为169*4的嵌入矩阵。

2. 创建位置索引

python 复制代码
pos = torch.arange(window_size)   # tensor([0, 1, 2, 3, 4, 5, 6])

pos 是一个从 0window_size-1 的一维张量

3. 生成二维网格

python 复制代码
grid = torch.meshgrid(pos, pos, indexing='ij')

torch.meshgrid(pos, pos, indexing='ij') 创建一个二维网格,表示每个位置的(x,y)坐标。pos 是形状为 (N,) 的张量,那么输出将是两个形状为 (N, N) 的张量。第一个张量沿着行变化,第二个张量沿着列变化。

python 复制代码
# torch.meshgrid(pos, pos, indexing='ij')返回两个张量,并作为一个元组返回
(tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6, 6]]), tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]]))

# 如果想要输出的是两个单独的张量,则可以使用下面的代码
# grid_x, grid_y = torch.meshgrid(pos, pos, indexing='ij')

# 打印单独的张量
# print(grid_x)
# print(grid_y)

3.1 torch.stack 将上述两个二维张量沿新的维度堆叠

python 复制代码
grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
python 复制代码
tensor([[[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]],

        [[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]])

3.2 Rearrange函数

python 复制代码
# 其中 c=2 代表两个网格(x 和 y 坐标),i=7 和 j=7 代表网格的维度
grid = rearrange(grid, 'c i j -> (i j) c')

# (i j) c 表示将原始张量的维度重新排列成 (i j) 和 c。
# 即将网格的每一个点((i, j))展平为一维,并将每个点的 c 个值(这里是两个值)放在新的维度 c 中

张量 grid 的形状从 (2, window_size, window_size) 变为 (window_size * window_size, 2)指的意思就是将第二个维度i和第三个维度j合并为一个维度。形状为 (49, 2)。

python 复制代码
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [1, 5],
        [1, 6],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [2, 5],
        [2, 6],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [4, 5],
        [4, 6],
        [5, 0],
        [5, 1],
        [5, 2],
        [5, 3],
        [5, 4],
        [5, 5],
        [5, 6],
        [6, 0],
        [6, 1],
        [6, 2],
        [6, 3],
        [6, 4],
        [6, 5],
        [6, 6]])

在重排后的张量中,每个元素代表窗口内一个位置的 (x, y) 坐标对。通过这种方式,可以方便地处理窗口内位置的相对关系。

4. 计算相对位置

python 复制代码
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
python 复制代码
rearrange(grid, 'i ... -> i 1 ...') 

将这个张量的第一个维度保持不变,同时在第二个维度的位置插入一个新的维度,新的维度大小为1。

python 复制代码
rearrange(grid, 'i ... -> i 1 ...'):
tensor([
  [[0, 0]], [[0, 1]], [[0, 2]], [[0, 3]], [[0, 4]], [[0, 5]], [[0, 6]],
  [[1, 0]], [[1, 1]], [[1, 2]], [[1, 3]], [[1, 4]], [[1, 5]], [[1, 6]],
  [[2, 0]], [[2, 1]], [[2, 2]], [[2, 3]], [[2, 4]], [[2, 5]], [[2, 6]],
  [[3, 0]], [[3, 1]], [[3, 2]], [[3, 3]], [[3, 4]], [[3, 5]], [[3, 6]],
  [[4, 0]], [[4, 1]], [[4, 2]], [[4, 3]], [[4, 4]], [[4, 5]], [[4, 6]],
  [[5, 0]], [[5, 1]], [[5, 2]], [[5, 3]], [[5, 4]], [[5, 5]], [[5, 6]],
  [[6, 0]], [[6, 1]], [[6, 2]], [[6, 3]], [[6, 4]], [[6, 5]], [[6, 6]]
])

rearrange(grid, 'j ... -> 1 j ...'):
tensor([
  [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
   [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
   [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
   [3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6],
   [4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6],
   [5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6],
   [6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6]]
])

grid变形为(49,1,2)和(1,49,2),则相减时, i 维度的 49 会与第二个张量的 49 对齐,第二个维度 11 对齐。... 维度大小为 2,这两个维度自然对齐。

减法操作通过广播机制进行:

对每个 (i, j) 对应的元素,计算 A[i, 0, :] - B[0, j, :]最终形状 : (49, 49, 2)。因为 49(第一个维度)和 49(第二个维度)都被保留下来,并且 2 是原始的最后一个维度。

python 复制代码
rel_pos += window_size - 1
# 相减得到相对位置,并加上 window_size - 1 调整索引为正:
python 复制代码
tensor([
  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]]
])

5. 计算相对位置索引

python 复制代码
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

# rel_pose的形状为(49,49,2)
# [2 * window_size - 1, 1]的2 * window_size - 1维度需要被广播到rel_pose 49维度
# 13 会被广播到 49,1 会广播到 49

# torch.tensor([2 * window_size - 1, 1])将变为 (49, 49, 13)

相对位置索引是一个形状为 (height, width, 2) 的张量,其中 2 代表相对位置的 xy 坐标。

相对位置乘以 [13, 1] 并求和得到唯一索引:

python 复制代码
# 给定的张量表示了一个二维网格中每个位置的线性索引。每个数值在张量中指示了在一维数组中的线性位置
tensor([
  [84, 83, 82, 81, 80, 79, 78, 71, 70, 69, 68, 67, 66, 65],
  [72, 71, 70, 69, 68, 67, 66, 59, 58, 57, 56, 55, 54, 53],
  [60, 59, 58, 57, 56, 55, 54, 47, 46, 45, 44, 43, 42, 41],
  [48, 47, 46, 45, 44, 43, 42, 35, 34, 33, 32, 31, 30, 29],
  [36, 35, 34, 33, 32, 31, 30, 23, 22, 21, 20, 19, 18, 17],
  [24, 23, 22, 21, 20, 19, 18, 11, 10, 9,  8,  7,  6,  5],
  [12, 11, 10,  9,  8,  7,  6, -1, -2, -3, -4, -5, -6, -7]
])

rel_pos 是一个张量,通常表示元素之间的相对位置关系。

6. 注册缓冲区

python 复制代码
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)

注册缓冲区(或称缓冲区,Buffer)的主要作用是在内存中预留指定大小的存储空间,用于对输入/输出(I/O)的数据进行临时存储。

rel_pos_indices: 这是一个张量(tensor),它将被注册为缓冲区。这个张量可以是任何需要在前向传播中使用的、但不希望被优化器更新的数据。

persistent=False: 这是一个可选参数。默认情况下,persistentTrue。当 persistent=True 时,这个缓冲区会被保存在模型的 state_dict 中,这样当模型被保存和加载时,这个缓冲区也会被保存和恢复。当 persistent=False 时,这个缓冲区不会被保存在 state_dict 中。这意味着如果你保存模型然后加载它,这个缓冲区的数据将不会被恢复。

**为什么有时我们需要一个 persistent=False 的缓冲区呢?**因为这个缓冲区的数据可以在模型初始化时重新计算或获取,或者这个缓冲区中的数据不是模型状态的重要部分,因此不需要在模型保存和加载时一起保存。

相关推荐
Mintopia21 分钟前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮1 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬1 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia2 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区2 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两4 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪5 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232555 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
王鑫星5 小时前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能