nnUnet 大模型学习笔记(续):3d_fullres 模型的推理、切片推理、计算dice系数

目录

[1. 前言](#1. 前言)

[2. 更改epochs](#2. 更改epochs)

[3. 推理](#3. 推理)

[3.1 nnUNet_predict](#3.1 nnUNet_predict)

[3.2 切成小的nii gz文件推理](#3.2 切成小的nii gz文件推理)

切片代码

融合代码

[3.3 可视化展示](#3.3 可视化展示)

[3.4 评估指标](#3.4 评估指标)

参考


1. 前言

训练了一天半,终于跑完了。。。。

训练的模型在这可以免费下载:

基于nnUnet3d-fullres训练的spineCT训练结果资源-CSDN文库

关于nnUnet的环境搭建、数据集制作、训练网络参考:

第四章:nnUnet大模型之环境配置、数据集制作_nnunet代码详解pytorch-CSDN博客

nnUnet 大模型学习笔记(续):训练网络(3d_fullres)以及数据集标签的处理-CSDN博客

训练过程如下:

生成的结果在 nnUnet_trained_models 目录下:

训练过程的指标可以看曲线图或者训练日志(validation_raw/summary.json):

这里validation_raw/summary.json 没有生成,不知道因为什么原因,程序被kill了**。。。。**

2. 更改epochs

1000个epoch太多了,可以更改官方的参数,如果按照本文的环境搭建,参数在这里

复制代码
*/environments/nnunet/lib/python3.8/site-packages/nnunet/training/network_training/

3. 推理

nnUnet 是没有推理和测试放在一起的,它会对指定的数据进行推理

如果你有推理的labels的话,那么就可以进行指标计算,这样就可以测试

如果没有labels,那么只有推理

3.1 nnUNet_predict

在最初的数据集里,新建inferTs,用于推理nnUnet推理的结果,把想要推理的数据放在imagesTs下就行了

好像这里的测试数据必须得是0000.nii.gz结尾的,是因为多模态?

运行下面的命令

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
  • -i 是想要预测的数据目录 ,一般为imagesTs
  • -o 是保存推理后的数据目录,一般为inferTs
  • -t 是任务的训练标号
  • -m 是nnUnet训练好的模型
  • -f 是训练m模型的几折交叉验证

由于是直接使用模型进行推理, inference done. 后出现 WARNING! Cannot run postprocessing because the postprocessing file is missing. 也即表示已经完成推理。

若CT层数太多或层间距小,可能会卡在 inference done. 阶段,此时需要将CT切分成几部分分别进行 推理。

这里有时候nnUnet推理不出来,可能是因为输入层数太多?参考3.2 处理

3.2 切成小的nii gz文件推理

如果3.1可以成功推理,可以不参考这步!!!

切片代码

注意,这里只需要把想要推理的数据进行切片,然后推理完再拼接即可!!

代码如下:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain(rt):
    img = sitk.ReadImage(rt)
    img_array = sitk.GetArrayFromImage(img)  # nii-->array

    print('input size:', img_array.shape)

    channel = img_array.shape[0]
    y, z = channel // 100, channel % 100  # 2 91

    if z == 0:
        n = y
        print('切出的nii.gz文件个数:', n)
    else:
        n = y + 1
        print('切出的nii.gz文件个数:', n)

    for i in range(n):
        star, end = i * 100, i * 100 + 100 

        if i == n - 1:  # 最后一个切片
            img_select = img_array[star:, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)

        else:
            img_select = img_array[star:end, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)


if __name__ == '__main__':
    root = 'spine_001.nii.gz'

    # 切片函数
    sliceMain(rt=root)

切片结果:

效果如下:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| |||
| | | |

然后推理就行了!

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0

融合代码

推理完成的nii数据下放在data目录下,然后运行下面代码会自动拼接:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain():
    data = [os.path.join('data',u) for u in os.listdir('data')]

    ret_nii = None
    for index,i in enumerate(data):
        img = sitk.ReadImage(i)
        img_array = sitk.GetArrayFromImage(img)  # nii-->array
        print(i,':',img_array.shape)

        if index ==0:
            ret_nii = img_array
        else:
            ret_nii = np.concatenate((ret_nii,img_array),axis=0)

    print('返回的数组size:',ret_nii.shape)
    sitk.WriteImage(sitk.GetImageFromArray(ret_nii),'ret.nii.gz')


if __name__ == '__main__':
    # 切片函数
    sliceMain()

效果如下:

左上角是拼接后的,其余三个是nnUNet推理生成的

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | |
| | |

3.3 可视化展示

下面是原图加真实gt

下面是原图加nnUNet推理结果:

3.4 评估指标

在labelsTs下放入对应的标签即可

如果评估的话,需要真实的gt图!!!

python 复制代码
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm


def main(pred, gt,n):
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt))             # [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))

    dice = []
    for h in tqdm(range(n)):
        if h == 0:
            continue

        g = np.zeros(gt.shape,dtype=np.uint8)           # 单独提取某个灰度级
        g[gt == h] = 255
        g[g<255] = 0
        g[g==255] = 1

        p = np.zeros(pred.shape,dtype=np.uint8)         # 单独提取某个灰度级
        p[pred == h] = 255
        p[p<255] = 0
        p[p==255] = 1

        if len(np.unique(p)) == 1 or len(np.unique(g)) == 1:
            dice.append('None')

        else:
            dice_score = (2*(p*g).sum() / ((p+g).sum()+1e-8))
            dice.append(round(dice_score,4))

    print(dice)

    for i in range(len(dice) - 1,-1,-1):
        if dice[i] == 'None':
            dice.remove('None')

    print('mean dice',np.array(dice).mean())


if __name__ == "__main__":

    gt_path = 'labels.nii.gz'
    pred_path = 'infer.nii.gz'
    classes = 19

    main(pred=pred_path,gt=gt_path,n=classes)

指标如下:

'None', 0.4523, 0.9311, 0.967, 0.9732, 0.9756, 0.9687, 0.9665, 0.9674, 0.9787, 0.9743, 0.9802, 0.9811, 0.9826, 0.974, 'None', 'None', 'None'

mean dice 0.9337642857142857

代码主要实现思路:

因为推理的时候,不是所有的数据同时包含所有标签,所以这里为了方便评估,将所有的类别全部显示。如果没有某个标签就设定为None,然后计算平均dice的时候,就会去掉相应的空标签。
这是nnUnet某个epoch计算的平均dice指标

参考

参考博文如下:nnUNet使用指南(一):Ubuntu系统下使用nnUNet对自己的多模态MR数据集训练 - 梅雨明夏 - 博客园

nnUNet训练并推理自己的数据集_nnunet训练自己数据集-CSDN博客

相关推荐
好奇龙猫18 分钟前
【AI学习-comfyUI学习-第三十节-第三十一节-FLUX-SD放大工作流+FLUX图生图工作流-各个部分学习】
人工智能·学习
沈浩(种子思维作者)25 分钟前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
saoys26 分钟前
Opencv 学习笔记:图像掩膜操作(精准提取指定区域像素)
笔记·opencv·学习
minhuan27 分钟前
大模型应用:大模型越大越好?模型参数量与效果的边际效益分析.51
人工智能·大模型参数评估·边际效益分析·大模型参数选择
Cherry的跨界思维33 分钟前
28、AI测试环境搭建与全栈工具实战:从本地到云平台的完整指南
java·人工智能·vue3·ai测试·ai全栈·测试全栈·ai测试全栈
MM_MS36 分钟前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
ASF1231415sd1 小时前
【基于YOLOv10n-CSP-PTB的大豆花朵检测与识别系统详解】
人工智能·yolo·目标跟踪
水如烟2 小时前
孤能子视角:“意识“的阶段性回顾,“感质“假说
人工智能
电子小白1232 小时前
第13期PCB layout工程师初级培训-1-EDA软件的通用设置
笔记·嵌入式硬件·学习·pcb·layout
Carl_奕然2 小时前
【数据挖掘】数据挖掘必会技能之:A/B测试
人工智能·python·数据挖掘·数据分析