比SAM小60倍的轻量级分割一切模型:MobileSAM

1 MobileSAM

SAM就是一类处理图像分割任务的通用模型。与以往只能处理某种特定类型图片的图像分割模型不同,SAM可以处理所有类型的图像。

在SAM出现前,基本上所有的图像分割模型都是专有模型。比如,在医学领域,有专门分割核磁图像的人工智能模型,也有专门分割CT影像的人工智能模型。但这些模型往往只在分割专有领域内的图像时,才具有良好性能,而在分割其他领域的图像时往往性能不佳。

1.1 模型介绍

SAM是一种prompt-guided的视觉基础模型,用于从其背景中剪切出感兴趣的对象。自Meta研究团队发布SA项目以来,SAM因其令人印象深刻的零样本传输性能和与其他模型兼容的高度通用性而备受关注,用于高级视觉应用,如具有细粒度控制的图像编辑。

许多这样的用例需要在资源受限的边缘设备上运行,比如移动应用程序。今天分享中,我们的目标是通过用轻量级图像编码器取代重量级图像编码器,使SAM对移动友好。原始SAM文件中训练这种新SAM的方式会导致性能不令人满意,尤其是当可用的训练来源有限时。

我们发现,这主要是由图像编码器和掩模解码器的耦合优化引起的,因此提出了解耦蒸馏。具体地说,将原始SAM中的图像编码器ViT-H的知识提取到一个轻量级的图像编码器中,该编码器可以自动与原始SAM中的掩码解码器兼容。

训练可以在不到一天的时间内在单个GPU上完成,由此产生的轻量级SAM被称为MobileSAM,它比原始SAM小60多倍,但性能与原始SAM相当。就推理速度而言,MobileSAM每幅图像运行约10ms:图像编码器运行8ms,掩码解码器运行2ms。凭借卓越的性能和更高的通用性,我们的MobileSAM比并发的FastSAM小7倍,快4倍,更适合移动应用。

论文地址:https://arxiv.org/pdf/2306.14289.pdf

代码地址:https://github.com/ChaoningZhang/MobileSAM

1.2 新框架

  • Background on SAM

在这里,我们首先总结SAM的结构及其工作原理。SAM由一个基于ViT的图像编码器和一个提示引导掩码解码器组成。图像编码器将图像作为输入并生成嵌入,然后将嵌入提供给掩码解码器。掩码解码器生成一个掩码,根据点(或框)等提示从背景中剪切出任何对象。此外,SAM允许为同一提示生成多个掩码,以解决模糊性问题,这提供了宝贵的灵活性。考虑到这一点,这项工作保持了SAM的流水线,首先采用基于ViT的编码器来生成图像嵌入,然后采用提示引导解码器来生成所需的掩码。这条管道是为"分段任何东西"而优化设计的,可用于"分段所有东西"的下游任务。

SAM的耦合知识蒸馏。左图表示完全耦合蒸馏,右图表示半耦合蒸馏。

  • Project goal

该项目的目标是生成一个移动友好型SAM(MobileSAM),以轻量级的方式实现令人满意的性能,并且比原始SAM快得多。原始SAM中的提示引导掩码解码器的参数小于4M,因此被认为是轻量级的。给定编码器处理的图像嵌入,如他们的公开演示中所示,SAM可以在资源受限的设备中工作,因为掩码解码器是轻量级的。然而,原始SAM中的默认图像编码器是基于ViT-H的,具有超过600M的参数,这是非常重量级的,并使整个SAM管道与移动设备不兼容。因此,获得移动友好SAM的关键在于用轻量级的图像编码器取代重量级的图像编码器,这也自动保持了原始SAM的所有功能和特性。

以ViT-B为图像编码器的SAM的耦合蒸馏和解耦蒸馏的比较。与耦合蒸馏相比,解耦蒸馏性能更好,所需计算资源少于1%。

1.3 实验

下图给出了point与bbox提示词下MobileSAM与原生SAM的结果对比,可以看到:MobileSAM可以取得令人满意的Mask预测结果。

下图从Segment everything角度对比了SAM、FastSAM以及MobileSAM三个模型,可以看到:

  • MobileSAM与原生SAM结果对齐惊人的好,而FastSAM会生成一些无法满意的结果
  • FastSAM通常生成非平滑的边缘,而SAM与MobileSAM并没有该问题

MobileSAM在所有方面都优于FastSAM

SAM原始论文的标题是"Segment anything",而不是"segment everything"。如SAM中所强调的,SAM执行可prompt分割的任务,该任务"在给定任何分割prompt的情况下返回有效的分割Mask"。

prompt的作用是指定要在图像中分割的内容。理论上,只要正确设置prompt,任何目标都可以被分割,因此,它被称为"Segment anything"。相比之下,"segment everything"本质上是目标建议生成,对此不需要prompt。在SAM中,选择"segment everything"(目标建议生成)作为下游任务之一,以演示其零样本传输性能。

总之,"Segment anything"解决了任何目标的可prompt分割的基础任务,而"segment everything"解决了为所有目标生成Mask建议的下游任务不一定需要prompt,FastSAM以无prompt的方式直接用YOLO v8生成Mask建议。为了实现可prompt分割,设计了一种映射算法来从提议Mask集中选择Mask。

2 运行环境与实战

2.1 conda环境准备

conda环境准备详见:annoconda

2.2 运行环境安装

git clone https://github.com/ChaoningZhang/MobileSAM
cd MobileSAM

conda create -n mobilesam python=3.9
conda activate mobilesam

pip install -e .
pip install gradio

pip install torchvision==0.15.1
pip install timm
pip install opencv-python

2.3 模型下载

下载地址:https://huggingface.co/spaces/dhkim2810/MobileSAM/tree/main

2.4 运行

cd app

修改app.py中的代码

demo.launch(server_name='192.168.1.160')  #地址为自己的内网IP

python app.py

3 通过代码分割图片

3.1 基于prompt分割

from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam.to(device=device)
mobile_sam.eval()

predictor = SamPredictor(mobile_sam)
predictor.set_image('/opt/data/cata_dog.jpg')
masks, _, _ = predictor.predict('cat')

3.2 分割整张图片

from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam.to(device=device)
mobile_sam.eval()

mask_generator = SamAutomaticMaskGenerator(mobile_sam)
masks = mask_generator.generate('/opt/data/cat_dog.jpg')
相关推荐
深度学习实战训练营1 小时前
基于keras的停车场车位识别
人工智能·深度学习·keras
菜就多练_08282 小时前
《深度学习》OpenCV 摄像头OCR 过程及案例解析
人工智能·深度学习·opencv·ocr
没有余地 EliasJie2 小时前
Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT
pytorch·windows·深度学习·ubuntu·pycharm·conda·tensorflow
技术无疆3 小时前
【Python】Streamlit:为数据科学与机器学习打造的简易应用框架
开发语言·人工智能·python·深度学习·神经网络·机器学习·数据挖掘
浊酒南街3 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)2.7-2.8
人工智能·深度学习·神经网络
被制作时长两年半的个人练习生4 小时前
【pytorch】权重为0的情况
人工智能·pytorch·深度学习
xiandong2011 小时前
240929-CGAN条件生成对抗网络
图像处理·人工智能·深度学习·神经网络·生成对抗网络·计算机视觉
innutritious12 小时前
车辆重识别(2020NIPS去噪扩散概率模型)论文阅读2024/9/27
人工智能·深度学习·计算机视觉
醒了就刷牙13 小时前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
橙子小哥的代码世界13 小时前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer