1、前言
最近SAM 模型复现的多了,看了不少官方的源码,尝试下SAM和Unet的结合
SAM的提示分割,其实就是在分割的时候,为数据增加一个提示信息,可以是框,点,或者文本等等。这样大模型网络就可以根据提示信息进行分割
SAM模型的复现可以参考:视觉大模型学习笔记_听风吹等浪起的博客-CSDN博客
如下图,就是手动的给出提示框,然后网络分割出框里的内容
需要注意的是,本文不对SAM的原理进行讲解,只是借鉴SAM的提示分割思想,将其和unet结合
其实,思想很简单,一般的图像分割网络是3通道 的,因为图像都是RGB的。
而这里需要传入sam的提示信息 ,所以还要增加一个维度,也就是说最开始输入神经网络的的通道是4维度的!
神经网络输入四维度很好实现,就是更改卷积核的个数罢了。毕竟正常神经网络隐藏层都是上百个,这里一开始和输入和中间隐藏层没啥区别,就是更改个数字罢了
实现了输入,最重要的是怎么把数据传进去!!!这里非常重要!!!
看了 MedSAM 源码实现,发现实现的思路确实很巧妙
大概思路就是,在图像上加上一层canvas,是全黑的。然后对mask的前景分割类别随机挑选一个,变成二值数据,这样就可以求取min和max,就可以得到box,将这个box传给canvas就可以达到提示框的目的
这里后面会根据代码介绍
2、SAM 提示框分割
SAM的提示信息分割有好几种,因为本人是医学图像专业,这里只实现了提示框分割。
其实点提示分割也很简单,明白了提示框分割的思路后,点分割就是怎么对数据处理的问题。或者就是把canvas的提示框换成其他mask模板。
点提示的话,可以用区域生长进行掩膜?
提示框的作用是什么呢?
说白了就是让网络更集中在一块区域, 对指定区域进行分割。
一般的语义分割,就是image和mask 一一对应,然后交叉熵、优化。这样网络同时要预测好几个类别,这也是为啥多类别分割效果不如二值分割好的原因
而提示框的box选定后,不管mask中有几个类别,这样mask都会变成二值的。
就像下面,分割的前景好十来个,但是我们绘制了box圈中小车后,那么对于网络来说,就是二值的分割。也就是前景车1 ,背景是除车的所有区域
代码实现的思路很简单,mask是0 1 2 3 4等等阈值图像,我们随机选中一个值,只取出这个值的数据。例如是3,那么除了3位置(x,y)得其他区域全部是0,3变成了新的阈值前景,这样不就是二值图像了嘛!!
至于二值图像,计算前景的min,max,就可以绘制边界框,这样就是sam提示框分割的精髓。代码实现如下
python
label_ids = np.unique(mask)[1:]
label_id = random.choice(label_ids.tolist())
mask = np.uint8(mask == label_id) # only one label
自动添加边界框:
python
y_indices, x_indices = np.where(mask > 0)
x_min, x_max = np.min(x_indices), np.max(x_indices)
y_min, y_max = np.min(y_indices), np.max(y_indices)
这样输入的数据就不是RGB图像,而是RGB+canvas(矩形边界框)的四通道数据
注意:这里的操作都是在dataset脚本实现的
3、Unet 网络
unet网络分割思想很简单,之前写了很多关于unet分割的介绍,自己参考就行
因为datase已经产生了四通道的数据,所有unet的输入是4维的,而输出是二值图像,所有outout是2维的
这里用1,因为后面用了 DiceCELoss 损失,而非交叉熵
Tips:
虽然Unet网络略显过时,什么DenseUne、TransUnet、SwinUnet等等都很好。
这里为了方便只是将sam和Unet结合,如果需要把sam和其他分割网络结合,把上面的model换成想要的分割模型,定义输入4,输出1就行
其他的代码不需要更改!
4、实验准备
环境如下,pip安装即可
python
matplotlib==3.9.2
monai==1.3.1
numpy==2.1.0
opencv_python==4.10.0.82
Pillow==10.4.0
torch==2.0.0
torchvision==0.15.1
tqdm==4.66.4
4.1 数据集
数据集使用的是自动驾驶车道线的分割,标签如下,可视化如下
python
0 background
1 white_lines
2 yellow_lines
数据集摆放如下:
--data--train---images 训练集的图像
--data--train---masks 训练集的图像标签
--data--val---images 验证集的图像
--data--val---masks 验证集的图像标签
4.2 训练
因为SAM提示框每次都是挑选一个前景分割,对于类别多的话,epoch最好设置的多点
参数如下:
python
parser = argparse.ArgumentParser(description="SamUnet segmentation")
parser.add_argument("--batch-size", default=8, type=int)
parser.add_argument("--epochs", default=200, type=int)
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--lrf',default=0.0001,type=float) # 最终学习率 = lr * lrf
parser.add_argument("--img_f", default='.jpg', type=str) # 数据图像的后缀
parser.add_argument("--mask_f", default='_mask.png', type=str) # mask图像的后缀
args = parser.parse_args()
训练集和验证集的比例约为3:1
数据的输入效果:
注意,这里前景应该是白色的,但由于白色不太好区别,这里填充为红色
loss和dice的曲线:
4.3 推理
这里编写了简单的ui推理界面,如下
这里会打印矩形框的坐标:
5、1 项目下载
项目为付费资源,包含数据集、完整代码、训练权重等等