MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING 笔记

来源:

MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning

相关工作:

#自注意力机制 #卷积神经网络 #动态卷积 #局部注意力 #并行计算

创新点:

贡献:

  1. 并行多尺度注意力机制

    • MUSE引入了并行多尺度注意力机制,该机制同时捕获序列数据中的长距离和短距离语言结构。这是通过对序列进行不同尺度的并行编码来实现的,利用自注意力和逐点变换。
  2. 结合卷积和自注意力

    • 为了解决自注意力在处理长序列时可能忽略局部信息的问题,MUSE模型将卷积操作和自注意力结合起来,以更有效地学习序列的局部和全局特征。
  3. 共享投影空间

    • 文章强调了在自注意力和卷积操作中使用共享投影空间的重要性。这种设计使得两种操作能够在相同的隐藏空间中进行,从而有助于整合局部和全局特征表示。
  4. 动态卷积核选择

    • MUSE引入了一种门控机制,用于自动选择不同卷积单元的权重,这允许模型动态地选择最适合当前层的卷积核大小。

代码:

python 复制代码
# ---------------------------------------  
# 论文: MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING (arxiv 2019)  
# ---------------------------------------  
import numpy as np  
import torch  
from torch import nn  
from torch.nn import init  
  
  
class Depth_Pointwise_Conv1d(nn.Module):  
    def __init__(self, in_ch, out_ch, k):  
        super().__init__()  
        if (k == 1):  
            self.depth_conv = nn.Identity()  
        else:  
            self.depth_conv = nn.Conv1d(  
                in_channels=in_ch,  
                out_channels=in_ch,  
                kernel_size=k,  
                groups=in_ch,  
                padding=k // 2  
            )  
        self.pointwise_conv = nn.Conv1d(  
            in_channels=in_ch,  
            out_channels=out_ch,  
            kernel_size=1,  
            groups=1  
        )  
  
    def forward(self, x):  
        out = self.pointwise_conv(self.depth_conv(x))  
        return out  
  
  
class MUSEAttention(nn.Module):  
  
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):  
  
        super(MUSEAttention, self).__init__()  
        self.fc_q = nn.Linear(d_model, h * d_k)  
        self.fc_k = nn.Linear(d_model, h * d_k)  
        self.fc_v = nn.Linear(d_model, h * d_v)  
        self.fc_o = nn.Linear(h * d_v, d_model)  
        self.dropout = nn.Dropout(dropout)  
  
        self.conv1 = Depth_Pointwise_Conv1d(h * d_v, d_model, 1)  
        self.conv3 = Depth_Pointwise_Conv1d(h * d_v, d_model, 3)  
        self.conv5 = Depth_Pointwise_Conv1d(h * d_v, d_model, 5)  
        self.dy_paras = nn.Parameter(torch.ones(3))  
        self.softmax = nn.Softmax(-1)  
  
        self.d_model = d_model  
        self.d_k = d_k  
        self.d_v = d_v  
        self.h = h  
  
        self.init_weights()  
  
    def init_weights(self):  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
            elif isinstance(m, nn.BatchNorm2d):  
                init.constant_(m.weight, 1)  
                init.constant_(m.bias, 0)  
            elif isinstance(m, nn.Linear):  
                init.normal_(m.weight, std=0.001)  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
  
    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):  
  
        # Self Attention  
        b_s, nq = queries.shape[:2]  
        nk = keys.shape[1]  
  
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)  
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)  
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)  
  
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)  
        if attention_weights is not None:  
            att = att * attention_weights  
        if attention_mask is not None:  
            att = att.masked_fill(attention_mask, -np.inf)  
        att = torch.softmax(att, -1)  
        att = self.dropout(att)  
  
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)  
        out = self.fc_o(out)  # (b_s, nq, d_model)  
  
        v2 = v.permute(0, 1, 3, 2).contiguous().view(b_s, -1, nk)  # bs,dim,n  
        self.dy_paras = nn.Parameter(self.softmax(self.dy_paras))  
        out2 = self.dy_paras[0] * self.conv1(v2) + self.dy_paras[1] * self.conv3(v2) + self.dy_paras[2] * self.conv5(v2)  
        out2 = out2.permute(0, 2, 1)  # bs.n.dim  
  
        out = out + out2  
        return out  
  
  
# 输入 B N C,  输出 B N Cif __name__ == '__main__':  
    block = MUSEAttention(d_model=32, d_k=32, d_v=32, h=8).cuda()  
    print("开始运行")  
    input = torch.rand(3, 64, 32).cuda()  
    output = block(input, input, input)  
    print(input.size(), output.size())  
    print('结束运行')
相关推荐
代码小将1 小时前
Leetcode209做题笔记
java·笔记·算法
朗迹 - 张伟1 小时前
UE5 PCG学习笔记
笔记·学习·ue5
寻丶幽风3 小时前
论文阅读笔记——双流网络
论文阅读·笔记·深度学习·视频理解·双流网络
sz66cm8 小时前
Linux基础 -- SSH 流式烧录与压缩传输笔记
linux·笔记·ssh
开发游戏的老王9 小时前
[虚幻官方教程学习笔记]深入理解实时渲染(An In-Depth Look at Real-Time Rendering)
笔记·学习·虚幻
愚润求学10 小时前
【Linux】Ext系列文件系统
linux·运维·服务器·笔记
幸好我会魔法12 小时前
使用githubPage+hexo搭建个人博客
笔记·github
jackson凌12 小时前
【Java学习笔记】finalize方法
java·笔记·学习
能来帮帮蒟蒻吗12 小时前
VUE3 -综合实践(Mock+Axios+ElementPlus)
前端·javascript·vue.js·笔记·学习·ajax·typescript
XQ丶YTY13 小时前
大二java第一面小厂(挂)
java·开发语言·笔记·学习·面试