SwinTransformer的相对位置索引的原理以及源码分析

文章目录

  • [1. 理论分析](#1. 理论分析)
  • [2. 完整代码](#2. 完整代码)

引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q和 K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)−(0,1)=(0,−1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

对应源代码为:

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引

coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

请注意,这里描述的一直是相对位置索引 ,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M−1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

接着将所有的行标都乘上2M-1。

对应源代码为:

python 复制代码
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

python 复制代码
relative_position_index = relative_coords.sum(-1)

刚刚上面也说了,之前计算的是相对位置索引 ,并不是相对位置偏执参数 。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M−1)×(2M−1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

对应源代码为:

python 复制代码
'''
relative_position_bias_table:
	其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

'''
index:
	虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),
	但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。
	作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

python 复制代码
import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_position_index = relative_coords.sum(-1)

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)

index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

print('relative_position_bias:',relative_position_bias.shape)
相关推荐
AI算法-图哥1 分钟前
pytorch量化训练
人工智能·pytorch·深度学习·文生图·模型压缩·量化
大山同学4 分钟前
DPGO:异步和并行分布式位姿图优化 2020 RA-L best paper
人工智能·分布式·语言模型·去中心化·slam·感知定位
机器学习之心4 分钟前
时序预测 | 改进图卷积+informer时间序列预测,pytorch架构
人工智能·pytorch·python·时间序列预测·informer·改进图卷积
天飓31 分钟前
基于OpenCV的自制Python访客识别程序
人工智能·python·opencv
檀越剑指大厂33 分钟前
开源AI大模型工作流神器Flowise本地部署与远程访问
人工智能·开源
声网36 分钟前
「人眼视觉不再是视频消费的唯一形式」丨智能编解码和 AI 视频生成专场回顾@RTE2024
人工智能·音视频
newxtc1 小时前
【AiPPT-注册/登录安全分析报告-无验证方式导致安全隐患】
人工智能·安全·ai写作·极验·行为验证
技术仔QAQ1 小时前
【tokenization分词】WordPiece, Byte-Pair Encoding(BPE), Byte-level BPE(BBPE)的原理和代码
人工智能·python·gpt·语言模型·自然语言处理·开源·nlp
陌上阳光1 小时前
动手学深度学习70 BERT微调
人工智能·深度学习·bert
正义的彬彬侠2 小时前
sklearn.datasets中make_classification函数
人工智能·python·机器学习·分类·sklearn