🔥 AI 即插即用 | 你的CV涨点模块"军火库"已开源!🔥
为了方便大家在CV科研和项目中高效涨点,我创建并维护了一个即插即用模块的GitHub代码仓库。
仓库里不仅有:
- 核心模块即插即用代码
- 论文精读总结
- 架构图深度解析
更有海量SOTA模型的创新模块汇总,致力于打造一个"AI即插即用"的百宝箱,方便大家快速实验、组合创新!
🚀 GitHub 仓库链接 :https://github.com/AITricks/AITricks
觉得有帮助的话,欢迎大家 Star, Fork, PR 一键三连,共同维护!
即插即用涨点系列(十四)2025 SOTA | Efficient ViM:基于"隐状态混合SSD"与"多阶段融合"的轻量级视觉 Mamba 新标杆 (含原理+代码)
论文原文 (Paper) :https://arxiv.org/pdf/2411.15241
官方代码 (Code) :https://github.com/mlvlab/EfficientViM
Efficient ViM: 基于隐状态混合器状态空间对偶的高效视觉 Mamba
1. 核心思想
本文针对现有轻量级视觉模型中 Attention 机制计算成本高以及标准 Mamba 模型在视觉任务中仍存在计算冗余的问题,提出了一种名为 Efficient ViM 的新型架构。其核心创新在于提出了 基于隐状态混合器的状态空间对偶(HSM-SSD) 算法,该算法通过在压缩的隐状态空间而非原始高维序列空间中执行通道混合(Channel Mixing)和门控操作,显著降低了线性投影的计算复杂度。此外,论文还引入了单头设计以减少内存访问瓶颈,并提出了多阶段隐状态融合(MSF)策略以增强模型的特征表示能力。最终,Efficient ViM 在 ImageNet-1k 上实现了新的速度与精度的 SOTA 权衡,优于 MobileNetV3 和最新的 SHVIT 等模型。
2. 背景与动机
-
文本角度总结 :
在资源受限的边缘设备上部署神经网络需要极高的效率。早期的轻量级模型(如 MobileNet)主要依赖卷积(CNN)来提取局部特征,而随后的混合架构引入了 Vision Transformer (ViT) 来捕捉全局依赖,但 Attention 的二次方复杂度 O ( L 2 ) O(L^2) O(L2) 限制了其效率。尽管状态空间模型(SSM,如 Mamba)提供了线性的全局建模能力 O ( L ) O(L) O(L),但现有的视觉 Mamba(如 Vim, VMamba)在推理速度上仍不如高度优化的轻量级 CNN。作者观察到,Mamba2 中引入的状态空间对偶(SSD)层,其运行时间瓶颈主要来自于对输入序列进行的线性投影操作(Linear Projections),这导致了不必要的计算开销。因此,本文旨在设计一种改进的 SSD 层,在保持全局感受野的同时,大幅削减这些线性投影的成本。
-
动机图解分析:
-
图 1(Figure 1):速度-精度权衡图
- 现象:该图展示了 ImageNet-1k 上各轻量级模型的 Top-1 准确率与吞吐量(Throughput)的关系。
- 分析:Efficient ViM(图中的红色和蓝色五角星)位于帕累托前沿的最左上方,这意味着在相同的速度下它精度最高,或在相同精度下速度最快。例如,相比于经典的 MobileNetV3,Efficient ViM 实现了 80% 的速度提升且精度更高;相比于最新的 SOTA 模型 SHVIT,也有显著优势。这直观地证明了现有方法在效率上仍有提升空间,而 Efficient ViM 成功突破了这一瓶颈。
-
图 2(Figure 2):NC-SSD 与 HSM-SSD 的复杂度对比

- 对比 :左图 (a) 是标准的非因果 SSD (NC-SSD),其中的红色块代表线性层,其复杂度与序列长度 L L L 和通道数 D D D 相关,即 O ( L D 2 ) O(LD^2) O(LD2)。右图 (b) 是本文提出的 HSM-SSD。
- 核心问题 :作者指出 NC-SSD 中主要的计算量浪费在对全长序列 x x x 进行投影。
- 解决方案 :HSM-SSD 将原本在 L L L 维度进行的昂贵操作(红色块),转移到了压缩后的隐状态 N N N 维度(橙色块),即 O ( N D 2 ) O(ND^2) O(ND2)。由于状态数 N N N 远小于序列长度 L L L,这直接解决了计算效率瓶颈。
-
图 3(Figure 3):运行时分解图
- 现象:该图分析了多头(Multi-Head)设计带来的内存开销。左饼图显示多头设计中,"Copy & Reshape"(橙色部分)占据了 25.2% 的时间。
- 分析:这揭示了理论 FLOPs 低并不代表实际推理快。多头机制引入了大量的内存读写操作(Memory-bound)。因此,本文改为单头设计(右饼图),将内存操作占比降至 5.1%,进一步提升了实际推理速度。
-
3. 主要贡献点
-
[贡献点 1]:提出了基于隐状态混合器的 SSD (HSM-SSD)
为了解决标准 SSD 层中线性投影成本过高的问题,作者设计了 HSM-SSD。它利用隐状态作为输入的压缩潜在表示,在隐状态空间内执行门控和线性投影操作。这一设计将核心计算复杂度从与序列长度相关 O ( L D 2 ) O(LD^2) O(LD2) 降低为与状态数相关 O ( N D 2 ) O(ND^2) O(ND2),在处理高分辨率图像时优势尤为明显。
-
[贡献点 2]:针对硬件效率优化的单头设计与网络架构
作者分析发现,传统的多头设计在轻量级模型中会带来沉重的内存访问负担(Memory-bound)。因此,Efficient ViM 采用了单头 HSM-SSD 设计,通过消除张量重塑和拷贝操作来最大化实际吞吐量,同时通过引入状态级的重要性权重来弥补去除多头带来的表征能力损失。
-
[贡献点 3]:多阶段隐状态融合 (Multi-stage Hidden State Fusion, MSF)
为了进一步提升模型性能,作者提出了一种融合机制,利用网络不同阶段(Stage)的隐状态来辅助最终的预测。通过对各阶段隐状态进行加权求和并参与 Logit 计算,该机制增强了隐状态的表征能力,并丰富了多尺度特征的利用,在不显著增加推理延迟的情况下提升了准确率。
4. 方法细节
-
整体网络架构:

- 宏观流程(Figure 4 左) :
Efficient ViM 采用了标准的分层金字塔结构。- Stem Layer :输入图像( H × W × 3 H \times W \times 3 H×W×3)首先经过 Stem 层(由四个 3 × 3 3 \times 3 3×3 卷积组成,步长为 2),将分辨率下采样为 H / 16 × W / 16 H/16 \times W/16 H/16×W/16。
- Stages 1-3 :随后进入三个主要的 Stage。每个 Stage 包含堆叠的 Efficient ViM Block。
- Patch Merging:在 Stage 之间,使用下采样层(Patch Merging)降低分辨率并增加通道数,构建分层特征。
- Head (MSF) :在最后输出阶段,引入了 多阶段隐状态融合 (MSF) 模块,将各阶段的隐状态聚合,与最终的特征图一起用于分类预测。
- 宏观流程(Figure 4 左) :
-
核心创新模块详解(HSM-SSD Layer - Figure 2b & Algorithm 1):
-
模块 A:隐状态混合器 (Hidden State Mixer, HSM)
- 内部结构 :这是本论文最核心的算子。不同于标准 SSD 先计算输出 y y y 再进行线性投影,HSM-SSD 先计算隐状态 h h h。
- 数据流 :
- 输入投影与离散化 :输入 x i n x_{in} xin 经过轻量级线性层生成参数 B , C , Δ B, C, \Delta B,C,Δ。
- 隐状态生成 :计算初始隐状态 h i n = ( A ⊙ B ) T x i n h_{in} = (A \odot B)^T x_{in} hin=(A⊙B)Txin。这一步将维度从 L L L(序列长度)压缩到了 N N N(状态数,通常 N ≪ L N \ll L N≪L)。
- 隐状态混合 (HSM) :在 h i n h_{in} hin 上进行核心的通道混合操作。具体为: h = Linear ( h i n ⊙ σ ( Linear ( h i n ) ) ) h = \text{Linear}(h_{in} \odot \sigma(\text{Linear}(h_{in}))) h=Linear(hin⊙σ(Linear(hin)))。这里包含了一个门控机制( σ \sigma σ)和线性投影。
- 输出生成 :最后通过 x o u t = C h x_{out} = C h xout=Ch 将更新后的隐状态投影回原始序列空间。
- 设计目的 :通过在 N N N 维度而非 L L L 维度进行密集的矩阵乘法,极大地减少了 FLOPs,同时利用隐状态的全局压缩特性捕捉上下文信息。
-
模块 B:Efficient ViM Block (Figure 4 右)
- 内部结构:类似于 Transformer Block,由两个子模块组成:HSM-SSD 模块和前馈网络(FFN)。
- 数据流 :
- 局部特征提取 :输入首先经过一个 3 × 3 3 \times 3 3×3 深度卷积(DWConv),用于捕获局部空间信息(这也是轻量级模型的标准操作)。
- 全局特征提取 :经过 LayerNorm 后,数据进入 HSM-SSD 层,负责捕捉全局依赖。
- 通道交互 :最后经过另一个 3 × 3 3 \times 3 3×3 DWConv 和 FFN(由两个 1 × 1 1 \times 1 1×1 卷积组成),进行通道间的信息交互。
- 残差连接:每个子模块都配有残差连接。
- 设计理念:结合卷积的局部归纳偏置和 SSM 的线性全局建模能力,同时保证推理速度。
-
-
理念与机制总结:
- 核心理念 :"在压缩空间做昂贵运算" 。标准 Attention 或 SSD 在 Token 数量巨大的序列空间做混合,成本高昂。HSM-SSD 认为隐状态 h h h 本身就是一种对输入的紧凑总结,因此在 h h h 上做混合既能捕捉全局信息,又能大幅降低计算量。
- 公式解读 :
原 SSD 输出: x o u t = ( C h ⊙ σ ( z ) ) W o u t x_{out} = (Ch \odot \sigma(z)) W_{out} xout=(Ch⊙σ(z))Wout (在 L L L 空间运算)。
HSM-SSD 近似: x o u t ≈ C ( ( h ⊙ σ ( h W z ) ) W o u t ) x_{out} \approx C ((h \odot \sigma(h W_z)) W_{out}) xout≈C((h⊙σ(hWz))Wout) (在 N N N 空间运算)。
这在数学上利用了线性变换的结合律,将计算转移到了维度更小的中间变量上。
-
图解总结 :
结合 Figure 2 和 Figure 4,Efficient ViM 通过在 Block 内部替换掉昂贵的 NC-SSD 为 HSM-SSD,并在网络末端通过虚线连接(MSF)利用中间层的隐状态。这种设计使得数据流在主干网络中保持高效流动(由 HSM 加速),同时在最后汇聚多层级信息以保证精度,完美解决了"全局建模成本高"和"轻量级模型表征弱"的矛盾。
5. 即插即用模块的作用
本文提出的技术具有很好的通用性,可以作为即插即用的模块应用于其他轻量级架构设计中:
-
HSM-SSD 模块 (Hidden State Mixer-based SSD)
- 适用场景 :任何需要线性复杂度全局建模的轻量级视觉任务(分类、检测、分割)。
- 具体应用 :
- 移动端主干网络替换 :可以直接替换 MobileNet 或 EdgeViT 中的 Attention 模块或大核卷积模块,显著降低 FLOPs 和推理延迟,特别是在处理高分辨率输入(如 51 2 2 512^2 5122 或更大)时优势巨大。
- 实时语义分割:在编码器阶段使用 HSM-SSD 替代 Self-Attention,能够以极低的计算成本提供全局感受野,这对于分割任务中的上下文理解至关重要。
-
多阶段隐状态融合 (MSF) 策略
- 适用场景:基于 SSM 或 RNN 的层级式网络架构。
- 具体应用 :
- 辅助监督/特征增强 :在任何基于 Mamba 或 LSTM 的视觉模型中,提取各 Stage 的隐状态 h h h,通过简单的加权平均和线性层生成 Logits 并融合到最终输出中。这是一种几乎零成本(仅增加少量参数,几乎不增加推理计算量)的涨点技巧,可增强模型的泛化能力。
-
单头 SSD 设计 (Single-Head Design)
- 适用场景 :受限于内存带宽的边缘设备部署。
- 具体应用 :
- FPGA/移动端加速:如果发现模型在特定硬件上的推理瓶颈在于内存读写(Memory-bound)而非计算(Compute-bound),可以将多头注意力或多头 SSM 替换为这种单头加状态级权重(State-wise Importance)的设计,以减少 Tensor 的 Reshape 和 Copy 操作。
6.即插即用模块
python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import trunc_normal_
from timm.models import register_model
from fvcore.nn import flop_count
from .utils import LayerNorm1D, LayerNorm2D, ConvLayer1D, ConvLayer2D, FFN, Stem, PatchMerging
class HSMSSD(nn.Module):
def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim = 64):
super().__init__()
self.ssd_expand = ssd_expand
self.d_inner = int(self.ssd_expand * d_model)
self.state_dim = state_dim
self.BCdt_proj = ConvLayer1D(d_model, 3*state_dim, 1, norm=None, act_layer=None)
conv_dim = self.state_dim*3
self.dw = ConvLayer2D(conv_dim, conv_dim, 3,1,1, groups=conv_dim, norm=None, act_layer=None, bn_weight_init=0)
self.hz_proj = ConvLayer1D(d_model, 2*self.d_inner, 1, norm=None, act_layer=None)
self.out_proj = ConvLayer1D(self.d_inner, d_model, 1, norm=None, act_layer=None, bn_weight_init=0)
A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
self.A = torch.nn.Parameter(A)
self.act = nn.SiLU()
self.D = nn.Parameter(torch.ones(1))
self.D._no_weight_decay = True
def forward(self, x):
batch, _, L= x.shape
H = int(math.sqrt(L))
BCdt = self.dw(self.BCdt_proj(x).view(batch,-1, H, H)).flatten(2)
B,C,dt = torch.split(BCdt, [self.state_dim, self.state_dim, self.state_dim], dim=1)
A = (dt + self.A.view(1,-1,1)).softmax(-1)
AB = (A * B)
h = x @ AB.transpose(-2,-1)
h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
h = self.out_proj(h * self.act(z)+ h * self.D)
y = h @ C # B C N, B C L -> B C L
y = y.view(batch,-1,H,H).contiguous()# + x * self.D # B C H W
return y, h
class EfficientViMBlock(nn.Module):
def __init__(self, dim, mlp_ratio=4., ssd_expand=1, state_dim=64):
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.mixer = HSMSSD(d_model=dim, ssd_expand=ssd_expand,state_dim=state_dim)
self.norm = LayerNorm1D(dim)
self.dwconv1 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
self.dwconv2 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer = None)
self.ffn = FFN(in_dim=dim, dim=int(dim * mlp_ratio))
#LayerScale
self.alpha = nn.Parameter(1e-4 * torch.ones(4,dim), requires_grad=True)
def forward(self, x):
alpha = torch.sigmoid(self.alpha).view(4,-1,1,1)
# DWconv1
x = (1-alpha[0]) * x + alpha[0] * self.dwconv1(x)
# HSM-SSD
x_prev = x
x, h = self.mixer(self.norm(x.flatten(2)))
x = (1-alpha[1]) * x_prev + alpha[1] * x
# DWConv2
x = (1-alpha[2]) * x + alpha[2] * self.dwconv2(x)
# FFN
x = (1-alpha[3]) * x + alpha[3] * self.ffn(x)
return x, h
class EfficientViMStage(nn.Module):
def __init__(self, in_dim, out_dim, depth, mlp_ratio=4.,downsample=None, ssd_expand=1, state_dim=64):
super().__init__()
self.depth = depth
self.blocks = nn.ModuleList([
EfficientViMBlock(dim=in_dim, mlp_ratio=mlp_ratio, ssd_expand=ssd_expand, state_dim=state_dim) for _ in range(depth)])
self.downsample = downsample(in_dim=in_dim, out_dim =out_dim) if downsample is not None else None
def forward(self, x):
for blk in self.blocks:
x, h = blk(x)
x_out = x
if self.downsample is not None:
x = self.downsample(x)
return x, x_out, h
class EfficientViM(nn.Module):
def __init__(self, in_dim=3, num_classes=1000, embed_dim=[128,256,512], depths=[2, 2, 2], mlp_ratio=4., ssd_expand=1, state_dim=[49,25,9], distillation=False, **kwargs):
super().__init__()
self.num_layers = len(depths)
self.num_classes = num_classes
self.distillation =distillation
self.patch_embed = Stem(in_dim=in_dim, dim=embed_dim[0])
PatchMergingBlock = PatchMerging
# build stages
self.stages = nn.ModuleList()
for i_layer in range(self.num_layers):
stage = EfficientViMStage(in_dim=int(embed_dim[i_layer]),
out_dim=int(embed_dim[i_layer+1]) if (i_layer < self.num_layers - 1) else None,
depth=depths[i_layer],
mlp_ratio=mlp_ratio,
downsample=PatchMergingBlock if (i_layer < self.num_layers - 1) else None,
ssd_expand=ssd_expand,
state_dim = state_dim[i_layer])
self.stages.append(stage)
# Weights for multi-stage hidden-state Fusion
self.weights = nn.Parameter(torch.ones(4))
self.norm = nn.ModuleList([
LayerNorm1D(embed_dim[0]),
LayerNorm1D(embed_dim[1]),
LayerNorm1D(embed_dim[2]),
LayerNorm2D(embed_dim[2]),
])
self.heads = nn.ModuleList([
nn.Linear(embed_dim[0], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[1], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[2], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[2], num_classes) if num_classes > 0 else nn.Identity()
])
if distillation:
self.weights_dist = nn.Parameter(torch.ones(4))
self.heads_dist = nn.ModuleList([
nn.Linear(embed_dim[0], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[1], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[2], num_classes) if num_classes > 0 else nn.Identity(),
nn.Linear(embed_dim[2], num_classes) if num_classes > 0 else nn.Identity()
])
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm2D):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, LayerNorm1D):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.no_grad()
def flops(self, shape=(3, 224, 224)):
supported_ops = {
"aten::silu": None,
"aten::neg": None,
"aten::exp": None,
"aten::flip": None,
"aten::softmax": None,
"aten::sigmoid": None,
"aten::mul": None,
"aten::add": None,
"aten::mean": None,
"aten::var": None,
"aten::sub": None,
"aten::sqrt": None,
"aten::div": None,
"aten::rsub": None,
"aten::adaptive_avg_pool1d": None,
}
import copy
model = copy.deepcopy(self)
model.cuda().eval()
input = torch.randn((1, *shape), device=next(model.parameters()).device)
Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
del model, input
return sum(Gflops.values()) * 1e9
def forward(self, x):
x = self.patch_embed(x)
weights = self.weights.softmax(-1)
z = torch.zeros((x.shape[0], self.num_classes), device=x.device)
if self.distillation:
weights_dist = self.weights_dist.softmax(-1)
z_dist = torch.zeros((x.shape[0], self.num_classes), device=x.device)
for i, stage in enumerate(self.stages):
x, x_out, h = stage(x)
h = self.norm[i](h)
h = torch.nn.functional.adaptive_avg_pool1d(h, 1).flatten(1)
z = z + weights[i] * self.heads[i](h)
if self.distillation:
z_dist = z_dist + weights_dist[i] * self.heads_dist[i](h)
x = self.norm[3](x)
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
z = z + weights[3] * self.heads[3](x)
if self.distillation:
z_dist = z_dist + weights_dist[3] * self.heads_dist[3](x)
z= z, z_dist
if not self.training:
z = (z[0] + z[1]) / 2
return z
@register_model
def EfficientViM_M1(pretrained=False, **kwargs):
model = EfficientViM(
in_dim=3,
embed_dim=[128,192,320],
depths=[2,2,2],
mlp_ratio=4.,
ssd_expand=1.,
state_dim=[49,25,9],
**kwargs)
return model
@register_model
def EfficientViM_M2(pretrained=False, **kwargs):
model = EfficientViM(
in_dim=3,
embed_dim=[128,256,512],
depths=[2,2,2],
mlp_ratio=4.,
ssd_expand=1.,
state_dim=[49,25,9],
**kwargs)
return model
@register_model
def EfficientViM_M3(pretrained=False, **kwargs):
model = EfficientViM(
in_dim=3,
embed_dim=[224,320,512],
depths=[2,2,2],
mlp_ratio=4.,
ssd_expand=1.,
state_dim=[49,25,9],
**kwargs)
return model
@register_model
def EfficientViM_M4(pretrained=False, **kwargs):
model = EfficientViM(
in_dim=3,
embed_dim=[224,320,512],
depths=[3,4,2],
mlp_ratio=4.,
ssd_expand=1.,
state_dim=[64,32,16],
**kwargs)
return model