比SAM小60倍的轻量级分割一切模型:MobileSAM

1 MobileSAM

SAM就是一类处理图像分割任务的通用模型。与以往只能处理某种特定类型图片的图像分割模型不同,SAM可以处理所有类型的图像。

在SAM出现前,基本上所有的图像分割模型都是专有模型。比如,在医学领域,有专门分割核磁图像的人工智能模型,也有专门分割CT影像的人工智能模型。但这些模型往往只在分割专有领域内的图像时,才具有良好性能,而在分割其他领域的图像时往往性能不佳。

1.1 模型介绍

SAM是一种prompt-guided的视觉基础模型,用于从其背景中剪切出感兴趣的对象。自Meta研究团队发布SA项目以来,SAM因其令人印象深刻的零样本传输性能和与其他模型兼容的高度通用性而备受关注,用于高级视觉应用,如具有细粒度控制的图像编辑。

许多这样的用例需要在资源受限的边缘设备上运行,比如移动应用程序。今天分享中,我们的目标是通过用轻量级图像编码器取代重量级图像编码器,使SAM对移动友好。原始SAM文件中训练这种新SAM的方式会导致性能不令人满意,尤其是当可用的训练来源有限时。

我们发现,这主要是由图像编码器和掩模解码器的耦合优化引起的,因此提出了解耦蒸馏。具体地说,将原始SAM中的图像编码器ViT-H的知识提取到一个轻量级的图像编码器中,该编码器可以自动与原始SAM中的掩码解码器兼容。

训练可以在不到一天的时间内在单个GPU上完成,由此产生的轻量级SAM被称为MobileSAM,它比原始SAM小60多倍,但性能与原始SAM相当。就推理速度而言,MobileSAM每幅图像运行约10ms:图像编码器运行8ms,掩码解码器运行2ms。凭借卓越的性能和更高的通用性,我们的MobileSAM比并发的FastSAM小7倍,快4倍,更适合移动应用。

论文地址:https://arxiv.org/pdf/2306.14289.pdf

代码地址:https://github.com/ChaoningZhang/MobileSAM

1.2 新框架

  • Background on SAM

在这里,我们首先总结SAM的结构及其工作原理。SAM由一个基于ViT的图像编码器和一个提示引导掩码解码器组成。图像编码器将图像作为输入并生成嵌入,然后将嵌入提供给掩码解码器。掩码解码器生成一个掩码,根据点(或框)等提示从背景中剪切出任何对象。此外,SAM允许为同一提示生成多个掩码,以解决模糊性问题,这提供了宝贵的灵活性。考虑到这一点,这项工作保持了SAM的流水线,首先采用基于ViT的编码器来生成图像嵌入,然后采用提示引导解码器来生成所需的掩码。这条管道是为"分段任何东西"而优化设计的,可用于"分段所有东西"的下游任务。

SAM的耦合知识蒸馏。左图表示完全耦合蒸馏,右图表示半耦合蒸馏。

  • Project goal

该项目的目标是生成一个移动友好型SAM(MobileSAM),以轻量级的方式实现令人满意的性能,并且比原始SAM快得多。原始SAM中的提示引导掩码解码器的参数小于4M,因此被认为是轻量级的。给定编码器处理的图像嵌入,如他们的公开演示中所示,SAM可以在资源受限的设备中工作,因为掩码解码器是轻量级的。然而,原始SAM中的默认图像编码器是基于ViT-H的,具有超过600M的参数,这是非常重量级的,并使整个SAM管道与移动设备不兼容。因此,获得移动友好SAM的关键在于用轻量级的图像编码器取代重量级的图像编码器,这也自动保持了原始SAM的所有功能和特性。

以ViT-B为图像编码器的SAM的耦合蒸馏和解耦蒸馏的比较。与耦合蒸馏相比,解耦蒸馏性能更好,所需计算资源少于1%。

1.3 实验

下图给出了point与bbox提示词下MobileSAM与原生SAM的结果对比,可以看到:MobileSAM可以取得令人满意的Mask预测结果。

下图从Segment everything角度对比了SAM、FastSAM以及MobileSAM三个模型,可以看到:

  • MobileSAM与原生SAM结果对齐惊人的好,而FastSAM会生成一些无法满意的结果
  • FastSAM通常生成非平滑的边缘,而SAM与MobileSAM并没有该问题

MobileSAM在所有方面都优于FastSAM

SAM原始论文的标题是"Segment anything",而不是"segment everything"。如SAM中所强调的,SAM执行可prompt分割的任务,该任务"在给定任何分割prompt的情况下返回有效的分割Mask"。

prompt的作用是指定要在图像中分割的内容。理论上,只要正确设置prompt,任何目标都可以被分割,因此,它被称为"Segment anything"。相比之下,"segment everything"本质上是目标建议生成,对此不需要prompt。在SAM中,选择"segment everything"(目标建议生成)作为下游任务之一,以演示其零样本传输性能。

总之,"Segment anything"解决了任何目标的可prompt分割的基础任务,而"segment everything"解决了为所有目标生成Mask建议的下游任务不一定需要prompt,FastSAM以无prompt的方式直接用YOLO v8生成Mask建议。为了实现可prompt分割,设计了一种映射算法来从提议Mask集中选择Mask。

2 运行环境与实战

2.1 conda环境准备

conda环境准备详见:annoconda

2.2 运行环境安装

git clone https://github.com/ChaoningZhang/MobileSAM
cd MobileSAM

conda create -n mobilesam python=3.9
conda activate mobilesam

pip install -e .
pip install gradio

pip install torchvision==0.15.1
pip install timm
pip install opencv-python

2.3 模型下载

下载地址:https://huggingface.co/spaces/dhkim2810/MobileSAM/tree/main

2.4 运行

cd app

修改app.py中的代码

demo.launch(server_name='192.168.1.160')  #地址为自己的内网IP

python app.py

3 通过代码分割图片

3.1 基于prompt分割

from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam.to(device=device)
mobile_sam.eval()

predictor = SamPredictor(mobile_sam)
predictor.set_image('/opt/data/cata_dog.jpg')
masks, _, _ = predictor.predict('cat')

3.2 分割整张图片

from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam.to(device=device)
mobile_sam.eval()

mask_generator = SamAutomaticMaskGenerator(mobile_sam)
masks = mask_generator.generate('/opt/data/cat_dog.jpg')
相关推荐
糖豆豆今天也要努力鸭40 分钟前
torch.__version__的torch版本和conda list的torch版本不一致
linux·pytorch·python·深度学习·conda·torch
何大春1 小时前
【弱监督语义分割】Self-supervised Image-specific Prototype Exploration for WSSS 论文阅读
论文阅读·人工智能·python·深度学习·论文笔记·原型模式
Suyuoa1 小时前
附录2-pytorch yolov5目标检测
python·深度学习·yolo
余生H2 小时前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
罗小罗同学3 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
孤独且没人爱的纸鹤3 小时前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭3 小时前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow
羊小猪~~3 小时前
tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建
人工智能·python·深度学习·机器学习·cnn·tensorflow·neo4j
极客代码3 小时前
【Python TensorFlow】进阶指南(续篇三)
开发语言·人工智能·python·深度学习·tensorflow
Seeklike3 小时前
11.22 深度学习-pytorch自动微分
人工智能·pytorch·深度学习