【Block总结】MEEM,多尺度边缘增强模块|即插即用|ACM 2024

论文信息

  • 标题:Multi-Scale and Detail-Enhanced Segment Anything Model for Salient Object Detection
  • 作者:高世轩、张平平、闫天宇、卢湖川(大连理工大学)
  • 发表:ACM International Conference on Multimedia (MM '24)
  • 代码https://github.com/BellyBeauty/MDSAM

创新点

  1. 首创多尺度适配SAM:首次将多尺度适配器应用于SAM迁移到下游任务,以极少参数实现高效迁移
  2. 轻量级多尺度适配器(LMSA):在保持SAM预训练权重的同时,使其学习多尺度信息
  3. 多级融合机制(MLFM):全面利用SAM编码器不同层级的特征,融合浅层细节与深层语义
  4. 细节增强模块(DEM):通过多尺度边缘增强模块(MEEM)解决SAM缺乏细粒度细节的问题
  5. MEEM多尺度边缘增强 :创新地结合平均池化和边缘检测原理,从输入图像中提取多尺度边缘特征,增强显著目标边界感知能力

方法

MDSAM框架包含三个核心模块:LMSA、MLFM和DEM。其中,MEEM (Multi-scale Edge Enhancement Module) 是DEM的关键组件,专门负责细粒度细节增强。

MEEM架构与原理

MEEM采用多分支结构,通过级联平均池化提取多尺度特征,并通过边缘增强器突出目标边缘。核心思想是:边缘 = 原始特征 - 平滑特征,这一简单而有效的边缘检测原理使模型能够捕获高频细节信息。

MEEM完整代码及详解

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class MEEM(nn.Module):
    def __init__(self, in_dim, hidden_dim, width, norm=LayerNorm2d, act=nn.GELU):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.width = width
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, 1, bias=False),
            norm(hidden_dim),
            nn.Sigmoid()
        )

        self.pool = nn.AvgPool2d(3, stride=1, padding=1)

        self.mid_conv = nn.ModuleList()
        self.edge_enhance = nn.ModuleList()
        for i in range(width - 1):
            self.mid_conv.append(nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 1, bias=False),
                norm(hidden_dim),
                nn.Sigmoid()
            ))
            self.edge_enhance.append(EdgeEnhancer(hidden_dim, norm, act))

        self.out_conv = nn.Sequential(
            nn.Conv2d(hidden_dim * width, in_dim, 1, bias=False),
            norm(in_dim),
            act()
        )

    def forward(self, x):
        mid = self.in_conv(x)

        out = mid
        # print(out.shape)

        for i in range(self.width - 1):
            mid = self.pool(mid)
            mid = self.mid_conv[i](mid)

            out = torch.cat([out, self.edge_enhance[i](mid)], dim=1)

        out = self.out_conv(out)

        return out


class EdgeEnhancer(nn.Module):
    def __init__(self, in_dim, norm, act):
        super().__init__()
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1, bias=False),
            norm(in_dim),
            nn.Sigmoid()
        )
        self.pool = nn.AvgPool2d(3, stride=1, padding=1)

    def forward(self, x):
        edge = self.pool(x)
        edge = x - edge
        edge = self.out_conv(edge)
        return x + edge


class DetailEnhancement(nn.Module):
    def __init__(self, img_dim, feature_dim, norm, act):
        super().__init__()
        self.img_in_conv = nn.Sequential(
            nn.Conv2d(img_dim, img_dim, 3, padding=1, bias=False),
            norm(img_dim),
            act()
        )
        self.img_er = MEEM(img_dim, img_dim // 2, 4, norm, act)

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(feature_dim + img_dim, img_dim, 3, padding=1, bias=False),
            norm(img_dim),
            act(),
        )


        self.feature_upsample = nn.Sequential(
            nn.Conv2d(feature_dim * 2, feature_dim, 3, padding=1, bias=False),
            norm(feature_dim),
            act(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1, bias=False),
            norm(feature_dim),
            act(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(feature_dim, feature_dim, 3, padding=1, bias=False),
            norm(feature_dim),
            act(),
        )

    def forward(self, img, feature, b_feature):
        feature = torch.cat([feature, b_feature], dim=1)
        feature = self.feature_upsample(feature)

        img_feature = self.img_in_conv(img)
        img_feature = self.img_er(img_feature) + img_feature

        out_feature = torch.cat([feature, img_feature], dim=1)
        out_feature = self.fusion_conv(out_feature)

        return out


class MLFusion(nn.Module):
    def __init__(self, norm, act):
        super().__init__()
        self.fusi_conv = nn.Sequential(
            nn.Conv2d(1024, 256, 1, bias=False),
            norm(256),
            act(),
        )

        self.attn_conv = nn.ModuleList()
        for i in range(4):
            self.attn_conv.append(nn.Sequential(
                nn.Conv2d(256, 256, 1, bias=False),
                norm(256),
                act(),
            ))

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, feature_list):
        fusi_feature = torch.cat(feature_list, dim=1).contiguous()
        fusi_feature = self.fusi_conv(fusi_feature)

        for i in range(4):
            x = feature_list[i]
            attn = self.attn_conv[i](x)
            attn = self.pool(attn)
            attn = self.sigmoid(attn)

            x = attn * x + x
            feature_list[i] = x

        return feature_list[0] + feature_list[1] + feature_list[2] + feature_list[3]


class ModifyPPM(nn.Module):
    def __init__(self, in_dim, reduction_dim, bins):
        super(ModifyPPM, self).__init__()
        self.features = []
        for bin in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1),
                nn.GELU(),
                nn.Conv2d(reduction_dim, reduction_dim, kernel_size=3, bias=False, groups=reduction_dim),
                nn.GELU()
            ))
        self.features = nn.ModuleList(self.features)
        self.local_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False, groups=in_dim),
            nn.GELU(),
        )

    def forward(self, x):
        x_size = x.size()
        out = [self.local_conv(x)]
        for f in self.features:
            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
        return torch.cat(out, 1)


class LMSA(nn.Module):
    def __init__(self, in_dim, hidden_dim, patch_num):
        super().__init__()
        self.down_project = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.mppm = ModifyPPM(hidden_dim, hidden_dim // 4, [3, 6, 9, 12])
        self.patch_num = patch_num
        self.up_project = nn.Linear(hidden_dim, in_dim)
        self.down_conv = nn.Sequential(nn.Conv2d(hidden_dim * 2, hidden_dim, 1),
                                       nn.GELU())

    def forward(self, x):
        down_x = self.down_project(x)
        down_x = self.act(down_x)

        down_x = down_x.permute(0, 3, 1, 2).contiguous()
        down_x = self.mppm(down_x).contiguous()
        down_x = self.down_conv(down_x)
        down_x = down_x.permute(0, 2, 3, 1).contiguous()

        up_x = self.up_project(down_x)
        return x + up_x
if __name__ == '__main__':

    model = MEEM(512,512*4,4,LayerNorm2d,nn.GELU)
    x = torch.randn(1, 512, 16, 16)
    y = model(x)

    print('MEEM Input size:',x.shape)
    print('MEEM Output size:', y.shape)

代码详解

  1. LayerNorm2d:通道维度的LayerNorm实现,替代传统的BatchNorm,适应SAM的预训练权重
  2. EdgeEnhancer
    • 核心原理:edge = x - pool(x),通过特征减法提取高频边缘
    • 使用3×3平均池化获取平滑特征
    • 1×1卷积层增强边缘响应,Sigmoid激活提供自适应权重
    • 残差连接保留原始特征信息
  3. MEEM主模块
    • 输入处理:1×1卷积降维,Sigmoid激活约束特征范围
    • 多尺度提取:通过级联平均池化获得不同尺度特征,每个尺度经过特定卷积处理
    • 边缘增强:对每个尺度应用EdgeEnhancer,增强边缘感知
    • 特征融合:通道拼接所有尺度特征,1×1卷积融合为最终输出
    • 参数效率:使用1×1卷积减少计算量,平均池化替代大卷积核

效果

  1. SOD基准测试

    • 在DUTS-OMRON、HKU-IS和ECSSD数据集上达到最佳性能
    • 512×512分辨率下,HKU-IS数据集上F-measure达0.963,MAE低至0.019
    • 384×384分辨率下推理速度达50 FPS,保持高精度
  2. MEEM消融实验

    • 消融研究显示,引入MEEM后边缘精度显著提升
    • 多尺度设置(4个尺度)比单尺度性能提升2.3%
    • 在复杂场景中能更准确地检测不同大小的显著目标
  3. 泛化能力

    • 在伪装目标检测(COD)任务上超越专业COD模型
    • 在医学图像息肉分割任务上MAE降至0.008
    • 零样本分析显示,即使没有提示,也能保持良好性能
  4. 视觉效果

    • 生成的显著图具有精确的边缘和丰富的细节
    • 在目标遮挡、复杂背景等挑战性场景中表现优异
    • 目标形状和轮廓识别能力显著优于现有方法

总结

MDSAM成功解决了SAM应用于显著目标检测的核心挑战,通过创新的MEEM模块有效增强了边缘细节感知能力。MEEM采用多尺度边缘提取和增强机制,通过简单的特征减法原理实现了高效的边缘检测。实验表明,该方法不仅在SOD任务上达到最先进性能,还保持了SAM的强大泛化能力。MDSAM证明了基础模型通过合理适配可以高效应用于特定视觉任务,为视觉基础模型的下游迁移提供了新思路。未来工作可进一步优化MEEM的计算效率,扩展到更多分割场景。

相关推荐
滨HI02 小时前
opencv 计算面积、周长
人工智能·opencv·计算机视觉
OpenBayes3 小时前
OCR 新范式!DeepSeek 以「视觉压缩」替代传统字符识别;Bald Classification数据集助力高精度人像分类
人工智能·深度学习·分类·数据挖掘·ocr·数据集·deepseek
亚马逊云开发者3 小时前
Agentic AI基础设施实践经验系列(四):MCP服务器从本地到云端的部署演进
人工智能
知识搬运工人3 小时前
深入解析U-Net
人工智能
weixin_421133413 小时前
深度强化学习,用神经网络代替 Q-table
人工智能·深度学习·神经网络
lx7416026983 小时前
面试可能的问题(自用)
人工智能·自然语言处理
数字化脑洞实验室3 小时前
智能决策算法的核心原理是什么?
人工智能·算法·机器学习
流烟默3 小时前
机器学习中拟合、欠拟合、过拟合是什么
人工智能·算法·机器学习
说私域3 小时前
社群时代下的商业变革:“开源AI智能名片链动2+1模式S2B2C商城小程序”的应用与影响
人工智能·小程序·开源