图像分割是计算机视觉领域的重要任务,它旨在将图像中的每个像素分配到特定的类别。本文将详细介绍如何使用 PyTorch 实现经典的 UNet 及其改进版本 NestedUNet,并完整展示从数据预处理到模型训练和评估的全流程。
项目概述
本项目实现了两种主流的图像分割模型:
- 经典 UNet 模型
- NestedUNet(也称为 U-Net++)模型
我们使用 DSB2018 数据集作为示例,展示如何构建一个完整的图像分割系统,包括数据预处理、模型定义、训练流程和结果评估。
项目结构
首先,让我们了解项目的文件结构:
plaintext
.
├── archs.py # 模型架构定义(UNet和NestedUNet)
├── train.py # 训练脚本
├── val.py # 验证与评估脚本
├── losses.py # 自定义损失函数
├── metrics.py # 评估指标
├── dataset.py # 数据集加载器
├── utils.py # 工具函数
└── preprocess_dsb2018.py # 数据预处理脚本
数据预处理
在训练模型之前,我们需要对原始数据进行预处理。preprocess_dsb2018.py脚本负责这一工作:
python
运行
import os
from glob import glob
import cv2
import numpy as np
from tqdm import tqdm
def main():
    """Preprocess the DSB2018 stage1 training data.

    For every sample directory under ``inputs/stage1_train`` this:
    - reads the raw image,
    - merges all per-nucleus instance masks into a single binary mask,
    - normalizes the image to 3 channels,
    - resizes both to ``img_size`` x ``img_size``,
    - writes the results into ``inputs/dsb2018_<size>/{images,masks/0}``.
    """
    img_size = 96  # target square size

    paths = glob('inputs/stage1_train/*')

    # Output layout expected by dataset.py: images/ plus masks/<class_index>/.
    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)

    for i in tqdm(range(len(paths))):
        path = paths[i]

        img = cv2.imread(os.path.join(path, 'images',
                                      os.path.basename(path) + '.png'))

        # Merge all individual instance masks into one foreground mask.
        mask = np.zeros((img.shape[0], img.shape[1]))
        for mask_path in glob(os.path.join(path, 'masks', '*')):
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
            mask[mask_] = 1

        # Normalize channel count: grayscale -> 3 channels, RGBA -> RGB.
        if len(img.shape) == 2:
            img = np.tile(img[..., None], (1, 1, 3))
        if img.shape[2] == 4:
            img = img[..., :3]

        img = cv2.resize(img, (img_size, img_size))
        # BUGFIX: resize the label mask with nearest-neighbor interpolation.
        # The default (bilinear) interpolation produces fractional values
        # along object borders, so the saved mask is no longer strictly
        # binary 0/255.
        mask = cv2.resize(mask, (img_size, img_size),
                          interpolation=cv2.INTER_NEAREST)

        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
                                 os.path.basename(path) + '.png'), img)
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
                                 os.path.basename(path) + '.png'),
                    (mask * 255).astype('uint8'))


if __name__ == '__main__':
    main()
预处理步骤主要做了以下工作:
- 将所有图像统一调整为 96x96 大小
- 合并多个掩码文件为一个
- 处理不同通道数的图像,统一为 3 通道
- 组织成标准的数据集目录结构
数据集加载器
dataset.py实现了自定义数据集类,方便加载和预处理图像数据:
python
运行
import os
import cv2
import numpy as np
import torch
import torch.utils.data
class Dataset(torch.utils.data.Dataset):
    """Segmentation dataset pairing images with per-class mask files.

    Expected directory layout::

        img_dir/<img_id><img_ext>
        mask_dir/<class_index>/<img_id><mask_ext>

    Args:
        img_ids: list of sample identifiers (file names without extension).
        img_dir: directory containing the input images.
        mask_dir: directory containing one sub-directory per mask class.
        img_ext: image file extension (e.g. '.png').
        mask_ext: mask file extension (e.g. '.png').
        num_classes: number of mask channels to load.
        transform: optional albumentations-style transform taking
            ``image=`` and ``mask=`` keyword arguments.
    """

    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        img_path = os.path.join(self.img_dir, img_id + self.img_ext)
        img = cv2.imread(img_path)
        # BUGFIX: cv2.imread returns None on a missing/unreadable file;
        # fail fast with a clear message instead of a cryptic error below.
        if img is None:
            raise FileNotFoundError('failed to read image: %s' % img_path)

        # One grayscale mask per class, stacked along the channel axis.
        mask = []
        for i in range(self.num_classes):
            mask_path = os.path.join(self.mask_dir, str(i),
                                     img_id + self.mask_ext)
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask_ is None:
                raise FileNotFoundError('failed to read mask: %s' % mask_path)
            mask.append(mask_[..., None])
        mask = np.dstack(mask)

        # Apply augmentation jointly to image and mask.
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Scale to [0, 1] and convert HWC -> CHW for PyTorch.
        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)

        return img, mask, {'img_id': img_id}
这个数据集类支持:
- 加载多类别的掩码
- 应用数据增强(通过 albumentations 库)
- 自动进行图像归一化和通道顺序调整
模型架构
archs.py文件定义了 UNet 和 NestedUNet 两种模型架构。
VGGBlock 组件
两种模型都使用了 VGGBlock 作为基本构建块:
python
运行
class VGGBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: channels of the input feature map.
        middle_channels: channels after the first convolution.
        out_channels: channels after the second convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # First conv -> BN -> ReLU stage.
        h = self.relu(self.bn1(self.conv1(x)))
        # Second conv -> BN -> ReLU stage.
        return self.relu(self.bn2(self.conv2(h)))
每个 VGGBlock 包含两个卷积层,每个卷积层后都跟着批归一化和 ReLU 激活函数。
UNet 模型
UNet 模型由编码器、解码器和跳跃连接组成:
python
运行
class UNet(nn.Module):
    """Classic U-Net: a contracting encoder and an expanding decoder
    joined by skip connections at every resolution level.

    Args:
        num_classes: number of output channels of the final 1x1 conv.
        input_channels: number of channels in the input image (default 3).
    """

    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]  # channel widths per level

        self.pool = nn.MaxPool2d(2, 2)  # 2x spatial downsampling
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 2x upsampling

        # Encoder blocks, one per resolution level.
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder blocks; input width = skip channels + upsampled channels.
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        # 1x1 projection to the requested number of classes.
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # Contracting path.
        enc0 = self.conv0_0(input)
        enc1 = self.conv1_0(self.pool(enc0))
        enc2 = self.conv2_0(self.pool(enc1))
        enc3 = self.conv3_0(self.pool(enc2))
        enc4 = self.conv4_0(self.pool(enc3))

        # Expanding path: upsample, concatenate the skip feature, convolve.
        dec3 = self.conv3_1(torch.cat([enc3, self.up(enc4)], 1))
        dec2 = self.conv2_2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.conv1_3(torch.cat([enc1, self.up(dec2)], 1))
        dec0 = self.conv0_4(torch.cat([enc0, self.up(dec1)], 1))

        return self.final(dec0)
NestedUNet 模型
NestedUNet(U-Net++)是 UNet 的改进版本,它引入了更多的跳跃连接,增强了特征融合:
python
运行
class NestedUNet(nn.Module):
    """U-Net++ (nested U-Net) with optional deep supervision.

    Node ``x<i>_<j>`` sits at resolution level ``i`` and decoder column
    ``j``; it fuses all earlier columns at the same level plus the
    upsampled node from the level below.

    Args:
        num_classes: number of output channels per segmentation head.
        input_channels: channels of the input image (default 3).
        deep_supervision: if True, forward() returns the four side
            outputs [head(x0_1), head(x0_2), head(x0_3), head(x0_4)]
            instead of a single tensor.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]  # channel widths per level

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone encoder (column 0).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested decoder: conv<i>_<j> takes j copies of level-i features
        # plus the upsampled level-(i+1) feature, hence the growing widths.
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # Output heads: one per side output under deep supervision,
        # otherwise a single head on the last column.
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # Encoder column (independent of the nested nodes, so computed first).
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # Decoder column 1.
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))

        # Decoder column 2.
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))

        # Decoder column 3.
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))

        # Decoder column 4.
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            return [self.final1(x0_1), self.final2(x0_2),
                    self.final3(x0_3), self.final4(x0_4)]
        return self.final(x0_4)
NestedUNet 的主要改进是引入了更多的嵌套跳跃连接,使低层级特征能够更直接地传递到高层级,同时支持深度监督(deep supervision),即从多个层级输出结果并联合优化,有助于模型更快收敛。
损失函数
losses.py实现了适用于图像分割的损失函数:
python
运行
import torch
import torch.nn as nn
import torch.nn.functional as F
class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and soft Dice loss.

    Returns ``0.5 * BCE + (1 - mean Dice)``. Expects raw logits as
    ``input`` and binary targets of the same shape.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # BCE term, evaluated directly on logits for numerical stability.
        bce = F.binary_cross_entropy_with_logits(input, target)

        # Soft Dice term, computed per sample on flattened probabilities.
        smooth = 1e-5
        batch = target.size(0)
        probs = torch.sigmoid(input).view(batch, -1)
        flat_target = target.view(batch, -1)
        overlap = (probs * flat_target).sum(1)
        dice = (2. * overlap + smooth) / (probs.sum(1) + flat_target.sum(1) + smooth)
        dice_loss = 1 - dice.sum() / batch

        return 0.5 * bce + dice_loss
class LovaszHingeLoss(nn.Module):
    """Lovasz hinge loss for binary segmentation.

    Thin wrapper around the external ``lovasz_hinge`` implementation.
    Both input and target are squeezed from (N, 1, H, W) to (N, H, W)
    before the loss is evaluated per image.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # Drop the singleton channel dimension expected by lovasz_hinge.
        squeezed_input = input.squeeze(1)
        squeezed_target = target.squeeze(1)
        return lovasz_hinge(squeezed_input, squeezed_target, per_image=True)
BCEDiceLoss 是 BCE 损失和 Dice 损失的组合,在医学图像分割中表现优异:
- BCE 损失擅长处理类别不平衡问题
- Dice 损失更关注前景区域的重叠度
评估指标
metrics.py实现了图像分割常用的评估指标:
python
运行
import numpy as np
import torch
import torch.nn.functional as F
def iou_score(output, target):
    """Intersection-over-Union of binarized prediction and target.

    Accepts tensors or numpy arrays. Tensor predictions are passed
    through a sigmoid first; both sides are then thresholded at 0.5.
    """
    smooth = 1e-5

    # Move everything to numpy; only model outputs need the sigmoid.
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()

    # Boolean foreground maps.
    pred_fg = output > 0.5
    true_fg = target > 0.5

    intersection = (pred_fg & true_fg).sum()
    union = (pred_fg | true_fg).sum()

    return (intersection + smooth) / (union + smooth)
def dice_coef(output, target):
    """Dice coefficient between sigmoid(output) and target (both tensors)."""
    smooth = 1e-5

    # Flatten to 1-D numpy vectors; predictions go through a sigmoid.
    probs = torch.sigmoid(output).view(-1).data.cpu().numpy()
    labels = target.view(-1).data.cpu().numpy()

    overlap = (probs * labels).sum()
    return (2. * overlap + smooth) / (probs.sum() + labels.sum() + smooth)
IoU(交并比)是语义分割中最常用的指标,计算预测区域与真实区域的交集和并集之比。
训练脚本
train.py实现了完整的模型训练流程:
参数解析
首先定义了可配置的训练参数:
python
运行
def parse_args():
    """Build and parse the command-line arguments for training."""
    p = argparse.ArgumentParser()

    # Run identification and schedule.
    p.add_argument('--name', default="dsb2018_96_NestedUNet_woDS",
                   help='model name: (default: arch+timestamp)')
    p.add_argument('--epochs', default=100, type=int,
                   help='number of total epochs to run')
    p.add_argument('-b', '--batch_size', default=8, type=int,
                   help='mini-batch size (default: 8)')

    # Model.
    p.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                   choices=ARCH_NAMES, help='model architecture')
    p.add_argument('--deep_supervision', default=False, type=str2bool)
    p.add_argument('--input_channels', default=3, type=int,
                   help='input channels')
    p.add_argument('--num_classes', default=1, type=int,
                   help='number of classes')
    p.add_argument('--input_w', default=96, type=int,
                   help='image width')
    p.add_argument('--input_h', default=96, type=int,
                   help='image height')

    # Loss.
    p.add_argument('--loss', default='BCEDiceLoss',
                   choices=LOSS_NAMES, help='loss function')

    # Dataset.
    p.add_argument('--dataset', default='dsb2018_96',
                   help='dataset name')
    p.add_argument('--img_ext', default='.png',
                   help='image file extension')
    p.add_argument('--mask_ext', default='.png',
                   help='mask file extension')

    # Optimizer.
    p.add_argument('--optimizer', default='SGD',
                   choices=['Adam', 'SGD'], help='optimizer')
    p.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                   help='initial learning rate')
    p.add_argument('--momentum', default=0.9, type=float,
                   help='momentum')
    p.add_argument('--weight_decay', default=1e-4, type=float,
                   help='weight decay')

    # Learning-rate scheduler.
    p.add_argument('--scheduler', default='CosineAnnealingLR',
                   choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                            'MultiStepLR', 'ConstantLR'])

    # ... other parameters

    return p.parse_args()
训练和验证函数
python
运行
def train(config, train_loader, model, criterion, optimizer):
    """Run one training epoch and return the average loss and IoU.

    Args:
        config: dict-like config; only 'deep_supervision' is read here.
        train_loader: DataLoader yielding (image, mask, meta) batches.
        model: network to optimize (moved to CUDA by the caller).
        criterion: loss function applied to (output, target).
        optimizer: optimizer stepping the model parameters.
    """
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

    model.train()  # enable dropout/BN training behavior

    pbar = tqdm(total=len(train_loader))
    for images, targets, _ in train_loader:
        images = images.cuda()
        targets = targets.cuda()

        if config['deep_supervision']:
            # Deep supervision: average the loss over all side outputs,
            # report IoU on the last (deepest) head.
            outputs = model(images)
            loss = sum(criterion(out, targets) for out in outputs) / len(outputs)
            iou = iou_score(outputs[-1], targets)
        else:
            output = model(images)
            loss = criterion(output, targets)
            iou = iou_score(output, targets)

        # Standard backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = images.size(0)
        avg_meters['loss'].update(loss.item(), batch)
        avg_meters['iou'].update(iou, batch)

        pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
        pbar.update(1)
    pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}
def validate(config, val_loader, model, criterion):
    """Evaluate the model on the validation set; returns avg loss and IoU."""
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

    model.eval()  # freeze dropout/BN statistics

    with torch.no_grad():  # no gradients needed during evaluation
        pbar = tqdm(total=len(val_loader))
        for images, targets, _ in val_loader:
            images = images.cuda()
            targets = targets.cuda()

            if config['deep_supervision']:
                # Average the loss over all side outputs; IoU on the last.
                outputs = model(images)
                loss = sum(criterion(out, targets) for out in outputs) / len(outputs)
                iou = iou_score(outputs[-1], targets)
            else:
                output = model(images)
                loss = criterion(output, targets)
                iou = iou_score(output, targets)

            batch = images.size(0)
            avg_meters['loss'].update(loss.item(), batch)
            avg_meters['iou'].update(iou, batch)

            pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
            pbar.update(1)
        pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}
主函数
python
运行
def main():
    """Training entry point: config, model, data, training loop, checkpoints."""
    config = vars(parse_args())

    os.makedirs('models/%s' % config['name'], exist_ok=True)

    # Persist the configuration next to the checkpoints for reproducibility.
    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # Loss function.
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    # Let cuDNN pick the fastest conv algorithms for fixed input sizes.
    cudnn.benchmark = True

    # Model.
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()

    # Optimizer over trainable parameters only.
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])

    # Learning-rate scheduler.
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'],
                                                   patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    # ... other schedulers

    # Data: 80/20 split of image ids into train/validation.
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    # Augmentation pipelines (training gets geometric + color jitter).
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    # Datasets and loaders (arguments elided in the article).
    train_dataset = Dataset(...)
    val_dataset = Dataset(...)
    train_loader = torch.utils.data.DataLoader(...)
    val_loader = torch.utils.data.DataLoader(...)

    # Training loop with CSV logging, best-checkpointing and early stopping.
    log = {'epoch': [], 'lr': [], 'loss': [], 'iou': [], 'val_loss': [], 'val_iou': []}
    best_iou = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        train_log = train(config, train_loader, model, criterion, optimizer)
        val_log = validate(config, val_loader, model, criterion)

        # Step the scheduler (ReduceLROnPlateau needs the monitored metric).
        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

        log['epoch'].append(epoch)
        # BUGFIX: log the learning rate actually in effect this epoch;
        # the original appended the constant initial config['lr'].
        log['lr'].append(optimizer.param_groups[0]['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'], index=False)

        # BUGFIX: count epochs without improvement; the original never
        # incremented `trigger`, so early stopping could never fire.
        trigger += 1

        # Keep the checkpoint with the best validation IoU.
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # Early stopping once patience is exhausted.
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()
验证与可视化
val.py用于加载训练好的模型进行验证,并可视化分割结果:
python
运行
def main():
    """Evaluate a trained model on the validation split and save predictions."""
    args = parse_args()

    # Restore the training-time configuration.
    with open('models/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Rebuild the architecture and load the best checkpoint.
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()
    model.load_state_dict(torch.load('models/%s/model.pth' % config['name']))
    model.eval()

    # Recreate the exact validation split used during training
    # (same test_size and random_state).
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    val_dataset = Dataset(...)
    val_loader = torch.utils.data.DataLoader(...)

    avg_meter = AverageMeter()

    # One output directory per class.
    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)

    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # Under deep supervision only the last head is evaluated.
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)

            avg_meter.update(iou_score(output, target), input.size(0))

            # Save each predicted probability map as an 8-bit image.
            probs = torch.sigmoid(output).cpu().numpy()
            for i in range(len(probs)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c),
                                             meta['img_id'][i] + '.jpg'),
                                (probs[i, c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_meter.avg)

    # Visualize a few samples from the last processed batch.
    plot_examples(input, target, model, num_examples=3)
可视化函数:
python
运行
def plot_examples(datax, datay, model, num_examples=6):
    """Plot random (image, prediction, target) triples in a grid.

    Args:
        datax: batch of input images, shape (N, C, H, W).
        datay: batch of target masks, shape (N, C, H, W).
        model: trained network used to produce predictions.
        num_examples: number of random rows to plot.

    NOTE(review): the prediction is thresholded at 0.40 without applying a
    sigmoid, i.e. directly on the model logits — confirm this matches the
    intended probability scale.
    """
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18, 4 * num_examples))
    n_samples = datax.shape[0]

    for row in range(num_examples):
        idx = np.random.randint(n_samples)

        # Forward a single sample and detach for plotting.
        prediction = model(datax[idx:idx + 1]).squeeze(0).detach().cpu().numpy()

        # Column 0: first channel of the input image.
        ax[row][0].imshow(np.transpose(datax[idx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row][0].set_title("Original Image")
        # Column 1: thresholded prediction.
        ax[row][1].imshow(np.squeeze((prediction > 0.40)[0, :, :].astype(int)))
        ax[row][1].set_title("Segmented Image")
        # Column 2: first channel of the ground-truth mask.
        ax[row][2].imshow(np.transpose(datay[idx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row][2].set_title("Target Mask")

    plt.show()
训练与使用指南
1. 数据准备:
bash
python preprocess_dsb2018.py
2. 模型训练:
bash
python train.py --dataset dsb2018_96 --arch NestedUNet --epochs 100 --batch_size 8
3. 模型验证:
bash
python val.py --name dsb2018_96_NestedUNet_woDS
总结
本文详细介绍了基于 PyTorch 的 UNet 和 NestedUNet 图像分割模型的实现。通过这个项目,我们可以学习到:
- 如何构建经典的 UNet 模型及其改进版本 NestedUNet
- 如何设计适用于图像分割的损失函数(如 BCEDiceLoss)
- 如何实现完整的训练流程,包括数据加载、数据增强、模型训练和验证
- 如何评估分割模型的性能(使用 IoU 等指标)
该项目可以作为图像分割任务的基础框架,通过修改数据集加载部分和调整模型参数,可应用于不同的分割任务中。NestedUNet 通过增加嵌套连接和深度监督,通常能比传统 UNet 获得更好的分割性能,但计算成本也更高,实际应用中可根据需求选择合适的模型。