【图像分割】【深度学习】Windows11下MedSAM官方代码Pytorch实现与源码讲解

【图像分割】【深度学习】Windows11下MedSAM官方代码Pytorch实现与源码讲解


文章目录


前言

MedSAM是多伦多大学的Jun Ma等人在《Segment Anything in Medical Images and Videos:Benchmark and Deployment【nature communications 2024】》【论文地址】一文中提出的模型,旨在充当通用医学图像分割的基础模型,能够适应各种成像条件、解剖结构和病理条件的变化。

在详细解析MedSAM网络之前,首要任务是搭建MedSAM【Pytorch-demo地址】所需的运行环境,并完成模型训练和测试工作,展开后续工作才有意义。


MedSAM模型运行环境搭建

在win11环境下安装anaconda环境参考,方便搭建专用于GCNet模型的虚拟环境。

  • 查看主机支持的cuda版本(最高)

    bash 复制代码
    # 打开cmd,执行下面的指令查看CUDA版本号
    nvidia-smi
  • 安装GPU版本的torch【官网】。

    其他cuda版本的torch在【以前版本】找对应的安装命令。

  • 博主安装环境参考

    bash 复制代码
    # 创建虚拟环境
    conda create -n medsam python=3.10 -y
    # 查看新环境是否安装成功
    conda env list
    # 激活环境
    activate medsam 
    # githup下载medsam源代码到适合目录内,解压文件
    git clone https://github.com/bowang-lab/MedSAM
    # 分别安装pytorch和torchvision
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    # 安装其他依赖库包
    pip install connected-components-3d
    pip install PyQt5
    # 编译一些额外的自定义包
    pip install -e .
    # 查看所有安装的包
    pip list
    conda list

    要求Pytorch大于2.0:


MedSAM模型运行

数据集与模型权重下载

名称 下载地址 说明
FLARE22Train数据集 官方下载地址 MICCAI FLARE 202 挑战赛提供的腹部CT数据集,完成对 13 种腹部器官的分割
MedSAM模型权重(medsam_vit_b.pth) 官方下载地址 MedSAM在FLARE22Train数据集上训练出的模型权重
SAM预训练模型权重(sam_vit_b_01ec64.pth) 官方下载地址 MedSAM是在SAM预训练模型权重基础上再用FLARE22Train数据集继续训练

简单说明下FLARE22Train数据集是单通道图像,但是.nii.gz格式是按空间位置(或者说扫描的先后顺序)将多个图像拼接在一起形成多通道的单一数据,读者可以运行以下代码验证:
show_niiImage.py

python 复制代码
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
# 加载nii.gz文件
# file_path = r'../data/FLARE22Train/labels/FLARE22_Tr_0001_0000.nii.gz'    # 替换为您的文件路径
file_path = r'../data/FLARE22Train/images/FLARE22_Tr_0001_0000.nii.gz'          # 替换为您的文件路径
img = nib.load(file_path)

# 获取图像数据
data = img.get_fdata()

# 查看图像的基本信息
print(f"Shape of the image data: {data.shape}")
print(f"Shape of the image data: {type(data)}")
print(f"Data type of the image: {data.dtype}")

# 显示中间切片作为示例
num = data.shape[2]
print(np.unique(data))
for i in range(num):
    plt.imshow(data[:, :, i], cmap='gray')
    plt.title(f'Slice number {i}')
    plt.axis('off')             # 关闭坐标轴
    plt.show()

数据集预处理:将.nii.gz格式多张图像通道拼接成的数据集重新拆分为单通道的单张图像,并用.npy格式保存。

bash 复制代码
python pre_CT_MR.py

pre_CT_MR.py

python 复制代码
import numpy as np
# SimpleITK 用于医学图像分析的开源软件包
import SimpleITK as sitk
import os

join = os.path.join
from skimage import transform
from tqdm import tqdm
import cc3d

modality = "CT"
anatomy = "Abd" 
img_name_suffix = "_0000.nii.gz"
gt_name_suffix = ".nii.gz"
prefix = modality + "_" + anatomy + "_"

nii_path = "data/FLARE22Train/images"   # 图像路径
gt_path = "data/FLARE22Train/labels"    # 标签路径
npy_path = r"data/npy/" + prefix[:-1]
os.makedirs(join(npy_path, "gts"), exist_ok=True)
os.makedirs(join(npy_path, "imgs"), exist_ok=True)

image_size = 1024
voxel_num_thre2d = 100
voxel_num_thre3d = 1000

names = sorted(os.listdir(gt_path))
print(f"ori \# files {len(names)=}")
names = [
    name
    for name in names
    if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix))
]
print(f"after sanity check \# files {len(names)=}")

# 设置排除的标签id
remove_label_ids = [
    12
]  # 去除脱氧核糖核酸,因为它在图像中分散,很难用边界框指定

# 仅在有多个肿瘤时设置;将语义掩码转换为实例掩码
tumor_id = None

# 设置窗口级别和宽度
# https://radiopaedia.org/articles/windowing-ct
WINDOW_LEVEL = 40  # only for CT images
WINDOW_WIDTH = 400  # only for CT images

# 进度条 %% 将预处理后的图像和掩模保存为npz文件
for name in tqdm(names[:40]):  # 保留剩余的10个案例进行验证
    image_name = name.split(gt_name_suffix)[0] + img_name_suffix
    gt_name = name
    # 读取掩码
    gt_sitk = sitk.ReadImage(join(gt_path, gt_name))
    # 将SimpleITK图像对象转换为NumPy数组
    gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk))
    # 删除标签id:对mask标签号置0
    for remove_label_id in remove_label_ids:
        gt_data_ori[gt_data_ori == remove_label_id] = 0
    # 将肿瘤标签标记为实例并从gt_data_ori中删除
    if tumor_id is not None:
        # 找到所有肿瘤标签:1为肿瘤,0是背景
        tumor_bw = np.uint8(gt_data_ori == tumor_id)
        # 去除肿瘤标签id
        gt_data_ori[tumor_bw > 0] = 0
        # 将肿瘤标签标记为实例:在三维二值图像(CWH)中识别和标记连通分量
        # connectivity:6、18 和 26 连通性分别定义了如何确定体素(voxel)之间的邻接关系。
        tumor_inst, tumor_n = cc3d.connected_components(
            tumor_bw, connectivity=26, return_N=True
        )
        # 在gt_data_ori中设置为肿瘤实例分割标签id(id号从语义分割标签总数后开始)
        gt_data_ori[tumor_inst > 0] = (
            tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1
        )

    # 用于去除三维二值图像中小于指定阈值1000的连通分量
    gt_data_ori = cc3d.dust(
        gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True
    )
    # 用于去除二维二值图像中小于指定阈值100的连通分量
    for slice_i in range(gt_data_ori.shape[0]):
        gt_i = gt_data_ori[slice_i, :, :]
        # remove small objects with less than 100 pixels
        # reason: fro such small objects, the main challenge is detection rather than segmentation
        gt_data_ori[slice_i, :, :] = cc3d.dust(
            gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True
        )
    # 有前景分割标签的图像序号(通道号)
    z_index, _, _ = np.where(gt_data_ori > 0)
    z_index = np.unique(z_index)
    if len(z_index) > 0:
        # 找到有前景目标的切片
        gt_roi = gt_data_ori[z_index, :, :]
        # 加载图像转换为NumPy数组
        img_sitk = sitk.ReadImage(join(nii_path, image_name))
        image_data = sitk.GetArrayFromImage(img_sitk)
        #
        if modality == "CT":
            lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2
            upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2
            # 指定最低下限和最高上限
            image_data_pre = np.clip(image_data, lower_bound, upper_bound)
            # 归一化
            image_data_pre = (
                (image_data_pre - np.min(image_data_pre))
                / (np.max(image_data_pre) - np.min(image_data_pre))
                * 255.0
            )
        else:
            # 计算数组中指定百分位数的值:(max-min)*percent+min
            # 下限是0.5% 上限是99.5%
            lower_bound, upper_bound = np.percentile(
                image_data[image_data > 0], 0.5
            ), np.percentile(image_data[image_data > 0], 99.5)
            # 指定最低下限和最高上限
            image_data_pre = np.clip(image_data, lower_bound, upper_bound)
            # 归一化
            image_data_pre = (
                (image_data_pre - np.min(image_data_pre))
                / (np.max(image_data_pre) - np.min(image_data_pre))
                * 255.0
            )
            # 目的只是为了把非零值的最低下限设置为lower_bound,将原本为0的部分从lower_bound恢复到0
            image_data_pre[image_data == 0] = 0

        image_data_pre = np.uint8(image_data_pre)
        img_roi = image_data_pre[z_index, :, :]
        np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing())
        img_roi_sitk = sitk.GetImageFromArray(img_roi)
        gt_roi_sitk = sitk.GetImageFromArray(gt_roi)
        sitk.WriteImage(
            img_roi_sitk,
            join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"),
        )
        sitk.WriteImage(
            gt_roi_sitk,
            join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"),
        )
        # save the each CT image as npy file
        for i in range(img_roi.shape[0]):
            img_i = img_roi[i, :, :]
            img_3c = np.repeat(img_i[:, :, None], 3, axis=-1)
            resize_img_skimg = transform.resize(
                img_3c,
                (image_size, image_size),
                order=3,
                preserve_range=True,
                mode="constant",
                anti_aliasing=True,
            )
            # 归一化
            resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip(
                resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None
            )  # normalize to [0, 1], (H, W, 3)
            gt_i = gt_roi[i, :, :]
            resize_gt_skimg = transform.resize(
                gt_i,
                (image_size, image_size),
                order=0,
                preserve_range=True,
                mode="constant",
                anti_aliasing=False,
            )
            resize_gt_skimg = np.uint8(resize_gt_skimg)
            assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape
            np.save(
                join(
                    npy_path,
                    "imgs",
                    prefix
                    + gt_name.split(gt_name_suffix)[0]
                    + "-"
                    + str(i).zfill(3)
                    + ".npy",
                ),
                resize_img_skimg_01,
            )
            np.save(
                join(
                    npy_path,
                    "gts",
                    prefix
                    + gt_name.split(gt_name_suffix)[0]
                    + "-"
                    + str(i).zfill(3)
                    + ".npy",
                ),
                resize_gt_skimg,
            )

数据集和模型权重的目录结构:

XML 复制代码
| -- MedSAM
    | -- work_dir
        | --MedSAM
        	| --medsam_vit_b.pth
        | --SAM
        	| --sam_vit_b_01ec64.pth
    | -- data
        | -- FLARE22Train
        	| --images
        	| --labels
        	| --npy

MedSAM训练与测试

训练

train_one_gpu.py

根据个人情况常规修改,不过多解释

checkpoint:表示预训练权重

num_epochs:训练轮数

batch_size:单次的图像数目

resume:继续训练加载的模型权重路径

bash 复制代码
# 单gpu训练	
python train_one_gpu.py

正在训练:

训练保存的权重(最好/最后的权重),放置在work_dir/MedSAM-ViT-B-20241209-0934(数字代表运行时刻的时间)里。

测试:

测试部分博主都是用官方提供的模型权重演示。
MedSAM_Inference.py

根据个人情况常规修改,不过多解释

i:输入图像的位置

o:训练轮数

box:在图像位置的检查区域框左上坐标和右下坐标[x1, y1, x2, y2]

checkpoint:选择加载的模型

bash 复制代码
```bash
# 测试效果
python MedSAM_Inference.py

完成测试:

gui.py

MedSAM_CKPT_PATH:指定加载的模型

bash 复制代码
# 使用图形界面来完成图像指定区域的分割
python gui.py

掩码图的保存路径与加载图像的路径一致:


总结

尽可能简单、详细的介绍了MedSAM的安装流程以及MedSAM的使用方法。后续会根据自己学到的知识结合个人理解讲解MedSAM的原理和代码。

相关推荐
WBingJ3 分钟前
李宏毅机器学习-批次 (batch)和动量(momentum)
人工智能·机器学习·batch
DevUI团队11 分钟前
开源开发者获奖,MateChat即将发布,快来看DevUI在华为云开源开发者论坛的精彩回顾吧
前端·人工智能
过去式的马马马24 分钟前
aippt:AI 智能生成 PPT 的开源项目
人工智能·开源·powerpoint
池央43 分钟前
AlexNet:开启深度学习图像识别新纪元
人工智能·深度学习
剑盾云安全专家1 小时前
让PPT不再“难搞”:智能工具如何改变办公体验
人工智能·aigc·powerpoint·软件
封步宇AIGC1 小时前
量化交易系统开发-实时行情自动化交易-8.25.真格(澎博财经旗下)平台
人工智能·python·机器学习·数据挖掘
黑客呀3 小时前
密码学——密码学基础、散列函数与数字签名
网络·数据库·人工智能
weixin_4769582710 小时前
自动驾驶技术——HSL
人工智能·机器学习·自动驾驶
代码小狗Codog10 小时前
WIDER FACE数据集转YOLO格式
人工智能·yolo·目标跟踪
物联高科10 小时前
智能电网技术如何助力能源转型?
人工智能·单片机·嵌入式硬件·能源·创业创新