每日Attention学习10——Scale-Aware Modulation

模块出处

ICCV 23 link code Scale-Aware Modulation Meet Transformer


模块名称

Scale-Aware Modulation (SAM)


模块作用

改进的自注意力


模块结构

模块代码
python3 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SAM(nn.Module):
    def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None,
                       attn_drop=0., proj_drop=0., expand_ratio=2):
        super().__init__()
        self.ca_attention = 1
        self.dim = dim
        self.ca_num_heads = ca_num_heads
        self.sa_num_heads = sa_num_heads
        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."
        assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}."
        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.split_groups=self.dim//ca_num_heads
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.s = nn.Linear(dim, dim, bias=qkv_bias)
        for i in range(self.ca_num_heads):
            local_conv = nn.Conv2d(dim//self.ca_num_heads, dim//self.ca_num_heads, kernel_size=(3+i*2), padding=(1+i), stride=1, groups=dim//self.ca_num_heads)
            setattr(self, f"local_conv_{i + 1}", local_conv)
        self.proj0 = nn.Conv2d(dim, dim*expand_ratio, kernel_size=1, padding=0, stride=1, groups=self.split_groups)
        self.bn = nn.BatchNorm2d(dim*expand_ratio)
        self.proj1 = nn.Conv2d(dim*expand_ratio, dim, kernel_size=1, padding=0, stride=1)

    def forward(self, x, H, W):
        # In
        B, N, C = x.shape
        v = self.v(x)
        s = self.s(x).reshape(B, H, W, self.ca_num_heads, C//self.ca_num_heads).permute(3, 0, 4, 1, 2)

        # Multi-Head Mixed Convolution
        for i in range(self.ca_num_heads):
            local_conv = getattr(self, f"local_conv_{i + 1}")
            s_i= s[i]
            s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W)
            if i == 0:
                s_out = s_i
            else:
                s_out = torch.cat([s_out,s_i],2)
        s_out = s_out.reshape(B, C, H, W)

        # Scale-Aware Aggregation (SAA)
        s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
        self.modulator = s_out
        s_out = s_out.reshape(B, C, N).permute(0, 2, 1)
        x = s_out * v

        # Out
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

if __name__ == '__main__':
    x = torch.randn([3, 1024, 256])  # B, N, C
    sam = SAM(dim=256)
    out = sam(x, H=32, W=32)  # H=N*W
    print(out.shape)  # 3, 1024, 256

原文表述

我们提出了一种新颖的卷积调制,称为尺度感知调制 (SAM),它包含两个新模块:多头混合卷积 (MHMC) 和尺度感知聚合 (SAA)。MHMC 模块旨在增强感受野并同时捕获多尺度特征。SAA 模块旨在有效地聚合不同头部之间的特征,同时保持轻量级架构。

相关推荐
cqbzcsq17 天前
CellFlow虚拟细胞论文阅读
论文阅读·人工智能·笔记·学习·生物信息
凌晨一点的秃头猪17 天前
论文阅读 GTI(Graph-based Tree Index): 面向高维空间最近邻搜索的动态图-树混合索引结构
论文阅读
有Li17 天前
PTCMIL:基于提示 token 聚类的全切片图像多实例学习分析文献速递/多模态医学影像最新进展
论文阅读·学习·数据挖掘·聚类·文献·医学生
大模型最新论文速读17 天前
06-16 · LLM 最新论文速览
论文阅读·人工智能·深度学习·机器学习·自然语言处理
墨绿色的摆渡人17 天前
论文笔记(一百三十七)Learning Dual-Arm Push and Grasp Synergy in Dense Clutter
arm开发·论文阅读
Chunyyyen18 天前
【第四十九周】论文阅读
论文阅读
Biomamba生信基地18 天前
NC | 单细胞分析揭示头颈部癌早期转移过程中潜在的免疫逃逸机制(R语言版本)
论文阅读·生物信息学·单细胞rna测序
大模型最新论文速读18 天前
06-15 · LLM 最新论文速览
论文阅读·人工智能·深度学习·自然语言处理
小马哥crazymxm18 天前
Arxiv论文周选 (2026-W24)
论文阅读·人工智能·考研
大模型最新论文速读18 天前
TRUST:RL 时保留模型的不确定性,效果提升 8%
论文阅读·人工智能·深度学习·机器学习·自然语言处理