nnUnet 大模型学习笔记(续):3d_fullres 模型的推理、切片推理、计算dice系数

目录

[1. 前言](#1. 前言)

[2. 更改epochs](#2. 更改epochs)

[3. 推理](#3. 推理)

[3.1 nnUNet_predict](#3.1 nnUNet_predict)

[3.2 切成小的nii gz文件推理](#3.2 切成小的nii gz文件推理)

切片代码

融合代码

[3.3 可视化展示](#3.3 可视化展示)

[3.4 评估指标](#3.4 评估指标)

参考


1. 前言

训练了一天半,终于跑完了。。。。

训练的模型在这可以免费下载:

基于nnUnet3d-fullres训练的spineCT训练结果资源-CSDN文库

关于nnUnet的环境搭建、数据集制作、训练网络参考:

第四章:nnUnet大模型之环境配置、数据集制作_nnunet代码详解pytorch-CSDN博客

nnUnet 大模型学习笔记(续):训练网络(3d_fullres)以及数据集标签的处理-CSDN博客

训练过程如下:

生成的结果在 nnUnet_trained_models 目录下:

训练过程的指标可以看曲线图或者训练日志(validation_raw/summary.json):

这里validation_raw/summary.json 没有生成,不知道因为什么原因,程序被kill了**。。。。**

2. 更改epochs

1000个epoch太多了,可以更改官方的参数,如果按照本文的环境搭建,参数在这里

*/environments/nnunet/lib/python3.8/site-packages/nnunet/training/network_training/

3. 推理

nnUnet 是没有推理和测试放在一起的,它会对指定的数据进行推理

如果你有推理的labels的话,那么就可以进行指标计算,这样就可以测试

如果没有labels,那么只有推理

3.1 nnUNet_predict

在最初的数据集里,新建inferTs,用于推理nnUnet推理的结果,把想要推理的数据放在imagesTs下就行了

好像这里的测试数据必须得是0000.nii.gz结尾的,是因为多模态?

运行下面的命令

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
  • -i 是想要预测的数据目录 ,一般为imagesTs
  • -o 是保存推理后的数据目录,一般为inferTs
  • -t 是任务的训练标号
  • -m 是nnUnet训练好的模型
  • -f 是训练m模型的几折交叉验证

由于是直接使用模型进行推理, inference done. 后出现 WARNING! Cannot run postprocessing because the postprocessing file is missing. 也即表示已经完成推理。

若CT层数太多或层间距小,可能会卡在 inference done. 阶段,此时需要将CT切分成几部分分别进行 推理。

这里有时候nnUnet推理不出来,可能是因为输入层数太多?参考3.2 处理

3.2 切成小的nii gz文件推理

如果3.1可以成功推理,可以不参考这步!!!

切片代码

注意,这里只需要把想要推理的数据进行切片,然后推理完再拼接即可!!

代码如下:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain(rt):
    img = sitk.ReadImage(rt)
    img_array = sitk.GetArrayFromImage(img)  # nii-->array

    print('input size:', img_array.shape)

    channel = img_array.shape[0]
    y, z = channel // 100, channel % 100  # 2 91

    if z == 0:
        n = y
        print('切出的nii.gz文件个数:', n)
    else:
        n = y + 1
        print('切出的nii.gz文件个数:', n)

    for i in range(n):
        star, end = i * 100, i * 100 + 100 

        if i == n - 1:  # 最后一个切片
            img_select = img_array[star:, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)

        else:
            img_select = img_array[star:end, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)


if __name__ == '__main__':
    root = 'spine_001.nii.gz'

    # 切片函数
    sliceMain(rt=root)

切片结果:

效果如下:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| |||
| | | |

然后推理就行了!

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0

融合代码

推理完成的nii数据下放在data目录下,然后运行下面代码会自动拼接:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain():
    data = [os.path.join('data',u) for u in os.listdir('data')]

    ret_nii = None
    for index,i in enumerate(data):
        img = sitk.ReadImage(i)
        img_array = sitk.GetArrayFromImage(img)  # nii-->array
        print(i,':',img_array.shape)

        if index ==0:
            ret_nii = img_array
        else:
            ret_nii = np.concatenate((ret_nii,img_array),axis=0)

    print('返回的数组size:',ret_nii.shape)
    sitk.WriteImage(sitk.GetImageFromArray(ret_nii),'ret.nii.gz')


if __name__ == '__main__':
    # 切片函数
    sliceMain()

效果如下:

左上角是拼接后的,其余三个是nnUNet推理生成的

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | |
| | |

3.3 可视化展示

下面是原图加真实gt

下面是原图加nnUNet推理结果:

3.4 评估指标

在labelsTs下放入对应的标签即可

如果评估的话,需要真实的gt图!!!

python 复制代码
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm


def main(pred, gt,n):
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt))             # [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))

    dice = []
    for h in tqdm(range(n)):
        if h == 0:
            continue

        g = np.zeros(gt.shape,dtype=np.uint8)           # 单独提取某个灰度级
        g[gt == h] = 255
        g[g<255] = 0
        g[g==255] = 1

        p = np.zeros(pred.shape,dtype=np.uint8)         # 单独提取某个灰度级
        p[pred == h] = 255
        p[p<255] = 0
        p[p==255] = 1

        if len(np.unique(p)) == 1 or len(np.unique(g)) == 1:
            dice.append('None')

        else:
            dice_score = (2*(p*g).sum() / ((p+g).sum()+1e-8))
            dice.append(round(dice_score,4))

    print(dice)

    for i in range(len(dice) - 1,-1,-1):
        if dice[i] == 'None':
            dice.remove('None')

    print('mean dice',np.array(dice).mean())


if __name__ == "__main__":

    gt_path = 'labels.nii.gz'
    pred_path = 'infer.nii.gz'
    classes = 19

    main(pred=pred_path,gt=gt_path,n=classes)

指标如下:

['None', 0.4523, 0.9311, 0.967, 0.9732, 0.9756, 0.9687, 0.9665, 0.9674, 0.9787, 0.9743, 0.9802, 0.9811, 0.9826, 0.974, 'None', 'None', 'None']
mean dice 0.9337642857142857

代码主要实现思路:

因为推理的时候,不是所有的数据同时包含所有标签,所以这里为了方便评估,将所有的类别全部显示。如果没有某个标签就设定为None,然后计算平均dice的时候,就会去掉相应的空标签。
这是nnUnet某个epoch计算的平均dice指标

参考

参考博文如下:nnUNet使用指南(一):Ubuntu系统下使用nnUNet对自己的多模态MR数据集训练 - 梅雨明夏 - 博客园

nnUNet训练并推理自己的数据集_nnunet训练自己数据集-CSDN博客

相关推荐
余生H7 分钟前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
一棵开花的树,枝芽无限靠近你10 分钟前
【PPTist】添加PPT模版
前端·学习·编辑器·html
果冻人工智能26 分钟前
2025 年将颠覆商业的 8 大 AI 应用场景
人工智能·ai员工
代码不行的搬运工27 分钟前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
VertexGeek29 分钟前
Rust学习(八):异常处理和宏编程:
学习·算法·rust
石小石Orz29 分钟前
Three.js + AI:AI 算法生成 3D 萤火虫飞舞效果~
javascript·人工智能·算法
罗小罗同学35 分钟前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
孤独且没人爱的纸鹤39 分钟前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭41 分钟前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow
羊小猪~~42 分钟前
tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建
人工智能·python·深度学习·机器学习·cnn·tensorflow·neo4j