一、项目背景:为什么要做细胞核分割?
细胞核分割是医学影像分析的基础任务之一,在病理诊断、细胞计数、疾病研究中都有重要应用。比如:
- 病理医生通过分析细胞核的形态(大小、形状、分布)判断细胞是否癌变;
- 细胞实验中,需要精确分割单个细胞核以统计数量或观察分裂状态。
传统方法依赖人工标注或阈值分割,效率低且精度差。而深度学习模型(如 U-net 系列)能自动学习细胞核的特征,实现高精度分割,大幅降低人工成本。
我们今天的目标是:用 NestedUNet 模型(U-net++ 的改进版)实现细胞核自动分割,最终在验证集上达到较高的 IoU(交并比,分割任务的核心指标)。
二、环境准备:一行代码搞定依赖
首先确保你的环境安装了以下库,推荐用 Anaconda 创建虚拟环境(避免版本冲突):
bash
# 创建虚拟环境(可选但推荐)
conda create -n seg_env python=3.8
conda activate seg_env
# 安装核心依赖
pip install torch torchvision torchaudio # PyTorch框架(根据CUDA版本选择,详见官网)
pip install albumentations # 数据增强库(比torchvision更强大)
pip install numpy pandas matplotlib # 数据处理与可视化
pip install scikit-image tqdm # 图像处理与进度条
验证环境 :运行python -c "import torch; print(torch.cuda.is_available())",输出True说明 GPU 可用(训练会快 10 倍以上),False则用 CPU 训练(适合入门调试)。
三、数据集解析:2018 Data Science Bowl 细胞核数据
我们使用的数据集是dsb2018_96,源自 2018 年 Data Science Bowl 比赛,已预处理为 96×96 的小尺寸图像,非常适合新手练手。
1. 数据集结构
数据集按 "图像 - 掩码" 对应存储,目录结构如下:
plaintext
inputs/
└── dsb2018_96/ # 数据集名称
├── images/ # 输入图像(细胞核原始图)
│ ├── 0.png
│ ├── 1.png
│ ...
└── masks/ # 掩码(标注的细胞核区域)
├── 0.png
├── 1.png
...
- 图像(images):96×96 像素的灰度图(单通道),显示细胞核的显微镜图像;
- 掩码(masks):与图像同名的二值图,白色区域(像素值 1)表示细胞核,黑色区域(像素值 0)表示背景。
2. 数据特点
- 任务类型:二分类语义分割(仅区分 "细胞核" 和 "背景");
- 难点:细胞核大小不一、形状不规则,且存在重叠(比如两个细胞核粘在一起),对模型的细节捕捉能力要求高;
- 数据量:约 600 张训练图 + 150 张验证图(按 8:2 划分),数量适中,适合中等规模模型训练。
3. 数据获取
如果你没有数据集,可以按以下方式生成类似结构:
-
从Kaggle 官网下载原始 DSB2018 数据;
-
用
scikit-image将图像 Resize 到 96×96:python
运行
from skimage import io, transform img = io.imread("original_image.png") img_resized = transform.resize(img, (96, 96), anti_aliasing=True) io.imsave("inputs/dsb2018_96/images/0.png", img_resized)
四、代码实战:从 0 到 1 训练 NestedUNet
我们的代码分为 5 个核心模块:参数配置、数据加载、模型定义、训练 / 验证循环、主流程。每个模块都有详细注释,确保你能看懂每一行的作用。
1. 完整代码结构
先看整体框架,后面会逐部分解析:
python
运行
import os
import argparse
import yaml
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt
# 自定义模块(后面会实现)
from dataset import SegDataset # 数据集类
from archs import NestedUNet # 模型类
from loss import BCEDiceLoss # 损失函数
from utils import calculate_iou # 评估指标计算
# 参数解析
def parse_args():
# 省略,后面详细讲
pass
# 数据加载
def get_loaders(args):
# 省略,后面详细讲
pass
# 训练函数
def train_fn(train_loader, model, criterion, optimizer, device):
# 省略,后面详细讲
pass
# 验证函数
def validate_fn(valid_loader, model, criterion, device):
# 省略,后面详细讲
pass
# 主函数
def main():
# 省略,后面详细讲
pass
if __name__ == "__main__":
main()
2. 参数配置(parse_args 函数)
通过命令行参数灵活配置训练细节,核心参数如下(可根据需求调整):
python
运行
def parse_args():
parser = argparse.ArgumentParser()
# 模型参数
parser.add_argument("--arch", default="NestedUNet", help="模型架构(NestedUNet/Unet等)")
parser.add_argument("--deep_supervision", action="store_true", help="是否使用深度监督")
parser.add_argument("--input_channels", default=1, type=int, help="输入通道数(灰度图为1,RGB为3)")
parser.add_argument("--num_classes", default=1, type=int, help="输出类别数(二分类为1)")
# 训练参数
parser.add_argument("--epochs", default=50, type=int, help="训练轮数")
parser.add_argument("--batch_size", default=16, type=int, help="批次大小")
parser.add_argument("--lr", default=1e-4, type=float, help="初始学习率")
parser.add_argument("--loss", default="bce_dice", help="损失函数(bce/bce_dice)")
parser.add_argument("--optimizer", default="adam", help="优化器(adam/sgd)")
parser.add_argument("--scheduler", default="cosine", help="学习率调度器")
# 数据参数
parser.add_argument("--dataset", default="dsb2018_96", help="数据集名称")
parser.add_argument("--img_ext", default=".png", help="图像文件扩展名")
parser.add_argument("--mask_ext", default=".png", help="掩码文件扩展名")
parser.add_argument("--input_w", default=96, type=int, help="图像宽度")
parser.add_argument("--input_h", default=96, type=int, help="图像高度")
# 其他参数
parser.add_argument("--name", default="nested_unet_dsb2018", help="实验名称(用于保存模型)")
parser.add_argument("--early_stopping", default=10, type=int, help="早停轮数(防止过拟合)")
return parser.parse_args()
关键参数说明:
--deep_supervision:NestedUNet 的核心特性,开启后模型会在多个解码阶段输出结果,损失函数对多输出加权,提升小目标分割精度;--loss:推荐用bce_dice(BCE 损失 + Dice 损失),BCE 擅长平衡类别,Dice 擅长处理样本不平衡(细胞核像素少);--early_stopping:若连续 10 轮验证集 IoU 不提升,则停止训练,避免过拟合。
3. 数据加载与增强(get_loaders 函数)
数据增强是提升分割精度的关键,尤其是医学数据量少时,通过增强可以 "伪造" 更多样本,提升模型泛化能力。
(1)自定义数据集类(dataset.py)
python
运行
import os
import numpy as np
from skimage import io
import torch
from torch.utils.data import Dataset
class SegDataset(Dataset):
def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, transform=None):
self.img_ids = img_ids # 图像文件名列表(不含扩展名)
self.img_dir = img_dir # 图像目录
self.mask_dir = mask_dir # 掩码目录
self.img_ext = img_ext # 图像扩展名
self.mask_ext = mask_ext # 掩码扩展名
self.transform = transform # 数据增强器
def __len__(self):
return len(self.img_ids)
def __getitem__(self, idx):
img_id = self.img_ids[idx]
# 读取图像和掩码(转为float32,便于PyTorch处理)
img = io.imread(os.path.join(self.img_dir, img_id + self.img_ext)).astype(np.float32)
mask = io.imread(os.path.join(self.mask_dir, img_id + self.mask_ext)).astype(np.float32)
# 若图像是单通道(灰度图),添加通道维度([H,W]→[H,W,1])
if len(img.shape) == 2:
img = img[..., np.newaxis]
mask = mask[..., np.newaxis]
# 应用数据增强
if self.transform is not None:
augmented = self.transform(image=img, mask=mask)
img = augmented["image"]
mask = augmented["mask"]
# 掩码二值化(确保只有0和1)
mask = (mask > 0.5).astype(np.float32)
return img, mask, img_id
(2)数据增强与加载器
python
运行
def get_loaders(args):
# 数据路径
img_dir = os.path.join("inputs", args.dataset, "images")
mask_dir = os.path.join("inputs", args.dataset, "masks")
img_ids = [os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(args.img_ext)]
# 划分训练集和验证集(8:2,随机种子41确保可复现)
train_img_ids, valid_img_ids = train_test_split(
img_ids, test_size=0.2, random_state=41
)
# 训练集增强:随机旋转、翻转、色彩抖动(提升模型鲁棒性)
train_transform = A.Compose([
A.RandomRotate90(), # 随机旋转90度
A.Flip(), # 随机水平/垂直翻转
A.OneOf([ # 随机选一种色彩增强
A.RandomBrightnessContrast(),
A.RandomGamma(),
], p=0.5),
A.Resize(args.input_h, args.input_w), # 调整尺寸
A.Normalize(mean=[0.485], std=[0.229]), # 归一化(单通道用一个均值和标准差)
ToTensorV2(), # 转为PyTorch张量([H,W,C]→[C,H,W])
])
# 验证集增强:仅调整尺寸和归一化(不添加噪声,保证评估准确)
valid_transform = A.Compose([
A.Resize(args.input_h, args.input_w),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2(),
])
# 创建数据集和加载器
train_dataset = SegDataset(
train_img_ids, img_dir, mask_dir, args.img_ext, args.mask_ext, train_transform
)
valid_dataset = SegDataset(
valid_img_ids, img_dir, mask_dir, args.img_ext, args.mask_ext, valid_transform
)
train_loader = DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
valid_loader = DataLoader(
valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
)
return train_loader, valid_loader, train_img_ids, valid_img_ids
增强技巧:
- 训练集用
OneOf随机选一种增强,避免过度增强导致特征失真; - 验证集不做随机变换,确保评估结果稳定;
- 单通道图像的归一化均值 / 标准差可根据数据集统计(这里用 ImageNet 的近似值)。
4. 模型定义:NestedUNet(U-net++)
NestedUNet 是 U-net 的升级版,通过密集特征融合 和深度监督解决 U-net 的 "语义鸿沟" 问题,特别适合分割小目标(如细胞核)。
核心结构(简化版,完整代码见archs.py):
python
运行
import torch
import torch.nn as nn
import torch.nn.functional as F
class NestedUNet(nn.Module):
def __init__(self, input_channels=1, num_classes=1, deep_supervision=True):
super().__init__()
self.deep_supervision = deep_supervision
# 编码端(下采样):提取语义特征
self.down1 = self._down_block(input_channels, 64) # 输出64通道
self.down2 = self._down_block(64, 128) # 输出128通道
self.down3 = self._down_block(128, 256) # 输出256通道
self.down4 = self._down_block(256, 512) # 输出512通道
# 瓶颈层(最深层)
self.center = self._conv_block(512, 1024) # 输出1024通道
# 解码端(上采样):密集特征融合
self.up4 = self._up_block(1024, 512)
self.up3 = self._up_block(512, 256)
self.up2 = self._up_block(256, 128)
self.up1 = self._up_block(128, 64)
# 输出层(深度监督:多个输出分支)
self.out1 = nn.Conv2d(64, num_classes, kernel_size=1)
self.out2 = nn.Conv2d(128, num_classes, kernel_size=1)
self.out3 = nn.Conv2d(256, num_classes, kernel_size=1)
self.out4 = nn.Conv2d(512, num_classes, kernel_size=1)
# 卷积块(2次卷积+ReLU)
def _conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
# 下采样块(卷积块+最大池化)
def _down_block(self, in_channels, out_channels):
return nn.Sequential(
self._conv_block(in_channels, out_channels),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 上采样块(上采样+特征拼接+卷积块)
def _up_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_channels, out_channels, kernel_size=1), # 降维
self._conv_block(out_channels * 2, out_channels) # 拼接编码端特征(×2是因为拼接)
)
def forward(self, x):
# 编码端输出
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
# 瓶颈层
center = self.center(x4)
# 解码端输出(密集融合)
up4 = self.up4(center)
up3 = self.up3(up4)
up2 = self.up2(up3)
up1 = self.up1(up2)
# 深度监督:输出多个分支
out1 = self.out1(up1)
if self.deep_supervision:
out2 = self.out2(up2)
out3 = self.out3(up3)
out4 = self.out4(up4)
return [out1, out2, out3, out4] # 多输出用于深度监督
else:
return out1
NestedUNet 核心优势:
- 解码端每个阶段都融合编码端多个层次的特征,解决 "语义鸿沟";
- 深度监督(多输出)让模型同时学习粗粒度和细粒度特征,小目标分割更准。
5. 损失函数:BCEDiceLoss(平衡类别 + 样本)
细胞核分割中,背景像素远多于细胞核(样本不平衡),且边界难区分,因此需要定制损失函数:
python
运行
import torch
import torch.nn as nn
import torch.nn.functional as F
class BCEDiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super().__init__()
def forward(self, inputs, targets, smooth=1):
# Sigmoid激活(将输出转为0-1概率)
inputs = torch.sigmoid(inputs)
# 展平张量(计算全局损失)
inputs = inputs.view(-1)
targets = targets.view(-1)
# BCE损失(处理类别不平衡)
bce_loss = F.binary_cross_entropy(inputs, targets, reduction='mean')
# Dice损失(衡量重叠度,对边界敏感)
intersection = (inputs * targets).sum()
dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
# 总损失:BCE + Dice(权重可调整)
return bce_loss + dice_loss
为什么用组合损失:
- BCE 损失:通过交叉熵惩罚错分样本,适合平衡正负类别;
- Dice 损失:直接衡量预测与真实掩码的重叠度,对边界误差更敏感,适合分割任务。
6. 训练与验证循环
训练循环的核心是 "正向传播算损失→反向传播更新参数→验证集评估泛化能力":
(1)训练函数
python
运行
def train_fn(train_loader, model, criterion, optimizer, device):
model.train() # 训练模式(启用Dropout、BN等)
total_loss = 0.0
total_iou = 0.0
# 进度条显示训练过程
loop = tqdm(train_loader, total=len(train_loader))
for imgs, masks, _ in loop:
# 数据移到GPU/CPU
imgs = imgs.to(device)
masks = masks.to(device)
# 梯度清零
optimizer.zero_grad()
# 正向传播
outputs = model(imgs)
# 计算损失(深度监督时,对多个输出加权)
if isinstance(outputs, list):
loss = 0.0
for out in outputs:
loss += criterion(out, masks)
loss /= len(outputs) # 平均多输出损失
else:
loss = criterion(outputs, masks)
# 反向传播+参数更新
loss.backward()
optimizer.step()
# 计算IoU(评估指标)
with torch.no_grad(): # 不计算梯度,节省内存
if isinstance(outputs, list):
pred = torch.sigmoid(outputs[0]) # 用第一个输出(最精细)计算IoU
else:
pred = torch.sigmoid(outputs)
pred = (pred > 0.5).float() # 二值化(0.5为阈值)
iou = calculate_iou(pred, masks)
# 累计损失和IoU
total_loss += loss.item()
total_iou += iou.item()
# 更新进度条
loop.set_postfix(loss=loss.item(), iou=iou.item())
# 计算平均损失和IoU
avg_loss = total_loss / len(train_loader)
avg_iou = total_iou / len(train_loader)
return avg_loss, avg_iou
(2)验证函数
python
运行
def validate_fn(valid_loader, model, criterion, device):
model.eval() # 评估模式(冻结BN、Dropout)
total_loss = 0.0
total_iou = 0.0
with torch.no_grad(): # 验证时不计算梯度
loop = tqdm(valid_loader, total=len(valid_loader))
for imgs, masks, _ in loop:
imgs = imgs.to(device)
masks = masks.to(device)
outputs = model(imgs)
# 计算损失(同训练函数)
if isinstance(outputs, list):
loss = 0.0
for out in outputs:
loss += criterion(out, masks)
loss /= len(outputs)
else:
loss = criterion(outputs, masks)
# 计算IoU
if isinstance(outputs, list):
pred = torch.sigmoid(outputs[0])
else:
pred = torch.sigmoid(outputs)
pred = (pred > 0.5).float()
iou = calculate_iou(pred, masks)
total_loss += loss.item()
total_iou += iou.item()
loop.set_postfix(loss=loss.item(), iou=iou.item())
avg_loss = total_loss / len(valid_loader)
avg_iou = total_iou / len(valid_loader)
return avg_loss, avg_iou
(3)IoU 计算函数(utils.py)
python
运行
import torch
def calculate_iou(pred, target, smooth=1e-6):
# pred和target都是二值张量(0或1)
intersection = (pred & target).sum()
union = (pred | target).sum()
iou = (intersection + smooth) / (union + smooth)
return iou
7. 主流程:整合所有模块
python
运行
def main():
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")
# 创建模型保存目录
os.makedirs(f"models/{args.name}", exist_ok=True)
# 保存配置参数(方便复现)
with open(f"models/{args.name}/config.yml", "w") as f:
yaml.dump(vars(args), f)
# 加载数据
train_loader, valid_loader, train_ids, valid_ids = get_loaders(args)
print(f"训练集样本数:{len(train_ids)},验证集样本数:{len(valid_ids)}")
# 初始化模型、损失函数、优化器
model = NestedUNet(
input_channels=args.input_channels,
num_classes=args.num_classes,
deep_supervision=args.deep_supervision
).to(device)
if args.loss == "bce_dice":
criterion = BCEDiceLoss()
else:
criterion = nn.BCEWithLogitsLoss() # 自带Sigmoid
if args.optimizer == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
# 学习率调度器(cosine退火,自动调整学习率)
if args.scheduler == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6
)
else:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=5
)
# 记录训练日志
log = {
"train_loss": [], "train_iou": [],
"valid_loss": [], "valid_iou": []
}
# 早停相关变量
best_iou = 0.0
early_stopping_counter = 0
# 训练循环
for epoch in range(1, args.epochs + 1):
print(f"\n===== Epoch {epoch}/{args.epochs} =====")
# 训练
train_loss, train_iou = train_fn(train_loader, model, criterion, optimizer, device)
# 验证
valid_loss, valid_iou = validate_fn(valid_loader, model, criterion, device)
# 更新日志
log["train_loss"].append(train_loss)
log["train_iou"].append(train_iou)
log["valid_loss"].append(valid_loss)
log["valid_iou"].append(valid_iou)
print(f"训练集:损失={train_loss:.4f},IoU={train_iou:.4f}")
print(f"验证集:损失={valid_loss:.4f},IoU={valid_iou:.4f}")
# 调整学习率
if args.scheduler == "cosine":
scheduler.step()
else:
scheduler.step(valid_iou) # 基于验证集IoU调整
# 保存最佳模型(验证集IoU最高)
if valid_iou > best_iou:
best_iou = valid_iou
torch.save(model.state_dict(), f"models/{args.name}/best_model.pth")
print(f"保存最佳模型(IoU={best_iou:.4f})")
early_stopping_counter = 0 # 重置早停计数器
else:
early_stopping_counter += 1
print(f"早停计数器:{early_stopping_counter}/{args.early_stopping}")
if early_stopping_counter >= args.early_stopping:
print("早停触发,停止训练")
break
# 绘制训练曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(log["train_loss"], label="Train Loss")
plt.plot(log["valid_loss"], label="Valid Loss")
plt.title("Loss Curve")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(log["train_iou"], label="Train IoU")
plt.plot(log["valid_iou"], label="Valid IoU")
plt.title("IoU Curve")
plt.legend()
plt.savefig(f"models/{args.name}/curves.png")
print(f"训练曲线已保存到 models/{args.name}/curves.png")
if __name__ == "__main__":
main()
五、训练结果与分析
1. 预期效果
在 GTX 1080Ti 上训练 50 轮(约 1 小时),验证集 IoU 可达 0.85 以上(越高越好,1.0 为完美分割)。训练曲线应呈现:
- 损失曲线:训练集和验证集损失均逐渐下降,且差距不大(无过拟合);
- IoU 曲线:训练集和验证集 IoU 均逐渐上升,最终稳定在 0.85 左右。
2. 分割结果可视化
随机选择验证集图像,对比 "原始图像→真实掩码→模型预测":
python
运行
import matplotlib.pyplot as plt
from skimage import io
# 加载模型(略)
model.load_state_dict(torch.load("models/nested_unet_dsb2018/best_model.pth"))
model.eval()
# 取一张验证集图像
img, mask, img_id = valid_dataset[0]
with torch.no_grad():
pred = model(img.unsqueeze(0).to(device)) # 加batch维度
pred = torch.sigmoid(pred[0])[0].cpu().numpy() # 转为numpy
pred = (pred > 0.5).astype(np.float32) # 二值化
# 可视化
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.imshow(img[0], cmap="gray") # 原始图像(单通道)
plt.title("Original Image")
plt.subplot(132)
plt.imshow(mask[0], cmap="gray") # 真实掩码
plt.title("True Mask")
plt.subplot(133)
plt.imshow(pred[0], cmap="gray") # 预测掩码
plt.title("Predicted Mask")
plt.show()
理想结果:预测掩码与真实掩码高度重合,尤其是细胞核的边缘和重叠区域能被准确分割。
3. 常见问题与调优
-
过拟合:训练集 IoU 高(>0.9),验证集 IoU 低(<0.7)。解决:增加数据增强强度(如添加高斯噪声)、减小模型深度、使用早停。
-
分割边界模糊:预测掩码边缘不清晰。解决:增加 Dice 损失权重(让模型更关注边界)、使用更大的输入尺寸(如 128×128)。
-
小细胞核漏检 :小目标未被分割。解决:开启深度监督(
--deep_supervision)、减小批次大小(让模型更关注小样本)。
六、总结与拓展
通过这个项目,你已经掌握了图像分割的核心流程:
- 数据预处理与增强(提升模型鲁棒性的关键);
- NestedUNet 模型的原理与实现(密集融合 + 深度监督);
- 损失函数与评估指标(BCE+Dice 损失、IoU 计算);
- 训练循环与调优技巧(早停、学习率调度)。
拓展方向
- 尝试更先进的模型:如 U-net+++、SegFormer(结合 Transformer);
- 多模态数据:融合细胞核的染色图像和荧光图像,提升分割精度;
- 后处理优化:用形态学操作(如腐蚀、膨胀)去除预测掩码中的噪声。
希望这篇教程能帮你快速入门图像分割,如果你在实战中遇到问题,欢迎在评论区交流~ 代码已整理到 GitHub,关注我获取完整项目链接!