各种注意力机制,Attention、MLP、ReP等系列的PyTorch实现,含核心代码

不知道CV方向的同学在读论文的时候有没有发现这样一个问题:论文的核心思想很简单,但当你找这篇论文的核心代码时发现,作者提供的源码模块会嵌入到分类、检测、分割等任务框架中,这时候如果你对某一特定框架不熟悉,尽管核心代码只有十几行,依然会发现很难找出。

今天我就帮大家解决一部分这个问题,还记得上次分享的attention论文合集吗?没印象的同学点这里。

这次总结了这30篇attention论文中的核心代码分享,还有一部分其他系列的论文,比如ReP、卷积级数等,核心代码与原文都整理了。

由于篇幅和时间原因,暂时只分享了一部分,需要全部论文以及完整核心代码的同学看文末

Attention论文

1、Axial Attention in Multidimensional Transformers
核心代码
from model.attention.Axial_attention import AxialImageTransformer
import torch

if __name__ == '__main__':
    input=torch.randn(3, 128, 7, 7)
    model = AxialImageTransformer(
        dim = 128,
        depth = 12,
        reversible = True
    )
    outputs = model(input)
    print(outputs.shape)
2、CCNet: Criss-Cross Attention for Semantic Segmentation
核心代码
from model.attention.CrissCrossAttention import CrissCrossAttention
import torch

if __name__ == '__main__':
    input=torch.randn(3, 64, 7, 7)
    model = CrissCrossAttention(64)
    outputs = model(input)
    print(outputs.shape)
3、Aggregating Global Features into Local Vision Transformer
核心代码
from model.attention.MOATransformer import MOATransformer
import torch

if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = MOATransformer(
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6],
        num_heads=[3, 6, 12],
        window_size=14,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        ape=False,
        patch_norm=True,
        use_checkpoint=False
    )
    output=model(input)
    print(output.shape)
4、CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION
核心代码
from model.attention.Crossformer import CrossFormer
import torch

if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = CrossFormer(img_size=224,
        patch_size=[4, 8, 16, 32],
        in_chans= 3,
        num_classes=1000,
        embed_dim=48,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        group_size=[7, 7, 7, 7],
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        merge_size=[[2, 4], [2,4], [2, 4]]
    )
    output=model(input)
    print(output.shape)
5、Vision Transformer with Deformable Attention
核心代码
from model.attention.DAT import DAT
import torch

if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = DAT(
        img_size=224,
        patch_size=4,
        num_classes=1000,
        expansion=4,
        dim_stem=96,
        dims=[96, 192, 384, 768],
        depths=[2, 2, 6, 2],
        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],
        heads=[3, 6, 12, 24],
        window_sizes=[7, 7, 7, 7] ,
        groups=[-1, -1, 3, 6],
        use_pes=[False, False, True, True],
        dwc_pes=[False, False, False, False],
        strides=[-1, -1, 1, 1],
        sr_ratios=[-1, -1, -1, -1],
        offset_range_factor=[-1, -1, 2, 2],
        no_offs=[False, False, False, False],
        fixed_pes=[False, False, False, False],
        use_dwc_mlps=[False, False, False, False],
        use_conv_patches=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
    )
    output=model(input)
    print(output[0].shape)
6、Separable Self-attention for Mobile Vision Transformers
核心代码
from model.attention.MobileViTv2Attention import MobileViTv2Attention
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input=torch.randn(50,49,512)
    sa = MobileViTv2Attention(d_model=512)
    output=sa(input)
    print(output.shape)
7、On the Integration of Self-Attention and Convolution
核心代码
from model.attention.ACmix import ACmix
import torch

if __name__ == '__main__':
    input=torch.randn(50,256,7,7)
    acmix = ACmix(in_planes=256, out_planes=256)
    output=acmix(input)
    print(output.shape)
8、Non-deep Networks
核心代码
from model.attention.ParNetAttention import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    pna = ParNetAttention(channel=512)
    output=pna(input)
    print(output.shape) #50,512,7,7
9、UFO-ViT: High Performance Linear Vision Transformer without Softmax
核心代码
from model.attention.UFOAttention import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input=torch.randn(50,49,512)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output=ufo(input,input,input)
    print(output.shape) #[50, 49, 512]
10、Coordinate Attention for Efficient Mobile Network Design
核心代码
from model.attention.CoordAttention import CoordAtt
import torch
from torch import nn
from torch.nn import functional as F

inp=torch.rand([2, 96, 56, 56])
inp_dim, oup_dim = 96, 96
reduction=32

coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)
output=coord_attention(inp)
print(output.shape)

ReP论文

1、RepVGG: Making VGG-style ConvNets Great Again
核心代码
from model.rep.repvgg import RepBlock
import torch


input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())
2、ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks
核心代码
from model.rep.acnet import ACNet
import torch
from torch import nn

input=torch.randn(50,512,49,49)
acnet=ACNet(512,512)
acnet.eval()
out=acnet(input)
acnet._switch_to_deploy()
out2=acnet(input)
print('difference:')
print(((out2-out)**2).sum())

卷积级数论文

1、CondConv: Conditionally Parameterized Convolutions for Efficient Inference
核心代码
from model.conv.CondConv import *
import torch
from torch import nn
from torch.nn import functional as F





if __name__ == '__main__':
    input=torch.randn(2,32,64,64)
    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    out=m(input)
    print(out.shape)
2、Dynamic Convolution: Attention over Convolution Kernels
核心代码
from model.conv.DynamicConv import *
import torch
from torch import nn
from torch.nn import functional as F

if __name__ == '__main__':
    input=torch.randn(2,32,64,64)
    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    out=m(input)
    print(out.shape) # 2,32,64,64
3、Involution: Inverting the Inherence of Convolution for Visual Recognition
核心代码
from model.conv.Involution import Involution
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,4,64,64)
involution=Involution(kernel_size=3,in_channel=4,stride=2)
out=involution(input)
print(out.shape)

关注下方《学姐带你玩AI》🚀🚀🚀

回复"核心代码"获取全部论文+代码合集

码字不易,欢迎大家点赞评论收藏!

相关推荐
盼小辉丶6 分钟前
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
深度学习·神经网络·tensorflow
起名字什么的好难12 分钟前
conda虚拟环境安装pytorch gpu版
人工智能·pytorch·conda
18号房客19 分钟前
计算机视觉-人工智能(AI)入门教程一
人工智能·深度学习·opencv·机器学习·计算机视觉·数据挖掘·语音识别
百家方案21 分钟前
「下载」智慧产业园区-数字孪生建设解决方案:重构产业全景图,打造虚实结合的园区数字化底座
大数据·人工智能·智慧园区·数智化园区
云起无垠27 分钟前
“AI+Security”系列第4期(一)之“洞” 见未来:AI 驱动的漏洞挖掘新范式
人工智能
QQ_7781329741 小时前
基于深度学习的图像超分辨率重建
人工智能·机器学习·超分辨率重建
清 晨1 小时前
Web3 生态全景:创新与发展之路
人工智能·web3·去中心化·智能合约
公众号Codewar原创作者1 小时前
R数据分析:工具变量回归的做法和解释,实例解析
开发语言·人工智能·python
IT古董2 小时前
【漫话机器学习系列】020.正则化强度的倒数C(Inverse of regularization strength)
人工智能·机器学习
进击的小小学生2 小时前
机器学习连载
人工智能·机器学习