MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING 笔记

来源:

MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning

相关工作:

#自注意力机制 #卷积神经网络 #动态卷积 #局部注意力 #并行计算

创新点:

贡献:

  1. 并行多尺度注意力机制

    • MUSE引入了并行多尺度注意力机制,该机制同时捕获序列数据中的长距离和短距离语言结构。这是通过对序列进行不同尺度的并行编码来实现的,利用自注意力和逐点变换。
  2. 结合卷积和自注意力

    • 为了解决自注意力在处理长序列时可能忽略局部信息的问题,MUSE模型将卷积操作和自注意力结合起来,以更有效地学习序列的局部和全局特征。
  3. 共享投影空间

    • 文章强调了在自注意力和卷积操作中使用共享投影空间的重要性。这种设计使得两种操作能够在相同的隐藏空间中进行,从而有助于整合局部和全局特征表示。
  4. 动态卷积核选择

    • MUSE引入了一种门控机制,用于自动选择不同卷积单元的权重,这允许模型动态地选择最适合当前层的卷积核大小。

代码:

python 复制代码
# ---------------------------------------  
# 论文: MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING (arxiv 2019)  
# ---------------------------------------  
import numpy as np  
import torch  
from torch import nn  
from torch.nn import init  
  
  
class Depth_Pointwise_Conv1d(nn.Module):  
    def __init__(self, in_ch, out_ch, k):  
        super().__init__()  
        if (k == 1):  
            self.depth_conv = nn.Identity()  
        else:  
            self.depth_conv = nn.Conv1d(  
                in_channels=in_ch,  
                out_channels=in_ch,  
                kernel_size=k,  
                groups=in_ch,  
                padding=k // 2  
            )  
        self.pointwise_conv = nn.Conv1d(  
            in_channels=in_ch,  
            out_channels=out_ch,  
            kernel_size=1,  
            groups=1  
        )  
  
    def forward(self, x):  
        out = self.pointwise_conv(self.depth_conv(x))  
        return out  
  
  
class MUSEAttention(nn.Module):  
  
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):  
  
        super(MUSEAttention, self).__init__()  
        self.fc_q = nn.Linear(d_model, h * d_k)  
        self.fc_k = nn.Linear(d_model, h * d_k)  
        self.fc_v = nn.Linear(d_model, h * d_v)  
        self.fc_o = nn.Linear(h * d_v, d_model)  
        self.dropout = nn.Dropout(dropout)  
  
        self.conv1 = Depth_Pointwise_Conv1d(h * d_v, d_model, 1)  
        self.conv3 = Depth_Pointwise_Conv1d(h * d_v, d_model, 3)  
        self.conv5 = Depth_Pointwise_Conv1d(h * d_v, d_model, 5)  
        self.dy_paras = nn.Parameter(torch.ones(3))  
        self.softmax = nn.Softmax(-1)  
  
        self.d_model = d_model  
        self.d_k = d_k  
        self.d_v = d_v  
        self.h = h  
  
        self.init_weights()  
  
    def init_weights(self):  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
            elif isinstance(m, nn.BatchNorm2d):  
                init.constant_(m.weight, 1)  
                init.constant_(m.bias, 0)  
            elif isinstance(m, nn.Linear):  
                init.normal_(m.weight, std=0.001)  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
  
    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):  
  
        # Self Attention  
        b_s, nq = queries.shape[:2]  
        nk = keys.shape[1]  
  
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)  
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)  
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)  
  
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)  
        if attention_weights is not None:  
            att = att * attention_weights  
        if attention_mask is not None:  
            att = att.masked_fill(attention_mask, -np.inf)  
        att = torch.softmax(att, -1)  
        att = self.dropout(att)  
  
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)  
        out = self.fc_o(out)  # (b_s, nq, d_model)  
  
        v2 = v.permute(0, 1, 3, 2).contiguous().view(b_s, -1, nk)  # bs,dim,n  
        self.dy_paras = nn.Parameter(self.softmax(self.dy_paras))  
        out2 = self.dy_paras[0] * self.conv1(v2) + self.dy_paras[1] * self.conv3(v2) + self.dy_paras[2] * self.conv5(v2)  
        out2 = out2.permute(0, 2, 1)  # bs.n.dim  
  
        out = out + out2  
        return out  
  
  
# 输入 B N C,  输出 B N Cif __name__ == '__main__':  
    block = MUSEAttention(d_model=32, d_k=32, d_v=32, h=8).cuda()  
    print("开始运行")  
    input = torch.rand(3, 64, 32).cuda()  
    output = block(input, input, input)  
    print(input.size(), output.size())  
    print('结束运行')
相关推荐
使一颗心免于哀伤9 小时前
《设计模式之禅》笔记摘录 - 21.状态模式
笔记·设计模式
_落纸2 天前
三大基础无源电子元件——电阻(R)、电感(L)、电容(C)
笔记
Alice-YUE2 天前
【CSS学习笔记3】css特性
前端·css·笔记·html
2303_Alpha2 天前
SpringBoot
笔记·学习
Hello_Embed3 天前
STM32HAL 快速入门(二十):UART 中断改进 —— 环形缓冲区解决数据丢失
笔记·stm32·单片机·学习·嵌入式软件
咸甜适中3 天前
rust语言 (1.88) 学习笔记:客户端和服务器端同在一个项目中
笔记·学习·rust
Grassto3 天前
RAG 从入门到放弃?丐版 demo 实战笔记(go+python)
笔记
Magnetic_h3 天前
【iOS】设计模式复习
笔记·学习·ios·设计模式·objective-c·cocoa
周周记笔记3 天前
学习笔记:第一个Python程序
笔记·学习
丑小鸭是白天鹅3 天前
Kotlin协程详细笔记之切线程和挂起函数
开发语言·笔记·kotlin