segmentation-models-pytorch 极简实战:快速搭建与训练高精度语义分割模型

在计算机视觉任务中,语义分割 (Semantic Segmentation)是核心方向之一,目标是给图像中每个像素分配对应类别标签。从零搭建分割模型繁琐且耗时,而 segmentation-models-pytorch(简称 smp)是 PyTorch 生态中最实用、开箱即用的分割模型库,封装了 UNet、FPN、DeepLabV3+ 等经典算法,支持多种骨干网络,无需复杂配置即可快速训练高精度分割模型。

本文带你从零开始,完成环境安装 → 数据集构建 → 模型定义 → 训练/验证 → 推理预测全流程,新手也能直接复刻运行。

一、库介绍与核心优势

segmentation-models-pytorch 是基于 PyTorch 的语义分割工具库,核心特点:

  1. 支持主流模型:UNet、UNet++、FPN、PSPNet、DeepLabV3、DeepLabV3+ 等;

  2. 丰富骨干网络:ResNet、MobileNet、EfficientNet 等,可自由搭配,兼顾精度与速度;

  3. 开箱即用:预训练权重、损失函数(DiceLoss、JaccardLoss)、评价指标(IoU、F1)已封装;

  4. 简洁 API:一行代码定义模型,大幅降低开发成本。

适用场景:医学图像分割、遥感图像分割、自动驾驶场景分割、工业缺陷检测等。

二、环境安装

首先安装核心依赖库,smp 依赖 PyTorch,建议提前配置好 CUDA 加速训练:

复制代码
# 安装 PyTorch(根据你的 CUDA 版本选择,官网复制对应命令) 
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 安装 segmentation-models-pytorch 
pip install segmentation-models-pytorch
# 安装辅助库 
pip install numpy opencv-python pillow matplotlib tqdm

三、数据集准备

语义分割数据集标准格式:原图 + 对应掩码图

  • 原图:RGB 图像(.jpg/.png)

  • 掩码图:单通道灰度图,像素值为类别编号(如 0=背景,1=目标)

推荐目录结构(自定义数据集通用):

自定义数据集类(PyTorch Dataset)

我们封装一个通用的分割数据集加载类,支持任意自定义数据集:

复制代码
import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transforms=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transforms = transforms
        # 获取文件名(保证原图与掩码一一对应)
        self.filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # 读取图像
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.filenames[idx])
        
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR转RGB
        mask = cv2.imread(mask_path, 0)  # 读取单通道灰度掩码
        
        # 数据增强
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

四、核心配置与数据增强

定义训练超参数、数据增强策略,使用 albumentations 库做分割专用增强(保证图像与掩码同步变换):

复制代码
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ===================== 超参数配置 =====================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CLASSES = 1  # 分割类别数(单类别=1,多类别修改为对应数量)
EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 1e-4

# ===================== 数据增强 =====================
# 训练集增强(随机裁剪、翻转、归一化)
train_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.RandomCrop(height=224, width=224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# 验证集增强(仅 resize + 归一化,无随机增强)
val_transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

五、一行代码定义分割模型

smp 最大优势:模型+骨干网络一键组合,无需手动搭建网络结构。

常用模型示例

复制代码
"""
segmentation_models_pytorch 多模型对比训练 (DDP 多卡)
支持 8 卡分布式训练, 自动对比各模型精度与速度, 输出排行榜

启动方式 (8卡):
    torchrun --nproc_per_node=8 pytorch_demo/unet_train.py

启动方式 (单卡):
    python pytorch_demo/unet_train.py
"""

import os
import time
import csv
from glob import glob

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import segmentation_models_pytorch as smp


# ===================== 分布式工具 =====================
def get_world_info():
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


def is_main_process():
    return int(os.environ.get("RANK", 0)) == 0


# ===================== 配置 =====================
class Config:
    # 数据路径
    train_img_dir = "data/road_B_4/train/images"
    train_mask_dir = "data/road_B_4/train/masks"
    val_img_dir = "data/road_B_4/val/images"
    val_mask_dir = "data/road_B_4/val/masks"

    # 训练参数
    image_size = (512, 512)
    batch_size = 8                  # 每卡 batch (8卡总batch=64)
    epochs = 100
    lr = 1e-3
    num_workers = 4

    # 损失权重
    dice_weight = 0.5
    bce_weight = 0.5

    # EarlyStopping
    early_stop_patience = 15
    early_stop_min_delta = 1e-4

    # 输出
    save_dir = "checkpoints"
    log_dir = "logs"

    # 待训练模型列表: (显示名, 模型类名, backbone)
    model_list = [
        ("Unet",          "Unet",          "resnet34"),
        ("Unet++",        "UnetPlusPlus",  "resnet34"),
        ("DeepLabV3+",    "DeepLabV3Plus", "resnet34"),
        ("FPN",           "FPN",           "resnet34"),
        ("PSPNet",        "PSPNet",        "resnet34"),
        ("PAN",           "PAN",           "resnet34"),
        ("MAnet",         "MAnet",         "resnet34"),
    ]


config = Config()


# ===================== 数据集 =====================
class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, image_size=(512, 512)):
        self.img_paths = sorted(glob(os.path.join(img_dir, "*")))
        self.mask_paths = sorted(glob(os.path.join(mask_dir, "*")))
        self.image_size = image_size
        assert len(self.img_paths) == len(self.mask_paths), \
            f"图片({len(self.img_paths)})和掩码({len(self.mask_paths)})数量不一致"

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.img_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.image_size)

        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0

        image = torch.from_numpy(image).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).unsqueeze(0).float()
        return image, mask


# ===================== 损失函数 =====================
class DiceBCELoss(nn.Module):
    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        pred_probs = torch.sigmoid(pred)
        smooth = 1.0
        pred_flat = pred_probs.contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        bce_loss = self.bce(pred, target)
        return self.dice_weight * dice_loss + self.bce_weight * bce_loss


# ===================== 评估指标 =====================
def compute_iou(pred, target, threshold=0.5):
    pred = (torch.sigmoid(pred) > threshold).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    if union == 0:
        return torch.tensor(1.0, device=pred.device)
    return intersection / union


def compute_dice(pred, target, threshold=0.5):
    pred = (torch.sigmoid(pred) > threshold).float()
    smooth = 1.0
    pred_flat = pred.contiguous().view(-1)
    target_flat = target.contiguous().view(-1)
    intersection = (pred_flat * target_flat).sum()
    return (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)


def reduce_tensor(tensor, world_size):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / world_size


# ===================== 训练器 =====================
class Trainer:
    def __init__(self, model, config, device, local_rank, rank, world_size, model_name):
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.local_rank = local_rank
        self.rank = rank
        self.world_size = world_size
        self.model_name = model_name

        if world_size > 1:
            self.model = nn.parallel.DistributedDataParallel(
                self.model, device_ids=[local_rank], output_device=local_rank
            )
        self.ddp_model = self.model if world_size > 1 else None

        self.criterion = DiceBCELoss(dice_weight=config.dice_weight, bce_weight=config.bce_weight)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.lr * world_size)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=5
        )

        # TensorBoard (rank 0)
        self.writer = None
        if rank == 0:
            model_log_dir = os.path.join(config.log_dir, model_name)
            os.makedirs(model_log_dir, exist_ok=True)
            self.writer = SummaryWriter(model_log_dir)

        # Checkpoint dir per model
        self.model_ckpt_dir = os.path.join(config.save_dir, model_name)
        if rank == 0:
            os.makedirs(self.model_ckpt_dir, exist_ok=True)

        self.total_params = 0
        self.trainable_params = 0

    def _unwrap(self):
        return self.ddp_model.module if self.ddp_model else self.model

    def train_epoch(self, loader):
        self.model.train()
        total_loss = 0.0
        total_iou = 0.0

        pbar = tqdm(loader, desc=f"Train [R{self.rank}]", disable=self.rank != 0)
        for images, masks in pbar:
            images, masks = images.to(self.device), masks.to(self.device)

            self.optimizer.zero_grad()
            preds = self.model(images)
            loss = self.criterion(preds, masks)
            loss.backward()
            self.optimizer.step()

            if self.world_size > 1:
                reduced_loss = reduce_tensor(loss.detach(), self.world_size)
                reduced_iou = reduce_tensor(compute_iou(preds, masks).detach(), self.world_size)
                batch_loss = reduced_loss.item()
                batch_iou = reduced_iou.item()
            else:
                batch_loss = loss.item()
                batch_iou = compute_iou(preds, masks).item()

            total_loss += batch_loss
            total_iou += batch_iou

            if self.rank == 0:
                pbar.set_postfix(loss=batch_loss)

        n = len(loader)
        return total_loss / n, total_iou / n

    @torch.no_grad()
    def validate(self, loader):
        self.model.eval()
        total_loss = 0.0
        total_iou = 0.0
        total_dice = 0.0

        for images, masks in tqdm(loader, desc=f"Val [R{self.rank}]", disable=self.rank != 0):
            images, masks = images.to(self.device), masks.to(self.device)
            preds = self.model(images)
            loss = self.criterion(preds, masks)

            if self.world_size > 1:
                reduced_loss = reduce_tensor(loss.detach(), self.world_size)
                reduced_iou = reduce_tensor(compute_iou(preds, masks).detach(), self.world_size)
                reduced_dice = reduce_tensor(compute_dice(preds, masks).detach(), self.world_size)
                total_loss += reduced_loss.item()
                total_iou += reduced_iou.item()
                total_dice += reduced_dice.item()
            else:
                total_loss += loss.item()
                total_iou += compute_iou(preds, masks).item()
                total_dice += compute_dice(preds, masks).item()

        n = len(loader)
        return total_loss / n, total_iou / n, total_dice / n

    def fit(self, train_loader, val_loader, train_sampler=None):
        best_val_iou = 0.0
        best_val_dice = 0.0
        best_val_loss = float("inf")
        patience_counter = 0
        epoch_times = []

        for epoch in range(self.config.epochs):
            if train_sampler is not None:
                train_sampler.set_epoch(epoch)

            if self.rank == 0:
                print(f"\n{'='*50}")
                print(f"[{self.model_name}] Epoch {epoch+1}/{self.config.epochs} | world_size={self.world_size}")

            t0 = time.perf_counter()
            train_loss, train_iou = self.train_epoch(train_loader)
            val_loss, val_iou, val_dice = self.validate(val_loader)
            epoch_time = time.perf_counter() - t0
            epoch_times.append(epoch_time)

            self.scheduler.step(val_loss)

            if self.rank == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                print(f"Train Loss: {train_loss:.4f}  IoU: {train_iou:.4f}  |  "
                      f"Val Loss: {val_loss:.4f}  IoU: {val_iou:.4f}  Dice: {val_dice:.4f}  |  "
                      f"LR: {lr:.2e}  Time: {epoch_time:.1f}s")

                self.writer.add_scalar("Loss/train", train_loss, epoch)
                self.writer.add_scalar("Loss/val", val_loss, epoch)
                self.writer.add_scalar("IoU/train", train_iou, epoch)
                self.writer.add_scalar("IoU/val", val_iou, epoch)
                self.writer.add_scalar("Dice/val", val_dice, epoch)
                self.writer.add_scalar("Time/epoch", epoch_time, epoch)

                # 保存最佳
                if val_iou > best_val_iou:
                    best_val_iou = val_iou
                    best_val_dice = val_dice
                    best_val_loss = val_loss
                    best_path = os.path.join(self.model_ckpt_dir, "best_model.pth")
                    self._save_checkpoint(best_path, epoch, val_loss, val_iou, val_dice)
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= self.config.early_stop_patience:
                    print(f"Early stopping triggered at epoch {epoch+1}")
                    break

        if self.rank == 0:
            self.writer.close()
            avg_time = np.mean(epoch_times) if epoch_times else 0
            print(f"\n[{self.model_name}] 完成 | "
                  f"best IoU={best_val_iou:.4f} Dice={best_val_dice:.4f} "
                  f"Loss={best_val_loss:.4f} | avg_time={avg_time:.1f}s/epoch")

        return best_val_iou, best_val_dice, best_val_loss, np.mean(epoch_times) if epoch_times else 0

    def _save_checkpoint(self, path, epoch, val_loss, val_iou, val_dice):
        raw_model = self._unwrap()
        torch.save({
            "epoch": epoch,
            "model_state_dict": raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "val_iou": val_iou,
            "val_dice": val_dice,
            "config": {
                "model_name": self.model_name,
                "in_channels": self.config.in_channels if hasattr(self.config, 'in_channels') else 3,
                "classes": self.config.classes if hasattr(self.config, 'classes') else 1,
                "image_size": self.config.image_size,
            },
        }, path)


# ===================== 排行榜 =====================
def print_ranking(results, rank):
    """打印对比排行榜 (仅 rank 0)"""
    if rank != 0:
        return

    print("\n\n" + "=" * 90)
    print("模型对比排行榜 (按 IoU 降序)")
    print("=" * 90)

    ranked = sorted(results, key=lambda r: r["best_iou"], reverse=True)

    header = (f"{'排名':>4} | {'模型':<14} | {'参数量':>8} | "
              f"{'IoU':>8} | {'Dice':>8} | {'Loss':>8} | {'时间/epoch':>10}")
    sep = "-" * len(header)
    print(header)
    print(sep)

    for i, r in enumerate(ranked, 1):
        print(f"{i:>4} | {r['name']:<14} | {r['params_m']:>5.2f}M | "
              f"{r['best_iou']:>8.4f} | {r['best_dice']:>8.4f} | "
              f"{r['best_loss']:>8.4f} | {r['avg_time']:>7.1f}s")

    print(sep)

    best_acc = ranked[0]
    fastest = min(ranked, key=lambda r: r['avg_time'])
    print(f"\n最佳精度: {best_acc['name']} (IoU={best_acc['best_iou']:.4f})")
    print(f"最快速:   {fastest['name']} ({fastest['avg_time']:.1f}s/epoch)")

    # 精度-速度综合: 取前3名中最快的
    top3 = ranked[:3]
    balanced = min(top3, key=lambda r: r['avg_time'])
    print(f"综合推荐: {balanced['name']} (IoU={balanced['best_iou']:.4f}, "
          f"{balanced['avg_time']:.1f}s/epoch) --- 精度前3中最快")
    print("=" * 90)


def save_csv(results, path, rank):
    if rank != 0:
        return
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    ranked = sorted(results, key=lambda r: r["best_iou"], reverse=True)
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["排名", "模型", "参数量", "IoU", "Dice", "Loss", "时间/epoch"])
        w.writeheader()
        for i, r in enumerate(ranked, 1):
            w.writerow({
                "排名": i, "模型": r["name"], "参数量": f"{r['params_m']}M",
                "IoU": r["best_iou"], "Dice": r["best_dice"],
                "Loss": r["best_loss"], "时间/epoch": f"{r['avg_time']}s",
            })
    print(f"对比结果已保存: {path}")


# ===================== 合成数据 =====================
def generate_synthetic_data(num_samples=100, image_size=(128, 128)):
    os.makedirs("synthetic_data/train/images", exist_ok=True)
    os.makedirs("synthetic_data/train/masks", exist_ok=True)
    os.makedirs("synthetic_data/val/images", exist_ok=True)
    os.makedirs("synthetic_data/val/masks", exist_ok=True)

    rng = np.random.RandomState(42)
    for split in ["train", "val"]:
        n = num_samples if split == "train" else num_samples // 5
        for i in range(n):
            H, W = image_size
            img = rng.randint(0, 256, (H, W, 3), dtype=np.uint8)
            mask = np.zeros((H, W), dtype=np.uint8)
            cx, cy = rng.randint(20, W - 20), rng.randint(20, H - 20)
            r = rng.randint(10, 30)
            cv2.circle(mask, (cx, cy), r, 255, -1)
            if rng.rand() > 0.5:
                x1, y1 = rng.randint(10, W - 10), rng.randint(10, H - 10)
                x2, y2 = x1 + rng.randint(10, 30), y1 + rng.randint(10, 30)
                cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
            cv2.imwrite(f"synthetic_data/{split}/images/{i:04d}.png", img)
            cv2.imwrite(f"synthetic_data/{split}/masks/{i:04d}.png", mask)

    if is_main_process():
        print(f"合成数据已生成: synthetic_data/ (train={num_samples}, val={num_samples//5})")
    return "synthetic_data"


# ===================== 主函数 =====================
def main():
    rank, local_rank, world_size = get_world_info()

    # 初始化分布式
    if world_size > 1:
        dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if rank == 0:
        print(f"设备: {device} | 总卡数: {world_size}")
        print(f"每卡Batch: {config.batch_size} | 总Batch: {config.batch_size * max(world_size, 1)}")
        print(f"Epochs: {config.epochs} | 图片尺寸: {config.image_size}")
        print(f"参与对比模型: {len(config.model_list)} 个")
        for name, cls_name, backbone in config.model_list:
            print(f"  - {name:<12} class={cls_name:<16} backbone={backbone}")

    # 数据
    if not os.path.exists(config.train_img_dir):
        if rank == 0:
            print("未找到数据, 使用合成数据演示...")
        data_root = generate_synthetic_data()
        config.train_img_dir = f"{data_root}/train/images"
        config.train_mask_dir = f"{data_root}/train/masks"
        config.val_img_dir = f"{data_root}/val/images"
        config.val_mask_dir = f"{data_root}/val/masks"

    train_dataset = SegmentationDataset(config.train_img_dir, config.train_mask_dir, config.image_size)
    val_dataset = SegmentationDataset(config.val_img_dir, config.val_mask_dir, config.image_size)

    if rank == 0:
        print(f"\n训练样本: {len(train_dataset)} | 验证样本: {len(val_dataset)}")

    # DataLoader (数据集不变, 各模型共享)
    train_sampler = DistributedSampler(train_dataset, shuffle=True) if world_size > 1 else None
    val_sampler = DistributedSampler(val_dataset, shuffle=False) if world_size > 1 else None

    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size,
        sampler=train_sampler, shuffle=(train_sampler is None),
        num_workers=config.num_workers, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size,
        sampler=val_sampler, shuffle=False,
        num_workers=config.num_workers, pin_memory=True,
    )

    # 逐个训练模型
    all_results = []
    for display_name, class_name, backbone in config.model_list:
        if rank == 0:
            print(f"\n\n{'#'*70}")
            print(f"# 开始训练: {display_name} ({class_name}, backbone={backbone})")
            print(f"{'#'*70}")

        model_class = getattr(smp, class_name)
        model = model_class(
            encoder_name=backbone,
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if rank == 0:
            print(f"参数量: {total_params/1e6:.2f}M (可训练: {trainable_params/1e6:.2f}M)")

        # DDP 包装
        if world_size > 1:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )

        trainer = Trainer(model, config, device, local_rank, rank, world_size, display_name)
        trainer.total_params = total_params
        trainer.trainable_params = trainable_params

        best_iou, best_dice, best_loss, avg_time = trainer.fit(train_loader, val_loader, train_sampler)

        all_results.append({
            "name": display_name,
            "class_name": class_name,
            "backbone": backbone,
            "params_m": round(total_params / 1e6, 2),
            "trainable_m": round(trainable_params / 1e6, 2),
            "best_iou": round(best_iou, 4),
            "best_dice": round(best_dice, 4),
            "best_loss": round(best_loss, 4),
            "avg_time": round(avg_time, 1),
        })

        torch.cuda.empty_cache()

    # 排行榜
    print_ranking(all_results, rank)
    save_csv(all_results, os.path.join(config.save_dir, "compare_results.csv"), rank)

    if rank == 0:
        print(f"\n所有模型训练完成! 结果对比: {config.save_dir}/compare_results.csv")


if __name__ == "__main__":
    main()

损失函数与优化器

smp 封装了分割任务专用损失函数,比单纯交叉熵效果更好:

复制代码
# 损失函数:DiceLoss + BCEWithLogitsLoss(单类别分割组合)
loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE)
# 多类别用:loss = smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE)

# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 评价指标:IoU(交并比,分割核心指标)
metrics = [
    smp.metrics.IoU(threshold=0.5)  # 二分类阈值 0.5
]

六、训练与验证 pipeline

smp 提供了 TrainEpochValidEpoch 封装好的训练循环,代码极简且稳定:

复制代码
from torch.utils.data import DataLoader
from segmentation_models_pytorch.utils.train import TrainEpoch, ValidEpoch

# ===================== 加载数据 =====================
# 替换为你的数据集路径
TRAIN_IMAGE_DIR = "dataset/images/train"
TRAIN_MASK_DIR = "dataset/masks/train"
VAL_IMAGE_DIR = "dataset/images/val"
VAL_MASK_DIR = "dataset/masks/val"

# 构建数据集
train_dataset = SegmentationDataset(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR, train_transform)
val_dataset = SegmentationDataset(VAL_IMAGE_DIR, VAL_MASK_DIR, val_transform)

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ===================== 初始化训练器 =====================
train_epoch = TrainEpoch(
    model,
    loss=loss,
    optimizer=optimizer,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

valid_epoch = ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# ===================== 开始训练 =====================
max_iou = 0  # 保存最优模型
for epoch in range(1, EPOCHS+1):
    print(f"\nEpoch: {epoch}/{EPOCHS}")
    
    # 训练
    train_logs = train_epoch.run(train_loader)
    # 验证
    valid_logs = valid_epoch.run(val_loader)
    
    # 保存最优模型(根据验证集 IoU)
    if max_iou < valid_logs['iou_score']:
        max_iou = valid_logs['iou_score']
        torch.save(model, 'best_segmentation_model.pth')
        print("最优模型已保存!")

训练过程会实时打印:损失值、IoU 分数,直观观察模型收敛情况。

七、模型推理预测

训练完成后,加载最优模型,对单张图像进行分割预测:

复制代码
"""
UNet 批量推理脚本
读取图像目录, 用训练好的 best_model.pth 进行推理, 保存预测掩码

使用方法:
    python pytorch_demo/unet_infer.py
"""

import os
import sys
from glob import glob

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import segmentation_models_pytorch as smp


# ===================== 推理配置 (按需修改) =====================
class InferConfig:
    # 输入图片路径 (单张图片 或 目录)
    input = r"D:\AI+X\AI+XCode\Xx\Code\Unet_0911_Demo\paishe_test"

    # 输出目录 (保存结果)
    output = "inference_results"

    # 模型权重路径 (由 train_compare.py 保存在 checkpoints/<模型名>/ 下)
    checkpoint = r"D:\AI+X\AI+XCode\Xx\dataset\train_sgement_model_pytorch\best_model.pth"

    # 推理尺寸 (W, H)
    image_size = (512, 512)

    # 推理 batch size
    batch_size = 4

    # 二值化阈值
    threshold = 0.7

    # 推理设备 (cuda / cpu)
    device = "cuda"

    # 预测掩码文件名后缀 (单独保存二值掩码, 设为 None 则不保存)
    mask_suffix = "_mask.png"

    # 分割结果画在原图上的叠加图后缀 (设为 None 则不保存)
    overlay_suffix = "_overlay.png"

    # 叠加图参数
    overlay_color = (0, 255, 0)      # BGR: 绿色
    overlay_alpha = 0.4              # 透明度 (0~1)


config = InferConfig()


# ===================== 数据集 =====================
class InferenceDataset(Dataset):
    """推理用数据集: 读取图片, resize, 返回文件名和信息"""
    def __init__(self, img_paths, image_size=(512, 512)):
        self.img_paths = img_paths
        self.image_size = image_size

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]

        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"无法读取图片: {path}")

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_h, original_w = image.shape[:2]

        # resize 到模型输入尺寸
        image_resized = cv2.resize(image_rgb, self.image_size)

        # 归一化 + 转 tensor
        img_tensor = torch.from_numpy(image_resized.astype(np.float32) / 255.0)
        img_tensor = img_tensor.permute(2, 0, 1).float()

        return img_tensor, path, original_w, original_h


# ===================== 模型加载 =====================
def load_model(checkpoint_path, device):
    """加载 checkpoint 并构建模型(自动识别架构)"""
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    # 从 checkpoint 解析模型配置
    model_name = ckpt.get("model_name", "Unet")
    backbone = ckpt.get("backbone", "resnet34")
    ckpt_config = ckpt.get("config", {})
    in_channels = ckpt_config.get("in_channels", 3)
    classes = ckpt_config.get("classes", 1)

    # 动态构建模型(支持 Unet, UnetPlusPlus, DeepLabV3Plus, FPN, PSPNet, PAN, MAnet ...)
    model_class = getattr(smp, model_name, smp.Unet)
    model = model_class(
        encoder_name=backbone,
        encoder_weights=None,
        in_channels=in_channels,
        classes=classes,
        activation="sigmoid",
    )

    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model.to(device)
    model.eval()

    print(f"模型加载完成 | arch={model_name} backbone={backbone} "
          f"in_channels={in_channels} classes={classes}")
    print(f"  best IoU={ckpt.get('best_iou', 'N/A')}  best Dice={ckpt.get('best_dice', 'N/A')}")

    return model


# ===================== 推理 =====================
@torch.no_grad()
def infer_batch(model, loader, device, threshold):
    """批量推理, 返回结果列表"""
    model.eval()
    results = []

    for images, paths, orig_ws, orig_hs in tqdm(loader, desc="Infer"):
        images = images.to(device)
        preds = model(images)

        masks = (preds > threshold).float()

        for i in range(len(paths)):
            mask = masks[i].squeeze(0).cpu().numpy()
            mask_uint8 = (mask * 255).astype(np.uint8)

            results.append({
                "mask": mask_uint8,
                "path": paths[i],
                "orig_w": orig_ws[i].item(),
                "orig_h": orig_hs[i].item(),
            })

    return results


# ===================== 保存结果 =====================
def save_results(results, output_dir, config):
    """保存预测掩码和叠加图到输出目录"""
    os.makedirs(output_dir, exist_ok=True)

    for r in results:
        stem = os.path.splitext(os.path.basename(r["path"]))[0]
        mask = r["mask"]
        orig_w, orig_h = r["orig_w"], r["orig_h"]

        # 恢复原图尺寸
        if mask.shape[1] != orig_w or mask.shape[0] != orig_h:
            mask_full = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            mask_full = mask

        # 1. 保存二值掩码
        if config.mask_suffix:
            cv2.imwrite(os.path.join(output_dir, stem + config.mask_suffix), mask_full)

        # 2. 在原图上画分割结果
        if config.overlay_suffix:
            img = cv2.imread(r["path"])
            if img is not None:
                overlay = draw_overlay(img, mask_full, config.overlay_color, config.overlay_alpha)
                cv2.imwrite(os.path.join(output_dir, stem + config.overlay_suffix), overlay)

    saved = []
    if config.mask_suffix:
        saved.append("掩码")
    if config.overlay_suffix:
        saved.append("叠加图")
    print(f"共保存 {len(results)} 张{' + '.join(saved)}到: {output_dir}")


def draw_overlay(image, mask, color=(0, 255, 0), alpha=0.4):
    """将分割掩码以半透明颜色叠加到原图上"""
    overlay = image.copy()
    mask_bool = mask > 0

    # 在掩码区域叠加颜色
    for c in range(3):
        overlay[..., c] = np.where(mask_bool, overlay[..., c] * (1 - alpha) + color[c] * alpha, overlay[..., c])

    # 画掩码轮廓 (更清晰地显示边界)
    mask_uint8 = mask.astype(np.uint8)
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, color, 2)

    return overlay.astype(np.uint8)


# ===================== 收集图片 =====================
def collect_images(input_path):
    """收集输入路径下的所有图片文件"""
    if os.path.isfile(input_path):
        return [input_path]

    if os.path.isdir(input_path):
        exts = ["*.png", "*.jpg", "*.jpeg", "*.tif", "*.tiff", "*.bmp"]
        paths = []
        for ext in exts:
            paths.extend(sorted(glob(os.path.join(input_path, ext))))
        return sorted(set(paths))

    raise FileNotFoundError(f"输入路径不存在: {input_path}")


# ===================== 主函数 =====================
def main():
    print(f"输入: {config.input}")
    print(f"输出: {config.output}")
    print(f"Checkpoint: {config.checkpoint}")

    # 1. 设备
    if config.device == "cuda" and not torch.cuda.is_available():
        print("CUDA 不可用, 回退到 CPU")
        config.device = "cpu"
    device = torch.device(config.device)
    print(f"设备: {device}")

    # 2. 收集图片
    img_paths = collect_images(config.input)
    if len(img_paths) == 0:
        print("未找到任何图片!")
        return
    print(f"找到 {len(img_paths)} 张图片")

    # 3. 加载模型
    if not os.path.exists(config.checkpoint):
        print(f"错误: checkpoint 不存在: {config.checkpoint}")
        sys.exit(1)
    model = load_model(config.checkpoint, device)

    # 4. DataLoader
    dataset = InferenceDataset(img_paths, config.image_size)
    loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    # 5. 推理
    results = infer_batch(model, loader, device, config.threshold)

    # 6. 保存
    save_results(results, config.output, config)

    print(f"推理完成! 输出目录: {config.output}")


if __name__ == "__main__":
    main()

运行后会直接示原图分割预测结果,快速验证模型效果。

八、进阶技巧

  1. 多类别分割

    1. 修改 CLASSES 为类别总数;

    2. 损失函数改为 smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE)

    3. 激活函数改为 activation="softmax"

  2. 骨干网络替换

支持上百种骨干网络,轻量级用 mobilenet_v2efficientnet-b0,高精度用 resnet50resnet101

  1. 训练优化

    1. 加入学习率调度器 torch.optim.lr_scheduler.ReduceLROnPlateau

    2. 增大数据集增强力度;

    3. 微调输入图像分辨率(512×512 精度更高)。

九、总结

segmentation-models-pytorch 是语义分割的效率神器,完美解决了「模型搭建难、训练繁琐」的痛点:

  1. 一行代码定义任意分割模型,无需手动构建网络;

  2. 预训练权重 + 专用损失函数,快速收敛到高精度;

  3. 全流程代码简洁,新手可直接用于比赛、项目、毕业设计

本文的代码可以直接适配医学分割、 遥感 分割、工业检测等几乎所有语义分割任务,替换数据集路径即可快速训练自己的模型。

tushengzhihe/segmentation-models-pytorch: segmentation 训练代码

相关推荐
m0_738120727 小时前
后渗透维权提权基础——CTF模拟红队进行权限维持(一)
服务器·前端·python·安全·web安全·php
qq_283720057 小时前
基于 Transformer,Python 搭建中文文本分类大模型:从零到一实现企业级文本分类
python·分类·transformer
雨声不在7 小时前
python relative_to
python
李景琰7 小时前
Spring AI + Milvus向量数据库:企业级RAG架构实战
人工智能·spring·milvus
aidesignplus7 小时前
扩散模型在自动驾驶路径规划中的技术演进与产业格局
人工智能·机器学习·自动驾驶
ai产品老杨7 小时前
深度解析:基于 Docker 与异构计算的工业级 AI 视频管理平台架构 —— 从 GB28181 接入到全平台源码交付
人工智能·docker·音视频
2301_776045237 小时前
什么叫流动性:数字货币与美股市场解读
人工智能·区块链
AI技术增长7 小时前
Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复
pytorch·深度学习·机器学习·cnn·transformer
玛卡巴卡ldf7 小时前
【Springboot升级AI】(大模型部署)LangChain4j、会话记忆、隔离消失持久化问题、ollama、RAG知识库、Tools工具
java·开发语言·人工智能·spring boot·后端·springboot