SwinTransformer的相对位置索引的原理以及源码分析

文章目录

  • [1. 理论分析](#1. 理论分析)
  • [2. 完整代码](#2. 完整代码)

引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q和 K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)−(0,1)=(0,−1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

对应源代码为:

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引

coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

请注意,这里描述的一直是相对位置索引 ,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M−1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

接着将所有的行标都乘上2M-1。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

python 复制代码
relative_position_index = relative_coords.sum(-1)

刚刚上面也说了,之前计算的是相对位置索引 ,并不是相对位置偏执参数 。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M−1)×(2M−1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

对应源代码为:

python 复制代码
'''
relative_position_bias_table:
	其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

'''
index:
	虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),
	但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。
	作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_position_index = relative_coords.sum(-1)

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)

index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

print('relative_position_bias:',relative_position_bias.shape)
相关推荐
Elastic 中国社区官方博客13 分钟前
使用 Vertex AI Gemini 模型和 Elasticsearch Playground 快速创建 RAG 应用程序
大数据·人工智能·elasticsearch·搜索引擎·全文检索
说私域38 分钟前
地理定位营销与开源AI智能名片O2O商城小程序的融合与发展
人工智能·小程序
Q_w77421 小时前
计算机视觉小目标检测模型
人工智能·目标检测·计算机视觉
创意锦囊1 小时前
ChatGPT推出Canvas功能
人工智能·chatgpt
知来者逆1 小时前
V3D——从单一图像生成 3D 物体
人工智能·计算机视觉·3d·图像生成
碳苯2 小时前
【rCore OS 开源操作系统】Rust 枚举与模式匹配
开发语言·人工智能·后端·rust·操作系统·os
whaosoft-1432 小时前
51c视觉~CV~合集3
人工智能
网络研究院4 小时前
如何安全地大规模部署 GenAI 应用程序
网络·人工智能·安全·ai·部署·观点
凭栏落花侧4 小时前
决策树:简单易懂的预测模型
人工智能·算法·决策树·机器学习·信息可视化·数据挖掘·数据分析
xiandong207 小时前
240929-CGAN条件生成对抗网络
图像处理·人工智能·深度学习·神经网络·生成对抗网络·计算机视觉