MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING 笔记

来源:

MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning

相关工作:

#自注意力机制 #卷积神经网络 #动态卷积 #局部注意力 #并行计算

创新点:

贡献:

  1. 并行多尺度注意力机制

    • MUSE引入了并行多尺度注意力机制,该机制同时捕获序列数据中的长距离和短距离语言结构。这是通过对序列进行不同尺度的并行编码来实现的,利用自注意力和逐点变换。
  2. 结合卷积和自注意力

    • 为了解决自注意力在处理长序列时可能忽略局部信息的问题,MUSE模型将卷积操作和自注意力结合起来,以更有效地学习序列的局部和全局特征。
  3. 共享投影空间

    • 文章强调了在自注意力和卷积操作中使用共享投影空间的重要性。这种设计使得两种操作能够在相同的隐藏空间中进行,从而有助于整合局部和全局特征表示。
  4. 动态卷积核选择

    • MUSE引入了一种门控机制,用于自动选择不同卷积单元的权重,这允许模型动态地选择最适合当前层的卷积核大小。

代码:

python 复制代码
# ---------------------------------------  
# 论文: MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING (arxiv 2019)  
# ---------------------------------------  
import numpy as np  
import torch  
from torch import nn  
from torch.nn import init  
  
  
class Depth_Pointwise_Conv1d(nn.Module):  
    def __init__(self, in_ch, out_ch, k):  
        super().__init__()  
        if (k == 1):  
            self.depth_conv = nn.Identity()  
        else:  
            self.depth_conv = nn.Conv1d(  
                in_channels=in_ch,  
                out_channels=in_ch,  
                kernel_size=k,  
                groups=in_ch,  
                padding=k // 2  
            )  
        self.pointwise_conv = nn.Conv1d(  
            in_channels=in_ch,  
            out_channels=out_ch,  
            kernel_size=1,  
            groups=1  
        )  
  
    def forward(self, x):  
        out = self.pointwise_conv(self.depth_conv(x))  
        return out  
  
  
class MUSEAttention(nn.Module):  
  
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):  
  
        super(MUSEAttention, self).__init__()  
        self.fc_q = nn.Linear(d_model, h * d_k)  
        self.fc_k = nn.Linear(d_model, h * d_k)  
        self.fc_v = nn.Linear(d_model, h * d_v)  
        self.fc_o = nn.Linear(h * d_v, d_model)  
        self.dropout = nn.Dropout(dropout)  
  
        self.conv1 = Depth_Pointwise_Conv1d(h * d_v, d_model, 1)  
        self.conv3 = Depth_Pointwise_Conv1d(h * d_v, d_model, 3)  
        self.conv5 = Depth_Pointwise_Conv1d(h * d_v, d_model, 5)  
        self.dy_paras = nn.Parameter(torch.ones(3))  
        self.softmax = nn.Softmax(-1)  
  
        self.d_model = d_model  
        self.d_k = d_k  
        self.d_v = d_v  
        self.h = h  
  
        self.init_weights()  
  
    def init_weights(self):  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
            elif isinstance(m, nn.BatchNorm2d):  
                init.constant_(m.weight, 1)  
                init.constant_(m.bias, 0)  
            elif isinstance(m, nn.Linear):  
                init.normal_(m.weight, std=0.001)  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
  
    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):  
  
        # Self Attention  
        b_s, nq = queries.shape[:2]  
        nk = keys.shape[1]  
  
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)  
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)  
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)  
  
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)  
        if attention_weights is not None:  
            att = att * attention_weights  
        if attention_mask is not None:  
            att = att.masked_fill(attention_mask, -np.inf)  
        att = torch.softmax(att, -1)  
        att = self.dropout(att)  
  
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)  
        out = self.fc_o(out)  # (b_s, nq, d_model)  
  
        v2 = v.permute(0, 1, 3, 2).contiguous().view(b_s, -1, nk)  # bs,dim,n  
        self.dy_paras = nn.Parameter(self.softmax(self.dy_paras))  
        out2 = self.dy_paras[0] * self.conv1(v2) + self.dy_paras[1] * self.conv3(v2) + self.dy_paras[2] * self.conv5(v2)  
        out2 = out2.permute(0, 2, 1)  # bs.n.dim  
  
        out = out + out2  
        return out  
  
  
# 输入 B N C,  输出 B N Cif __name__ == '__main__':  
    block = MUSEAttention(d_model=32, d_k=32, d_v=32, h=8).cuda()  
    print("开始运行")  
    input = torch.rand(3, 64, 32).cuda()  
    output = block(input, input, input)  
    print(input.size(), output.size())  
    print('结束运行')
相关推荐
DKPT2 小时前
ZGC和G1收集器相比哪个更好?
java·jvm·笔记·学习·spring
QT 小鲜肉3 小时前
【孙子兵法之上篇】001. 孙子兵法·计篇
笔记·读书·孙子兵法
星轨初途4 小时前
数据结构排序算法详解(5)——非比较函数:计数排序(鸽巢原理)及排序算法复杂度和稳定性分析
c语言·开发语言·数据结构·经验分享·笔记·算法·排序算法
QT 小鲜肉4 小时前
【孙子兵法之上篇】001. 孙子兵法·计篇深度解析与现代应用
笔记·读书·孙子兵法
love530love7 小时前
【笔记】ComfUI RIFEInterpolation 节点缺失问题(cupy CUDA 安装)解决方案
人工智能·windows·笔记·python·插件·comfyui
愚戏师7 小时前
MySQL 数据导出
数据库·笔记·mysql
摇滚侠7 小时前
2025最新 SpringCloud 教程,教程简介,笔记01
笔记·spring cloud
RickyWasYoung9 小时前
【笔记】智能汽车、电动汽车政策文件
笔记·汽车
love530love12 小时前
【保姆级教程】Windows + Podman 从零部署 Duix-Avatar 数字人项目
人工智能·windows·笔记·python·数字人·podman·duix-avatar
草莓熊Lotso13 小时前
《算法闯关指南:动态规划算法--斐波拉契数列模型》--01.第N个泰波拉契数,02.三步问题
开发语言·c++·经验分享·笔记·其他·算法·动态规划