SAM 提示框和 Unet的语义分割的融合:自动驾驶车道线分割

1、前言

最近SAM 模型复现的多了,看了不少官方的源码,尝试下SAM和Unet的结合

SAM的提示分割,其实就是在分割的时候,为数据增加一个提示信息,可以是框,点,或者文本等等。这样大模型网络就可以根据提示信息进行分割

SAM模型的复现可以参考:视觉大模型学习笔记_听风吹等浪起的博客-CSDN博客

如下图,就是手动的给出提示框,然后网络分割出框里的内容

需要注意的是,本文不对SAM的原理进行讲解,只是借鉴SAM的提示分割思想,将其和unet结合

其实,思想很简单,一般的图像分割网络是3通道 的,因为图像都是RGB的。

而这里需要传入sam的提示信息 ,所以还要增加一个维度,也就是说最开始输入神经网络的的通道是4维度的!

神经网络输入四维度很好实现,就是更改卷积核的个数罢了。毕竟正常神经网络隐藏层都是上百个,这里一开始和输入和中间隐藏层没啥区别,就是更改个数字罢了

实现了输入,最重要的是怎么把数据传进去!!!这里非常重要!!!

看了 MedSAM 源码实现,发现实现的思路确实很巧妙

大概思路就是,在图像上加上一层canvas,是全黑的。然后对mask的前景分割类别随机挑选一个,变成二值数据,这样就可以求取min和max,就可以得到box,将这个box传给canvas就可以达到提示框的目的

这里后面会根据代码介绍

2、SAM 提示框分割

SAM的提示信息分割有好几种,因为本人是医学图像专业,这里只实现了提示框分割。

其实点提示分割也很简单,明白了提示框分割的思路后,点分割就是怎么对数据处理的问题。或者就是把canvas的提示框换成其他mask模板。

点提示的话,可以用区域生长进行掩膜?

提示框的作用是什么呢?

说白了就是让网络更集中在一块区域, 对指定区域进行分割。

一般的语义分割,就是image和mask 一一对应,然后交叉熵、优化。这样网络同时要预测好几个类别,这也是为啥多类别分割效果不如二值分割好的原因

而提示框的box选定后,不管mask中有几个类别,这样mask都会变成二值的。

就像下面,分割的前景好十来个,但是我们绘制了box圈中小车后,那么对于网络来说,就是二值的分割。也就是前景车1 ,背景是除车的所有区域

代码实现的思路很简单,mask是0 1 2 3 4等等阈值图像,我们随机选中一个值,只取出这个值的数据。例如是3,那么除了3位置(x,y)得其他区域全部是0,3变成了新的阈值前景,这样不就是二值图像了嘛!!

至于二值图像,计算前景的min,max,就可以绘制边界框,这样就是sam提示框分割的精髓。代码实现如下

python 复制代码
        label_ids = np.unique(mask)[1:]
        label_id = random.choice(label_ids.tolist())
        mask = np.uint8(mask == label_id)  # only one label

自动添加边界框:

python 复制代码
        y_indices, x_indices = np.where(mask > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

这样输入的数据就不是RGB图像,而是RGB+canvas(矩形边界框)的四通道数据

注意:这里的操作都是在dataset脚本实现的

3、Unet 网络

unet网络分割思想很简单,之前写了很多关于unet分割的介绍,自己参考就行

UNet 网络做图像分割DRIVE数据集-CSDN博客

图像分割_听风吹等浪起的博客-CSDN博客

因为datase已经产生了四通道的数据,所有unet的输入是4维的,而输出是二值图像,所有outout是2维的

这里用1,因为后面用了 DiceCELoss 损失,而非交叉熵

Tips:

虽然Unet网络略显过时,什么DenseUne、TransUnet、SwinUnet等等都很好。

这里为了方便只是将sam和Unet结合,如果需要把sam和其他分割网络结合,把上面的model换成想要的分割模型,定义输入4,输出1就行

其他的代码不需要更改!

4、实验准备

环境如下,pip安装即可

python 复制代码
matplotlib==3.9.2
monai==1.3.1
numpy==2.1.0
opencv_python==4.10.0.82
Pillow==10.4.0
torch==2.0.0
torchvision==0.15.1
tqdm==4.66.4

4.1 数据集

数据集使用的是自动驾驶车道线的分割,标签如下,可视化如下

python 复制代码
0	 background
1	 white_lines
2	 yellow_lines

数据集摆放如下:

复制代码
--data--train---images   训练集的图像
--data--train---masks    训练集的图像标签
--data--val---images     验证集的图像
--data--val---masks      验证集的图像标签

4.2 训练

因为SAM提示框每次都是挑选一个前景分割,对于类别多的话,epoch最好设置的多点

参数如下:

python 复制代码
    parser = argparse.ArgumentParser(description="SamUnet segmentation")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf',default=0.0001,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument("--img_f", default='.jpg', type=str)               # 数据图像的后缀
    parser.add_argument("--mask_f", default='_mask.png', type=str)              # mask图像的后缀

    args = parser.parse_args()

训练集和验证集的比例约为3:1

数据的输入效果:

注意,这里前景应该是白色的,但由于白色不太好区别,这里填充为红色

loss和dice的曲线:

4.3 推理

这里编写了简单的ui推理界面,如下

这里会打印矩形框的坐标:

5、1 项目下载

项目为付费资源,包含数据集、完整代码、训练权重等等

完整测试项目:基于Unet+SAM提示框实现的自动驾驶道路线语义分割、数据集、源码、训练权重资源-CSDN文库

相关推荐
mahuifa1 分钟前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐3 分钟前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
蓝天星空16 分钟前
Python调用open ai接口
人工智能·python
睡觉狂魔er17 分钟前
自动驾驶控制与规划——Project 3: LQR车辆横向控制
人工智能·机器学习·自动驾驶
scan72440 分钟前
LILAC采样算法
人工智能·算法·机器学习
leaf_leaves_leaf42 分钟前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零11 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
菌菌的快乐生活1 小时前
理解支持向量机
算法·机器学习·支持向量机
爱喝热水的呀哈喽1 小时前
《机器学习》支持向量机
人工智能·决策树·机器学习
大山同学1 小时前
第三章线性判别函数(二)
线性代数·算法·机器学习