# Seg3DDataset: Keras data loader for 3D segmentation volumes read with SimpleITK.
from tensorflow import keras
import SimpleITK as sitk
from scipy import ndimage
import numpy as np
import random
import math
import os
class Seg3DDataset(keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of 3D image volumes and one-hot labels.

    Expected layout under ``work_dir``::

        work_dir/JPEGImages/<name>      # image volume
        work_dir/Segmentations/<name>   # label volume with the same file name

    Sample names are listed from ``Segmentations``; the image for each one is
    looked up under the same name in ``JPEGImages``. Files are read with
    SimpleITK.

    Each batch is an ``(x, y)`` tuple of float32 arrays:

    * ``x``: shape ``(B, Y, X, Z, 1)``. SimpleITK's ``(z, y, x)`` array is
      transposed so the slice axis is last. Intensities are clipped to
      ``[hu_min_val, hu_max_val]`` and mapped linearly onto ``[-1, 1]``.
    * ``y``: shape ``(B, Y, X, Z, num_classes)``, one-hot over the label values
      ``0 .. num_classes - 1``.

    When ``mode == 'train'``, each sample is rotated in-plane with probability
    1/2 as data augmentation.
    """

    def __init__(self, work_dir, num_classes, batch_size,
                 hu_min_val, hu_max_val,
                 mode='train'):
        """Index the sample files under ``work_dir``.

        Args:
            work_dir: Root directory containing ``JPEGImages`` and
                ``Segmentations``.
            num_classes: Number of one-hot label channels.
            batch_size: Samples per batch (must be >= 1).
            hu_min_val: Intensity mapped to -1 (lower values are clipped).
            hu_max_val: Intensity mapped to +1 (higher values are clipped).
            mode: ``'train'`` enables random rotation augmentation.

        Raises:
            ValueError: If ``batch_size < 1`` or ``hu_max_val <= hu_min_val``.
        """
        # Keras 3 PyDataset expects this call; it is harmless for tf.keras Sequence.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        if hu_max_val <= hu_min_val:
            # An empty or inverted window would divide by zero (NaN inputs)
            # or flip the intensity mapping.
            raise ValueError(
                f'hu_max_val ({hu_max_val}) must be greater than '
                f'hu_min_val ({hu_min_val})')
        self.work_dir = work_dir
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.mode = mode
        self.hu_min_val = hu_min_val
        self.hu_max_val = hu_max_val
        images_dir = os.path.join(work_dir, "JPEGImages")
        labels_dir = os.path.join(work_dir, "Segmentations")
        # os.listdir order is filesystem-dependent; sorting first makes the
        # shuffle reproducible under a fixed random seed.
        file_names = sorted(os.listdir(labels_dir))
        random.shuffle(file_names)
        self.images_path = [os.path.join(images_dir, name) for name in file_names]
        self.labels_path = [os.path.join(labels_dir, name) for name in file_names]

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.images_path) // self.batch_size

    def __getitem__(self, idx):
        """Load batch ``idx`` as an ``(images, one_hot_labels)`` tuple."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        samples = [
            self._load_sample(x_path, y_path)
            for x_path, y_path in zip(self.images_path[start:stop],
                                      self.labels_path[start:stop])
        ]
        # np.stack fails loudly if volumes in a batch differ in shape.
        images = np.stack([image for image, _ in samples])
        labels = np.stack([label for _, label in samples])
        return images, labels

    def _load_sample(self, x_path, y_path):
        """Read one image/label pair, optionally augment it, and convert it to arrays."""
        image = sitk.ReadImage(x_path)
        label = sitk.ReadImage(y_path)
        # Training-time augmentation: with probability 1/2, rotate by a random
        # integer angle in [20, 90] degrees.
        if self.mode == 'train' and random.randint(0, 1) == 0:
            angle = random.randint(20, 90)
            image, label = self.rotate_image(image, label, angle=angle)
        # SimpleITK arrays are (z, y, x); move the slice axis last -> (y, x, z).
        image_array = self.normalize_img(sitk.GetArrayFromImage(image))
        image_array = np.transpose(image_array, [1, 2, 0]).astype('float32')
        label_array = np.transpose(sitk.GetArrayFromImage(label), [1, 2, 0])
        # Add a trailing channel axis for the network input -> (y, x, z, 1).
        image_array = np.expand_dims(image_array, axis=-1)
        return image_array, self._one_hot(label_array)

    def _one_hot(self, label_array):
        """One-hot encode ``label_array`` along a new last axis as float32.

        Voxels whose value is not in ``range(num_classes)`` get an all-zero vector.
        """
        classes = np.arange(self.num_classes)
        return (label_array[..., np.newaxis] == classes).astype(np.float32)

    def rotate_image(self, image, label, angle):
        """Rotate an aligned image/label pair in-plane by ``angle`` degrees.

        Both inputs are SimpleITK images (e.g. loaded from NIfTI). They must
        share spacing, origin, direction and array shape.

        The rotation uses ``scipy.ndimage.rotate`` with these settings:

        * It acts on array axes 1 and 2 (SimpleITK's y/x).
        * It keeps the array shape (``reshape=False``).
        * It fills out-of-bounds voxels from the nearest edge (``mode='nearest'``).
        * It uses nearest-neighbour interpolation (``order=0``), so the
          rotated label contains only values already present in the input.

        Returns:
            ``(image_rotate, label_rotate)``: new SimpleITK images carrying the
            input spacing, origin and direction.

        Raises:
            ValueError: If the two inputs are not aligned.
        """
        spacing = image.GetSpacing()
        origin = image.GetOrigin()
        direction = image.GetDirection()
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if spacing != label.GetSpacing():
            raise ValueError(f'image: {spacing}; label: {label.GetSpacing()}')
        if origin != label.GetOrigin():
            raise ValueError(f'image: {origin}; label: {label.GetOrigin()}')
        if direction != label.GetDirection():
            raise ValueError(f'image: {direction}; label: {label.GetDirection()}')
        image_array = sitk.GetArrayFromImage(image)
        label_array = sitk.GetArrayFromImage(label)
        if image_array.shape != label_array.shape:
            raise ValueError(
                f'image_array: {image_array.shape}, label_array: {label_array.shape}')
        rotate_kwargs = dict(axes=(1, 2), reshape=False, mode='nearest', order=0)
        # NOTE(review): order=0 also applies to the intensity image, which gives
        # a blocky result; order=1 would interpolate linearly. Kept as before.
        image_rotate = self._to_sitk(
            ndimage.rotate(image_array, angle, **rotate_kwargs),
            spacing, origin, direction)
        # The label is deliberately not binarized. order=0 introduces no new
        # values, so multi-class labels survive, and rotated samples keep the
        # same value set as unrotated ones.
        label_rotate = self._to_sitk(
            ndimage.rotate(label_array, angle, **rotate_kwargs),
            spacing, origin, direction)
        return image_rotate, label_rotate

    @staticmethod
    def _to_sitk(array, spacing, origin, direction):
        """Wrap ``array`` as a SimpleITK image with the given physical metadata."""
        sitk_image = sitk.GetImageFromArray(array)
        sitk_image.SetSpacing(spacing)
        sitk_image.SetOrigin(origin)
        sitk_image.SetDirection(direction)
        return sitk_image

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Map ``[hu_min_val, hu_max_val]`` linearly onto ``[-1, 1]``, clipping outside values."""
        value_range = self.hu_max_val - self.hu_min_val
        norm_0_1 = (img - self.hu_min_val) / value_range
        return np.clip(2 * norm_0_1 - 1, -1, 1)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs, keeping image/label pairs aligned.

        Pairs are shuffled together rather than reseeding the global ``random``
        module. Reseeding would reduce all later randomness (shuffles and
        rotation augmentation) to a handful of fixed streams and would clobber
        the RNG state of other code.
        """
        pairs = list(zip(self.images_path, self.labels_path))
        random.shuffle(pairs)
        # Slice assignment keeps the same list objects, as the in-place shuffle did.
        self.images_path[:] = [image_path for image_path, _ in pairs]
        self.labels_path[:] = [label_path for _, label_path in pairs]