在计算机视觉任务中,语义分割 (Semantic Segmentation)是核心方向之一,目标是给图像中每个像素分配对应类别标签。从零搭建分割模型繁琐且耗时,而 segmentation-models-pytorch(简称 smp)是 PyTorch 生态中最实用、开箱即用的分割模型库,封装了 UNet、FPN、DeepLabV3+ 等经典算法,支持多种骨干网络,无需复杂配置即可快速训练高精度分割模型。
本文带你从零开始,完成环境安装 → 数据集构建 → 模型定义 → 训练/验证 → 推理预测全流程,新手也能直接复刻运行。
一、库介绍与核心优势
segmentation-models-pytorch 是基于 PyTorch 的语义分割工具库,核心特点:
- 支持主流模型:UNet、UNet++、FPN、PSPNet、DeepLabV3、DeepLabV3+ 等;
- 丰富骨干网络:ResNet、MobileNet、EfficientNet 等,可自由搭配,兼顾精度与速度;
- 开箱即用:预训练权重、损失函数(DiceLoss、JaccardLoss)、评价指标(IoU、F1)已封装;
- 简洁 API:一行代码定义模型,大幅降低开发成本。
适用场景:医学图像分割、遥感图像分割、自动驾驶场景分割、工业缺陷检测等。
二、环境安装
首先安装核心依赖库,smp 依赖 PyTorch,建议提前配置好 CUDA 加速训练:
# Install PyTorch (pick the command matching your CUDA version from pytorch.org)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Install segmentation-models-pytorch
pip install segmentation-models-pytorch
# Helper libraries: albumentations is used for the augmentations in section 4,
# tensorboard is required by torch.utils.tensorboard.SummaryWriter in the DDP script
pip install numpy opencv-python pillow matplotlib tqdm albumentations tensorboard
三、数据集准备
语义分割数据集标准格式:原图 + 对应掩码图
- 原图:RGB 图像(.jpg/.png)
- 掩码图:单通道灰度图,像素值为类别编号(如 0=背景,1=目标);二分类掩码也常保存为 0/255
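推荐目录结构(自定义数据集通用,原图与掩码文件同名、一一对应):
dataset/
├── images/
│   ├── train/   # 训练原图
│   └── val/     # 验证原图
└── masks/
    ├── train/   # 训练掩码(与原图同名)
    └── val/     # 验证掩码(与原图同名)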
自定义数据集类(PyTorch Dataset)
我们封装一个通用的分割数据集加载类,支持任意自定义数据集:
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for the single-model tutorial pipeline.

    Images are read from ``image_dir`` and masks from ``mask_dir`` using the
    *same* file name, so every image needs a mask with an identical name.

    Args:
        image_dir: Directory containing the RGB images.
        mask_dir: Directory containing single-channel masks (class ids, or
            0/255 for binary masks).
        transforms: Optional albumentations pipeline, called as
            ``transforms(image=..., mask=...)``.
        binary: If True (default), the mask becomes a float32 0/1 map with a
            leading channel axis, shape (1, H, W), matching the (N, 1, H, W)
            output of a ``classes=1`` model. If False, the raw class-id mask
            is returned unchanged (e.g. for multi-class training).
    """

    def __init__(self, image_dir, mask_dir, transforms=None, binary=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transforms = transforms
        self.binary = binary
        # Sorted so the sample order is deterministic; masks share the image names.
        self.filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with a clear message.
            raise FileNotFoundError(f"无法读取图片: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"无法读取掩码: {mask_path}")
        if self.binary:
            # Accept both 0/1 and 0/255 encodings: any non-zero pixel is foreground.
            mask = (mask > 0).astype(np.float32)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if self.binary:
            # (H, W) -> (1, H, W); ``[None]`` works for numpy arrays and torch tensors.
            mask = mask[None]
        return image, mask
四、核心配置与数据增强
定义训练超参数、数据增强策略,使用 albumentations 库做分割专用增强(保证图像与掩码同步变换):
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ===================== Hyper-parameters =====================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CLASSES = 1  # number of output channels (1 = binary segmentation)
EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 1e-4


# ===================== Data augmentation =====================
def _normalize_and_to_tensor():
    """Return fresh ImageNet normalisation + tensor conversion steps.

    Both pipelines end with the same two steps; new instances are created per
    pipeline so the two ``Compose`` objects share no transform objects.
    """
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Training: resize, random 224x224 crop and random flips (applied to image and mask together).
_train_geometry = [
    A.Resize(height=256, width=256),
    A.RandomCrop(height=224, width=224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
]
train_transform = A.Compose(_train_geometry + _normalize_and_to_tensor())

# Validation: deterministic resize only, no random augmentation.
val_transform = A.Compose([A.Resize(height=224, width=224)] + _normalize_and_to_tensor())
五、一行代码定义分割模型
smp 最大优势:模型+骨干网络一键组合,无需手动搭建网络结构。
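常用模型示例:单个模型只需一行,例如 `model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=1)`。下面的完整脚本用同样的方式依次构建 Unet、Unet++、DeepLabV3+、FPN、PSPNet、PAN、MAnet,支持 DDP 多卡对比训练并输出排行榜: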
"""
segmentation_models_pytorch 多模型对比训练 (DDP 多卡)
支持 8 卡分布式训练, 自动对比各模型精度与速度, 输出排行榜
启动方式 (8卡):
torchrun --nproc_per_node=8 pytorch_demo/unet_train.py
启动方式 (单卡):
python pytorch_demo/unet_train.py
"""
import os
import time
import csv
from glob import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import segmentation_models_pytorch as smp
# ===================== 分布式工具 =====================
def get_world_info():
    """Read the process coordinates that torchrun exports.

    Returns:
        ``(rank, local_rank, world_size)``; ``(0, 0, 1)`` when the variables
        are absent (plain single-process launch).
    """
    env = os.environ
    return (
        int(env.get("RANK", 0)),
        int(env.get("LOCAL_RANK", 0)),
        int(env.get("WORLD_SIZE", 1)),
    )
def is_main_process():
    """Return True for global rank 0 (also true when not launched distributed)."""
    global_rank = int(os.environ.get("RANK", 0))
    return not global_rank
# ===================== 配置 =====================
class Config:
    """Training settings shared by every model in the comparison run."""

    # ---- dataset locations (image/mask pairs are matched by sorted file order) ----
    train_img_dir = "data/road_B_4/train/images"
    train_mask_dir = "data/road_B_4/train/masks"
    val_img_dir = "data/road_B_4/val/images"
    val_mask_dir = "data/road_B_4/val/masks"

    # ---- optimisation ----
    image_size = (512, 512)  # handed to cv2.resize, i.e. (width, height)
    batch_size = 8           # per process; with 8 processes the global batch is 64
    epochs = 100
    lr = 1e-3                # base learning rate
    num_workers = 4

    # ---- loss mixing ----
    dice_weight = 0.5
    bce_weight = 0.5

    # ---- early stopping ----
    early_stop_patience = 15
    early_stop_min_delta = 1e-4

    # ---- outputs ----
    save_dir = "checkpoints"
    log_dir = "logs"

    # ---- models to compare: (display name, smp class name, encoder) ----
    model_list = [
        ("Unet", "Unet", "resnet34"),
        ("Unet++", "UnetPlusPlus", "resnet34"),
        ("DeepLabV3+", "DeepLabV3Plus", "resnet34"),
        ("FPN", "FPN", "resnet34"),
        ("PSPNet", "PSPNet", "resnet34"),
        ("PAN", "PAN", "resnet34"),
        ("MAnet", "MAnet", "resnet34"),
    ]


config = Config()
# ===================== 数据集 =====================
class SegmentationDataset(Dataset):
    """Image/mask dataset used by the multi-model comparison script.

    Images and masks are paired by sorted file order, so both directories must
    contain the same number of files in matching order.

    Args:
        img_dir: Directory of RGB images.
        mask_dir: Directory of single-channel binary masks (0/255 or 0/1).
        image_size: Target size passed to ``cv2.resize``, i.e. ``(width, height)``.

    Each item is ``(image, mask)``: a float32 tensor (3, H, W) scaled to
    [0, 1] and a float32 tensor (1, H, W) with values in [0, 1].

    Raises:
        ValueError: if the image and mask counts differ.
    """

    def __init__(self, img_dir, mask_dir, image_size=(512, 512)):
        self.img_paths = sorted(glob(os.path.join(img_dir, "*")))
        self.mask_paths = sorted(glob(os.path.join(mask_dir, "*")))
        self.image_size = image_size
        # Explicit raise instead of assert so the check survives ``python -O``.
        if len(self.img_paths) != len(self.mask_paths):
            raise ValueError(
                f"图片({len(self.img_paths)})和掩码({len(self.mask_paths)})数量不一致"
            )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with a clear message.
            raise FileNotFoundError(f"无法读取图片: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.image_size)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"无法读取掩码: {mask_path}")
        # Nearest-neighbour keeps mask values from being interpolated.
        mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32)
        # 0/255 masks are scaled to 0/1; masks already stored as 0/1 are kept
        # (dividing those by 255 would turn every foreground target into ~0.004).
        if mask.max() > 1.0:
            mask /= 255.0

        image = torch.from_numpy(image).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).unsqueeze(0).float()
        return image, mask
# ===================== 损失函数 =====================
class DiceBCELoss(nn.Module):
    """Weighted sum of a soft Dice loss and ``BCEWithLogitsLoss``.

    ``pred`` holds raw logits and ``target`` values in [0, 1], with the same
    number of elements. The Dice term pools the whole batch into a single
    vector (one global Dice score) with additive smoothing of 1.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        eps = 1.0
        probs = torch.sigmoid(pred).reshape(-1)
        truth = target.reshape(-1)
        overlap = torch.sum(probs * truth)
        dice_score = (2. * overlap + eps) / (probs.sum() + truth.sum() + eps)
        weighted_dice = self.dice_weight * (1 - dice_score)
        weighted_bce = self.bce_weight * self.bce(pred, target)
        return weighted_dice + weighted_bce
# ===================== 评估指标 =====================
def compute_iou(pred, target, threshold=0.5):
    """Batch-level IoU between thresholded ``sigmoid(pred)`` and ``target``.

    All pixels in the batch are pooled into one intersection / union. When
    both prediction and target are empty (union 0), IoU is defined as 1.0.
    """
    hard = torch.sigmoid(pred).gt(threshold).float()
    inter = torch.sum(hard * target)
    union = hard.sum() + target.sum() - inter
    return torch.tensor(1.0, device=hard.device) if union == 0 else inter / union
def compute_dice(pred, target, threshold=0.5):
    """Smoothed Dice coefficient between thresholded ``sigmoid(pred)`` and ``target``.

    The whole batch is flattened into one vector; a smoothing term of 1 keeps
    the result defined (and equal to 1.0) when both inputs are empty.
    """
    smooth = 1.0
    hard = torch.sigmoid(pred).gt(threshold).float().reshape(-1)
    flat_target = target.reshape(-1)
    overlap = torch.sum(hard * flat_target)
    return (2. * overlap + smooth) / (hard.sum() + flat_target.sum() + smooth)
def reduce_tensor(tensor, world_size):
    """Average ``tensor`` over all ranks and return the result as a new tensor.

    The input is cloned before the in-place all-reduce, so the caller's tensor
    is left untouched. Requires an initialised ``torch.distributed`` group.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return torch.div(summed, world_size)
# ===================== 训练器 =====================
class Trainer:
    """Train, validate and checkpoint a single segmentation model (optionally under DDP).

    Args:
        model: A plain ``nn.Module``. An already ``DistributedDataParallel``
            wrapped model is unwrapped first, so DDP is never nested.
        config: Object exposing ``lr``, ``epochs``, ``dice_weight``,
            ``bce_weight``, ``early_stop_patience``, ``save_dir``, ``log_dir``
            and ``image_size``; ``early_stop_min_delta``, ``in_channels`` and
            ``classes`` are optional.
        device: Target device for this process.
        local_rank, rank, world_size: torchrun process coordinates.
        model_name: Display name, used for the log / checkpoint sub-directories.
        class_name: smp class name written to checkpoints so the inference
            script can rebuild the architecture. Defaults to the model's class name.
        backbone: Encoder name written to checkpoints when known. May also be
            assigned on the instance after construction.
    """

    def __init__(self, model, config, device, local_rank, rank, world_size, model_name,
                 class_name=None, backbone=None):
        # Accept a model that the caller already wrapped, but never wrap twice.
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.local_rank = local_rank
        self.rank = rank
        self.world_size = world_size
        self.model_name = model_name
        self.class_name = class_name or type(model).__name__
        self.backbone = backbone

        if world_size > 1:
            on_gpu = torch.device(device).type == "cuda"
            # device_ids / output_device must be None for CPU (gloo) modules.
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[local_rank] if on_gpu else None,
                output_device=local_rank if on_gpu else None,
            )
        self.ddp_model = self.model if world_size > 1 else None

        self.criterion = DiceBCELoss(dice_weight=config.dice_weight, bce_weight=config.bce_weight)
        # Linear LR scaling: the global batch grows with the number of processes.
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.lr * world_size)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=5
        )

        # TensorBoard writer and checkpoint directory live on rank 0 only.
        self.writer = None
        self.model_ckpt_dir = os.path.join(config.save_dir, model_name)
        if rank == 0:
            model_log_dir = os.path.join(config.log_dir, model_name)
            os.makedirs(model_log_dir, exist_ok=True)
            self.writer = SummaryWriter(model_log_dir)
            os.makedirs(self.model_ckpt_dir, exist_ok=True)

        self.total_params = 0
        self.trainable_params = 0

    def _unwrap(self):
        """Return the underlying model without the DDP wrapper."""
        return self.ddp_model.module if self.ddp_model is not None else self.model

    def _mean_across_ranks(self, value):
        """Return the 0-d tensor ``value`` averaged over all ranks, as a Python float."""
        value = value.detach()
        if self.world_size > 1:
            value = reduce_tensor(value, self.world_size)
        return value.item()

    def train_epoch(self, loader):
        """Run one training epoch; return ``(mean_loss, mean_iou)`` over batches."""
        self.model.train()
        total_loss = 0.0
        total_iou = 0.0
        pbar = tqdm(loader, desc=f"Train [R{self.rank}]", disable=self.rank != 0)
        for images, masks in pbar:
            images, masks = images.to(self.device), masks.to(self.device)
            self.optimizer.zero_grad()
            preds = self.model(images)
            loss = self.criterion(preds, masks)
            loss.backward()
            self.optimizer.step()

            batch_loss = self._mean_across_ranks(loss)
            # The metric does not need the autograd graph.
            batch_iou = self._mean_across_ranks(compute_iou(preds.detach(), masks))
            total_loss += batch_loss
            total_iou += batch_iou
            if self.rank == 0:
                pbar.set_postfix(loss=batch_loss)
        n = max(len(loader), 1)  # an empty loader must not divide by zero
        return total_loss / n, total_iou / n

    @torch.no_grad()
    def validate(self, loader):
        """Evaluate on ``loader``; return ``(mean_loss, mean_iou, mean_dice)`` over batches."""
        self.model.eval()
        total_loss = 0.0
        total_iou = 0.0
        total_dice = 0.0
        for images, masks in tqdm(loader, desc=f"Val [R{self.rank}]", disable=self.rank != 0):
            images, masks = images.to(self.device), masks.to(self.device)
            preds = self.model(images)
            total_loss += self._mean_across_ranks(self.criterion(preds, masks))
            total_iou += self._mean_across_ranks(compute_iou(preds, masks))
            total_dice += self._mean_across_ranks(compute_dice(preds, masks))
        n = max(len(loader), 1)
        return total_loss / n, total_iou / n, total_dice / n

    def _log_epoch(self, epoch, train_loss, train_iou, val_loss, val_iou, val_dice, epoch_time):
        """Print the epoch summary and write TensorBoard scalars (call on rank 0 only)."""
        lr = self.optimizer.param_groups[0]["lr"]
        print(f"Train Loss: {train_loss:.4f} IoU: {train_iou:.4f} | "
              f"Val Loss: {val_loss:.4f} IoU: {val_iou:.4f} Dice: {val_dice:.4f} | "
              f"LR: {lr:.2e} Time: {epoch_time:.1f}s")
        scalars = (
            ("Loss/train", train_loss),
            ("Loss/val", val_loss),
            ("IoU/train", train_iou),
            ("IoU/val", val_iou),
            ("Dice/val", val_dice),
            ("Time/epoch", epoch_time),
        )
        for tag, value in scalars:
            self.writer.add_scalar(tag, value, epoch)

    def fit(self, train_loader, val_loader, train_sampler=None):
        """Train with best-IoU checkpointing and early stopping.

        Returns:
            ``(best_val_iou, best_val_dice, best_val_loss, avg_epoch_seconds)``.
        """
        best_val_iou = 0.0
        best_val_dice = 0.0
        best_val_loss = float("inf")
        patience_counter = 0
        epoch_times = []
        min_delta = getattr(self.config, "early_stop_min_delta", 0.0)

        for epoch in range(self.config.epochs):
            if train_sampler is not None:
                # Give DistributedSampler a new shuffle seed every epoch.
                train_sampler.set_epoch(epoch)
            if self.rank == 0:
                print(f"\n{'='*50}")
                print(f"[{self.model_name}] Epoch {epoch+1}/{self.config.epochs} | world_size={self.world_size}")

            t0 = time.perf_counter()
            train_loss, train_iou = self.train_epoch(train_loader)
            val_loss, val_iou, val_dice = self.validate(val_loader)
            epoch_time = time.perf_counter() - t0
            epoch_times.append(epoch_time)
            self.scheduler.step(val_loss)

            if self.rank == 0:
                self._log_epoch(epoch, train_loss, train_iou, val_loss, val_iou, val_dice, epoch_time)

            # The metrics are already averaged over ranks, so every rank takes the
            # same branches below and all ranks stop at the same epoch. A
            # rank-0-only break would leave the other ranks blocked in the next
            # collective. Only rank 0 writes the checkpoint.
            improvement = val_iou - best_val_iou
            if val_iou > best_val_iou:
                best_val_iou = val_iou
                best_val_dice = val_dice
                best_val_loss = val_loss
                if self.rank == 0:
                    best_path = os.path.join(self.model_ckpt_dir, "best_model.pth")
                    self._save_checkpoint(best_path, epoch, val_loss, val_iou, val_dice)
            if improvement > min_delta:
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.early_stop_patience:
                    if self.rank == 0:
                        print(f"Early stopping triggered at epoch {epoch+1}")
                    break

        avg_time = float(np.mean(epoch_times)) if epoch_times else 0.0
        if self.rank == 0:
            self.writer.close()
            print(f"\n[{self.model_name}] 完成 | "
                  f"best IoU={best_val_iou:.4f} Dice={best_val_dice:.4f} "
                  f"Loss={best_val_loss:.4f} | avg_time={avg_time:.1f}s/epoch")
        return best_val_iou, best_val_dice, best_val_loss, avg_time

    def _save_checkpoint(self, path, epoch, val_loss, val_iou, val_dice):
        """Save the unwrapped weights plus the metadata needed to rebuild the model.

        The top-level ``model_name`` (smp class name), ``backbone``,
        ``best_iou`` and ``best_dice`` keys are the ones the inference
        script's ``load_model`` reads.
        """
        raw_model = self._unwrap()
        backbone = getattr(self, "backbone", None)
        class_name = getattr(self, "class_name", None) or type(raw_model).__name__
        ckpt = {
            "epoch": epoch,
            "model_name": class_name,
            "model_state_dict": raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "val_iou": val_iou,
            "val_dice": val_dice,
            "best_iou": val_iou,
            "best_dice": val_dice,
            "config": {
                "model_name": self.model_name,
                "arch": class_name,
                "in_channels": getattr(self.config, "in_channels", 3),
                "classes": getattr(self.config, "classes", 1),
                "image_size": self.config.image_size,
            },
        }
        if backbone is not None:
            ckpt["backbone"] = backbone
            ckpt["config"]["backbone"] = backbone
        torch.save(ckpt, path)
# ===================== 排行榜 =====================
def print_ranking(results, rank):
    """Print the model comparison leaderboard, sorted by IoU descending (rank 0 only).

    Args:
        results: List of dicts with keys ``name``, ``params_m``, ``best_iou``,
            ``best_dice``, ``best_loss`` and ``avg_time``.
        rank: Global rank of the caller; non-zero ranks print nothing.
    """
    if rank != 0:
        return
    print("\n\n" + "=" * 90)
    print("模型对比排行榜 (按 IoU 降序)")
    print("=" * 90)
    if not results:
        # ranked[0] and min() below would raise on an empty list.
        print("没有可排名的结果")
        print("=" * 90)
        return

    ranked = sorted(results, key=lambda r: r["best_iou"], reverse=True)
    header = (f"{'排名':>4} | {'模型':<14} | {'参数量':>8} | "
              f"{'IoU':>8} | {'Dice':>8} | {'Loss':>8} | {'时间/epoch':>10}")
    sep = "-" * len(header)
    print(header)
    print(sep)
    for i, r in enumerate(ranked, 1):
        print(f"{i:>4} | {r['name']:<14} | {r['params_m']:>5.2f}M | "
              f"{r['best_iou']:>8.4f} | {r['best_dice']:>8.4f} | "
              f"{r['best_loss']:>8.4f} | {r['avg_time']:>7.1f}s")
    print(sep)

    best_acc = ranked[0]
    fastest = min(ranked, key=lambda r: r['avg_time'])
    print(f"\n最佳精度: {best_acc['name']} (IoU={best_acc['best_iou']:.4f})")
    print(f"最快速: {fastest['name']} ({fastest['avg_time']:.1f}s/epoch)")
    # Accuracy/speed trade-off: the fastest among the top-3 by IoU.
    balanced = min(ranked[:3], key=lambda r: r['avg_time'])
    print(f"综合推荐: {balanced['name']} (IoU={balanced['best_iou']:.4f}, "
          f"{balanced['avg_time']:.1f}s/epoch) --- 精度前3中最快")
    print("=" * 90)
def save_csv(results, path, rank):
    """Write the leaderboard (sorted by IoU, descending) to ``path`` as CSV (rank 0 only).

    The file is written as UTF-8 with a BOM. The Chinese header row therefore
    does not depend on the platform's locale encoding and opens correctly in Excel.
    Parent directories are created as needed.
    """
    if rank != 0:
        return
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    ranked = sorted(results, key=lambda r: r["best_iou"], reverse=True)
    fieldnames = ["排名", "模型", "参数量", "IoU", "Dice", "Loss", "时间/epoch"]
    with open(path, "w", newline="", encoding="utf-8-sig") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for i, r in enumerate(ranked, 1):
            w.writerow({
                "排名": i, "模型": r["name"], "参数量": f"{r['params_m']}M",
                "IoU": r["best_iou"], "Dice": r["best_dice"],
                "Loss": r["best_loss"], "时间/epoch": f"{r['avg_time']}s",
            })
    print(f"对比结果已保存: {path}")
# ===================== 合成数据 =====================
def generate_synthetic_data(num_samples=100, image_size=(128, 128), root="synthetic_data"):
    """Write a small random binary-segmentation dataset to disk for demos.

    Each mask holds a filled circle and, with probability 0.5, a filled
    rectangle (foreground = 255). The same shapes are painted onto the noisy
    image in a random solid colour so that the image actually carries the
    signal to be segmented; an image of pure noise would be unlearnable.

    Args:
        num_samples: Number of training samples; validation gets ``num_samples // 5``.
        image_size: ``(H, W)`` of the generated images; each side must exceed 40 px.
        root: Output directory; ``train/`` and ``val/`` sub-folders, each with
            ``images/`` and ``masks/``, are created inside it.

    Returns:
        ``root``.
    """
    splits = {"train": num_samples, "val": num_samples // 5}
    for split in splits:
        for sub in ("images", "masks"):
            os.makedirs(os.path.join(root, split, sub), exist_ok=True)

    rng = np.random.RandomState(42)
    H, W = image_size
    for split, n in splits.items():
        for i in range(n):
            img = rng.randint(0, 256, (H, W, 3), dtype=np.uint8)
            mask = np.zeros((H, W), dtype=np.uint8)
            cx, cy = rng.randint(20, W - 20), rng.randint(20, H - 20)
            r = rng.randint(10, 30)
            cv2.circle(mask, (cx, cy), r, 255, -1)
            if rng.rand() > 0.5:
                x1, y1 = rng.randint(10, W - 10), rng.randint(10, H - 10)
                x2, y2 = x1 + rng.randint(10, 30), y1 + rng.randint(10, 30)
                cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
            # Paint the foreground in one solid colour so it is visible in the image.
            img[mask > 0] = rng.randint(0, 256, 3, dtype=np.uint8)
            name = f"{i:04d}.png"
            cv2.imwrite(os.path.join(root, split, "images", name), img)
            cv2.imwrite(os.path.join(root, split, "masks", name), mask)
    if is_main_process():
        print(f"合成数据已生成: {root}/ (train={splits['train']}, val={splits['val']})")
    return root
# ===================== 主函数 =====================
def main():
    """Train every model in ``config.model_list`` on the same data and rank them.

    Runs as a single process (``python ...``) or under ``torchrun`` with one
    process per GPU (NCCL), or with CPU processes (gloo) when CUDA is unavailable.
    """
    rank, local_rank, world_size = get_world_info()
    use_cuda = torch.cuda.is_available()

    # Distributed initialisation
    if world_size > 1:
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")
        if use_cuda:
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")
        else:
            # CPU-only gloo run: torch.cuda.set_device would fail here.
            device = torch.device("cpu")
    else:
        device = torch.device("cuda" if use_cuda else "cpu")

    if rank == 0:
        print(f"设备: {device} | 总卡数: {world_size}")
        print(f"每卡Batch: {config.batch_size} | 总Batch: {config.batch_size * max(world_size, 1)}")
        print(f"Epochs: {config.epochs} | 图片尺寸: {config.image_size}")
        print(f"参与对比模型: {len(config.model_list)} 个")
        for name, cls_name, backbone in config.model_list:
            print(f" - {name:<12} class={cls_name:<16} backbone={backbone}")

    # Data: fall back to a synthetic demo dataset when the real one is missing.
    if not os.path.exists(config.train_img_dir):
        data_root = "synthetic_data"
        if rank == 0:
            print("未找到数据, 使用合成数据演示...")
            data_root = generate_synthetic_data()
        if world_size > 1:
            # Only rank 0 writes the files; the other ranks wait until they
            # exist instead of racing to write/read the same paths.
            dist.barrier()
        config.train_img_dir = f"{data_root}/train/images"
        config.train_mask_dir = f"{data_root}/train/masks"
        config.val_img_dir = f"{data_root}/val/images"
        config.val_mask_dir = f"{data_root}/val/masks"

    train_dataset = SegmentationDataset(config.train_img_dir, config.train_mask_dir, config.image_size)
    val_dataset = SegmentationDataset(config.val_img_dir, config.val_mask_dir, config.image_size)
    if rank == 0:
        print(f"\n训练样本: {len(train_dataset)} | 验证样本: {len(val_dataset)}")

    # DataLoaders are shared by all models (the dataset does not change).
    train_sampler = DistributedSampler(train_dataset, shuffle=True) if world_size > 1 else None
    val_sampler = DistributedSampler(val_dataset, shuffle=False) if world_size > 1 else None
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size,
        sampler=train_sampler, shuffle=(train_sampler is None),
        num_workers=config.num_workers, pin_memory=use_cuda, drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size,
        sampler=val_sampler, shuffle=False,
        num_workers=config.num_workers, pin_memory=use_cuda,
    )

    # Train the models one after another.
    all_results = []
    for display_name, class_name, backbone in config.model_list:
        if rank == 0:
            print(f"\n\n{'#'*70}")
            print(f"# 开始训练: {display_name} ({class_name}, backbone={backbone})")
            print(f"{'#'*70}")

        model_class = getattr(smp, class_name)
        if world_size > 1 and rank != 0:
            # Let rank 0 download the pretrained encoder weights first, so the
            # ranks do not race on the same cache file.
            dist.barrier()
        model = model_class(
            encoder_name=backbone,
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )
        if world_size > 1 and rank == 0:
            dist.barrier()

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if rank == 0:
            print(f"参数量: {total_params/1e6:.2f}M (可训练: {trainable_params/1e6:.2f}M)")

        # No DDP wrapping here: Trainer moves the model to ``device`` and wraps
        # it itself, so wrapping here as well would nest DDP inside DDP.
        trainer = Trainer(model, config, device, local_rank, rank, world_size, display_name)
        trainer.total_params = total_params
        trainer.trainable_params = trainable_params
        # Written into checkpoints so the inference script can rebuild the model.
        trainer.class_name = class_name
        trainer.backbone = backbone
        best_iou, best_dice, best_loss, avg_time = trainer.fit(train_loader, val_loader, train_sampler)
        all_results.append({
            "name": display_name,
            "class_name": class_name,
            "backbone": backbone,
            "params_m": round(total_params / 1e6, 2),
            "trainable_m": round(trainable_params / 1e6, 2),
            "best_iou": round(best_iou, 4),
            "best_dice": round(best_dice, 4),
            "best_loss": round(best_loss, 4),
            "avg_time": round(avg_time, 1),
        })
        # Drop the references first so empty_cache can actually release the memory.
        del trainer, model
        if use_cuda:
            torch.cuda.empty_cache()

    # Leaderboard
    print_ranking(all_results, rank)
    save_csv(all_results, os.path.join(config.save_dir, "compare_results.csv"), rank)
    if rank == 0:
        print(f"\n所有模型训练完成! 结果对比: {config.save_dir}/compare_results.csv")
    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
损失函数与优化器
smp 封装了分割任务专用损失函数,比单纯交叉熵效果更好:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils.losses import DiceLoss
from segmentation_models_pytorch.utils.metrics import IoU

# Model: U-Net with an ImageNet-pretrained ResNet-34 encoder. The sigmoid
# activation makes the model output probabilities, which is what the legacy
# ``smp.utils`` loss and metric used with TrainEpoch/ValidEpoch expect.
# For multi-class segmentation set CLASSES accordingly and use activation="softmax2d".
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=CLASSES,
    activation="sigmoid",
)

# Loss: Dice loss on probabilities. ``smp.losses.DiceLoss`` is not used here:
# TrainEpoch logs the loss under ``loss.__name__``, which the ``smp.losses``
# classes do not define, and ``smp.losses`` expects logits by default.
loss = DiceLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Metric: IoU at a 0.5 threshold (logged as "iou_score").
metrics = [
    IoU(threshold=0.5),
]
六、训练与验证 pipeline
smp 的 `segmentation_models_pytorch.utils` 模块提供了 TrainEpoch、ValidEpoch 封装好的训练循环,代码极简(注意:该模块在新版 smp 中已被标记为弃用,若导入失败,可参考上文 DDP 脚本中的纯 PyTorch 训练循环):
from torch.utils.data import DataLoader
from segmentation_models_pytorch.utils.train import TrainEpoch, ValidEpoch

# ===================== Data =====================
# Replace with your dataset paths.
TRAIN_IMAGE_DIR = "dataset/images/train"
TRAIN_MASK_DIR = "dataset/masks/train"
VAL_IMAGE_DIR = "dataset/images/val"
VAL_MASK_DIR = "dataset/masks/val"

train_dataset = SegmentationDataset(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR, train_transform)
val_dataset = SegmentationDataset(VAL_IMAGE_DIR, VAL_MASK_DIR, val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ===================== Epoch runners =====================
train_epoch = TrainEpoch(
    model,
    loss=loss,
    optimizer=optimizer,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)
valid_epoch = ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# ===================== Training =====================
BEST_MODEL_PATH = 'best_segmentation_model.pth'
max_iou = 0.0  # best validation IoU seen so far
for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch: {epoch}/{EPOCHS}")
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(val_loader)

    # Keep the checkpoint with the best validation IoU.
    val_iou = valid_logs['iou_score']
    if val_iou > max_iou:
        max_iou = val_iou
        # Save a state_dict checkpoint rather than pickling the whole model:
        # torch.load of a pickled nn.Module fails under the weights_only=True
        # default of recent PyTorch releases. This layout also matches what
        # the inference script reads ("model_state_dict", "model_name").
        torch.save(
            {
                "epoch": epoch,
                "model_name": type(model).__name__,
                "model_state_dict": model.state_dict(),
                "best_iou": max_iou,
                "config": {"in_channels": 3, "classes": CLASSES},
            },
            BEST_MODEL_PATH,
        )
        print("最优模型已保存!")
训练过程会实时打印:损失值、IoU 分数,直观观察模型收敛情况。
七、模型推理预测
训练完成后,加载最优模型的 checkpoint,对单张图像或整个目录中的图像批量进行分割预测:
"""
UNet 批量推理脚本
读取图像目录, 用训练好的 best_model.pth 进行推理, 保存预测掩码
使用方法:
python pytorch_demo/unet_infer.py
"""
import os
import sys
from glob import glob
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import segmentation_models_pytorch as smp
# ===================== 推理配置 (按需修改) =====================
class InferConfig:
    """Inference settings; edit these before running the script."""

    # Input: a single image file or a directory of images.
    input = "test_images"
    # Output directory for masks / overlays.
    output = "inference_results"
    # Checkpoint written by the training script to checkpoints/<display name>/best_model.pth.
    # Relative by default so the script is portable; point it at the model you want.
    checkpoint = os.path.join("checkpoints", "Unet", "best_model.pth")
    # Inference size (W, H) as passed to cv2.resize; should match the training image_size.
    image_size = (512, 512)
    # Inference batch size.
    batch_size = 4
    # Probability threshold for the binary mask (the training metrics use 0.5).
    threshold = 0.7
    # Inference device (cuda / cpu); falls back to CPU when CUDA is unavailable.
    device = "cuda"
    # File-name suffix of the saved binary mask (None disables saving it).
    mask_suffix = "_mask.png"
    # File-name suffix of the overlay drawn on the original image (None disables it).
    overlay_suffix = "_overlay.png"
    # Overlay parameters
    overlay_color = (0, 255, 0)  # BGR: green
    overlay_alpha = 0.4          # weight of overlay_color in the blend (0~1)


config = InferConfig()
# ===================== 数据集 =====================
class InferenceDataset(Dataset):
    """Dataset for inference: loads each image, resizes it and records its original size.

    Each item is ``(tensor, path, original_w, original_h)``, where ``tensor``
    is a float32 (3, H, W) RGB image scaled to [0, 1] at ``image_size``
    (``(W, H)`` as understood by ``cv2.resize``).
    """

    def __init__(self, img_paths, image_size=(512, 512)):
        self.img_paths = img_paths
        self.image_size = image_size

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f"无法读取图片: {path}")
        height, width = bgr.shape[:2]
        rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), self.image_size)
        scaled = rgb.astype(np.float32) / 255.0
        tensor = torch.from_numpy(scaled).permute(2, 0, 1).float()
        return tensor, path, width, height
# ===================== 模型加载 =====================
def load_model(checkpoint_path, device):
    """Load a checkpoint and rebuild the matching smp model in eval mode.

    The architecture is resolved from, in order: the top-level
    ``model_name``, ``config["arch"]``, then ``config["model_name"]``
    (training display names such as ``"Unet++"`` are mapped to smp class
    names). The default is ``Unet``. The encoder comes from ``backbone`` /
    ``config["backbone"]``, defaulting to ``resnet34``. The model is built
    with a sigmoid activation, so it outputs probabilities.

    Raises:
        ValueError: if the recorded architecture is not an smp model class.
        RuntimeError: if the saved weights do not match the rebuilt model.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    ckpt_config = ckpt.get("config", {})

    # Training display names that are not valid smp class names.
    display_to_class = {"Unet++": "UnetPlusPlus", "DeepLabV3+": "DeepLabV3Plus"}
    model_name = (ckpt.get("model_name") or ckpt_config.get("arch")
                  or ckpt_config.get("model_name") or "Unet")
    model_name = display_to_class.get(model_name, model_name)
    backbone = ckpt.get("backbone") or ckpt_config.get("backbone") or "resnet34"
    in_channels = ckpt_config.get("in_channels", 3)
    classes = ckpt_config.get("classes", 1)

    model_class = getattr(smp, model_name, None)
    if not isinstance(model_class, type):
        # Falling back to Unet here would silently load weights into the wrong architecture.
        raise ValueError(f"未知的模型架构: {model_name}")
    model = model_class(
        encoder_name=backbone,
        encoder_weights=None,
        in_channels=in_channels,
        classes=classes,
        activation="sigmoid",
    )

    def _strip_ddp_prefix(key):
        # Weights saved from a (possibly nested) DDP wrapper carry "module." prefixes.
        while key.startswith("module."):
            key = key[len("module."):]
        return key

    state_dict = {_strip_ddp_prefix(k): v for k, v in ckpt["model_state_dict"].items()}
    # strict=True: a partial load (wrong architecture or encoder) must fail
    # loudly instead of silently predicting with random weights.
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    print(f"模型加载完成 | arch={model_name} backbone={backbone} "
          f"in_channels={in_channels} classes={classes}")
    best_iou = ckpt.get("best_iou", ckpt.get("val_iou", "N/A"))
    best_dice = ckpt.get("best_dice", ckpt.get("val_dice", "N/A"))
    print(f" best IoU={best_iou} best Dice={best_dice}")
    return model
# ===================== 推理 =====================
@torch.no_grad()
def infer_batch(model, loader, device, threshold):
    """Run ``model`` over ``loader`` and binarise its outputs at ``threshold``.

    The model output is compared directly against ``threshold``, so it is
    presumably a probability map (``load_model`` builds it with a sigmoid).

    Returns:
        One dict per image, in loader order, with keys ``mask`` (uint8 0/255
        array at model resolution), ``path``, ``orig_w`` and ``orig_h``.
    """
    model.eval()
    outputs = []
    for batch, paths, widths, heights in tqdm(loader, desc="Infer"):
        probs = model(batch.to(device))
        binary = (probs > threshold).float()
        for k, path in enumerate(paths):
            mask_u8 = (binary[k].squeeze(0).cpu().numpy() * 255).astype(np.uint8)
            outputs.append({
                "mask": mask_u8,
                "path": path,
                "orig_w": widths[k].item(),
                "orig_h": heights[k].item(),
            })
    return outputs
# ===================== 保存结果 =====================
def save_results(results, output_dir, config):
    """Write each predicted mask and, optionally, an overlay on the source image.

    Masks are resized back to the original image size with nearest-neighbour
    interpolation. Output names are ``<stem><config.mask_suffix>`` and
    ``<stem><config.overlay_suffix>``; a falsy suffix disables that output.
    Overlays are skipped for source images that cannot be re-read.
    """
    os.makedirs(output_dir, exist_ok=True)
    for item in results:
        stem = os.path.splitext(os.path.basename(item["path"]))[0]
        target_w, target_h = item["orig_w"], item["orig_h"]
        mask = item["mask"]
        # Restore the original resolution only when it differs.
        same_size = mask.shape[1] == target_w and mask.shape[0] == target_h
        full_mask = mask if same_size else cv2.resize(
            mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
        )

        if config.mask_suffix:
            cv2.imwrite(os.path.join(output_dir, stem + config.mask_suffix), full_mask)

        if not config.overlay_suffix:
            continue
        original = cv2.imread(item["path"])
        if original is None:
            continue
        blended = draw_overlay(original, full_mask, config.overlay_color, config.overlay_alpha)
        cv2.imwrite(os.path.join(output_dir, stem + config.overlay_suffix), blended)

    enabled = (("掩码", config.mask_suffix), ("叠加图", config.overlay_suffix))
    kinds = [label for label, suffix in enabled if suffix]
    print(f"共保存 {len(results)} 张{' + '.join(kinds)}到: {output_dir}")
def draw_overlay(image, mask, color=(0, 255, 0), alpha=0.4):
    """Blend ``color`` into ``image`` where ``mask > 0`` and outline the mask regions.

    Args:
        image: BGR image of shape (H, W, 3), normally uint8.
        mask: Array of shape (H, W); non-zero pixels are foreground.
        color: BGR colour used for the fill and the contour.
        alpha: Weight of ``color`` in the blend (0 = invisible, 1 = opaque).

    Returns:
        A new uint8 image; ``image`` itself is not modified.
    """
    overlay = image.copy()
    fg = mask > 0
    # Vectorised blend over all channels; round (not truncate) back to integers.
    blended = (overlay[fg].astype(np.float32) * (1 - alpha)
               + np.asarray(color, dtype=np.float32) * alpha)
    overlay[fg] = np.clip(np.rint(blended), 0, 255).astype(overlay.dtype)
    # Outline the external contours. ``[-2]`` selects the contour list under
    # both OpenCV 3 (3 return values) and OpenCV 4 (2 return values).
    contours = cv2.findContours(fg.astype(np.uint8), cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(overlay, contours, -1, color, 2)
    return overlay.astype(np.uint8)
# ===================== 收集图片 =====================
def collect_images(input_path):
    """Return the list of image files to run inference on.

    A file path is returned as a one-element list (its extension is not
    checked). For a directory, the regular, non-hidden files directly inside
    it are returned, sorted. Only extensions .png/.jpg/.jpeg/.tif/.tiff/.bmp
    are kept, matched case-insensitively so ``IMG_001.JPG`` is found on every
    platform.

    Raises:
        FileNotFoundError: if ``input_path`` does not exist.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
        # os.listdir instead of glob: case-insensitive extension matching, and
        # no surprises when the directory name contains glob metacharacters.
        # Hidden files (e.g. macOS "._x.jpg") are skipped, as glob("*") did.
        candidates = (
            os.path.join(input_path, name)
            for name in os.listdir(input_path)
            if not name.startswith(".") and os.path.splitext(name)[1].lower() in exts
        )
        return sorted(p for p in candidates if os.path.isfile(p))
    raise FileNotFoundError(f"输入路径不存在: {input_path}")
# ===================== 主函数 =====================
def main():
    """Run batch inference end to end using the module-level ``config``."""
    for label, value in (("输入", config.input), ("输出", config.output), ("Checkpoint", config.checkpoint)):
        print(f"{label}: {value}")

    # 1. Device: fall back to CPU when CUDA was requested but is unavailable.
    cuda_requested = config.device == "cuda"
    if cuda_requested and not torch.cuda.is_available():
        print("CUDA 不可用, 回退到 CPU")
        config.device = "cpu"
    device = torch.device(config.device)
    print(f"设备: {device}")

    # 2. Input images
    img_paths = collect_images(config.input)
    if not img_paths:
        print("未找到任何图片!")
        return
    print(f"找到 {len(img_paths)} 张图片")

    # 3. Model
    if not os.path.exists(config.checkpoint):
        print(f"错误: checkpoint 不存在: {config.checkpoint}")
        sys.exit(1)
    model = load_model(config.checkpoint, device)

    # 4-6. Predict and save
    loader = DataLoader(
        InferenceDataset(img_paths, config.image_size),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
    )
    results = infer_batch(model, loader, device, config.threshold)
    save_results(results, config.output, config)
    print(f"推理完成! 输出目录: {config.output}")


if __name__ == "__main__":
    main()
运行后会在输出目录中保存二值掩码(`_mask.png`)以及叠加在原图上的分割结果(`_overlay.png`),便于快速验证模型效果。
八、进阶技巧
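- 多类别分割
  - 修改 `CLASSES` 为类别总数;
  - 损失函数改为 `smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE)`,此时掩码为形状 (N, H, W) 的整型类别编号;
  - `smp.losses` 默认 `from_logits=True`,训练时模型保持 `activation=None`,推理时再对输出做 softmax(或在构建模型时使用 `activation="softmax2d"` 直接输出概率)。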
- 骨干网络替换
  支持上百种骨干网络,轻量级可用 mobilenet_v2、efficientnet-b0,高精度可用 resnet50、resnet101。
- 训练优化
  - 加入学习率调度器 `torch.optim.lr_scheduler.ReduceLROnPlateau`;
  - 增大数据增强力度;
  - 适当调整输入图像分辨率(如 512×512 通常精度更高,但显存占用和训练耗时也更大)。
九、总结
segmentation-models-pytorch 是语义分割的效率神器,完美解决了「模型搭建难、训练繁琐」的痛点:
- 一行代码定义任意分割模型,无需手动构建网络;
- 预训练权重 + 专用损失函数,快速收敛到高精度;
- 全流程代码简洁,新手可直接用于比赛、项目、毕业设计。

本文的代码可以直接适配医学分割、遥感分割、工业检测等几乎所有语义分割任务,替换数据集路径即可快速训练自己的模型。