nnUnet 大模型学习笔记(续):3d_fullres 模型的推理、切片推理、计算dice系数

目录

[1. 前言](#1. 前言)

[2. 更改epochs](#2. 更改epochs)

[3. 推理](#3. 推理)

[3.1 nnUNet_predict](#3.1 nnUNet_predict)

[3.2 切成小的nii gz文件推理](#3.2 切成小的nii gz文件推理)

切片代码

融合代码

[3.3 可视化展示](#3.3 可视化展示)

[3.4 评估指标](#3.4 评估指标)

参考


1. 前言

训练了一天半,终于跑完了。。。。

训练的模型在这可以免费下载:

基于nnUnet3d-fullres训练的spineCT训练结果资源-CSDN文库

关于nnUnet的环境搭建、数据集制作、训练网络参考:

第四章:nnUnet大模型之环境配置、数据集制作_nnunet代码详解pytorch-CSDN博客

nnUnet 大模型学习笔记(续):训练网络(3d_fullres)以及数据集标签的处理-CSDN博客

训练过程如下:

生成的结果在 nnUnet_trained_models 目录下:

训练过程的指标可以看曲线图或者训练日志(validation_raw/summary.json):

这里validation_raw/summary.json 没有生成,不知道因为什么原因,程序被kill了**。。。。**

2. 更改epochs

1000个epoch太多了,可以更改官方的参数,如果按照本文的环境搭建,参数在这里

*/environments/nnunet/lib/python3.8/site-packages/nnunet/training/network_training/

3. 推理

nnUnet 是没有推理和测试放在一起的,它会对指定的数据进行推理

如果你有推理的labels的话,那么就可以进行指标计算,这样就可以测试

如果没有labels,那么只有推理

3.1 nnUNet_predict

在最初的数据集里,新建inferTs,用于推理nnUnet推理的结果,把想要推理的数据放在imagesTs下就行了

好像这里的测试数据必须得是0000.nii.gz结尾的,是因为多模态?

运行下面的命令

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
  • -i 是想要预测的数据目录 ,一般为imagesTs
  • -o 是保存推理后的数据目录,一般为inferTs
  • -t 是任务的训练标号
  • -m 是nnUnet训练好的模型
  • -f 是训练m模型的几折交叉验证

由于是直接使用模型进行推理, inference done. 后出现 WARNING! Cannot run postprocessing because the postprocessing file is missing. 也即表示已经完成推理。

若CT层数太多或层间距小,可能会卡在 inference done. 阶段,此时需要将CT切分成几部分分别进行 推理。

这里有时候nnUnet推理不出来,可能是因为输入层数太多?参考3.2 处理

3.2 切成小的nii gz文件推理

如果3.1可以成功推理,可以不参考这步!!!

切片代码

注意,这里只需要把想要推理的数据进行切片,然后推理完再拼接即可!!

代码如下:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain(rt):
    img = sitk.ReadImage(rt)
    img_array = sitk.GetArrayFromImage(img)  # nii-->array

    print('input size:', img_array.shape)

    channel = img_array.shape[0]
    y, z = channel // 100, channel % 100  # 2 91

    if z == 0:
        n = y
        print('切出的nii.gz文件个数:', n)
    else:
        n = y + 1
        print('切出的nii.gz文件个数:', n)

    for i in range(n):
        star, end = i * 100, i * 100 + 100 

        if i == n - 1:  # 最后一个切片
            img_select = img_array[star:, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)

        else:
            img_select = img_array[star:end, :, :]
            shape = img_select.shape
            img_select = sitk.GetImageFromArray(img_select)

            img_save_name = 'data_' + str(i) + '_0000.nii.gz'
            print(img_save_name, 'channel:', shape)
            sitk.WriteImage(img_select, img_save_name)


if __name__ == '__main__':
    root = 'spine_001.nii.gz'

    # 切片函数
    sliceMain(rt=root)

切片结果:

效果如下:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| |||
| | | |

然后推理就行了!

python 复制代码
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0

融合代码

推理完成的nii数据下放在data目录下,然后运行下面代码会自动拼接:

python 复制代码
import SimpleITK as sitk
import numpy as np
import os
import cv2


# 切片函数
def sliceMain():
    data = [os.path.join('data',u) for u in os.listdir('data')]

    ret_nii = None
    for index,i in enumerate(data):
        img = sitk.ReadImage(i)
        img_array = sitk.GetArrayFromImage(img)  # nii-->array
        print(i,':',img_array.shape)

        if index ==0:
            ret_nii = img_array
        else:
            ret_nii = np.concatenate((ret_nii,img_array),axis=0)

    print('返回的数组size:',ret_nii.shape)
    sitk.WriteImage(sitk.GetImageFromArray(ret_nii),'ret.nii.gz')


if __name__ == '__main__':
    # 切片函数
    sliceMain()

效果如下:

左上角是拼接后的,其余三个是nnUNet推理生成的

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | |
| | |

3.3 可视化展示

下面是原图加真实gt

下面是原图加nnUNet推理结果:

3.4 评估指标

在labelsTs下放入对应的标签即可

如果评估的话,需要真实的gt图!!!

python 复制代码
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm


def main(pred, gt,n):
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt))             # [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))

    dice = []
    for h in tqdm(range(n)):
        if h == 0:
            continue

        g = np.zeros(gt.shape,dtype=np.uint8)           # 单独提取某个灰度级
        g[gt == h] = 255
        g[g<255] = 0
        g[g==255] = 1

        p = np.zeros(pred.shape,dtype=np.uint8)         # 单独提取某个灰度级
        p[pred == h] = 255
        p[p<255] = 0
        p[p==255] = 1

        if len(np.unique(p)) == 1 or len(np.unique(g)) == 1:
            dice.append('None')

        else:
            dice_score = (2*(p*g).sum() / ((p+g).sum()+1e-8))
            dice.append(round(dice_score,4))

    print(dice)

    for i in range(len(dice) - 1,-1,-1):
        if dice[i] == 'None':
            dice.remove('None')

    print('mean dice',np.array(dice).mean())


if __name__ == "__main__":

    gt_path = 'labels.nii.gz'
    pred_path = 'infer.nii.gz'
    classes = 19

    main(pred=pred_path,gt=gt_path,n=classes)

指标如下:

['None', 0.4523, 0.9311, 0.967, 0.9732, 0.9756, 0.9687, 0.9665, 0.9674, 0.9787, 0.9743, 0.9802, 0.9811, 0.9826, 0.974, 'None', 'None', 'None']
mean dice 0.9337642857142857

代码主要实现思路:

因为推理的时候,不是所有的数据同时包含所有标签,所以这里为了方便评估,将所有的类别全部显示。如果没有某个标签就设定为None,然后计算平均dice的时候,就会去掉相应的空标签。
这是nnUnet某个epoch计算的平均dice指标

参考

参考博文如下:nnUNet使用指南(一):Ubuntu系统下使用nnUNet对自己的多模态MR数据集训练 - 梅雨明夏 - 博客园

nnUNet训练并推理自己的数据集_nnunet训练自己数据集-CSDN博客

相关推荐
泰迪智能科技0128 分钟前
高校深度学习视觉应用平台产品介绍
人工智能·深度学习
盛派网络小助手1 小时前
微信 SDK 更新 Sample,NCF 文档和模板更新,更多更新日志,欢迎解锁
开发语言·人工智能·后端·架构·c#
Eric.Lee20211 小时前
Paddle OCR 中英文检测识别 - python 实现
人工智能·opencv·计算机视觉·ocr检测
cd_farsight1 小时前
nlp初学者怎么入门?需要学习哪些?
人工智能·自然语言处理
AI明说1 小时前
评估大语言模型在药物基因组学问答任务中的表现:PGxQA
人工智能·语言模型·自然语言处理·数智药师·数智药学
Focus_Liu2 小时前
NLP-UIE(Universal Information Extraction)
人工智能·自然语言处理
PowerBI学谦2 小时前
使用copilot轻松将电子邮件转为高效会议
人工智能·copilot
audyxiao0012 小时前
AI一周重要会议和活动概览
人工智能·计算机视觉·数据挖掘·多模态
CCSBRIDGE2 小时前
Magento2项目部署笔记
笔记