SAM 提示框和 Unet的语义分割的融合:自动驾驶车道线分割

1、前言

最近SAM 模型复现的多了,看了不少官方的源码,尝试下SAM和Unet的结合

SAM的提示分割,其实就是在分割的时候,为数据增加一个提示信息,可以是框,点,或者文本等等。这样大模型网络就可以根据提示信息进行分割

SAM模型的复现可以参考:视觉大模型学习笔记_听风吹等浪起的博客-CSDN博客

如下图,就是手动的给出提示框,然后网络分割出框里的内容

需要注意的是,本文不对SAM的原理进行讲解,只是借鉴SAM的提示分割思想,将其和unet结合

其实,思想很简单,一般的图像分割网络是3通道 的,因为图像都是RGB的。

而这里需要传入sam的提示信息 ,所以还要增加一个维度,也就是说最开始输入神经网络的的通道是4维度的!

神经网络输入四维度很好实现,就是更改卷积核的个数罢了。毕竟正常神经网络隐藏层都是上百个,这里一开始和输入和中间隐藏层没啥区别,就是更改个数字罢了

实现了输入,最重要的是怎么把数据传进去!!!这里非常重要!!!

看了 MedSAM 源码实现,发现实现的思路确实很巧妙

大概思路就是,在图像上加上一层canvas,是全黑的。然后对mask的前景分割类别随机挑选一个,变成二值数据,这样就可以求取min和max,就可以得到box,将这个box传给canvas就可以达到提示框的目的

这里后面会根据代码介绍

2、SAM 提示框分割

SAM的提示信息分割有好几种,因为本人是医学图像专业,这里只实现了提示框分割。

其实点提示分割也很简单,明白了提示框分割的思路后,点分割就是怎么对数据处理的问题。或者就是把canvas的提示框换成其他mask模板。

点提示的话,可以用区域生长进行掩膜?

提示框的作用是什么呢?

说白了就是让网络更集中在一块区域, 对指定区域进行分割。

一般的语义分割,就是image和mask 一一对应,然后交叉熵、优化。这样网络同时要预测好几个类别,这也是为啥多类别分割效果不如二值分割好的原因

而提示框的box选定后,不管mask中有几个类别,这样mask都会变成二值的。

就像下面,分割的前景好十来个,但是我们绘制了box圈中小车后,那么对于网络来说,就是二值的分割。也就是前景车1 ,背景是除车的所有区域

代码实现的思路很简单,mask是0 1 2 3 4等等阈值图像,我们随机选中一个值,只取出这个值的数据。例如是3,那么除了3位置(x,y)得其他区域全部是0,3变成了新的阈值前景,这样不就是二值图像了嘛!!

至于二值图像,计算前景的min,max,就可以绘制边界框,这样就是sam提示框分割的精髓。代码实现如下

python 复制代码
        label_ids = np.unique(mask)[1:]
        label_id = random.choice(label_ids.tolist())
        mask = np.uint8(mask == label_id)  # only one label

自动添加边界框:

python 复制代码
        y_indices, x_indices = np.where(mask > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

这样输入的数据就不是RGB图像,而是RGB+canvas(矩形边界框)的四通道数据

注意:这里的操作都是在dataset脚本实现的

3、Unet 网络

unet网络分割思想很简单,之前写了很多关于unet分割的介绍,自己参考就行

UNet 网络做图像分割DRIVE数据集-CSDN博客

图像分割_听风吹等浪起的博客-CSDN博客

因为datase已经产生了四通道的数据,所有unet的输入是4维的,而输出是二值图像,所有outout是2维的

这里用1,因为后面用了 DiceCELoss 损失,而非交叉熵

Tips:

虽然Unet网络略显过时,什么DenseUne、TransUnet、SwinUnet等等都很好。

这里为了方便只是将sam和Unet结合,如果需要把sam和其他分割网络结合,把上面的model换成想要的分割模型,定义输入4,输出1就行

其他的代码不需要更改!

4、实验准备

环境如下,pip安装即可

python 复制代码
matplotlib==3.9.2
monai==1.3.1
numpy==2.1.0
opencv_python==4.10.0.82
Pillow==10.4.0
torch==2.0.0
torchvision==0.15.1
tqdm==4.66.4

4.1 数据集

数据集使用的是自动驾驶车道线的分割,标签如下,可视化如下

python 复制代码
0	 background
1	 white_lines
2	 yellow_lines

数据集摆放如下:

复制代码
--data--train---images   训练集的图像
--data--train---masks    训练集的图像标签
--data--val---images     验证集的图像
--data--val---masks      验证集的图像标签

4.2 训练

因为SAM提示框每次都是挑选一个前景分割,对于类别多的话,epoch最好设置的多点

参数如下:

python 复制代码
    parser = argparse.ArgumentParser(description="SamUnet segmentation")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf',default=0.0001,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument("--img_f", default='.jpg', type=str)               # 数据图像的后缀
    parser.add_argument("--mask_f", default='_mask.png', type=str)              # mask图像的后缀

    args = parser.parse_args()

训练集和验证集的比例约为3:1

数据的输入效果:

注意,这里前景应该是白色的,但由于白色不太好区别,这里填充为红色

loss和dice的曲线:

4.3 推理

这里编写了简单的ui推理界面,如下

这里会打印矩形框的坐标:

5、1 项目下载

项目为付费资源,包含数据集、完整代码、训练权重等等

完整测试项目:基于Unet+SAM提示框实现的自动驾驶道路线语义分割、数据集、源码、训练权重资源-CSDN文库

相关推荐
Elastic 中国社区官方博客2 小时前
Elasticsearch:使用 Agent Builder 的 A2A 实现 - 开发者的圣诞颂歌
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
chools2 小时前
【AI超级智能体】快速搞懂工具调用Tool Calling 和 MCP协议
java·人工智能·学习·ai
郝学胜-神的一滴2 小时前
深度学习必学:PyTorch 神经网络参数初始化全攻略(原理 + 代码 + 选择指南)
人工智能·pytorch·python·深度学习·神经网络·机器学习
leobertlan2 小时前
好玩系列:用20元实现快乐保存器
android·人工智能·算法
笨笨饿3 小时前
#58_万能函数的构造方法:ReLU函数
数据结构·人工智能·stm32·单片机·硬件工程·学习方法
jr-create(•̀⌄•́)3 小时前
从零开始:手动实现神经网络识别手写数字(完整代码讲解)
人工智能·深度学习·神经网络
冬奇Lab3 小时前
一天一个开源项目(第78篇):MiroFish - 用群体智能引擎预测未来
人工智能·开源·资讯
冬奇Lab3 小时前
你的 Skill 真的好用吗?来自OpenAI的 Eval 系统化验证 Agent 技能方法论
人工智能·openai
数智工坊3 小时前
Transformer 全套逻辑:公式推导 + 原理解剖 + 逐行精读 - 划时代封神之作!
人工智能·深度学习·transformer
GreenTea4 小时前
AI 时代,工程师的不可替代性在哪里
前端·人工智能·后端