基于 PyTorch 的 UNet 与 NestedUNet 图像分割

图像分割是计算机视觉领域的重要任务,它旨在将图像中的每个像素分配到特定的类别。本文将详细介绍如何使用 PyTorch 实现经典的 UNet 及其改进版本 NestedUNet,并完整展示从数据预处理到模型训练和评估的全流程。

项目概述

本项目实现了两种主流的图像分割模型:

  • 经典 UNet 模型
  • NestedUNet(也称为 U-Net++)模型

我们使用 DSB2018 数据集作为示例,展示如何构建一个完整的图像分割系统,包括数据预处理、模型定义、训练流程和结果评估。

项目结构

首先,让我们了解项目的文件结构:

plaintext

复制代码
.
├── archs.py          # 模型架构定义(UNet和NestedUNet)
├── train.py          # 训练脚本
├── val.py            # 验证与评估脚本
├── losses.py         # 自定义损失函数
├── metrics.py        # 评估指标
├── dataset.py        # 数据集加载器
├── utils.py          # 工具函数
└── preprocess_dsb2018.py # 数据预处理脚本

数据预处理

在训练模型之前,我们需要对原始数据进行预处理。preprocess_dsb2018.py脚本负责这一工作:

python

运行

复制代码
import os
from glob import glob
import cv2
import numpy as np
from tqdm import tqdm

def main():
    img_size = 96  # 统一图像尺寸为96x96
    
    paths = glob('inputs/stage1_train/*')
    
    # 创建输出目录
    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)
    
    for i in tqdm(range(len(paths))):
        path = paths[i]
        # 读取图像
        img = cv2.imread(os.path.join(path, 'images',
                         os.path.basename(path) + '.png'))
        # 合并所有掩码
        mask = np.zeros((img.shape[0], img.shape[1]))
        for mask_path in glob(os.path.join(path, 'masks', '*')):
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
            mask[mask_] = 1
        
        # 处理不同通道数的图像
        if len(img.shape) == 2:
            img = np.tile(img[..., None], (1, 1, 3))
        if img.shape[2] == 4:
            img = img[..., :3]
            
        # 调整大小
        img = cv2.resize(img, (img_size, img_size))
        mask = cv2.resize(mask, (img_size, img_size))
        
        # 保存处理后的图像和掩码
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
                    os.path.basename(path) + '.png'), img)
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
                    os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))

if __name__ == '__main__':
    main()

预处理步骤主要做了以下工作:

  1. 将所有图像统一调整为 96x96 大小
  2. 合并多个掩码文件为一个
  3. 处理不同通道数的图像,统一为 3 通道
  4. 组织成标准的数据集目录结构

数据集加载器

dataset.py实现了自定义数据集类,方便加载和预处理图像数据:

python

运行

复制代码
import os
import cv2
import numpy as np
import torch
import torch.utils.data

class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        
        # 读取图像
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
        
        # 读取掩码
        mask = []
        for i in range(self.num_classes):
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
        mask = np.dstack(mask)
        
        # 应用数据增强
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        
        # 归一化并调整通道顺序
        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)  # 从HWC转为CHW
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)
        
        return img, mask, {'img_id': img_id}

这个数据集类支持:

  • 加载多类别的掩码
  • 应用数据增强(通过 albumentations 库)
  • 自动进行图像归一化和通道顺序调整

模型架构

archs.py文件定义了 UNet 和 NestedUNet 两种模型架构。

VGGBlock 组件

两种模型都使用了 VGGBlock 作为基本构建块:

python

运行

复制代码
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

每个 VGGBlock 包含两个卷积层,每个卷积层后都跟着批归一化和 ReLU 激活函数。

UNet 模型

UNet 模型由编码器、解码器和跳跃连接组成:

python

运行

复制代码
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]  # 每个层级的滤波器数量

        self.pool = nn.MaxPool2d(2, 2)  # 下采样
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 上采样

        # 编码器部分
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # 解码器部分(带跳跃连接)
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        # 最终卷积层,输出类别数
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # 编码器前向传播
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # 解码器前向传播(带跳跃连接)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output

NestedUNet 模型

NestedUNet(U-Net++)是 UNet 的改进版本,它引入了更多的跳跃连接,增强了特征融合:

python

运行

复制代码
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision  # 是否启用深度监督

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # 编码器部分
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # 嵌套连接的解码器部分
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # 深度监督的输出层
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # 编码器和嵌套连接的前向传播
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        # 根据是否启用深度监督返回不同结果
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

NestedUNet 的主要改进是引入了更多的嵌套跳跃连接,使低层级特征能够更直接地传递到高层级,同时支持深度监督(deep supervision),即从多个层级输出结果并联合优化,有助于模型更快收敛。

损失函数

losses.py实现了适用于图像分割的损失函数:

python

运行

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # BCE损失
        bce = F.binary_cross_entropy_with_logits(input, target)
        
        # Dice损失
        smooth = 1e-5
        input = torch.sigmoid(input)
        num = target.size(0)
        input = input.view(num, -1)
        target = target.view(num, -1)
        intersection = (input * target)
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
        dice = 1 - dice.sum() / num
        
        # 组合损失
        return 0.5 * bce + dice

class LovaszHingeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        # Lovasz Hinge损失,需要安装对应的库
        loss = lovasz_hinge(input, target, per_image=True)
        return loss

BCEDiceLoss 是 BCE 损失和 Dice 损失的组合,在医学图像分割中表现优异:

  • BCE 损失擅长处理类别不平衡问题
  • Dice 损失更关注前景区域的重叠度

评估指标

metrics.py实现了图像分割常用的评估指标:

python

运行

复制代码
import numpy as np
import torch
import torch.nn.functional as F

def iou_score(output, target):
    """计算交并比(IoU)"""
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
        
    # 二值化输出和目标
    output_ = output > 0.5
    target_ = target > 0.5
    
    # 计算交集和并集
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)

def dice_coef(output, target):
    """计算Dice系数"""
    smooth = 1e-5

    output = torch.sigmoid(output).view(-1).data.cpu().numpy()
    target = target.view(-1).data.cpu().numpy()
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)

IoU(交并比)是语义分割中最常用的指标,计算预测区域与真实区域的交集和并集之比。

训练脚本

train.py实现了完整的模型训练流程:

参数解析

首先定义了可配置的训练参数:

python

运行

复制代码
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS",
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=100, type=int,
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        help='mini-batch size (default: 8)')
    
    # 模型参数
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES, help='model architecture')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')
    
    # 损失函数
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES, help='loss function')
    
    # 数据集参数
    parser.add_argument('--dataset', default='dsb2018_96',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # 优化器参数
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'], help='optimizer')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    
    # 学习率调度器
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 
                                 'MultiStepLR', 'ConstantLR'])
    # ... 其他参数
    
    return parser.parse_args()

训练和验证函数

python

运行

复制代码
def train(config, train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    model.train()  # 设置为训练模式
    
    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()

        # 前向传播
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            # 深度监督:对所有输出计算损失并平均
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 更新指标
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
        pbar.update(1)
    pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}

def validate(config, val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    model.eval()  # 设置为评估模式

    with torch.no_grad():  # 禁用梯度计算
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.cuda()
            target = target.cuda()

            # 前向传播
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            # 更新指标
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
            pbar.update(1)
        pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}

主函数

python

运行

复制代码
def main():
    config = vars(parse_args())
    
    # 创建输出目录
    os.makedirs('models/%s' % config['name'], exist_ok=True)
    
    # 保存配置
    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)
    
    # 定义损失函数
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()
    
    # 启用cudnn加速
    cudnn.benchmark = True
    
    # 创建模型
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()
    
    # 定义优化器
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    
    # 定义学习率调度器
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], 
                                                  patience=config['patience'],
                                                  verbose=1, min_lr=config['min_lr'])
    # ... 其他调度器
    
    # 数据加载
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
    
    # 数据增强
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    
    # 创建数据加载器
    train_dataset = Dataset(...)
    val_dataset = Dataset(...)
    train_loader = torch.utils.data.DataLoader(...)
    val_loader = torch.utils.data.DataLoader(...)
    
    # 训练循环
    log = {'epoch': [], 'lr': [], 'loss': [], 'iou': [], 'val_loss': [], 'val_iou': []}
    best_iou = 0
    trigger = 0
    
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))
        
        # 训练一个epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # 验证
        val_log = validate(config, val_loader, model, criterion)
        
        # 更新学习率
        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])
        
        # 打印日志
        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))
        
        # 保存日志
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'], index=False)
        
        # 保存最佳模型
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0
        
        # 早停机制
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break
        
        torch.cuda.empty_cache()

验证与可视化

val.py用于加载训练好的模型进行验证,并可视化分割结果:

python

运行

复制代码
def main():
    args = parse_args()
    
    # 加载配置
    with open('models/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    
    # 创建模型
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()
    
    # 加载模型权重
    model.load_state_dict(torch.load('models/%s/model.pth' % config['name']))
    model.eval()
    
    # 准备数据
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
    
    # 加载验证集
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    val_dataset = Dataset(...)
    val_loader = torch.utils.data.DataLoader(...)
    
    # 评估并保存结果
    avg_meter = AverageMeter()
    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
    
    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()
            
            # 模型预测
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)
            
            # 计算IoU
            iou = iou_score(output, target)
            avg_meter.update(iou, input.size(0))
            
            # 保存输出结果
            output = torch.sigmoid(output).cpu().numpy()
            for i in range(len(output)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c), 
                                meta['img_id'][i] + '.jpg'),
                                (output[i, c] * 255).astype('uint8'))
    
    print('IoU: %.4f' % avg_meter.avg)
    
    # 可视化结果
    plot_examples(input, target, model, num_examples=3)

可视化函数:

python

运行

复制代码
def plot_examples(datax, datay, model, num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))
    m = datax.shape[0]
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        # 获取模型预测
        image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()
        
        # 绘制原图
        ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][0].set_title("Original Image")
        
        # 绘制分割结果
        ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))
        ax[row_num][1].set_title("Segmented Image")
        
        # 绘制目标掩码
        ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][2].set_title("Target Mask")
    plt.show()

训练与使用指南

  1. 数据准备

    bash

    复制代码
    python preprocess_dsb2018.py
  2. 模型训练

    bash

    复制代码
    python train.py --dataset dsb2018_96 --arch NestedUNet --epochs 100 --batch_size 8
  3. 模型验证

    bash

    复制代码
    python val.py --name dsb2018_96_NestedUNet_woDS

总结

本文详细介绍了基于 PyTorch 的 UNet 和 NestedUNet 图像分割模型的实现。通过这个项目,我们可以学习到:

  1. 如何构建经典的 UNet 模型及其改进版本 NestedUNet
  2. 如何设计适用于图像分割的损失函数(如 BCEDiceLoss)
  3. 如何实现完整的训练流程,包括数据加载、数据增强、模型训练和验证
  4. 如何评估分割模型的性能(使用 IoU 等指标)

该项目可以作为图像分割任务的基础框架,通过修改数据集加载部分和调整模型参数,可应用于不同的分割任务中。NestedUNet 通过增加嵌套连接和深度监督,通常能比传统 UNet 获得更好的分割性能,但计算成本也更高,实际应用中可根据需求选择合适的模型。

相关推荐
2501_9411474221 小时前
人工智能赋能智慧城市互联网应用:智能交通、能源与公共管理优化实践探索》
人工智能
咚咚王者21 小时前
人工智能之数据分析 numpy:第十五章 项目实践
人工智能·数据分析·numpy
水月wwww1 天前
深度学习——神经网络
人工智能·深度学习·神经网络
司铭鸿1 天前
祖先关系的数学重构:从家谱到算法的思维跃迁
开发语言·数据结构·人工智能·算法·重构·c#·哈希算法
机器之心1 天前
从推荐算法优化到AI4S、Pico和大模型,杨震原长文揭秘字节跳动的技术探索
人工智能·openai
johnny2331 天前
AI加持测试工具汇总:Strix、
人工智能·测试工具
机器之心1 天前
哈工大深圳团队推出Uni-MoE-2.0-Omni:全模态理解、推理及生成新SOTA
人工智能·openai
w***Q3501 天前
人工智能在智能家居中的控制
人工智能·智能家居
青瓷程序设计1 天前
花朵识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
阿里云大数据AI技术1 天前
PAI Physical AI Notebook详解4:基于仿真的GR00T-N1.5模型微调
人工智能