深度学习语义分割完全指南:从原理到实战

本文将系统讲解语义分割的核心概念、经典网络架构、损失函数设计以及完整的PyTorch实战代码,帮助读者从零掌握这一计算机视觉核心技术。


一、什么是语义分割

1.1 定义与任务描述

语义分割(Semantic Segmentation)是计算机视觉中的一项基础任务,其目标是对图像中的每个像素进行分类,将其归属到预定义的语义类别中。

与其他视觉任务的区别:

任务 输出 粒度
图像分类 整张图像的类别标签 图像级
目标检测 物体的边界框+类别 区域级
语义分割 每个像素的类别 像素级
实例分割 每个像素的类别+实例ID 像素级+实例级

1.2 应用场景

语义分割技术已广泛应用于多个领域:

自动驾驶:识别道路、车辆、行人、交通标志等,为决策规划提供环境感知信息。

医学影像:分割器官、病灶区域,辅助医生进行诊断和手术规划。

遥感图像分析:土地利用分类、建筑物提取、变化检测等。

机器人导航:场景理解、可行驶区域识别、障碍物检测。

人像处理:背景替换、虚拟试衣、视频会议背景虚化。


二、语义分割的发展历程

2.1 传统方法时代

在深度学习兴起之前,语义分割主要依赖手工设计的特征:

复制代码
传统方法流程:
图像 → 特征提取(SIFT/HOG) → 超像素分割 → 特征聚合 → 分类器(SVM/RF) → 分割结果

这些方法的局限性在于特征表达能力有限,难以处理复杂场景。

2.2 深度学习革命

2015年,Long等人提出的全卷积网络(FCN) 开创了深度学习语义分割的新时代,此后涌现出众多经典架构:

复制代码
时间线:
2015: FCN (CVPR)
2015: U-Net (MICCAI)
2015: SegNet (TPAMI)
2016: DeepLab v2 (TPAMI)
2017: PSPNet (CVPR)
2018: DeepLab v3+ (ECCV)
2020: SETR (Transformer)
2021: SegFormer
2023+: SAM, Foundation Models

三、核心网络架构详解

3.1 FCN:开山之作

FCN的核心思想是将分类网络的全连接层替换为卷积层,实现任意尺寸输入的端到端分割。

关键创新点

  1. 全卷积化:移除全连接层,保持空间信息

  2. 跳跃连接:融合多尺度特征,恢复空间细节

  3. 转置卷积:上采样恢复原始分辨率

    FCN-8s 结构示意:

    Input(H×W×3)

    Conv1 + Pool → 1/2

    Conv2 + Pool → 1/4

    Conv3 + Pool → 1/8 ──────────┐
    ↓ │
    Conv4 + Pool → 1/16 ─────┐ │
    ↓ │ │
    Conv5 + Pool → 1/32 │ │
    ↓ │ │
    FC6 → FC7 → Score │ │
    ↓ │ │
    Upsample 2× ─────────────┼────┤
    ↓ ↓ │
    + (Fuse) ←── Score_pool4 │
    ↓ │
    Upsample 2× ──────────────────┤
    ↓ ↓
    + (Fuse) ←───────── Score_pool3

    Upsample 8×

    Output(H×W×C)

3.2 U-Net:对称编解码结构

U-Net采用对称的编码器-解码器结构,通过跳跃连接精确恢复空间信息,在医学图像分割中表现卓越。

复制代码
U-Net 结构示意:

编码器(下采样)                    解码器(上采样)
    
Input ─────────────────────────────────────── Skip Connection ─→ Output
  │                                                              ↑
  ↓ Conv×2                                              Conv×2   │
[64] ─────────────────────────────────────────────────────→ [64]
  │                                                              ↑
  ↓ MaxPool                                            UpConv    │
[128] ────────────────────────────────────────────────→ [128]
  │                                                              ↑
  ↓ MaxPool                                            UpConv    │
[256] ────────────────────────────────────────────────→ [256]
  │                                                              ↑
  ↓ MaxPool                                            UpConv    │
[512] ────────────────────────────────────────────────→ [512]
  │                                                              ↑
  ↓ MaxPool                                            UpConv    │
[1024] ──────────────── Bottleneck ────────────────→ [1024]

U-Net的优势

  1. 对称结构便于特征传递
  2. 跳跃连接保留高分辨率细节
  3. 在小样本数据集上也能取得良好效果

3.3 DeepLab系列:空洞卷积与多尺度

DeepLab系列引入了空洞卷积(Atrous/Dilated Convolution),在不增加参数的情况下扩大感受野。

空洞卷积原理

复制代码
标准3×3卷积(rate=1):        空洞卷积(rate=2):
                              
■ ■ ■                         ■ □ ■ □ ■
■ ■ ■                         □ □ □ □ □
■ ■ ■                         ■ □ ■ □ ■
                              □ □ □ □ □
感受野: 3×3                    ■ □ ■ □ ■
                              
                              感受野: 5×5(使用相同参数)

ASPP模块(Atrous Spatial Pyramid Pooling)

复制代码
             ┌─── 1×1 Conv ──────────────┐
             │                           │
Input ───────┼─── 3×3 Conv, rate=6 ──────┼──── Concat ─── 1×1 Conv ─── Output
             │                           │
             ├─── 3×3 Conv, rate=12 ─────┤
             │                           │
             ├─── 3×3 Conv, rate=18 ─────┤
             │                           │
             └─── Global Average Pool ───┘

3.4 PSPNet:金字塔池化

PSPNet提出金字塔池化模块(Pyramid Pooling Module),聚合不同尺度的全局上下文信息。

复制代码
PPM模块:

                ┌── Pool 1×1 → Conv → Upsample ──┐
                │                                │
Feature Map ────┼── Pool 2×2 → Conv → Upsample ──┼── Concat ─→ Output
    (H×W×C)     │                                │
                ├── Pool 3×3 → Conv → Upsample ──┤
                │                                │
                └── Pool 6×6 → Conv → Upsample ──┘

四、损失函数设计

语义分割中常用的损失函数及其适用场景:

4.1 交叉熵损失(Cross Entropy Loss)

最基础的分类损失函数:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    """
    二维交叉熵损失
    适用于类别均衡的数据集
    """
    def __init__(self, weight=None, ignore_index=-100):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
    
    def forward(self, pred, target):
        """
        Args:
            pred: [B, C, H, W] 预测logits
            target: [B, H, W] 真实标签
        """
        return F.cross_entropy(
            pred, target, 
            weight=self.weight, 
            ignore_index=self.ignore_index
        )

4.2 Dice Loss

针对类别不平衡问题,直接优化Dice系数:

python 复制代码
class DiceLoss(nn.Module):
    """
    Dice损失
    适用于类别不平衡场景,如医学图像分割
    """
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        """
        Args:
            pred: [B, C, H, W] 预测logits
            target: [B, H, W] 真实标签
        """
        num_classes = pred.shape[1]
        pred = F.softmax(pred, dim=1)
        
        # 将target转换为one-hot编码
        target_onehot = F.one_hot(target, num_classes)  # [B, H, W, C]
        target_onehot = target_onehot.permute(0, 3, 1, 2).float()  # [B, C, H, W]
        
        # 计算Dice系数
        intersection = (pred * target_onehot).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target_onehot.sum(dim=(2, 3))
        
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        
        return 1 - dice.mean()

4.3 Focal Loss

针对难样本挖掘,降低易分类样本的权重:

python 复制代码
class FocalLoss(nn.Module):
    """
    Focal Loss
    适用于正负样本极度不平衡的场景
    """
    def __init__(self, alpha=0.25, gamma=2.0, ignore_index=-100):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
    
    def forward(self, pred, target):
        """
        FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
        """
        ce_loss = F.cross_entropy(
            pred, target, 
            reduction='none', 
            ignore_index=self.ignore_index
        )
        
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        return focal_loss.mean()

4.4 组合损失

实际应用中常将多种损失组合使用:

python 复制代码
class CombinedLoss(nn.Module):
    """
    组合损失:CE + Dice
    兼顾像素级精度和区域级一致性
    """
    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_loss = CrossEntropyLoss2d()
        self.dice_loss = DiceLoss()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
    
    def forward(self, pred, target):
        ce = self.ce_loss(pred, target)
        dice = self.dice_loss(pred, target)
        return self.ce_weight * ce + self.dice_weight * dice

五、评价指标

5.1 常用指标定义

python 复制代码
import numpy as np

class SegmentationMetrics:
    """语义分割评价指标计算器"""
    
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()
    
    def reset(self):
        """重置混淆矩阵"""
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))
    
    def update(self, pred, target):
        """
        更新混淆矩阵
        Args:
            pred: [H, W] 预测结果
            target: [H, W] 真实标签
        """
        mask = (target >= 0) & (target < self.num_classes)
        label = self.num_classes * target[mask].astype(int) + pred[mask]
        count = np.bincount(label, minlength=self.num_classes**2)
        self.confusion_matrix += count.reshape(self.num_classes, self.num_classes)
    
    def get_pixel_accuracy(self):
        """
        像素准确率 (Pixel Accuracy)
        PA = 正确分类的像素数 / 总像素数
        """
        acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return acc
    
    def get_mean_pixel_accuracy(self):
        """
        平均像素准确率 (Mean Pixel Accuracy)
        MPA = 各类别PA的平均值
        """
        acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        acc = np.nanmean(acc)
        return acc
    
    def get_iou(self):
        """
        交并比 (Intersection over Union)
        IoU = TP / (TP + FP + FN)
        """
        intersection = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1) + 
                 self.confusion_matrix.sum(axis=0) - 
                 np.diag(self.confusion_matrix))
        iou = intersection / (union + 1e-10)
        return iou
    
    def get_mean_iou(self):
        """
        平均交并比 (Mean IoU)
        最常用的语义分割评价指标
        """
        iou = self.get_iou()
        return np.nanmean(iou)
    
    def get_frequency_weighted_iou(self):
        """
        频率加权交并比 (Frequency Weighted IoU)
        考虑各类别的像素频率
        """
        freq = self.confusion_matrix.sum(axis=1) / self.confusion_matrix.sum()
        iou = self.get_iou()
        fwiou = (freq[freq > 0] * iou[freq > 0]).sum()
        return fwiou
    
    def get_dice_coefficient(self):
        """
        Dice系数
        Dice = 2*TP / (2*TP + FP + FN) = 2*IoU / (1 + IoU)
        """
        intersection = np.diag(self.confusion_matrix)
        dice = (2 * intersection / 
                (self.confusion_matrix.sum(axis=1) + 
                 self.confusion_matrix.sum(axis=0) + 1e-10))
        return dice

5.2 指标对比

指标 优点 缺点 适用场景
Pixel Accuracy 直观易懂 对类别不平衡敏感 初步评估
Mean IoU 平衡各类别贡献 对小目标敏感 通用评估
FWIoU 考虑类别频率 大类别主导 实际应用
Dice 等价于F1-score 与IoU高度相关 医学影像

六、完整实战代码

6.1 数据集定义

python 复制代码
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

class SegmentationDataset(Dataset):
    """
    通用语义分割数据集
    目录结构:
    root/
        images/
            xxx.jpg
            yyy.jpg
        masks/
            xxx.png  # 单通道,像素值为类别ID
            yyy.png
    """
    
    def __init__(self, root_dir, split='train', transform=None, num_classes=21):
        """
        Args:
            root_dir: 数据集根目录
            split: 'train' 或 'val'
            transform: albumentations变换
            num_classes: 类别数量
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.num_classes = num_classes
        
        # 读取图像列表
        self.images_dir = os.path.join(root_dir, split, 'images')
        self.masks_dir = os.path.join(root_dir, split, 'masks')
        self.images = sorted(os.listdir(self.images_dir))
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # 读取图像
        img_name = self.images[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # 读取mask
        mask_name = os.path.splitext(img_name)[0] + '.png'
        mask_path = os.path.join(self.masks_dir, mask_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        # 数据增强
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        
        return image, mask.long()


def get_transforms(split, img_size=512):
    """
    获取数据增强变换
    """
    if split == 'train':
        return A.Compose([
            A.RandomResizedCrop(height=img_size, width=img_size, scale=(0.5, 2.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            A.GaussNoise(var_limit=(10, 50), p=0.3),
            A.GaussianBlur(blur_limit=(3, 7), p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

6.2 U-Net模型实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(Conv -> BN -> ReLU) × 2"""
    
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """下采样:MaxPool + DoubleConv"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    
    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """上采样:Upsample + Concat + DoubleConv"""
    
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # 处理尺寸不匹配
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """
    U-Net语义分割网络
    
    Args:
        n_channels: 输入通道数
        n_classes: 输出类别数
        bilinear: 是否使用双线性插值上采样
        base_channels: 基础通道数
    """
    
    def __init__(self, n_channels=3, n_classes=21, bilinear=True, base_channels=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        
        # 编码器
        self.inc = DoubleConv(n_channels, base_channels)
        self.down1 = Down(base_channels, base_channels * 2)
        self.down2 = Down(base_channels * 2, base_channels * 4)
        self.down3 = Down(base_channels * 4, base_channels * 8)
        
        factor = 2 if bilinear else 1
        self.down4 = Down(base_channels * 8, base_channels * 16 // factor)
        
        # 解码器
        self.up1 = Up(base_channels * 16, base_channels * 8 // factor, bilinear)
        self.up2 = Up(base_channels * 8, base_channels * 4 // factor, bilinear)
        self.up3 = Up(base_channels * 4, base_channels * 2 // factor, bilinear)
        self.up4 = Up(base_channels * 2, base_channels, bilinear)
        
        # 输出层
        self.outc = nn.Conv2d(base_channels, n_classes, kernel_size=1)
    
    def forward(self, x):
        # 编码
        x1 = self.inc(x)      # [B, 64, H, W]
        x2 = self.down1(x1)   # [B, 128, H/2, W/2]
        x3 = self.down2(x2)   # [B, 256, H/4, W/4]
        x4 = self.down3(x3)   # [B, 512, H/8, W/8]
        x5 = self.down4(x4)   # [B, 512, H/16, W/16]
        
        # 解码
        x = self.up1(x5, x4)  # [B, 256, H/8, W/8]
        x = self.up2(x, x3)   # [B, 128, H/4, W/4]
        x = self.up3(x, x2)   # [B, 64, H/2, W/2]
        x = self.up4(x, x1)   # [B, 64, H, W]
        
        logits = self.outc(x) # [B, n_classes, H, W]
        return logits

6.3 训练脚本

python 复制代码
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)


class Trainer:
    """语义分割训练器"""
    
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # 初始化模型
        self.model = UNet(
            n_channels=config['n_channels'],
            n_classes=config['n_classes'],
            bilinear=config.get('bilinear', True)
        ).to(self.device)
        
        # 损失函数
        self.criterion = CombinedLoss(ce_weight=0.5, dice_weight=0.5)
        
        # 优化器
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=config.get('weight_decay', 1e-4)
        )
        
        # 学习率调度器
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=config['lr'] * 0.01
        )
        
        # 混合精度训练
        self.scaler = GradScaler()
        
        # 评估指标
        self.metrics = SegmentationMetrics(config['n_classes'])
        
        # 最佳模型
        self.best_miou = 0.0
    
    def train_epoch(self, train_loader):
        """训练一个epoch"""
        self.model.train()
        total_loss = 0.0
        
        pbar = tqdm(train_loader, desc='Training')
        for images, masks in pbar:
            images = images.to(self.device)
            masks = masks.to(self.device)
            
            # 前向传播(混合精度)
            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)
            
            # 反向传播
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        return total_loss / len(train_loader)
    
    @torch.no_grad()
    def validate(self, val_loader):
        """验证"""
        self.model.eval()
        self.metrics.reset()
        total_loss = 0.0
        
        for images, masks in tqdm(val_loader, desc='Validating'):
            images = images.to(self.device)
            masks = masks.to(self.device)
            
            outputs = self.model(images)
            loss = self.criterion(outputs, masks)
            total_loss += loss.item()
            
            # 计算预测结果
            preds = outputs.argmax(dim=1).cpu().numpy()
            targets = masks.cpu().numpy()
            
            # 更新指标
            for pred, target in zip(preds, targets):
                self.metrics.update(pred, target)
        
        # 计算指标
        val_loss = total_loss / len(val_loader)
        miou = self.metrics.get_mean_iou()
        pixel_acc = self.metrics.get_pixel_accuracy()
        class_iou = self.metrics.get_iou()
        
        return {
            'loss': val_loss,
            'miou': miou,
            'pixel_acc': pixel_acc,
            'class_iou': class_iou
        }
    
    def train(self, train_loader, val_loader):
        """完整训练流程"""
        for epoch in range(self.config['epochs']):
            logger.info(f"\nEpoch {epoch+1}/{self.config['epochs']}")
            logger.info(f"Learning Rate: {self.scheduler.get_last_lr()[0]:.6f}")
            
            # 训练
            train_loss = self.train_epoch(train_loader)
            logger.info(f"Train Loss: {train_loss:.4f}")
            
            # 验证
            val_results = self.validate(val_loader)
            logger.info(f"Val Loss: {val_results['loss']:.4f}")
            logger.info(f"Val mIoU: {val_results['miou']:.4f}")
            logger.info(f"Val Pixel Acc: {val_results['pixel_acc']:.4f}")
            
            # 保存最佳模型
            if val_results['miou'] > self.best_miou:
                self.best_miou = val_results['miou']
                self.save_checkpoint('best_model.pth')
                logger.info(f"New best model saved! mIoU: {self.best_miou:.4f}")
            
            # 更新学习率
            self.scheduler.step()
            
            # 定期保存
            if (epoch + 1) % self.config.get('save_interval', 10) == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pth')
    
    def save_checkpoint(self, filename):
        """保存检查点"""
        os.makedirs(self.config['save_dir'], exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_miou': self.best_miou
        }, os.path.join(self.config['save_dir'], filename))
    
    def load_checkpoint(self, filepath):
        """加载检查点"""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_miou = checkpoint['best_miou']


# 主函数
def main():
    # 配置参数
    config = {
        'data_root': './data/voc',
        'n_channels': 3,
        'n_classes': 21,
        'img_size': 512,
        'batch_size': 8,
        'epochs': 100,
        'lr': 1e-3,
        'weight_decay': 1e-4,
        'save_dir': './checkpoints',
        'num_workers': 4
    }
    
    # 创建数据集
    train_dataset = SegmentationDataset(
        config['data_root'],
        split='train',
        transform=get_transforms('train', config['img_size']),
        num_classes=config['n_classes']
    )
    
    val_dataset = SegmentationDataset(
        config['data_root'],
        split='val',
        transform=get_transforms('val', config['img_size']),
        num_classes=config['n_classes']
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    # 创建训练器并开始训练
    trainer = Trainer(config)
    trainer.train(train_loader, val_loader)


if __name__ == '__main__':
    main()

6.4 推理与可视化

python 复制代码
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt


class Inferencer:
    """语义分割推理器"""
    
    # VOC数据集颜色映射
    VOC_COLORMAP = np.array([
        [0, 0, 0],        # 背景
        [128, 0, 0],      # 飞机
        [0, 128, 0],      # 自行车
        [128, 128, 0],    # 鸟
        [0, 0, 128],      # 船
        [128, 0, 128],    # 瓶子
        [0, 128, 128],    # 公交车
        [128, 128, 128],  # 汽车
        [64, 0, 0],       # 猫
        [192, 0, 0],      # 椅子
        [64, 128, 0],     # 牛
        [192, 128, 0],    # 餐桌
        [64, 0, 128],     # 狗
        [192, 0, 128],    # 马
        [64, 128, 128],   # 摩托车
        [192, 128, 128],  # 人
        [0, 64, 0],       # 盆栽
        [128, 64, 0],     # 羊
        [0, 192, 0],      # 沙发
        [128, 192, 0],    # 火车
        [0, 64, 128],     # 电视
    ], dtype=np.uint8)
    
    def __init__(self, model_path, num_classes=21, img_size=512, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.img_size = img_size
        self.num_classes = num_classes
        
        # 加载模型
        self.model = UNet(n_channels=3, n_classes=num_classes)
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        
        # 图像预处理参数
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])
    
    def preprocess(self, image):
        """图像预处理"""
        # 调整尺寸
        img = cv2.resize(image, (self.img_size, self.img_size))
        img = img.astype(np.float32) / 255.0
        
        # 归一化
        img = (img - self.mean) / self.std
        
        # 转换为张量
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
        img = img.unsqueeze(0)
        
        return img
    
    @torch.no_grad()
    def predict(self, image):
        """
        单张图像推理
        Args:
            image: BGR图像 (H, W, 3)
        Returns:
            pred: 预测标签 (H, W)
            prob: 类别概率 (H, W, C)
        """
        orig_h, orig_w = image.shape[:2]
        
        # 预处理
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img_tensor = self.preprocess(img).to(self.device)
        
        # 推理
        output = self.model(img_tensor)
        prob = torch.softmax(output, dim=1)
        pred = output.argmax(dim=1)
        
        # 恢复原始尺寸
        pred = pred.squeeze().cpu().numpy()
        pred = cv2.resize(pred.astype(np.uint8), (orig_w, orig_h), 
                         interpolation=cv2.INTER_NEAREST)
        
        prob = prob.squeeze().permute(1, 2, 0).cpu().numpy()
        prob = cv2.resize(prob, (orig_w, orig_h))
        
        return pred, prob
    
    def colorize(self, pred):
        """将预测标签转换为彩色图像"""
        color_mask = self.VOC_COLORMAP[pred]
        return color_mask
    
    def visualize(self, image, pred, alpha=0.5, save_path=None):
        """可视化分割结果"""
        # 转换颜色空间
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # 获取彩色mask
        color_mask = self.colorize(pred)
        
        # 叠加
        overlay = cv2.addWeighted(image_rgb, 1 - alpha, color_mask, alpha, 0)
        
        # 绘制
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(image_rgb)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        axes[1].imshow(color_mask)
        axes[1].set_title('Segmentation Mask')
        axes[1].axis('off')
        
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay')
        axes[2].axis('off')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close()
        else:
            plt.show()
        
        return overlay


# 使用示例
def inference_demo():
    # 初始化推理器
    inferencer = Inferencer(
        model_path='./checkpoints/best_model.pth',
        num_classes=21,
        img_size=512
    )
    
    # 读取图像
    image = cv2.imread('test_image.jpg')
    
    # 推理
    pred, prob = inferencer.predict(image)
    
    # 可视化
    inferencer.visualize(image, pred, save_path='result.png')
    
    # 获取各类别置信度
    print(f"预测类别: {np.unique(pred)}")
    print(f"类别分布: {np.bincount(pred.flatten(), minlength=21)}")


if __name__ == '__main__':
    inference_demo()

七、进阶技巧

7.1 数据增强策略

针对语义分割任务的专用增强:

python 复制代码
import albumentations as A

def get_strong_augmentation(img_size=512):
    """强数据增强"""
    return A.Compose([
        # 几何变换
        A.RandomResizedCrop(height=img_size, width=img_size, scale=(0.5, 2.0)),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.3),
        A.GridDistortion(p=0.3),
        A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=0.3),
        
        # 颜色变换
        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        A.CLAHE(clip_limit=4.0, p=0.3),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        
        # 噪声与模糊
        A.GaussNoise(var_limit=(10, 50), p=0.3),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
        A.MotionBlur(blur_limit=7, p=0.2),
        
        # 遮挡
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, 
                       fill_value=0, mask_fill_value=255, p=0.3),
        
        # 归一化
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

7.2 多尺度测试(TTA)

python 复制代码
def multi_scale_inference(model, image, scales=[0.5, 0.75, 1.0, 1.25, 1.5]):
    """
    多尺度测试时增强
    """
    h, w = image.shape[:2]
    final_prob = np.zeros((h, w, model.n_classes), dtype=np.float32)
    
    for scale in scales:
        # 缩放图像
        new_h, new_w = int(h * scale), int(w * scale)
        scaled_img = cv2.resize(image, (new_w, new_h))
        
        # 推理
        _, prob = model.predict(scaled_img)
        
        # 恢复尺寸并累加
        prob = cv2.resize(prob, (w, h))
        final_prob += prob
        
        # 水平翻转
        flipped_img = cv2.flip(scaled_img, 1)
        _, prob_flip = model.predict(flipped_img)
        prob_flip = cv2.flip(prob_flip, 1)
        prob_flip = cv2.resize(prob_flip, (w, h))
        final_prob += prob_flip
    
    # 平均并取argmax
    final_prob /= (len(scales) * 2)
    pred = np.argmax(final_prob, axis=-1)
    
    return pred

7.3 模型轻量化

python 复制代码
class LightUNet(nn.Module):
    """
    轻量级U-Net
    使用深度可分离卷积减少参数量
    """
    
    def __init__(self, n_channels=3, n_classes=21, base_channels=32):
        super().__init__()
        
        self.inc = self._make_layer(n_channels, base_channels)
        self.down1 = self._down_layer(base_channels, base_channels * 2)
        self.down2 = self._down_layer(base_channels * 2, base_channels * 4)
        self.down3 = self._down_layer(base_channels * 4, base_channels * 8)
        self.down4 = self._down_layer(base_channels * 8, base_channels * 8)
        
        self.up1 = self._up_layer(base_channels * 16, base_channels * 4)
        self.up2 = self._up_layer(base_channels * 8, base_channels * 2)
        self.up3 = self._up_layer(base_channels * 4, base_channels)
        self.up4 = self._up_layer(base_channels * 2, base_channels)
        
        self.outc = nn.Conv2d(base_channels, n_classes, kernel_size=1)
    
    def _depthwise_separable_conv(self, in_ch, out_ch):
        """深度可分离卷积"""
        return nn.Sequential(
            # 深度卷积
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),
            # 逐点卷积
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def _make_layer(self, in_ch, out_ch):
        return nn.Sequential(
            self._depthwise_separable_conv(in_ch, out_ch),
            self._depthwise_separable_conv(out_ch, out_ch)
        )
    
    def _down_layer(self, in_ch, out_ch):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self._make_layer(in_ch, out_ch)
        )
    
    def _up_layer(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            self._make_layer(in_ch, out_ch)
        )
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        x = self.up1(torch.cat([x5, x4], dim=1))
        x = self.up2(torch.cat([x, x3], dim=1))
        x = self.up3(torch.cat([x, x2], dim=1))
        x = self.up4(torch.cat([x, x1], dim=1))
        
        return self.outc(x)

八、常见问题与解决方案

8.1 类别不平衡

问题:背景像素远多于前景,模型偏向预测背景。

解决方案

  1. 使用加权交叉熵损失
  2. 采用Dice Loss或Focal Loss
  3. 在线难样本挖掘(OHEM)
  4. 过采样小类别样本

8.2 边界模糊

问题:分割边界不够清晰,存在锯齿。

解决方案

  1. 使用边界感知损失(Boundary Loss)
  2. 添加CRF后处理
  3. 使用更深的解码器
  4. 增加边界增强的数据增强

8.3 小目标丢失

问题:小目标容易被忽略或分割不完整。

解决方案

  1. 使用多尺度特征融合
  2. 提高输入分辨率
  3. 使用注意力机制
  4. 针对小目标的专门增强

九、总结

语义分割作为计算机视觉的核心任务,经历了从传统方法到深度学习的演进。本文系统介绍了FCN、U-Net、DeepLab等经典架构的原理,详细讲解了损失函数设计、评价指标计算,并提供了完整的PyTorch实战代码。

关键要点回顾

  1. 编码器-解码器结构是语义分割的主流范式
  2. 跳跃连接对恢复空间细节至关重要
  3. 空洞卷积可在不丢失分辨率的情况下扩大感受野
  4. 针对类别不平衡,组合损失通常比单一损失效果更好
  5. mIoU是最常用的评价指标

希望这篇文章能帮助你深入理解语义分割技术。如有问题,欢迎在评论区交流讨论!


参考文献

  1. Long J, et al. "Fully Convolutional Networks for Semantic Segmentation." CVPR 2015.
  2. Ronneberger O, et al. "U-Net: Convolutional Networks for Biomedical Image Segmentation." MICCAI 2015.
  3. Chen L C, et al. "DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs." TPAMI 2018.
  4. Zhao H, et al. "Pyramid Scene Parsing Network." CVPR 2017.

作者:Jia

更多技术文章,欢迎关注我的CSDN博客!

相关推荐
毕不了业的硏䆒僧2 小时前
ARM架构的ModuleNotFoundError: No module named ‘thop‘
深度学习·annconda环境
RoboWizard2 小时前
8TB SSD还有掉速问题吗?
人工智能·缓存·智能手机·电脑·金士顿
l14372332672 小时前
电影解说详细教程:从「一条视频」到「持续更新」
人工智能
MUTA️2 小时前
BCEWithLogitsLoss
人工智能
deephub2 小时前
使用 tsfresh 和 AutoML 进行时间序列特征工程
人工智能·python·机器学习·特征工程·时间序列
静听松涛1332 小时前
从模式识别到逻辑推理的认知跨越
人工智能·机器学习
牛客企业服务2 小时前
AI面试选型策略:2026年五大核心维度解析
人工智能
啊阿狸不会拉杆2 小时前
《机器学习》第四章-无监督学习
人工智能·学习·算法·机器学习·计算机视觉
Duang007_2 小时前
【万字学习总结】API设计与接口开发实战指南
开发语言·javascript·人工智能·python·学习