论文阅读与源码解析:MogaNet: Multi-order Gated Aggregation Network
论文地址:https://arxiv.org/pdf/2211.03295
GitHub项目地址:https://github.com/Westlake-AI/MogaNet
源码:https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py
Motivation
现代 ConvNet 的表示能力尚未得到充分利用,需要重新审视其架构设计。从多阶博弈论交互的角度看,低阶交互倾向于建模相对简单、常见的局部视觉概念,这些概念表达能力较差,无法捕获高层语义模式;相比之下,高阶交互表示具有绝对全局范围的复杂概念,但容易受到对抗攻击,且泛化能力较差。而现有网络往往隐式地偏向于编码极低阶或极高阶的交互,忽视了判别力更强的中阶交互。因此可以将低阶和高阶信息结合起来,利用各自的优势形成互补,使网络更充分地建模中阶交互。
Method
模型分为 4 个 stage(layer),每个 stage 包含若干 block;每个 block 由两个模块组成:一个负责空间聚合,另一个负责通道聚合。整体架构类似于 Swin Transformer 的层级式设计,在相邻 stage 之间对特征进行下采样;block 中的空间聚合模块对应 Swin Transformer 中的 Attention 模块,通道聚合模块则对应其中的 MLP 模块。
源码解读
Spatial Aggregation
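- 将特征分解为细粒度的局部纹理(低阶)和复杂的全局形状(中阶),分别由 Conv1×1(·) 和 GAP(·) 实例化。为了迫使网络克服其隐式偏好的交互强度,作者设计了特征分解 FD(·),自适应地排除平凡的(被忽视的)交互:先计算 $Y = \mathrm{Conv}_{1\times1}(X)$,再计算 $Z = \phi\big(Y + \gamma_s \odot (Y - \mathrm{GAP}(Y))\big)$,其中 $\gamma_s$ 为可学习的逐通道缩放因子(即下文的 sigma 模块),$\phi$ 为激活函数。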
- 利用卷积核大小与空洞率(dilation)各不相同的深度卷积来提取多阶上下文。
python
class MultiOrderGatedAggregation(nn.Module):
    """Spatial Block with Multi-order Gated Aggregation.

    The input is projected by a 1x1 conv and re-weighted against its global
    average (``feat_decompose``). Two branches are then built from the
    result:

    - a gate branch (1x1 conv);
    - a value branch (``MultiOrderDWConv``).

    Both branches are activated with ``act_gate`` and multiplied
    element-wise. The product is projected by another 1x1 conv and added to
    the block input.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (sequence of int): Dilations of the three DWConv
            layers in the value branch. Defaults to (1, 2, 3).
        attn_channel_split (sequence of int): The relative ratio of the
            three channel groups in the value branch. Defaults to (1, 3, 4).
        attn_act_type (str): The activation type for Spatial Block.
            Defaults to 'SiLU'.
        attn_force_fp32 (bool): If True, the gating product and ``proj_2``
            are computed in float32 with CUDA autocast disabled.
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=(1, 2, 3),
                 attn_channel_split=(1, 3, 4),
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                 ):
        super(MultiOrderGatedAggregation, self).__init__()
        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        # activation for the decomposed features (act_value) and for both
        # gating branches (act_gate)
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)
        # learnable scale applied to (x - GAP(x)) in feat_decompose
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        """Project ``x`` and re-weight it against its spatial mean.

        Computes ``act_value(y + sigma(y - GAP(y)))`` with ``y = proj_1(x)``.
        The deviation of each position from the global average is scaled by
        a learnable factor and added back to ``y``.
        """
        x = self.proj_1(x)
        # x_d: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        """Gated aggregation computed in float32.

        Args:
            g: Pre-activation output of the gate branch.
            v: Pre-activation output of the value branch.

        Returns:
            ``proj_2(act_gate(g) * act_gate(v))`` evaluated in float32.

        The activation is applied here exactly once, so the result matches
        the default path in ``forward`` up to numerical precision.
        """
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        """Apply multi-order gated aggregation with a residual connection."""
        shortcut = x.clone()
        # 1x1 projection + feature decomposition
        x = self.feat_decompose(x)
        # gating branch (1x1 conv) and value branch (multi-order DW convs)
        g = self.gate(x)
        v = self.value(x)
        # aggregation: act(g) * act(v), then 1x1 projection
        if self.attn_force_fp32:
            # Pass the raw branches: forward_gating activates them itself.
            # Pre-activating here would apply act_gate twice.
            x = self.forward_gating(g, v)
        else:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        return x + shortcut
sigma 模块(ElementScale):一个可学习的逐通道缩放因子。
python
class ElementScale(nn.Module):
    """Multiply the input by a learnable tensor of shape (1, C, 1, 1).

    Args:
        embed_dims (int): Number of channels C of the scale tensor.
        init_value (float): Value every scale entry starts at.
            Defaults to 0.
        requires_grad (bool): Whether the scale is trainable.
            Defaults to True.
    """

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super().__init__()
        initial = torch.ones((1, embed_dims, 1, 1)) * init_value
        self.scale = nn.Parameter(initial, requires_grad=requires_grad)

    def forward(self, x):
        # For a 4-D input the (1, C, 1, 1) scale broadcasts over batch,
        # height and width, i.e. it is a per-channel multiplier.
        return torch.mul(x, self.scale)
多阶卷积模块(MultiOrderDWConv):先对全部通道做一次 5×5 深度卷积,再把结果沿通道维度按比例(默认 1:3:4)分成三部分:第一部分保持不变,第二部分经过空洞率为 2 的 5×5 深度卷积,第三部分经过空洞率为 3 的 7×7 深度卷积。最后将三部分拼接起来,并用 1×1 卷积融合通道。
python
class MultiOrderDWConv(nn.Module):
    """Multi-order Features with Dilated DWConv Kernel.

    Pipeline:

    1. The whole input first goes through a 5x5 depth-wise conv
       (``DW_conv0``, dilation ``dw_dilation[0]``).
    2. Its output is split along the channel axis into three groups whose
       sizes follow ``channel_split``:
       - the first group is kept unchanged;
       - the second goes through a dilated 5x5 depth-wise conv
         (``DW_conv1``);
       - the third goes through a dilated 7x7 depth-wise conv
         (``DW_conv2``).
    3. The groups are concatenated again and mixed by a 1x1 point-wise
       conv.

    Args:
        embed_dims (int): Number of input channels. Must be divisible by
            ``sum(channel_split)``.
        dw_dilation (sequence of int): Dilations of the three DWConv layers,
            each in [1, 3]. Defaults to (1, 2, 3).
        channel_split (sequence of int): The relative ratio of the three
            channel groups. Defaults to (1, 3, 4).
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=(1, 2, 3),
                 channel_split=(1, 3, 4),
                 ):
        super(MultiOrderDWConv, self).__init__()
        # Validate the arguments before they are indexed / used for sizing.
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0
        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        # the first group takes whatever channels remain
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        # basic DW conv over all channels; padding (k - 1) * d // 2 keeps
        # H and W unchanged
        self.DW_conv0 = nn.Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],
        )
        # DW conv 1: dilated 5x5 on the second channel group
        self.DW_conv1 = nn.Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],
        )
        # DW conv 2: dilated 7x7 on the third channel group
        self.DW_conv2 = nn.Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],
        )
        # a channel convolution
        self.PW_conv = nn.Conv2d(  # point-wise convolution
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        """Return ``PW_conv(cat(x0_a, DW_conv1(x0_b), DW_conv2(x0_c)))``.

        Here ``x0 = DW_conv0(x)`` and ``x0_a``, ``x0_b``, ``x0_c`` are its
        three channel groups.
        """
        x_0 = self.DW_conv0(x)
        x_1 = self.DW_conv1(
            x_0[:, self.embed_dims_0: self.embed_dims_0 + self.embed_dims_1, ...])
        x_2 = self.DW_conv2(
            x_0[:, self.embed_dims - self.embed_dims_2:, ...])
        x = torch.cat([
            x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
        x = self.PW_conv(x)
        return x
通道聚合:作者设计了轻量级的通道聚合模块 CA(·),在高维隐藏空间中自适应地重新分配通道级特征,并将其进一步扩展为通道聚合 (CA) 块,即下面的 ChannelAggregationFFN。
python
class ChannelAggregationFFN(nn.Module):
    """Feed-forward network with Channel Aggregation (CA).

    Pipeline:

    1. 1x1 conv (``embed_dims -> feedforward_channels``).
    2. Depth-wise conv, activation and dropout.
    3. Channel re-weighting (``feat_decompose``).
    4. 1x1 conv (``feedforward_channels -> embed_dims``) and dropout.

    Args:
        embed_dims (int): Number of input and output channels.
        feedforward_channels (int): The hidden dimension of the FFN.
        kernel_size (int): Kernel size of the depth-wise convolution; padding
            is ``kernel_size // 2``. Defaults to 3.
        act_type (str): The type of activation. Defaults to 'GELU'.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super().__init__()
        hidden = feedforward_channels
        self.embed_dims = embed_dims
        self.feedforward_channels = hidden
        # Keep this creation order. It determines parameter initialisation
        # under a given seed and the state_dict key order.
        self.fc1 = nn.Conv2d(embed_dims, hidden, kernel_size=1)
        self.dwconv = nn.Conv2d(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=hidden,
        )
        self.act = build_act_layer(act_type)
        self.fc2 = nn.Conv2d(hidden, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)
        # squeezes the hidden channels into a single map (C -> 1)
        self.decompose = nn.Conv2d(hidden, 1, kernel_size=1)
        self.sigma = ElementScale(hidden, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        """Return ``x + sigma(x - act(decompose(x)))``.

        ``decompose`` maps the hidden channels to one channel. The
        subtraction broadcasts that map back over every channel. The scaled
        difference is added to ``x``.
        """
        squeezed = self.decompose_act(self.decompose(x))  # [B, 1, H, W]
        return x + self.sigma(x - squeezed)

    def forward(self, x):
        # expand to the hidden width and mix spatially with a depth-wise conv
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))
        # redistribute channel features, then project back to embed_dims
        return self.drop(self.fc2(self.feat_decompose(hidden)))