项目概述
该项目实现了一个端到端的医学图像分割流程,包括:
-
数据预处理与增强
-
U-Net 模型构建与训练
-
模型验证与可视化
-
结果保存与分析
数据预处理
项目使用 DSB2018 数据集,通过 preprocess_dsb2018.py 进行数据预处理:
def main():
img_size = 96
paths = glob('inputs/stage1_train/*')
# 创建输出目录
os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)
for i in tqdm(range(len(paths))):
path = paths[i]
img = cv2.imread(os.path.join(path, 'images', os.path.basename(path) + '.png'))
mask = np.zeros((img.shape[0], img.shape[1]))
# 合并多个掩码
for mask_path in glob(os.path.join(path, 'masks', '*')):
mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
mask[mask_] = 1
# 调整图像尺寸
img = cv2.resize(img, (img_size, img_size))
mask = cv2.resize(mask, (img_size, img_size))
# 保存处理后的图像和掩码
cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
os.path.basename(path) + '.png'), img)
cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))
预处理步骤包括:
-
统一图像尺寸为 96×96 像素
-
合并多个掩码文件为单个二值掩码
-
标准化图像格式和通道
数据集类设计
dataset.py 中实现了自定义数据集类,支持数据增强:
class Dataset(torch.utils.data.Dataset):
def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
self.img_ids = img_ids
self.img_dir = img_dir
self.mask_dir = mask_dir
self.img_ext = img_ext
self.mask_ext = mask_ext
self.num_classes = num_classes
self.transform = transform
def __getitem__(self, idx):
img_id = self.img_ids[idx]
# 读取图像和掩码
img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
mask = []
for i in range(self.num_classes):
mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
mask = np.dstack(mask)
# 数据增强
if self.transform is not None:
augmented = self.transform(image=img, mask=mask)
img = augmented['image']
mask = augmented['mask']
# 标准化和维度调整
img = img.astype('float32') / 255
img = img.transpose(2, 0, 1)
mask = mask.astype('float32') / 255
mask = mask.transpose(2, 0, 1)
return img, mask, {'img_id': img_id}
数据增强策略
项目使用 Albumentations 库进行数据增强:
训练集增强:
train_transform = Compose([
albu.RandomRotate90(),
albu.HorizontalFlip(),
albu.OneOf([
albu.HueSaturationValue(),
albu.RandomBrightnessContrast(),
], p=1),
albu.Resize(config['input_h'], config['input_w']),
albu.Normalize(),
])
验证集增强:
val_transform = Compose([
albu.Resize(config['input_h'], config['input_w']),
albu.Normalize(),
])
模型训练
train.py 实现了完整的训练流程:
参数配置
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS")
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--arch', default='NestedUNet')
parser.add_argument('--deep_supervision', default=False, type=str2bool)
parser.add_argument('--loss', default='BCEDiceLoss')
parser.add_argument('--optimizer', default='SGD')
parser.add_argument('--lr', default=1e-3, type=float)
# ... 更多参数
训练循环
def train(config, train_loader, model, criterion, optimizer):
model.train()
for input, target, _ in train_loader:
# 前向传播
if config['deep_supervision']:
outputs = model(input)
loss = 0
for output in outputs:
loss += criterion(output, target)
loss /= len(outputs)
iou = iou_score(outputs[-1], target)
else:
output = model(input)
loss = criterion(output, target)
iou = iou_score(output, target)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
模型验证与可视化
val.py 提供了模型验证和结果可视化功能:
def plot_examples(datax, datay, model, num_examples=6):
fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))
m = datax.shape[0]
for row_num in range(num_examples):
image_indx = np.random.randint(m)
image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()
ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
ax[row_num][0].set_title("Orignal Image")
ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))
ax[row_num][1].set_title("Segmented Image localization")
ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
ax[row_num][2].set_title("Target image")
plt.show()
关键特性
1. 深度监督
支持深度监督训练,通过多个输出层提供中间监督信号。
2. 灵活的损失函数
支持多种损失函数,包括 BCEWithLogitsLoss 和自定义的 BCEDiceLoss。
3. 学习率调度
提供多种学习率调度策略:
-
CosineAnnealingLR
-
ReduceLROnPlateau
-
MultiStepLR
-
ConstantLR
4. 早停机制
通过监控验证集性能实现早停,防止过拟合。
使用方式
训练模型
python train.py --dataset dsb2018_96 --arch NestedUNet
验证模型
python val.py --name dsb2018_96_NestedUNet_woDS
总结
该项目提供了一个完整的医学图像分割解决方案,具有以下优点:
-
模块化设计:各个组件独立,便于修改和扩展
-
丰富的数据增强:提高模型泛化能力
-
灵活的配置:通过配置文件管理所有超参数
-
完整的训练监控:记录训练过程中的各项指标
-
结果可视化:直观展示分割效果
这个项目不仅适用于细胞核分割,通过调整配置也可以应用于其他医学图像分割任务,为医学图像分析研究提供了有力的工具。