decomposed Relative Positional Embeddings的理解

文章目录

正文

relative positional embedding的一种实现方式是:先计算q和k的相对位置坐标,然后依据相对位置坐标从给定的table中取值。

以q和k都是7×7为例,每个相对位置有两个索引对应x和y两个方向,每个索引值的取值范围是[-6,6]。(第0行相对第6行,x索引相对值为-6;第6行相对第0行,x索引相对值为6;所以索引取值范围是[-6,6])。

这个时候可以构建一个shape为[13,13, head_dim]的table,则当相对位置为(i,j)时,

python 复制代码
position embedding=table[i, j]

(i,j的取值范围都是[0, 12])具体可参考:有关swin transformer相对位置编码的理解

decomposed Relative Positional Embeddings的思想在于,分别计算x和y两个方向上计算相对位置坐标,并分别从两个table中取出对应的位置编码,再将两个方向的编码相加作为最终的编码。

以q为4×4和k是4×4为例,在x和y方向上,每个索引值的取值范围是[-3,3],所以需要构建两个shape为[7, head_dim]的table:

python 复制代码
if use_rel_pos:
    assert (
        input_size is not None
    ), "Input size must be provided if using relative positional encoding."
    # initialize relative positional embeddings
    rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
    rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

然后依据q和k的shape来计算每个方向上对应的相对位置编码:

python 复制代码
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    # q_size和k_size分别为当前方向上,q和k的个数, rel_pos为当前方向上定义的table
    q_coords = torch.arange(q_size)[:, None] # shape: [4, 1],给当前方向上每个q编号
    k_coords = torch.arange(k_size)[None, :]  # shape:[1, 4],给当前方向上每个k编号
    relative_coords = (q_coords - k_coords) + (k_size - 1) # q_coords - k_coords就是当前方向上每个q相对于k的位置,加上k_size - 1是为了让相对位置非负
    return rel_pos[relative_coords.long()] # 依据相对位置从预定义好的table中取值

依据q和每个方向上对应的位置编码来计算最终的编码:

python 复制代码
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h) # 获取h方向的位置编码,shape:[4, 4, head_dim]
    Rw = get_rel_pos(q_w, k_w, rel_pos_w) # 获取w方向的位置编码,shape:[4, 4, head_dim]

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) # r_q与Rh在h方向矩阵乘
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    # attn是自注意力机制计算得到的注意力图
    attn = attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn

文献来源

https://blog.csdn.net/weixin_42364196/article/details/132477924

https://github.com/microsoft/Swin-Transformer

相关推荐
拓朗工控10 分钟前
基于IBOX-602GT工控机在高精度机器视觉检测系统技术方案
人工智能·计算机视觉·视觉检测
七夜zippoe27 分钟前
OpenClaw 上下文管理:Token 优化策略
大数据·人工智能·深度学习·token·openclaw
web守墓人40 分钟前
【深度学习】Pytorch gpu加速原理探究
人工智能·pytorch·深度学习
沪漂阿龙44 分钟前
面试题:循环神经网络(RNN)是什么?词嵌入、时序建模、梯度消失、LSTM/GRU 一文讲透
人工智能·rnn·深度学习·gru·lstm
深度森林1 小时前
医学应用“手术机器人导航”高价值专利案例:基于计算机视觉的临床手术机器人导航规划方法
人工智能·计算机视觉·机器人
ZPC82101 小时前
识别物体 3D 位置 + 自动生成机器人抓取位姿」
数码相机·yolo·计算机视觉
坐望云起1 小时前
机器学习笔记 - 基于C++的深度学习 四、实现梯度下降
笔记·深度学习·机器学习
源码之家1 小时前
计算机毕业设计:Python基于知识图谱的医疗问答系统 Neo4j 机器学习 BERT 深度学习 ECharts(建议收藏)✅
python·深度学习·机器学习·信息可视化·数据分析·知识图谱·课程设计
沪漂阿龙1 小时前
面试题:传统序列模型详解——RNN、LSTM、GRU 原理、区别、优缺点一文讲透
人工智能·rnn·深度学习·gru·lstm
栈溢出了1 小时前
GAT(Graph Attention Network)学习笔记
人工智能·深度学习·算法·机器学习