基于 U-Net 的医学图像分割

项目概述

该项目实现了一个端到端的医学图像分割流程,包括:

  • 数据预处理与增强

  • U-Net 模型构建与训练

  • 模型验证与可视化

  • 结果保存与分析

数据预处理

项目使用 DSB2018 数据集,通过 preprocess_dsb2018.py 进行数据预处理:

复制代码
def main():
    img_size = 96
    paths = glob('inputs/stage1_train/*')
    
    # 创建输出目录
    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)
    
    for i in tqdm(range(len(paths))):
        path = paths[i]
        img = cv2.imread(os.path.join(path, 'images', os.path.basename(path) + '.png'))
        mask = np.zeros((img.shape[0], img.shape[1]))
        
        # 合并多个掩码
        for mask_path in glob(os.path.join(path, 'masks', '*')):
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
            mask[mask_] = 1
            
        # 调整图像尺寸
        img = cv2.resize(img, (img_size, img_size))
        mask = cv2.resize(mask, (img_size, img_size))
        
        # 保存处理后的图像和掩码
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size, 
                    os.path.basename(path) + '.png'), img)
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size, 
                    os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))

预处理步骤包括:

  1. 统一图像尺寸为 96×96 像素

  2. 合并多个掩码文件为单个二值掩码

  3. 标准化图像格式和通道

数据集类设计

dataset.py 中实现了自定义数据集类,支持数据增强:

复制代码
class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        
        # 读取图像和掩码
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
        mask = []
        for i in range(self.num_classes):
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
        mask = np.dstack(mask)

        # 数据增强
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        
        # 标准化和维度调整
        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)
        
        return img, mask, {'img_id': img_id}

数据增强策略

项目使用 Albumentations 库进行数据增强:

训练集增强:

复制代码
train_transform = Compose([
    albu.RandomRotate90(),
    albu.HorizontalFlip(),
    albu.OneOf([
        albu.HueSaturationValue(),
        albu.RandomBrightnessContrast(),
    ], p=1),
    albu.Resize(config['input_h'], config['input_w']),
    albu.Normalize(),
])

验证集增强:

复制代码
val_transform = Compose([
    albu.Resize(config['input_h'], config['input_w']),
    albu.Normalize(),
])

模型训练

train.py 实现了完整的训练流程:

参数配置

复制代码
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS")
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--arch', default='NestedUNet')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--loss', default='BCEDiceLoss')
    parser.add_argument('--optimizer', default='SGD')
    parser.add_argument('--lr', default=1e-3, type=float)
    # ... 更多参数

训练循环

复制代码
def train(config, train_loader, model, criterion, optimizer):
    model.train()
    
    for input, target, _ in train_loader:
        # 前向传播
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

模型验证与可视化

val.py 提供了模型验证和结果可视化功能:

复制代码
def plot_examples(datax, datay, model, num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))
    m = datax.shape[0]
    
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()
        
        ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][0].set_title("Orignal Image")
        ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))
        ax[row_num][1].set_title("Segmented Image localization")
        ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][2].set_title("Target image")
    
    plt.show()

关键特性

1. 深度监督

支持深度监督训练,通过多个输出层提供中间监督信号。

2. 灵活的损失函数

支持多种损失函数,包括 BCEWithLogitsLoss 和自定义的 BCEDiceLoss。

3. 学习率调度

提供多种学习率调度策略:

  • CosineAnnealingLR

  • ReduceLROnPlateau

  • MultiStepLR

  • ConstantLR

4. 早停机制

通过监控验证集性能实现早停,防止过拟合。

使用方式

训练模型

复制代码
python train.py --dataset dsb2018_96 --arch NestedUNet

验证模型

复制代码
python val.py --name dsb2018_96_NestedUNet_woDS

总结

该项目提供了一个完整的医学图像分割解决方案,具有以下优点:

  1. 模块化设计:各个组件独立,便于修改和扩展

  2. 丰富的数据增强:提高模型泛化能力

  3. 灵活的配置:通过配置文件管理所有超参数

  4. 完整的训练监控:记录训练过程中的各项指标

  5. 结果可视化:直观展示分割效果

这个项目不仅适用于细胞核分割,通过调整配置也可以应用于其他医学图像分割任务,为医学图像分析研究提供了有力的工具。

相关推荐
XINVRY-FPGA18 小时前
XCVU9P-2FLGC2104I Xilinx AMD Virtex UltraScale+ FPGA
嵌入式硬件·机器学习·计算机视觉·fpga开发·硬件工程·dsp开发·fpga
用户23452670098218 小时前
Python实现异步任务队列深度好文
后端·python
夫唯不争,故无尤也19 小时前
PyTorch 的维度变形一站式入门
人工智能·pytorch·python
熊猫钓鱼>_>19 小时前
从零开始构建RPG游戏战斗系统:实战心得与技术要点
开发语言·人工智能·经验分享·python·游戏·ai·qoder
BoBoZz1919 小时前
TriangleStrip连续三角带
python·vtk·图形渲染·图形处理
生信大表哥19 小时前
Python单细胞分析-基于leiden算法的降维聚类
linux·python·算法·生信·数信院生信服务器·生信云服务器
一晌小贪欢20 小时前
【Python办公】用 Selenium 自动化网页批量录入
开发语言·python·selenium·自动化·python3·python学习·网页自动化
wuk99820 小时前
MATLAB双树复小波变换(DTCWT)工具包详解
人工智能·计算机视觉·matlab
诸神缄默不语20 小时前
如何用Python处理文件:Word导出PDF & 如何用Python从Word中提取数据:以处理简历为例
python·pdf·word
vvoennvv20 小时前
【Python TensorFlow】 TCN-LSTM时间序列卷积长短期记忆神经网络时序预测算法(附代码)
python·神经网络·机器学习·tensorflow·lstm·tcn