
语义分割:FCN与U-Net(全卷积网络、跳跃连接、医学图像应用)
一、什么是语义分割?
1.1 图像理解的不同层次
python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import warnings

warnings.filterwarnings('ignore')  # keep demo output free of library warnings

print("=" * 60)
print("语义分割:像素级的分类")
print("=" * 60)

# Side-by-side comparison of the four levels of image understanding:
# classification, detection, instance segmentation, semantic segmentation.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# A seeded random RGB image stands in for a real photo.
np.random.seed(42)
img = np.random.rand(100, 100, 3)

# 1. Image classification: a single label for the whole image.
ax1 = axes[0]
ax1.imshow(img)
ax1.set_title('图像分类\n这是什么物体?', fontsize=10)
ax1.axis('off')
ax1.text(50, 50, '→ 猫', ha='center', va='center', fontsize=12, color='white')

# 2. Object detection: a label plus a bounding box.
ax2 = axes[1]
ax2.imshow(img)
rect = Rectangle((30, 30), 40, 40, linewidth=2, edgecolor='r', facecolor='none')
ax2.add_patch(rect)
ax2.set_title('目标检测\n物体在哪里?', fontsize=10)
ax2.axis('off')

# 3. Instance segmentation: a separate pixel mask per object
#    (two overlapping simulated masks, one colormap each).
ax3 = axes[2]
ax3.imshow(img)
mask1 = np.zeros((100, 100))
mask1[30:70, 30:70] = 0.5
mask2 = np.zeros((100, 100))
mask2[50:80, 60:90] = 0.5
ax3.imshow(mask1, alpha=0.3, cmap='Reds')
ax3.imshow(mask2, alpha=0.3, cmap='Blues')
ax3.set_title('实例分割\n区分不同物体', fontsize=10)
ax3.axis('off')

# 4. Semantic segmentation: one class label per pixel — the union of
#    both objects shares a single class (instances are not separated).
ax4 = axes[3]
ax4.imshow(img)
semantic_mask = np.zeros((100, 100))
semantic_mask[30:70, 30:70] = 1
semantic_mask[50:80, 60:90] = 1
ax4.imshow(semantic_mask, alpha=0.3, cmap='viridis')
ax4.set_title('语义分割\n像素级分类', fontsize=10)
ax4.axis('off')

plt.suptitle('计算机视觉任务层次', fontsize=14)
plt.tight_layout()
plt.show()

print("\n💡 语义分割定义:")
print(" 对图像中的每个像素进行分类,赋予语义标签")
print(" 输出: 与输入同尺寸的分割图(每个像素一个类别)")
print("\n应用场景:")
print(" - 自动驾驶: 道路、车辆、行人分割")
print(" - 医学影像: 器官、肿瘤分割")
print(" - 遥感图像: 土地覆盖分类")
print(" - 人像分割: 背景虚化")
二、FCN:全卷积网络
2.1 全连接层 vs 全卷积层
python
def fcn_principle():
    """Explain the FCN idea with a two-panel diagram.

    Left panel: a conventional CNN ending in fully connected layers that
    produce a class-score vector. Right panel: an FCN, which replaces
    the FC layers with 1x1 convolutions followed by upsampling.
    Prints a three-point summary afterwards. Returns None.
    """
    print("\n" + "=" * 60)
    print("FCN:全卷积网络")
    print("=" * 60)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # --- Traditional CNN (with fully connected layers) ---
    ax1 = axes[0]
    ax1.axis('off')
    ax1.set_title('传统CNN\n(有全连接层)', fontsize=10)
    # Convolutional stage.
    conv_rect = plt.Rectangle((0.1, 0.6), 0.3, 0.2,
                              facecolor='lightblue', ec='black')
    ax1.add_patch(conv_rect)
    ax1.text(0.25, 0.7, '卷积层', ha='center', va='center', fontsize=8)
    # Fully connected stage.
    fc_rect = plt.Rectangle((0.5, 0.6), 0.2, 0.2,
                            facecolor='lightcoral', ec='black')
    ax1.add_patch(fc_rect)
    ax1.text(0.6, 0.7, '全连接层', ha='center', va='center', fontsize=8)
    # Output: a class-score vector.
    out_rect = plt.Rectangle((0.8, 0.6), 0.1, 0.2,
                             facecolor='lightgray', ec='black')
    ax1.add_patch(out_rect)
    ax1.text(0.85, 0.7, '类别\n向量', ha='center', va='center', fontsize=6)
    ax1.annotate('', xy=(0.5, 0.7), xytext=(0.4, 0.7), arrowprops=dict(arrowstyle='->', lw=1))
    ax1.annotate('', xy=(0.8, 0.7), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle='->', lw=1))

    # --- FCN (fully convolutional) ---
    ax2 = axes[1]
    ax2.axis('off')
    ax2.set_title('FCN\n(全卷积网络)', fontsize=10)
    # Convolutional stage (same as before).
    conv_rect2 = plt.Rectangle((0.1, 0.6), 0.3, 0.2,
                               facecolor='lightblue', ec='black')
    ax2.add_patch(conv_rect2)
    ax2.text(0.25, 0.7, '卷积层', ha='center', va='center', fontsize=8)
    # 1x1 convolution replaces the FC layers.
    fc_conv = plt.Rectangle((0.5, 0.6), 0.2, 0.2,
                            facecolor='lightgreen', ec='black')
    ax2.add_patch(fc_conv)
    ax2.text(0.6, 0.7, '1×1\n卷积', ha='center', va='center', fontsize=8)
    # Upsampling stage restores spatial resolution.
    up_rect = plt.Rectangle((0.8, 0.6), 0.1, 0.2,
                            facecolor='lightyellow', ec='black')
    ax2.add_patch(up_rect)
    ax2.text(0.85, 0.7, '上采样', ha='center', va='center', fontsize=6)
    ax2.annotate('', xy=(0.5, 0.7), xytext=(0.4, 0.7), arrowprops=dict(arrowstyle='->', lw=1))
    ax2.annotate('', xy=(0.8, 0.7), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle='->', lw=1))

    plt.suptitle('FCN:用卷积层替代全连接层', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n📐 FCN核心思想:")
    print(" 1. 全卷积: 将全连接层替换为1×1卷积")
    print(" 2. 端到端: 输入任意尺寸图像,输出同尺寸分割图")
    print(" 3. 上采样: 通过转置卷积恢复分辨率")


fcn_principle()
2.2 FCN架构
python
def fcn_architecture():
    """Draw a block diagram of the FCN pipeline and its skip variants.

    Shows the downsampling conv/pool stages, the 1x1 prediction conv,
    the 32x upsampling head, and the extra FCN-16s / FCN-8s skip
    connections, then prints a summary of the three variants.
    Returns None.
    """
    print("\n" + "=" * 60)
    print("FCN架构")
    print("=" * 60)

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.axis('off')

    # Input image box.
    input_box = plt.Rectangle((0.02, 0.4), 0.06, 0.2,
                              facecolor='lightgray', ec='black')
    ax.add_patch(input_box)
    ax.text(0.05, 0.5, '输入\n图像', ha='center', va='center', fontsize=7)

    # Backbone stages: three conv+pool blocks, then the 1x1 prediction
    # conv. Each tuple is (x, y, label, fill color).
    conv_layers = [
        (0.12, 0.4, '卷积+池化\n(下采样)', 'lightblue'),
        (0.25, 0.4, '卷积+池化\n(下采样)', 'lightblue'),
        (0.38, 0.4, '卷积+池化\n(下采样)', 'lightblue'),
        (0.51, 0.4, '1×1卷积\n(预测)', 'lightgreen'),
    ]
    for x, y, label, color in conv_layers:
        box = plt.Rectangle((x, y), 0.1, 0.2,
                            facecolor=color, ec='black')
        ax.add_patch(box)
        ax.text(x+0.05, y+0.1, label, ha='center', va='center', fontsize=6)
        ax.annotate('', xy=(x+0.1, 0.5), xytext=(x+0.07, 0.5),
                    arrowprops=dict(arrowstyle='->', lw=1))

    # FCN-32s head: direct 32x upsampling of the coarse prediction.
    up32 = plt.Rectangle((0.65, 0.4), 0.08, 0.2,
                         facecolor='lightyellow', ec='black')
    ax.add_patch(up32)
    ax.text(0.69, 0.5, '32x\n上采样', ha='center', va='center', fontsize=6)
    ax.annotate('', xy=(0.65, 0.5), xytext=(0.61, 0.5), arrowprops=dict(arrowstyle='->', lw=1))

    # Final segmentation map.
    out_box = plt.Rectangle((0.77, 0.4), 0.06, 0.2,
                            facecolor='lightpink', ec='black')
    ax.add_patch(out_box)
    ax.text(0.8, 0.5, '分割\n图', ha='center', va='center', fontsize=7)
    ax.annotate('', xy=(0.77, 0.5), xytext=(0.73, 0.5), arrowprops=dict(arrowstyle='->', lw=1))

    # Skip connections drawn from intermediate stages.
    # FCN-16s: fuses pool4 features (red arrow).
    ax.annotate('', xy=(0.38, 0.3), xytext=(0.38, 0.4),
                arrowprops=dict(arrowstyle='->', lw=1, color='red'))
    ax.text(0.42, 0.33, '跳跃连接\n(16s)', fontsize=7, color='red')
    # FCN-8s: additionally fuses pool3 features (blue arrow).
    ax.annotate('', xy=(0.25, 0.25), xytext=(0.25, 0.4),
                arrowprops=dict(arrowstyle='->', lw=1, color='blue'))
    ax.text(0.29, 0.27, '跳跃连接\n(8s)', fontsize=7, color='blue')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('FCN架构(32s, 16s, 8s)', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n📊 FCN三种变体:")
    print(" - FCN-32s: 直接32倍上采样(粗糙)")
    print(" - FCN-16s: 融合pool4特征,16倍上采样")
    print(" - FCN-8s: 融合pool3+pool4特征,8倍上采样(最优)")
    print("\n💡 跳跃连接的作用:")
    print(" 1. 融合浅层细节信息(边缘、纹理)")
    print(" 2. 融合深层语义信息(类别)")
    print(" 3. 恢复空间分辨率")


fcn_architecture()
三、U-Net
3.1 U-Net架构
python
def unet_architecture():
    """Draw the symmetric U-shaped encoder-decoder diagram of U-Net.

    The left column is the contracting path (conv + max-pool stages up
    to 1024 channels), the right column the expanding path (up-conv
    stages down to the 1x1 output conv); red arrows mark the
    copy-and-crop skip connections between matching depths. Prints a
    summary of U-Net's key properties afterwards. Returns None.
    """
    print("\n" + "=" * 60)
    print("U-Net:对称的编码器-解码器")
    print("=" * 60)

    fig, ax = plt.subplots(figsize=(12, 10))
    ax.axis('off')

    # Encoder (downsampling) column: (x, y, label, color) per stage.
    # Sizes follow the original U-Net paper (572x572 in, 388x388 out).
    encoder = [
        (0.2, 0.85, '输入\n(572×572)', 'lightgray'),
        (0.2, 0.75, 'Conv 3×3×64\n(×2)', 'lightblue'),
        (0.2, 0.65, 'MaxPool + Conv\n(128)', 'lightblue'),
        (0.2, 0.55, 'MaxPool + Conv\n(256)', 'lightblue'),
        (0.2, 0.45, 'MaxPool + Conv\n(512)', 'lightblue'),
        (0.2, 0.35, 'MaxPool + Conv\n(1024)', 'lightblue'),
    ]
    # Decoder (upsampling) column, mirrored on the right.
    decoder = [
        (0.7, 0.35, 'UpConv + Conv\n(512)', 'lightgreen'),
        (0.7, 0.45, 'UpConv + Conv\n(256)', 'lightgreen'),
        (0.7, 0.55, 'UpConv + Conv\n(128)', 'lightgreen'),
        (0.7, 0.65, 'UpConv + Conv\n(64)', 'lightgreen'),
        (0.7, 0.75, 'Conv 1×1\n(2)', 'lightgreen'),
        (0.7, 0.85, '输出\n(388×388)', 'lightpink'),
    ]

    # Draw the encoder; arrows point downward (deeper = lower).
    for i, (x, y, label, color) in enumerate(encoder):
        box = plt.Rectangle((x, y), 0.15, 0.08,
                            facecolor=color, ec='black')
        ax.add_patch(box)
        ax.text(x+0.075, y+0.04, label, ha='center', va='center', fontsize=7)
        if i < len(encoder)-1:
            ax.annotate('', xy=(x+0.075, y), xytext=(x+0.075, y+0.08),
                        arrowprops=dict(arrowstyle='->', lw=1))

    # Draw the decoder; arrows point upward, toward the output.
    for i, (x, y, label, color) in enumerate(decoder):
        box = plt.Rectangle((x, y), 0.15, 0.08,
                            facecolor=color, ec='black')
        ax.add_patch(box)
        ax.text(x+0.075, y+0.04, label, ha='center', va='center', fontsize=7)
        if i < len(decoder)-1:
            ax.annotate('', xy=(x+0.075, y+0.08), xytext=(x+0.075, y),
                        arrowprops=dict(arrowstyle='->', lw=1))

    # Copy-and-crop skip connections at each of the five matching depths.
    for i in range(5):
        x1 = 0.35
        x2 = 0.7
        y = 0.75 - i * 0.1
        ax.annotate('', xy=(x2, y+0.04), xytext=(x1, y+0.04),
                    arrowprops=dict(arrowstyle='->', lw=1, color='red'))
        ax.text(0.52, y+0.05, f'复制\n裁剪', ha='center', va='center', fontsize=6, color='red')

    # Bottleneck connection across the bottom of the "U".
    ax.annotate('', xy=(0.7, 0.39), xytext=(0.35, 0.39),
                arrowprops=dict(arrowstyle='->', lw=2))

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('U-Net架构(对称U形结构)', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\n📊 U-Net特点:")
    print(" 1. U形对称结构: 编码器-解码器")
    print(" 2. 跳跃连接: 连接对应层,保留细节")
    print(" 3. 裁剪拼接: 解决边界信息丢失")
    print(" 4. 适用于小数据集(医学图像)")


unet_architecture()
3.2 跳跃连接详解
python
def skip_connections():
    """Illustrate what a skip connection contributes.

    Left panel: diagram of shallow (high-resolution) and deep (semantic)
    features fused into one feature map. Right panel: mock comparison of
    a segmentation result with and without skip connections. Prints a
    summary of the benefits afterwards. Returns None.
    """
    print("\n" + "=" * 60)
    print("跳跃连接(Skip Connection)")
    print("=" * 60)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # --- Panel 1: fusion principle ---
    ax1 = axes[0]
    ax1.axis('off')
    ax1.set_title('跳跃连接原理', fontsize=11)
    # Shallow features: high resolution, fine detail.
    shallow = plt.Rectangle((0.1, 0.6), 0.2, 0.15,
                            facecolor='lightblue', ec='black')
    ax1.add_patch(shallow)
    ax1.text(0.2, 0.675, '浅层特征\n(高分辨率)', ha='center', va='center', fontsize=8)
    # Deep features: coarse but semantically rich.
    deep = plt.Rectangle((0.5, 0.6), 0.2, 0.15,
                         facecolor='lightcoral', ec='black')
    ax1.add_patch(deep)
    ax1.text(0.6, 0.675, '深层特征\n(语义信息)', ha='center', va='center', fontsize=8)
    # Fused result combines detail with semantics.
    fused = plt.Rectangle((0.3, 0.3), 0.4, 0.15,
                          facecolor='lightgreen', ec='black')
    ax1.add_patch(fused)
    ax1.text(0.5, 0.375, '融合特征\n(细节+语义)', ha='center', va='center', fontsize=8)
    ax1.annotate('', xy=(0.3, 0.45), xytext=(0.2, 0.6),
                 arrowprops=dict(arrowstyle='->', lw=1))
    ax1.annotate('', xy=(0.5, 0.45), xytext=(0.6, 0.6),
                 arrowprops=dict(arrowstyle='->', lw=1))

    # --- Panel 2: with/without comparison on toy masks ---
    ax2 = axes[1]
    ax2.axis('off')
    ax2.set_title('有无跳跃连接对比', fontsize=11)
    # Without skip connections: blurry edges, lost detail.
    ax2.text(0.25, 0.8, '无跳跃连接', ha='center', fontsize=9, fontweight='bold')
    img_no_skip = np.random.rand(10, 10)
    img_no_skip[2:8, 2:8] = 0.8
    ax2.imshow(img_no_skip, cmap='gray', alpha=0.7, extent=[0.1, 0.4, 0.3, 0.7])
    ax2.text(0.25, 0.25, '边缘模糊\n细节丢失', ha='center', fontsize=8, color='red')
    # With skip connections: crisp edges, inner structure preserved
    # (the extra dark 2x2 patch stands in for recovered detail).
    ax2.text(0.75, 0.8, '有跳跃连接', ha='center', fontsize=9, fontweight='bold')
    img_skip = np.random.rand(10, 10)
    img_skip[2:8, 2:8] = 0.8
    img_skip[4:6, 4:6] = 0.2
    ax2.imshow(img_skip, cmap='gray', alpha=0.7, extent=[0.6, 0.9, 0.3, 0.7])
    ax2.text(0.75, 0.25, '边缘清晰\n细节保留', ha='center', fontsize=8, color='green')

    plt.suptitle('跳跃连接:融合浅层细节与深层语义', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n💡 跳跃连接的作用:")
    print(" 1. 保留空间细节(边缘、纹理)")
    print(" 2. 梯度流动更顺畅")
    print(" 3. 缓解梯度消失")
    print(" 4. 提高分割精度")


skip_connections()
四、医学图像应用
4.1 U-Net在医学图像中的优势
python
def medical_imaging():
    """Show a 2x3 grid of mock medical-image segmentation scenarios.

    Each cell overlays a toy segmentation mask on a random grayscale
    "scan" and names one clinical application, then prints why U-Net
    suits medical imaging. Returns None.
    """
    print("\n" + "=" * 60)
    print("U-Net在医学图像中的应用")
    print("=" * 60)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # One (row, col, title, description) tuple per application cell.
    applications = [
        (0, 0, '细胞分割', '识别细胞边界\n计数细胞数量'),
        (0, 1, '器官分割', '肝脏、心脏\n肾脏分割'),
        (0, 2, '肿瘤检测', '脑瘤、肺结节\n乳腺癌检测'),
        (1, 0, '血管分割', '视网膜血管\n冠状动脉'),
        (1, 1, '骨骼分割', 'X光图像\n骨骼提取'),
        (1, 2, '病变区域', '肺炎区域\n病灶定位'),
    ]
    for row, col, title, desc in applications:
        ax = axes[row, col]
        # Mock scan: random background with a brighter region of interest.
        img = np.random.rand(50, 50)
        img[20:35, 20:35] = 0.8
        # Mock segmentation mask, slightly inside the bright region.
        mask = np.zeros((50, 50))
        mask[22:33, 22:33] = 1
        ax.imshow(img, cmap='gray')
        ax.imshow(mask, cmap='Reds', alpha=0.5)
        ax.set_title(f'{title}', fontsize=10)
        ax.text(25, -5, desc, ha='center', fontsize=7)
        ax.axis('off')

    plt.suptitle('U-Net在医学图像分割中的应用', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\n📊 U-Net适合医学图像的原因:")
    print(" 1. 数据量小: 医学图像标注成本高")
    print(" 2. 需要细节: 肿瘤边界、器官轮廓")
    print(" 3. 多模态: CT、MRI、X光等")
    print(" 4. 可解释性: 可视化分割结果")


medical_imaging()
五、代码实现示例
5.1 U-Net实现
python
def unet_code():
    """Print a self-contained PyTorch U-Net reference implementation.

    The snippet is held in a string and printed, not executed here.
    Fix over the original: the outer string is delimited with
    single-quoted triple quotes, because the double-quoted docstrings
    inside the snippet (e.g. \"\"\"双卷积块\"\"\") terminated the original
    `code = \"\"\"...\"\"\"` literal early — a SyntaxError. Returns None.
    """
    print("\n" + "=" * 60)
    print("U-Net PyTorch实现")
    print("=" * 60)
    # NOTE: outer delimiters must stay ''' — the snippet contains """.
    code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """双卷积块"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """下采样块"""
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """上采样块(带跳跃连接)"""
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                     kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # 处理尺寸不一致
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """U-Net完整模型"""
    def __init__(self, n_channels=1, n_classes=2, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # 编码器
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # 解码器
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        # 输出层
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # 编码路径
        x1 = self.inc(x)      # 64
        x2 = self.down1(x1)   # 128
        x3 = self.down2(x2)   # 256
        x4 = self.down3(x3)   # 512
        x5 = self.down4(x4)   # 1024
        # 解码路径(跳跃连接)
        x = self.up1(x5, x4)  # 512
        x = self.up2(x, x3)   # 256
        x = self.up3(x, x2)   # 128
        x = self.up4(x, x1)   # 64
        # 输出
        logits = self.outc(x)
        return logits


# 使用示例
model = UNet(n_channels=1, n_classes=2)
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")

# 测试前向传播
x = torch.randn(1, 1, 572, 572)
output = model(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
'''
    print(code)


unet_code()
5.2 训练代码
python
def unet_training():
    """Print a U-Net training reference snippet (not executed here).

    Fixes inside the printed snippet relative to the original:
    - `Dataset` is imported alongside `DataLoader` (the example
      subclasses it, which would otherwise raise NameError);
    - `iou_score` converts predictions/targets to bool before using
      the bitwise `&`/`|` operators, which raise on float tensors.
    The snippet still assumes `UNet` and `plt` are defined by the
    reader's surrounding code, as in the earlier example. Returns None.
    """
    print("\n" + "=" * 60)
    print("U-Net训练代码")
    print("=" * 60)
    code = """
# U-Net训练示例(医学图像分割)
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


# 数据集
class MedicalDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = plt.imread(self.image_paths[idx])
        mask = plt.imread(self.mask_paths[idx])
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask


# 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=1, n_classes=2).to(device)
criterion = nn.BCEWithLogitsLoss()  # 二分类
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# 训练循环
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


# IoU评估指标
def iou_score(pred, target, smooth=1e-6):
    pred = pred > 0.5
    target = target.bool()
    intersection = (pred & target).sum().float()
    union = (pred | target).sum().float()
    return (intersection + smooth) / (union + smooth)


# Dice系数(医学图像常用)
def dice_coefficient(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
"""
    print(code)


unet_training()
六、总结
| 模型 | 特点 | 优点 | 缺点 |
|---|---|---|---|
| FCN | 全卷积+上采样 | 端到端、任意尺寸输入 | 细节丢失 |
| U-Net | U形结构+跳跃连接 | 细节保留、适合医学图像 | 计算量大 |
FCN vs U-Net对比:
| 特性 | FCN | U-Net |
|---|---|---|
| 跳跃连接 | 逐元素相加融合 | 通道拼接(concat) |
| 上采样方式 | 转置卷积 | 转置卷积 |
| 适用场景 | 通用分割 | 医学图像 |
| 参数量 | 较少 | 较多 |
核心要点:
- 全卷积网络是语义分割的基础
- 跳跃连接是U-Net的关键创新
- U-Net在医学图像领域应用广泛
- Dice系数是医学图像分割常用指标