【图像分割】【深度学习】Windows11下MedSAM官方代码Pytorch实现与源码讲解
文章目录
前言
MedSAM是多伦多大学的Jun Ma等人在《Segment Anything in Medical Images and Videos:Benchmark and Deployment【nature communications 2024】》【论文地址】一文中提出的模型,旨在充当通用医学图像分割的基础模型,能够适应各种成像条件、解剖结构和病理条件的变化。
在详细解析MedSAM网络之前,首要任务是搭建MedSAM【Pytorch-demo地址】所需的运行环境,并完成模型训练和测试工作,展开后续工作才有意义。
MedSAM模型运行环境搭建
在win11环境下安装anaconda环境参考,方便搭建专用于GCNet模型的虚拟环境。
-
查看主机支持的cuda版本(最高)
bash# 打开cmd,执行下面的指令查看CUDA版本号 nvidia-smi
-
博主安装环境参考
bash# 创建虚拟环境 conda create -n medsam python=3.10 -y # 查看新环境是否安装成功 conda env list # 激活环境 activate medsam # githup下载medsam源代码到适合目录内,解压文件 git clone https://github.com/bowang-lab/MedSAM # 分别安装pytorch和torchvision pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 安装其他依赖库包 pip install connected-components-3d pip install PyQt5 # 编译一些额外的自定义包 pip install -e . # 查看所有安装的包 pip list conda list
要求Pytorch大于2.0:
MedSAM模型运行
数据集与模型权重下载
名称 | 下载地址 | 说明 |
---|---|---|
FLARE22Train数据集 | 【官方下载地址】 | MICCAI FLARE 202 挑战赛提供的腹部CT数据集,完成对 13 种腹部器官的分割 |
MedSAM模型权重(medsam_vit_b.pth) | 【官方下载地址】 | MedSAM在FLARE22Train数据集上训练出的模型权重 |
SAM预训练模型权重(sam_vit_b_01ec64.pth) | 【官方下载地址】 | MedSAM是在SAM预训练模型权重基础上再用FLARE22Train数据集继续训练 |
简单说明下FLARE22Train数据集是单通道图像,但是.nii.gz格式是按空间位置(或者说扫描的先后顺序)将多个图像拼接在一起形成多通道的单一数据,读者可以运行以下代码验证:
show_niiImage.py
python
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
# 加载nii.gz文件
# file_path = r'../data/FLARE22Train/labels/FLARE22_Tr_0001_0000.nii.gz' # 替换为您的文件路径
file_path = r'../data/FLARE22Train/images/FLARE22_Tr_0001_0000.nii.gz' # 替换为您的文件路径
img = nib.load(file_path)
# 获取图像数据
data = img.get_fdata()
# 查看图像的基本信息
print(f"Shape of the image data: {data.shape}")
print(f"Shape of the image data: {type(data)}")
print(f"Data type of the image: {data.dtype}")
# 显示中间切片作为示例
num = data.shape[2]
print(np.unique(data))
for i in range(num):
plt.imshow(data[:, :, i], cmap='gray')
plt.title(f'Slice number {i}')
plt.axis('off') # 关闭坐标轴
plt.show()
数据集预处理:将.nii.gz格式多张图像通道拼接成的数据集重新拆分为单通道的单张图像,并用.npy格式保存。
bash
python pre_CT_MR.py
pre_CT_MR.py
python
import numpy as np
# SimpleITK 用于医学图像分析的开源软件包
import SimpleITK as sitk
import os
join = os.path.join
from skimage import transform
from tqdm import tqdm
import cc3d
modality = "CT"
anatomy = "Abd"
img_name_suffix = "_0000.nii.gz"
gt_name_suffix = ".nii.gz"
prefix = modality + "_" + anatomy + "_"
nii_path = "data/FLARE22Train/images" # 图像路径
gt_path = "data/FLARE22Train/labels" # 标签路径
npy_path = r"data/npy/" + prefix[:-1]
os.makedirs(join(npy_path, "gts"), exist_ok=True)
os.makedirs(join(npy_path, "imgs"), exist_ok=True)
image_size = 1024
voxel_num_thre2d = 100
voxel_num_thre3d = 1000
names = sorted(os.listdir(gt_path))
print(f"ori \# files {len(names)=}")
names = [
name
for name in names
if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix))
]
print(f"after sanity check \# files {len(names)=}")
# 设置排除的标签id
remove_label_ids = [
12
] # 去除脱氧核糖核酸,因为它在图像中分散,很难用边界框指定
# 仅在有多个肿瘤时设置;将语义掩码转换为实例掩码
tumor_id = None
# 设置窗口级别和宽度
# https://radiopaedia.org/articles/windowing-ct
WINDOW_LEVEL = 40 # only for CT images
WINDOW_WIDTH = 400 # only for CT images
# 进度条 %% 将预处理后的图像和掩模保存为npz文件
for name in tqdm(names[:40]): # 保留剩余的10个案例进行验证
image_name = name.split(gt_name_suffix)[0] + img_name_suffix
gt_name = name
# 读取掩码
gt_sitk = sitk.ReadImage(join(gt_path, gt_name))
# 将SimpleITK图像对象转换为NumPy数组
gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk))
# 删除标签id:对mask标签号置0
for remove_label_id in remove_label_ids:
gt_data_ori[gt_data_ori == remove_label_id] = 0
# 将肿瘤标签标记为实例并从gt_data_ori中删除
if tumor_id is not None:
# 找到所有肿瘤标签:1为肿瘤,0是背景
tumor_bw = np.uint8(gt_data_ori == tumor_id)
# 去除肿瘤标签id
gt_data_ori[tumor_bw > 0] = 0
# 将肿瘤标签标记为实例:在三维二值图像(CWH)中识别和标记连通分量
# connectivity:6、18 和 26 连通性分别定义了如何确定体素(voxel)之间的邻接关系。
tumor_inst, tumor_n = cc3d.connected_components(
tumor_bw, connectivity=26, return_N=True
)
# 在gt_data_ori中设置为肿瘤实例分割标签id(id号从语义分割标签总数后开始)
gt_data_ori[tumor_inst > 0] = (
tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1
)
# 用于去除三维二值图像中小于指定阈值1000的连通分量
gt_data_ori = cc3d.dust(
gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True
)
# 用于去除二维二值图像中小于指定阈值100的连通分量
for slice_i in range(gt_data_ori.shape[0]):
gt_i = gt_data_ori[slice_i, :, :]
# remove small objects with less than 100 pixels
# reason: fro such small objects, the main challenge is detection rather than segmentation
gt_data_ori[slice_i, :, :] = cc3d.dust(
gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True
)
# 有前景分割标签的图像序号(通道号)
z_index, _, _ = np.where(gt_data_ori > 0)
z_index = np.unique(z_index)
if len(z_index) > 0:
# 找到有前景目标的切片
gt_roi = gt_data_ori[z_index, :, :]
# 加载图像转换为NumPy数组
img_sitk = sitk.ReadImage(join(nii_path, image_name))
image_data = sitk.GetArrayFromImage(img_sitk)
#
if modality == "CT":
lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2
upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2
# 指定最低下限和最高上限
image_data_pre = np.clip(image_data, lower_bound, upper_bound)
# 归一化
image_data_pre = (
(image_data_pre - np.min(image_data_pre))
/ (np.max(image_data_pre) - np.min(image_data_pre))
* 255.0
)
else:
# 计算数组中指定百分位数的值:(max-min)*percent+min
# 下限是0.5% 上限是99.5%
lower_bound, upper_bound = np.percentile(
image_data[image_data > 0], 0.5
), np.percentile(image_data[image_data > 0], 99.5)
# 指定最低下限和最高上限
image_data_pre = np.clip(image_data, lower_bound, upper_bound)
# 归一化
image_data_pre = (
(image_data_pre - np.min(image_data_pre))
/ (np.max(image_data_pre) - np.min(image_data_pre))
* 255.0
)
# 目的只是为了把非零值的最低下限设置为lower_bound,将原本为0的部分从lower_bound恢复到0
image_data_pre[image_data == 0] = 0
image_data_pre = np.uint8(image_data_pre)
img_roi = image_data_pre[z_index, :, :]
np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing())
img_roi_sitk = sitk.GetImageFromArray(img_roi)
gt_roi_sitk = sitk.GetImageFromArray(gt_roi)
sitk.WriteImage(
img_roi_sitk,
join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"),
)
sitk.WriteImage(
gt_roi_sitk,
join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"),
)
# save the each CT image as npy file
for i in range(img_roi.shape[0]):
img_i = img_roi[i, :, :]
img_3c = np.repeat(img_i[:, :, None], 3, axis=-1)
resize_img_skimg = transform.resize(
img_3c,
(image_size, image_size),
order=3,
preserve_range=True,
mode="constant",
anti_aliasing=True,
)
# 归一化
resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip(
resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None
) # normalize to [0, 1], (H, W, 3)
gt_i = gt_roi[i, :, :]
resize_gt_skimg = transform.resize(
gt_i,
(image_size, image_size),
order=0,
preserve_range=True,
mode="constant",
anti_aliasing=False,
)
resize_gt_skimg = np.uint8(resize_gt_skimg)
assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape
np.save(
join(
npy_path,
"imgs",
prefix
+ gt_name.split(gt_name_suffix)[0]
+ "-"
+ str(i).zfill(3)
+ ".npy",
),
resize_img_skimg_01,
)
np.save(
join(
npy_path,
"gts",
prefix
+ gt_name.split(gt_name_suffix)[0]
+ "-"
+ str(i).zfill(3)
+ ".npy",
),
resize_gt_skimg,
)
数据集和模型权重的目录结构:
XML
| -- MedSAM
| -- work_dir
| --MedSAM
| --medsam_vit_b.pth
| --SAM
| --sam_vit_b_01ec64.pth
| -- data
| -- FLARE22Train
| --images
| --labels
| --npy
MedSAM训练与测试
训练
train_one_gpu.py
根据个人情况常规修改,不过多解释
checkpoint:表示预训练权重
num_epochs:训练轮数
batch_size:单次的图像数目
resume:继续训练加载的模型权重路径
bash
# 单gpu训练
python train_one_gpu.py
正在训练:
训练保存的权重(最好/最后的权重),放置在work_dir/MedSAM-ViT-B-20241209-0934(数字代表运行时刻的时间)里。
测试:
测试部分博主都是用官方提供的模型权重演示。
MedSAM_Inference.py
根据个人情况常规修改,不过多解释
i:输入图像的位置
o:训练轮数
box:在图像位置的检查区域框左上坐标和右下坐标[x1, y1, x2, y2]
checkpoint:选择加载的模型
bash
```bash
# 测试效果
python MedSAM_Inference.py
完成测试:
MedSAM_CKPT_PATH:指定加载的模型
bash
# 使用图形界面来完成图像指定区域的分割
python gui.py
掩码图的保存路径与加载图像的路径一致:
总结
尽可能简单、详细的介绍了MedSAM的安装流程以及MedSAM的使用方法。后续会根据自己学到的知识结合个人理解讲解MedSAM的原理和代码。