SwinTransformer的相对位置索引的原理以及源码分析

文章目录

  • [1. 理论分析](#1. 理论分析)
  • [2. 完整代码](#2. 完整代码)

引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q和 K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)−(0,1)=(0,−1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

对应源代码为:

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引

coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

请注意,这里描述的一直是相对位置索引 ,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M−1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

接着将所有的行标都乘上2M-1。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

python 复制代码
relative_position_index = relative_coords.sum(-1)

刚刚上面也说了,之前计算的是相对位置索引 ,并不是相对位置偏执参数 。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M−1)×(2M−1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

对应源代码为:

python 复制代码
'''
relative_position_bias_table:
	其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

'''
index:
	虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),
	但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。
	作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_position_index = relative_coords.sum(-1)

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)

index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

print('relative_position_bias:',relative_position_bias.shape)
相关推荐
桃源学社(接毕设)2 分钟前
基于MATLAB的运动模糊图像修复方法研究(LW+源码+讲解+部署)
图像处理·深度学习·计算机视觉·matlab·毕业设计·图像滤波去噪
钢铁男儿6 分钟前
PyTorch 机器学习基础(机器学习一般流程)
人工智能·pytorch·机器学习
老鱼说AI7 分钟前
当自回归模型遇上扩散模型:下一代序列预测模型详解与Pytorch实现
人工智能·pytorch·深度学习·神经网络·语言模型·自然语言处理·stable diffusion
兰亭妙微7 分钟前
用户体验设计 | 从UX到AX:人工智能如何重构交互范式?
人工智能·重构·ux
2501_9247311111 分钟前
智慧城市交通场景误检率↓78%!陌讯多模态融合算法实战解析
人工智能·算法·目标检测·视觉检测·智慧城市
掘金安东尼1 小时前
机器在看“断言”:AI 消费时代的内容策略升级
人工智能
木头左1 小时前
利用机器学习优化Backtrader策略原理与实践
人工智能·机器学习
2501_924534894 小时前
智慧零售商品识别误报率↓74%!陌讯多模态融合算法在自助结算场景的落地优化
大数据·人工智能·算法·计算机视觉·目标跟踪·视觉检测·零售
盖雅工场4 小时前
连锁零售排班难?自动排班系统来解决
大数据·人工智能·物联网·算法·零售
bryant_meng6 小时前
【Apache MXNet】
人工智能·apache·mxnet