论文阅读与源码解析:MogaNet

论文阅读与源码解析:MogaNet: Multi-order Gated Aggregation Network

论文地址:https://arxiv.org/pdf/2211.03295

GitHub项目地址:https://github.com/Westlake-AI/MogaNet

源码:https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py

Motivation

现代 ConvNets 的表示能力尚未得到充分利用,需要对现代 ConvNet 架构设计进行重新设计。低阶交互倾向于对相对简单和常见的局部视觉概念进行建模,这些概念表达能力较差,无法捕获高级语义模式。相比之下,高阶表示绝对全局范围的复杂概念,但容易受到攻击和泛化性差。因此可以将低阶和高阶结合起来,利用各自的优势形成互补。

Method

模型分为4个layer,里面有许多block,block里面分为两个module,一个是进行空间聚合,另一个对通道进行聚合。整体架构类似于Swin Transformer的设计,在每个layer处理之后对特征进行下采样,同时每个block里面空间聚合模块类似于Swin Transformer里面的Attention模块,通道聚合模块则类似于Swin Transformer里面的MLP模块。

源码解读

Spatial Aggregation

  1. 细粒度的局部纹理(低阶)和复杂的全局形状(中阶),分别由 Conv1×1(·) 和 GAP(·) 实例化。为了迫使网络与其隐式倾斜的交互强度,我们设计了 FD(·) 来自适应地排除琐碎的(被忽视)交互。
  2. 利用不同kernel大小的卷积模块来提取多阶上下文。
python 复制代码
class MultiOrderGatedAggregation(nn.Module):
    """Spatial Block with Multi-order Gated Aggregation.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (list): Dilations of three DWConv layers.
        attn_channel_split (list): The raletive ratio of splited channels.
        attn_act_type (str): The activation type for Spatial Block.
            Defaults to 'SiLU'.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                ):
        super(MultiOrderGatedAggregation, self).__init__()

        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

        # activation for gating and value
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)

        # decompose
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        x = self.proj_1(x)
        # 对特征平均池化,用原来的特征减去池化后的特征,再将差值乘以一个系数加回到原来的特征中去。
        # x_d: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # proj 1x1
        # 首先对特征进行空间特征的多样性处理
        x = self.feat_decompose(x)
        # gating and value branch
        # 然后设计两条路径进行特征处理,一条路径gate只进行非线性激活,另一条路径value对特征提取全局和局部的特征,最后将两个特征进行相乘。
        g = self.gate(x)
        v = self.value(x)
        # aggregation
        if not self.attn_force_fp32:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            x = self.forward_gating(self.act_gate(g), self.act_gate(v))
        x = x + shortcut
        return x

sigma模块

python 复制代码
class ElementScale(nn.Module):
    """A learnable element-wise scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        return x * self.scale

多阶(kernel)卷积模块:把特征沿通道维度按比例分成三部分,每部分的特征经过不同大小的kernel卷积处理,最后拼接起来。

python 复制代码
class MultiOrderDWConv(nn.Module):
    """Multi-order Features with Dilated DWConv Kernel.

    Args:
        embed_dims (int): Number of input channels.
        dw_dilation (list): Dilations of three DWConv layers.
        channel_split (list): The raletive ratio of three splited channels.
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=[1, 2, 3,],
                 channel_split=[1, 3, 4,],
                ):
        super(MultiOrderDWConv, self).__init__()

        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0

        # basic DW conv
        self.DW_conv0 = nn.Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],
        )
        # DW conv 1
        self.DW_conv1 = nn.Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],
        )
        # DW conv 2
        self.DW_conv2 = nn.Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],
        )
        # a channel convolution
        self.PW_conv = nn.Conv2d(  # point-wise convolution
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        x_0 = self.DW_conv0(x)
        x_1 = self.DW_conv1(
            x_0[:, self.embed_dims_0: self.embed_dims_0+self.embed_dims_1, ...])
        x_2 = self.DW_conv2(
            x_0[:, self.embed_dims-self.embed_dims_2:, ...])
        x = torch.cat([
            x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
        x = self.PW_conv(x)
        return x

轻量级通道聚合模块 CA(·) 来自适应地在高维隐藏空间中重新分配通道级特征,并进一步将其扩展到通道聚合 (CA) 块。

python 复制代码
class ChannelAggregationFFN(nn.Module):
    """An implementation of FFN with Channel Aggregation.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        kernel_size (int): The depth-wise conv kernel size as the
            depth-wise convolution. Defaults to 3.
        act_type (str): The type of activation. Defaults to 'GELU'.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels

        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        self.fc2 = nn.Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,  # C -> 1
            out_channels=1, kernel_size=1,
        )
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        # x_d: [B, C, H, W] -> [B, 1, H, W]
        # 将特征的通道维度压缩为1,然后进行非线性激活,用原来的特征减去这个特征,形成互补交互,最后把差值乘以一个系数,加到原来的特征里面。
        x = x + self.sigma(x - self.decompose_act(self.decompose(x)))
        return x

    def forward(self, x):
        # proj 1
        # 线性层扩展到较高维度,利用深度卷积提取特征
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # proj 2
        # 互补交互重新分配通道特征,最后投影到原来的维度
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
相关推荐
我登哥MVP1 分钟前
VS Code 安装 Claude Code 并接入 DeepSeek V4 Model
人工智能·python·node.js·agent·codex·deepseek·claude code
unique2 分钟前
AI Native 调研报告
人工智能
云烟成雨TD3 分钟前
Spring AI Alibaba 1.x 系列【73】两步 RAG
java·人工智能·spring
ai产品老杨4 分钟前
解耦视频高并发与边缘计算AI布控:基于Docker的高性能安防平台,破局GB28181/RTSP协议兼容与源码交付痛点
人工智能·音视频·边缘计算
CHrisFC5 分钟前
LIMS 系统 AI 建设路径:从自动化到智能化的演进之路
运维·人工智能·自动化
饼干哥哥6 分钟前
一口气搭了300个AI Agents并发处理跨境运营的dirty work
人工智能
AI行业学习6 分钟前
CC‑Switch v3.16.1-下载、配置、安装(2026‑06‑01 最新官方版)
开发语言·人工智能·windows·python
小糖学代码7 分钟前
机器学习:5.深度学习
人工智能·深度学习·机器学习
unity工具人7 分钟前
python+yolov8 图像识别-测试案例
python·opencv·yolo
lipku9 分钟前
LiveTalking 更新:集成 vLLM-Omni TTS服务
python·开源·数字人·vllm·实时数字人