Attention Free Transformer (AFT)-2020论文笔记


名称:

Attention Free Transformer (AFT)

来源:

[2105.14103] An Attention Free Transformer

相关工作:

#Approximatingthedotproduct #Sparselocalattention #Contextcompression #Eliminatingdotproductattention #MLPsforvision

创新点:

贡献:

  • 提出了一种全新的注意力机制替代方案,完全摒弃了点积注意力。

  • AFT的计算复杂度与输入长度和特征维度呈线性关系,适用于大规模数据。

  • AFT-local和AFT-conv变体通过引入局部性和空间权重共享,进一步提高了模型的效率和性能。

代码:

python 复制代码
# ---------------------------------------  
# 论文:An Attention Free Transformer (arxiv2021)  
# ---------------------------------------  
import torch  
from torch import nn  
from torch.nn import init  
  
  
class AFT_FULL(nn.Module):  
  
    def __init__(self, d_model, n=49, simple=False):  
  
        super(AFT_FULL, self).__init__()  
        self.fc_q = nn.Linear(d_model, d_model)  
        self.fc_k = nn.Linear(d_model, d_model)  
        self.fc_v = nn.Linear(d_model, d_model)  
        if (simple):  
            self.position_biases = torch.zeros((n, n))  
        else:  
            self.position_biases = nn.Parameter(torch.ones((n, n)))  
        self.d_model = d_model  
        self.n = n  
        self.sigmoid = nn.Sigmoid()  
  
        self.init_weights()  
  
    def init_weights(self):  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
            elif isinstance(m, nn.BatchNorm2d):  
                init.constant_(m.weight, 1)  
                init.constant_(m.bias, 0)  
            elif isinstance(m, nn.Linear):  
                init.normal_(m.weight, std=0.001)  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
  
    def forward(self, input):  
  
        bs, n, dim = input.shape  
  
        q = self.fc_q(input)  # bs,n,dim  
        k = self.fc_k(input).view(1, bs, n, dim)  # 1,bs,n,dim  
        v = self.fc_v(input).view(1, bs, n, dim)  # 1,bs,n,dim  
  
        numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2)  # n,bs,dim  
        denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2)  # n,bs,dim  
  
        out = (numerator / denominator)  # n,bs,dim  
        out = self.sigmoid(q) * (out.permute(1, 0, 2))  # bs,n,dim  
  
        return out  
  
  
# 输入 B C N,  输出 B C Nif __name__ == '__main__':  
    block = AFT_FULL(d_model=512, n=64).cuda()  
    input = torch.rand(64, 64, 512).cuda()  
    output = block( input)  
    print(input.size(), output.size())
    
相关推荐
CoovallyAIHub1 小时前
SBP-YOLO:面向嵌入式悬架的轻量实时模型,实现减速带与坑洼高精度检测
深度学习·算法·计算机视觉
CoovallyAIHub1 小时前
医药、零件、饮料瓶盖……SuperSimpleNet让质检“即插即用”
深度学习·算法·计算机视觉
跳跳糖炒酸奶1 小时前
第六章、从transformer到nlp大模型:编码器-解码器模型 (Encoder-Decoder)
深度学习·自然语言处理·transformer
大千AI助手2 小时前
VeRL:强化学习与大模型训练的高效融合框架
人工智能·深度学习·神经网络·llm·强化学习·verl·字节跳动seed
初级炼丹师(爱说实话版)3 小时前
2025算法八股——深度学习——优化器小结
人工智能·深度学习·算法
AI算法工程师Moxi3 小时前
人工智能在医学图像中的应用:从机器学习到深度学习
人工智能·深度学习·机器学习
张较瘦_3 小时前
[论文阅读] 人工智能 + 软件工程 | 首个仓库级多任务调试数据集!RepoDebug揭秘LLM真实调试水平
论文阅读·人工智能
CV-杨帆4 小时前
论文阅读:ACL 2023 MEETINGQA: Extractive Question-Answering on Meeting Transcripts
论文阅读
胡耀超5 小时前
大模型架构演进全景:从Transformer到下一代智能系统的技术路径(MoE、Mamba/SSM、混合架构)
人工智能·深度学习·ai·架构·大模型·transformer·技术趋势分析
Gyoku Mint12 小时前
提示词工程(Prompt Engineering)的崛起——为什么“会写Prompt”成了新技能?
人工智能·pytorch·深度学习·神经网络·语言模型·自然语言处理·nlp