nnUnet 大模型学习笔记(续):3d_fullres 模型的推理、切片推理、计算dice系数

目录

[1. 前言](#1. 前言)

[2. 更改epochs](#2. 更改epochs)

[3. 推理](#3. 推理)

[3.1 nnUNet_predict](#3.1 nnUNet_predict)

[3.2 切成小的nii gz文件推理](#3.2 切成小的nii gz文件推理)

切片代码

融合代码

[3.3 可视化展示](#3.3 可视化展示)

[3.4 评估指标](#3.4 评估指标)

参考


1. 前言

训练了一天半,终于跑完了。。。。

训练的模型在这可以免费下载:

基于nnUnet3d-fullres训练的spineCT训练结果资源-CSDN文库

关于nnUnet的环境搭建、数据集制作、训练网络参考:

第四章:nnUnet大模型之环境配置、数据集制作_nnunet代码详解pytorch-CSDN博客

nnUnet 大模型学习笔记(续):训练网络(3d_fullres)以及数据集标签的处理-CSDN博客

训练过程如下:

生成的结果在 nnUnet_trained_models 目录下:

训练过程的指标可以看曲线图或者训练日志(validation_raw/summary.json):

这里validation_raw/summary.json 没有生成,不知道因为什么原因,程序被kill了**。。。。**

2. 更改epochs

1000个epoch太多了,可以更改官方的参数,如果按照本文的环境搭建,参数在这里

*/environments/nnunet/lib/python3.8/site-packages/nnunet/training/network_training/

3. 推理

nnUnet 是没有推理和测试放在一起的,它会对指定的数据进行推理

如果你有推理的labels的话,那么就可以进行指标计算,这样就可以测试

如果没有labels,那么只有推理

3.1 nnUNet_predict

在最初的数据集里,新建inferTs,用于推理nnUnet推理的结果,把想要推理的数据放在imagesTs下就行了

好像这里的测试数据必须得是0000.nii.gz结尾的,是因为多模态?

运行下面的命令

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
  • -i 是想要预测的数据目录 ,一般为imagesTs
  • -o 是保存推理后的数据目录,一般为inferTs
  • -t 是任务的训练标号
  • -m 是nnUnet训练好的模型
  • -f 是训练m模型的几折交叉验证

由于是直接使用模型进行推理, inference done. 后出现 WARNING! Cannot run postprocessing because the postprocessing file is missing. 也即表示已经完成推理。

若CT层数太多或层间距小,可能会卡在 inference done. 阶段,此时需要将CT切分成几部分分别进行 推理。

这里有时候nnUnet推理不出来,可能是因为输入层数太多?参考3.2 处理

3.2 切成小的nii gz文件推理

如果3.1可以成功推理,可以不参考这步!!!

切片代码

注意,这里只需要把想要推理的数据进行切片,然后推理完再拼接即可!!

代码如下:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain(rt):
    img = sitk.ReadImage(rt)
    img_array = sitk.GetArrayFromImage(img)  # nii-->array

    print('input size:', img_array.shape)

    channel = img_array.shape[0]
    y, z = channel // 100, channel % 100  # 2 91

    if z == 0:
        n = y
        print('切出的nii.gz文件个数:', n)
    else:
        n = y + 1
        print('切出的nii.gz文件个数:', n)

    for i in range(n):
        star, end = i * 100, i * 100 + 100 

        if i == n - 1:  # 最后一个切片
            img_select = img_array[star:, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)

        else:
            img_select = img_array[star:end, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)


if __name__ == '__main__':
    root = 'spine_001.nii.gz'

    # 切片函数
    sliceMain(rt=root)

切片结果:

效果如下:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| |||
| | | |

然后推理就行了!

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0

融合代码

推理完成的nii数据下放在data目录下,然后运行下面代码会自动拼接:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain():
    data = [os.path.join('data',u) for u in os.listdir('data')]

    ret_nii = None
    for index,i in enumerate(data):
        img = sitk.ReadImage(i)
        img_array = sitk.GetArrayFromImage(img)  # nii-->array
        print(i,':',img_array.shape)

        if index ==0:
            ret_nii = img_array
        else:
            ret_nii = np.concatenate((ret_nii,img_array),axis=0)

    print('返回的数组size:',ret_nii.shape)
    sitk.WriteImage(sitk.GetImageFromArray(ret_nii),'ret.nii.gz')


if __name__ == '__main__':
    # 切片函数
    sliceMain()

效果如下:

左上角是拼接后的,其余三个是nnUNet推理生成的

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | |
| | |

3.3 可视化展示

下面是原图加真实gt

下面是原图加nnUNet推理结果:

3.4 评估指标

在labelsTs下放入对应的标签即可

如果评估的话,需要真实的gt图!!!

python 复制代码
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm


def main(pred, gt,n):
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt))             # [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))

    dice = []
    for h in tqdm(range(n)):
        if h == 0:
            continue

        g = np.zeros(gt.shape,dtype=np.uint8)           # 单独提取某个灰度级
        g[gt == h] = 255
        g[g<255] = 0
        g[g==255] = 1

        p = np.zeros(pred.shape,dtype=np.uint8)         # 单独提取某个灰度级
        p[pred == h] = 255
        p[p<255] = 0
        p[p==255] = 1

        if len(np.unique(p)) == 1 or len(np.unique(g)) == 1:
            dice.append('None')

        else:
            dice_score = (2*(p*g).sum() / ((p+g).sum()+1e-8))
            dice.append(round(dice_score,4))

    print(dice)

    for i in range(len(dice) - 1,-1,-1):
        if dice[i] == 'None':
            dice.remove('None')

    print('mean dice',np.array(dice).mean())


if __name__ == "__main__":

    gt_path = 'labels.nii.gz'
    pred_path = 'infer.nii.gz'
    classes = 19

    main(pred=pred_path,gt=gt_path,n=classes)

指标如下:

['None', 0.4523, 0.9311, 0.967, 0.9732, 0.9756, 0.9687, 0.9665, 0.9674, 0.9787, 0.9743, 0.9802, 0.9811, 0.9826, 0.974, 'None', 'None', 'None']
mean dice 0.9337642857142857

代码主要实现思路:

因为推理的时候,不是所有的数据同时包含所有标签,所以这里为了方便评估,将所有的类别全部显示。如果没有某个标签就设定为None,然后计算平均dice的时候,就会去掉相应的空标签。
这是nnUnet某个epoch计算的平均dice指标

参考

参考博文如下:nnUNet使用指南(一):Ubuntu系统下使用nnUNet对自己的多模态MR数据集训练 - 梅雨明夏 - 博客园

nnUNet训练并推理自己的数据集_nnunet训练自己数据集-CSDN博客

相关推荐
UMS攸信技术1 小时前
汽车电子行业数字化转型的实践与探索——以盈趣汽车电子为例
人工智能·汽车
ws2019071 小时前
聚焦汽车智能化与电动化︱AUTO TECH 2025 华南展,以展带会,已全面启动,与您相约11月广州!
大数据·人工智能·汽车
堇舟2 小时前
斯皮尔曼相关(Spearman correlation)系数
人工智能·算法·机器学习
Y.O.U..2 小时前
STL学习-容器适配器
开发语言·c++·学习·stl·1024程序员节
爱写代码的小朋友2 小时前
使用 OpenCV 进行人脸检测
人工智能·opencv·计算机视觉
TT哇3 小时前
【Java】数组的定义与使用
java·开发语言·笔记
Cici_ovo3 小时前
摄像头点击器常见问题——摄像头视窗打开慢
人工智能·单片机·嵌入式硬件·物联网·计算机视觉·硬件工程
QQ39575332373 小时前
中阳智能交易系统:创新金融科技赋能投资新时代
人工智能·金融
yyfhq3 小时前
dcgan
深度学习·机器学习·生成对抗网络
这个男人是小帅3 小时前
【图神经网络】 AM-GCN论文精讲(全网最细致篇)
人工智能·pytorch·深度学习·神经网络·分类