Paper Reading and Source Code Analysis: MogaNet: Multi-order Gated Aggregation Network

Paper: https://arxiv.org/pdf/2211.03295

GitHub repository: https://github.com/Westlake-AI/MogaNet

Source file: https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py

Motivation

The representational capacity of modern ConvNets has not been fully exploited, which calls for rethinking modern ConvNet architecture design. Low-order interactions tend to model relatively simple and common local visual concepts; they are less expressive and cannot capture high-level semantic patterns. In contrast, high-order interactions represent complex concepts over a nearly global extent, but they are vulnerable to perturbations and generalize poorly. Combining low-order and high-order interactions therefore lets their respective strengths complement each other.

Method

The model is organized into 4 stages, each containing a number of blocks. Every block consists of two modules: one performs spatial aggregation, the other aggregates information across channels. The overall architecture resembles the Swin Transformer design: features are downsampled between stages, the spatial aggregation module in each block plays the role of Swin Transformer's Attention module, and the channel aggregation module corresponds to its MLP module.
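
As a rough sketch of this two-module block structure (my own simplification, not the exact repository code: it reuses the `MultiOrderGatedAggregation` and `ChannelAggregationFFN` classes walked through below, and omits the layer scale and drop path used in the real block):

```python
import torch
import torch.nn as nn

class MogaBlockSketch(nn.Module):
    """Illustrative outline of one MogaNet block: spatial + channel aggregation.

    Assumes the MultiOrderGatedAggregation and ChannelAggregationFFN classes
    from the source walkthrough below are in scope.
    """

    def __init__(self, embed_dims, ffn_ratio=4):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(embed_dims)
        # spatial aggregation: plays the role of attention in Swin Transformer
        self.attn = MultiOrderGatedAggregation(embed_dims)
        self.norm2 = nn.BatchNorm2d(embed_dims)
        # channel aggregation: plays the role of the MLP in Swin Transformer
        self.mlp = ChannelAggregationFFN(embed_dims, int(embed_dims * ffn_ratio))

    def forward(self, x):
        # MultiOrderGatedAggregation adds its own residual shortcut internally,
        # so the spatial step is applied directly to the normalized features
        x = self.attn(self.norm1(x))
        # channel step with an explicit outer residual
        x = x + self.mlp(self.norm2(x))
        return x
```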

Source Code Walkthrough

Spatial Aggregation

  1. Fine-grained local textures (low-order) and complex global shapes (middle-order) are instantiated by Conv1×1(·) and GAP(·), respectively. To force the network against its implicitly inclined interaction strengths, a feature decomposition module FD(·) is designed to adaptively exclude the trivial (overlooked) interactions.
  2. Depth-wise convolution modules with different kernel sizes (and dilations) are used to extract multi-order contexts.
```python
class MultiOrderGatedAggregation(nn.Module):
    """Spatial Block with Multi-order Gated Aggregation.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (list): Dilations of three DWConv layers.
        attn_channel_split (list): The relative ratio of split channels.
        attn_act_type (str): The activation type for Spatial Block.
            Defaults to 'SiLU'.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                ):
        super(MultiOrderGatedAggregation, self).__init__()

        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

        # activation for gating and value
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)

        # decompose
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        x = self.proj_1(x)
        # Average-pool the features globally, subtract the pooled features from the
        # original ones, scale the difference, and add it back to the original features.
        # x_d: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # proj 1x1
        # First diversify the spatial features via feature decomposition
        x = self.feat_decompose(x)
        # gating and value branch
        # Two branches then process the features: the gate branch is a 1x1 conv followed by
        # an activation, while the value branch extracts local and global multi-order context;
        # the two results are fused by element-wise multiplication.
        g = self.gate(x)
        v = self.value(x)
        # aggregation
        if not self.attn_force_fp32:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            x = self.forward_gating(self.act_gate(g), self.act_gate(v))
        x = x + shortcut
        return x
```
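
A minimal smoke test of the spatial block (a sketch, assuming the repository's imports and its `build_act_layer` helper are available, together with the `ElementScale` and `MultiOrderDWConv` classes shown below). The module preserves both the channel count and the spatial size:

```python
import torch

# embed_dims must be divisible by sum(attn_channel_split) = 8 for MultiOrderDWConv
attn = MultiOrderGatedAggregation(embed_dims=64)
x = torch.randn(2, 64, 56, 56)   # [B, C, H, W]
out = attn(x)
print(out.shape)                 # torch.Size([2, 64, 56, 56]) -- shape preserved
```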

The sigma module (ElementScale)

```python
class ElementScale(nn.Module):
    """A learnable element-wise scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        return x * self.scale
```
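
One detail worth noting: `sigma` is created with `init_value=1e-5`, so the decomposed residual `sigma(x - x_d)` is almost zero at the start of training and `feat_decompose` begins as a near-identity mapping; the interaction strength is then learned. A quick check (assuming the `ElementScale` class above):

```python
import torch

scale = ElementScale(embed_dims=8, init_value=1e-5)
x = torch.randn(1, 8, 4, 4)
print(scale(x).abs().max().item())   # on the order of 1e-5, i.e. nearly zero at init
```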

Multi-order (multi-kernel) depth-wise convolution module: a basic 5×5 depth-wise convolution is first applied to all channels; two channel groups (split by a fixed ratio along the channel dimension) are then further processed by dilated depth-wise convolutions with larger receptive fields, and the three groups are concatenated and fused by a 1×1 convolution.

```python
class MultiOrderDWConv(nn.Module):
    """Multi-order Features with Dilated DWConv Kernel.

    Args:
        embed_dims (int): Number of input channels.
        dw_dilation (list): Dilations of three DWConv layers.
        channel_split (list): The relative ratio of the three split channels.
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=[1, 2, 3,],
                 channel_split=[1, 3, 4,],
                ):
        super(MultiOrderDWConv, self).__init__()

        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0

        # basic DW conv
        self.DW_conv0 = nn.Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],
        )
        # DW conv 1
        self.DW_conv1 = nn.Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],
        )
        # DW conv 2
        self.DW_conv2 = nn.Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],
        )
        # a channel convolution
        self.PW_conv = nn.Conv2d(  # point-wise convolution
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        x_0 = self.DW_conv0(x)
        x_1 = self.DW_conv1(
            x_0[:, self.embed_dims_0: self.embed_dims_0+self.embed_dims_1, ...])
        x_2 = self.DW_conv2(
            x_0[:, self.embed_dims-self.embed_dims_2:, ...])
        x = torch.cat([
            x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
        x = self.PW_conv(x)
        return x
```
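
For a concrete sense of the split: with `embed_dims=64` and the default `channel_split=[1, 3, 4]`, the three branches get 8, 24 and 32 channels, with effective receptive fields of 5×5, 9×9 (5×5 kernel, dilation 2) and 19×19 (7×7 kernel, dilation 3). A small usage sketch under the same assumptions as above:

```python
import torch

conv = MultiOrderDWConv(embed_dims=64)   # channel_split=[1, 3, 4] -> 8 / 24 / 32 channels
print(conv.embed_dims_0, conv.embed_dims_1, conv.embed_dims_2)   # 8 24 32
x = torch.randn(2, 64, 56, 56)
print(conv(x).shape)                     # torch.Size([2, 64, 56, 56]) -- spatial size preserved
```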

Channel Aggregation

A lightweight channel aggregation module CA(·) is introduced to adaptively redistribute channel-wise features in the high-dimensional hidden space, and it is further extended into the channel aggregation (CA) block.

```python
class ChannelAggregationFFN(nn.Module):
    """An implementation of FFN with Channel Aggregation.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        kernel_size (int): The kernel size of the depth-wise convolution.
            Defaults to 3.
        act_type (str): The type of activation. Defaults to 'GELU'.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels

        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        self.fc2 = nn.Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,  # C -> 1
            out_channels=1, kernel_size=1,
        )
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        # x_d: [B, C, H, W] -> [B, 1, H, W]
        # Compress the channel dimension to 1 and apply a non-linearity; subtract the result
        # from the original features to form the complementary interaction, scale the
        # difference, and add it back to the original features.
        x = x + self.sigma(x - self.decompose_act(self.decompose(x)))
        return x

    def forward(self, x):
        # proj 1
        # A 1x1 conv expands to a higher hidden dimension, then a depth-wise conv extracts local features
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # proj 2
        # Complementary interaction redistributes the channel features, then project back to the original dimension
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```
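
A usage sketch for the channel-aggregation FFN (same assumptions as above regarding `build_act_layer` and `ElementScale`): it expands to `feedforward_channels`, mixes locally with the depth-wise conv, redistributes channels via the complementary `feat_decompose` step, and projects back to `embed_dims`:

```python
import torch

ffn = ChannelAggregationFFN(embed_dims=64, feedforward_channels=64 * 4)
x = torch.randn(2, 64, 56, 56)
print(ffn(x).shape)   # torch.Size([2, 64, 56, 56]) -- back to embed_dims after expansion
```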