
# 实例分割:Mask R-CNN(分割分支、ROI Align)
## 一、实例分割概述
### 1.1 语义分割 vs 实例分割
python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import warnings
warnings.filterwarnings('ignore')
# Intro demo: print a banner, then draw three panels contrasting the raw
# image, a semantic-segmentation overlay and an instance-segmentation overlay.
print("=" * 60, "实例分割:区分不同物体实例", "=" * 60, sep="\n")

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax1, ax2, ax3 = axes

# Panel 1: random RGB noise stands in for the input photo.
ax1.imshow(np.random.rand(100, 100, 3))
ax1.set_title('原始图像', fontsize=10)
ax1.axis('off')

# Panel 2: semantic segmentation. Both regions carry the same label value,
# so a single colormap paints them in the same colour.
ax2.imshow(np.random.rand(100, 100, 3))
mask_semantic = np.zeros((100, 100))
mask_semantic[20:50, 20:50] = 1
mask_semantic[60:80, 60:80] = 1
ax2.imshow(mask_semantic, cmap='viridis', alpha=0.5)
ax2.set_title('语义分割\n(所有人物同一类)', fontsize=10)
ax2.axis('off')

# Panel 3: instance segmentation. Each region lives in its own mask and is
# overlaid with its own colormap, so the two instances get different colours.
ax3.imshow(np.random.rand(100, 100, 3))
mask1 = np.zeros((100, 100))
mask2 = np.zeros((100, 100))
mask1[20:50, 20:50] = 0.5
mask2[60:80, 60:80] = 0.7
ax3.imshow(mask1, cmap='Reds', alpha=0.5)
ax3.imshow(mask2, cmap='Blues', alpha=0.5)
ax3.set_title('实例分割\n(区分不同人物)', fontsize=10)
ax3.axis('off')

plt.suptitle('语义分割 vs 实例分割', fontsize=14)
plt.tight_layout()
plt.show()

# Textual recap of the definition and of how the three tasks differ.
print("\n".join([
    "\n💡 实例分割定义:",
    " 不仅要知道每个像素的类别,还要区分同一类别的不同实例",
    "\n📊 任务对比:",
    " - 目标检测: 边界框",
    " - 语义分割: 像素级分类(不区分实例)",
    " - 实例分割: 像素级分类 + 实例区分",
]))
## 二、Mask R-CNN架构
### 2.1 整体架构
python
def mask_rcnn_architecture():
    """Draw the Mask R-CNN pipeline as a top-to-bottom flow chart.

    Prints a banner, renders the stacked stages (input image, backbone,
    feature map, RPN, RoIs, ROI Align), fans out into the three heads
    (classification, box regression, mask) with their outputs, shows the
    figure, and finally prints a summary of the three outputs.

    Every arrow follows the data flow: it starts at the bottom edge of the
    upstream box and ends (arrow head) at the top edge of the downstream box.
    """
    print("\n" + "=" * 60)
    print("Mask R-CNN:在Faster R-CNN上加分割分支")
    print("=" * 60)

    fig, ax = plt.subplots(figsize=(14, 10))
    ax.axis('off')

    def draw_box(left, bottom, width, height, label, color, fontsize):
        """Add a filled rectangle with a centred label."""
        ax.add_patch(plt.Rectangle((left, bottom), width, height,
                                   facecolor=color, ec='black'))
        ax.text(left + width / 2, bottom + height / 2, label,
                ha='center', va='center', fontsize=fontsize)

    def draw_arrow(tail, head):
        """Draw an arrow from ``tail`` to ``head``."""
        # annotate() draws from xytext (tail) towards xy (arrow head).
        ax.annotate('', xy=head, xytext=tail,
                    arrowprops=dict(arrowstyle='->', lw=1))

    # Main vertical chain, listed top to bottom:
    # (bottom edge, height, label, face colour, font size).
    stages = [
        (0.85, 0.06, '输入图像', 'lightgray', 9),
        (0.73, 0.08, '骨干网络\n(ResNet+FPN)', 'lightblue', 7),
        (0.62, 0.07, '特征图', 'lightyellow', 8),
        (0.52, 0.06, 'RPN\n(区域提议)', 'lightgreen', 7),
        (0.43, 0.06, '候选区域\n(RoIs)', 'lightcoral', 7),
        (0.33, 0.06, 'ROI Align', 'lightpink', 8),
    ]
    upstream_bottom = None
    for bottom, height, label, color, fontsize in stages:
        draw_box(0.35, bottom, 0.3, height, label, color, fontsize)
        if upstream_bottom is not None:
            draw_arrow((0.5, upstream_bottom), (0.5, bottom + height))
        upstream_bottom = bottom

    # Three parallel heads fed by ROI Align, each followed by its output:
    # (centre x, head label, output label, face colour).
    heads = [
        (0.2, '分类分支', '类别', 'lightblue'),
        (0.5, '回归分支', '边界框', 'lightgreen'),
        (0.8, '分割分支', '掩码', 'lightcoral'),
    ]
    head_y, output_y = 0.2, 0.08
    for x, head_label, output_label, color in heads:
        draw_box(x - 0.08, head_y - 0.03, 0.16, 0.06, head_label, color, 7)
        draw_arrow((0.5, upstream_bottom), (x, head_y + 0.03))

        draw_box(x - 0.08, output_y - 0.03, 0.16, 0.06, output_label, color, 7)
        # Start at the head box's bottom edge so the arrow spans the whole gap.
        draw_arrow((x, head_y - 0.03), (x, output_y + 0.03))

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Mask R-CNN架构', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\n📊 Mask R-CNN三个输出:")
    print(" 1. 分类分支: 预测物体类别")
    print(" 2. 回归分支: 精修边界框")
    print(" 3. 分割分支: 预测实例掩码(新增)")


mask_rcnn_architecture()
## 三、ROI Align
### 3.1 ROI Pooling vs ROI Align
python
def roi_align_vs_pooling():
    """Plot ROI Pooling next to ROI Align on the same 8x8 feature grid.

    The left panel shows the real-valued RoI (red) and the RoI after its
    coordinates are snapped to the grid (blue, dashed). The right panel
    shows the same RoI with a 4x4 lattice of sampling points placed at
    continuous coordinates. A textual summary is printed after the figure.
    """
    # Only Rectangle is imported at file level; import Circle here so the
    # sampling points can be drawn without a NameError.
    from matplotlib.patches import Circle

    print("\n" + "=" * 60)
    print("ROI Align:解决量化误差")
    print("=" * 60)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    roi_origin, roi_size = (0.28, 0.28), 0.44

    def draw_grid_and_roi(ax):
        """Draw the 8x8 feature-map cells and the un-quantized RoI outline."""
        for i in range(8):
            for j in range(8):
                ax.add_patch(Rectangle((i * 0.1, j * 0.1), 0.09, 0.09,
                                       facecolor='lightgray', ec='gray',
                                       alpha=0.5))
        ax.add_patch(Rectangle(roi_origin, roi_size, roi_size,
                               linewidth=2, edgecolor='red',
                               facecolor='none'))

    def finish_axes(ax):
        """Apply the shared limits and a square aspect ratio."""
        ax.set_xlim(0, 0.8)
        ax.set_ylim(0, 0.8)
        ax.set_aspect('equal')

    # Left: ROI Pooling, where the RoI is snapped to cell boundaries.
    ax1 = axes[0]
    ax1.set_title('ROI Pooling\n(有量化误差)', fontsize=10)
    draw_grid_and_roi(ax1)
    ax1.add_patch(Rectangle((0.3, 0.3), 0.4, 0.4,
                            linewidth=2, edgecolor='blue', facecolor='none',
                            linestyle='--'))
    ax1.text(0.5, 0.15, '量化导致位置偏移', ha='center', fontsize=8,
             color='red')
    finish_axes(ax1)

    # Right: ROI Align, sampling at bin centres in continuous coordinates.
    ax2 = axes[1]
    ax2.set_title('ROI Align\n(双线性插值)', fontsize=10)
    draw_grid_and_roi(ax2)
    bin_size = roi_size / 4
    for i in range(4):
        for j in range(4):
            x = roi_origin[0] + (j + 0.5) * bin_size
            y = roi_origin[1] + (i + 0.5) * bin_size
            ax2.add_patch(Circle((x, y), 0.01, color='green', alpha=0.8))
    ax2.text(0.5, 0.15, '连续坐标 + 双线性插值', ha='center', fontsize=8,
             color='green')
    finish_axes(ax2)

    plt.suptitle('ROI Align vs ROI Pooling', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n📊 ROI Pooling问题:")
    print(" 1. 量化: 坐标取整,丢失精度")
    print(" 2. 位置偏移: 影响分割精度")
    print("\n✅ ROI Align改进:")
    print(" 1. 保留浮点坐标")
    print(" 2. 双线性插值采样")
    print(" 3. 无量化误差")


roi_align_vs_pooling()
### 3.2 双线性插值
python
def bilinear_interpolation():
    """Illustrate bilinear interpolation on a single grid cell.

    Prints a banner, plots four known corner samples (blue) and the query
    point (red) with dashed guide lines through the corner coordinates,
    overlays the interpolation formula, shows the figure and then prints
    the three-step recipe.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("双线性插值")
    print(banner)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Known samples as (x, y, value), one per cell corner.
    corners = [(0.2, 0.2, 10), (0.6, 0.2, 20), (0.2, 0.6, 30), (0.6, 0.6, 40)]
    for cx, cy, value in corners:
        ax.plot(cx, cy, 'bo', markersize=10)
        ax.text(cx + 0.02, cy + 0.02, f'v={value}', fontsize=9)

    # Point whose value is to be interpolated.
    qx, qy = 0.4, 0.4
    ax.plot(qx, qy, 'ro', markersize=12)
    ax.text(qx + 0.02, qy + 0.02, '插值点', fontsize=9, color='red')

    # Dashed guides through the corner coordinates (verticals first).
    guide_style = dict(color='gray', linestyle='--', alpha=0.5)
    for coord in (0.2, 0.6):
        ax.axvline(x=coord, **guide_style)
    for coord in (0.2, 0.6):
        ax.axhline(y=coord, **guide_style)

    # Formula text; the leading and trailing empty entries reproduce the
    # newline at each end of the displayed block.
    formula = "\n".join([
        "",
        "双线性插值公式:",
        "v = (1 - dx)(1 - dy)v₁ + dx(1 - dy)v₂",
        "+ (1 - dx)dy v₃ + dx dy v₄",
        "其中:",
        "dx = (x - x₁)/(x₂ - x₁)",
        "dy = (y - y₁)/(y₂ - y₁)",
        "",
    ])
    ax.text(0.7, 0.7, formula, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', fontfamily='monospace')

    ax.set_xlim(0.1, 0.7)
    ax.set_ylim(0.1, 0.7)
    ax.set_aspect('equal')
    ax.set_title('双线性插值原理', fontsize=12)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.tight_layout()
    plt.show()

    for line in ("\n💡 双线性插值步骤:",
                 " 1. 在x方向线性插值两次",
                 " 2. 在y方向线性插值一次",
                 " 3. 得到平滑的采样值"):
        print(line)


bilinear_interpolation()
## 四、分割分支
### 4.1 Mask分支结构
python
def mask_branch():
    """Draw the mask head as a vertical stack of layers.

    The stack is: ROI Align output, four 3x3 convolutions, one 2x2
    transposed convolution and the per-class mask output. Each layer is
    connected to the next by an arrow running from the bottom edge of the
    upstream box to the top edge of the downstream box. A textual summary
    is printed after the figure is shown.
    """
    print("\n" + "=" * 60)
    print("Mask分支:小全卷积网络")
    print("=" * 60)

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis('off')

    left, width, height = 0.35, 0.3, 0.06
    center_x = left + width / 2

    # Layers from top to bottom: (bottom edge, label, face colour).
    layers = [
        (0.8, 'ROI Align输出\n(14×14×256)', 'lightpink'),
        (0.68, 'Conv 3×3, 256', 'lightblue'),
        (0.56, 'Conv 3×3, 256', 'lightblue'),
        (0.44, 'Conv 3×3, 256', 'lightblue'),
        (0.32, 'Conv 3×3, 256', 'lightblue'),
        (0.2, '反卷积 2×2\n(上采样到28×28)', 'lightgreen'),
        (0.08, 'Mask输出\n(28×28×K)', 'lightcoral'),
    ]

    upstream_bottom = None
    for bottom, label, color in layers:
        ax.add_patch(plt.Rectangle((left, bottom), width, height,
                                   facecolor=color, ec='black'))
        ax.text(center_x, bottom + height / 2, label,
                ha='center', va='center', fontsize=7)
        if upstream_bottom is not None:
            # annotate() draws from xytext (tail) to xy (head): connect the
            # previous box's bottom edge to this box's top edge.
            ax.annotate('', xy=(center_x, bottom + height),
                        xytext=(center_x, upstream_bottom),
                        arrowprops=dict(arrowstyle='->', lw=1))
        upstream_bottom = bottom

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Mask分支结构(小FCN)', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n📊 Mask分支特点:")
    print(" - 输入: 14×14×256 (ROI Align输出)")
    print(" - 输出: 28×28×K (K为类别数)")
    print(" - 结构: 4个卷积层 + 1个反卷积")
    print(" - 每个类别独立预测掩码")


mask_branch()
### 4.2 掩码预测
python
def mask_prediction():
    """Show the three stages of mask prediction with synthetic data.

    Panels, left to right: a random image with a red RoI rectangle, a
    random 14x14 array standing in for the ROI Align features, and a random
    28x28 array with a solid square set to 1 standing in for the predicted
    mask. Prints a banner before and a list of details after the figure.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("掩码预测")
    print(divider)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    roi_ax, feature_ax, mask_ax = axes

    # Stage 1: the RoI outlined on the (random) input image.
    roi_ax.imshow(np.random.rand(100, 100, 3))
    roi_ax.add_patch(Rectangle((30, 30), 50, 50, linewidth=2,
                               edgecolor='red', facecolor='none'))
    roi_ax.set_title('1. ROI区域')
    roi_ax.axis('off')

    # Stage 2: placeholder for the 14x14 ROI Align feature map.
    feature_ax.imshow(np.random.rand(14, 14), cmap='viridis')
    feature_ax.set_title('2. ROI Align\n(14×14特征)')
    feature_ax.axis('off')

    # Stage 3: placeholder mask, random noise with a square forced to 1.
    predicted = np.random.rand(28, 28)
    predicted[10:18, 10:18] = 1
    mask_ax.imshow(predicted, cmap='gray')
    mask_ax.set_title('3. 预测掩码\n(28×28)')
    mask_ax.axis('off')

    plt.suptitle('掩码预测流程', fontsize=12)
    plt.tight_layout()
    plt.show()

    print("\n".join([
        "\n💡 掩码预测细节:",
        " - 每个ROI独立预测掩码",
        " - 使用Sigmoid激活(二分类)",
        " - 损失函数: 平均二值交叉熵",
        " - 推理时: 掩码上采样到ROI大小",
    ]))


mask_prediction()
## 五、损失函数
### 5.1 多任务损失
python
def mask_rcnn_loss():
    """Render the Mask R-CNN multi-task loss as a text-only figure.

    Prints a banner, draws the loss breakdown (classification, box and mask
    terms) as monospace text on an otherwise empty axes, shows the figure
    and then prints the key properties of the loss.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Mask R-CNN损失函数")
    print(rule)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.axis('off')

    # One entry per displayed line; the empty first and last entries
    # reproduce the newline at each end of the text block.
    loss_lines = [
        "",
        "📐 Mask R-CNN总损失:",
        "L = L_cls + L_box + L_mask",
        "其中:",
        "1. 分类损失 L_cls:",
        "- 多类别交叉熵",
        "- 判断ROI中物体的类别",
        "2. 边界框损失 L_box:",
        "- Smooth L1损失",
        "- 精修边界框位置",
        "3. 掩码损失 L_mask:",
        "- 平均二值交叉熵",
        "- 只对正样本计算",
        "- 每个类别独立预测",
        "",
    ]
    loss_text = "\n".join(loss_lines)
    ax.text(0.05, 0.95, loss_text, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace')
    plt.tight_layout()
    plt.show()

    for line in ("\n📊 损失函数特点:",
                 " - 多任务联合训练",
                 " - 掩码损失与类别解耦",
                 " - 每个类别独立预测掩码"):
        print(line)


mask_rcnn_loss()
## 六、代码实现示例
### 6.1 ROI Align实现
python
def roi_align_code():
    """Print a reference (loop-based) PyTorch implementation of ROI Align.

    The function only prints: a banner followed by the sample source code.
    The printed sample is kept runnable as-is:

    * its indentation is intact, so it is valid Python;
    * only the box coordinates are multiplied by ``spatial_scale`` - the
      batch-index column is left untouched, otherwise every RoI of a
      multi-image batch would read from the wrong image;
    * sampling points are clamped to the feature map and the bilinear
      weights are computed from fractional offsets, so they always sum
      to 1, including on the border.
    """
    print("\n" + "=" * 60)
    print("ROI Align代码")
    print("=" * 60)
    # The sample is user-facing text; its inline comments stay in the
    # tutorial's language.
    code = """
import torch
import torch.nn as nn
import torch.nn.functional as F


class ROIAlign(nn.Module):
    def __init__(self, output_size, spatial_scale=1.0):
        super(ROIAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, features, rois):
        # features: [N, C, H, W]
        # rois: [num_rois, 5] (batch_idx, x1, y1, x2, y2)
        num_rois = rois.shape[0]
        # 缩放坐标(只缩放 x1, y1, x2, y2,batch_idx 不参与缩放)
        boxes = rois[:, 1:5].float() * self.spatial_scale
        # 计算每个ROI的特征
        roi_features = []
        for i in range(num_rois):
            batch_idx = int(rois[i, 0])
            x1, y1, x2, y2 = boxes[i]
            # 采样点坐标
            step_x = (x2 - x1) / self.output_size
            step_y = (y2 - y1) / self.output_size
            feat = []
            for iy in range(self.output_size):
                for ix in range(self.output_size):
                    # 采样点中心
                    px = x1 + (ix + 0.5) * step_x
                    py = y1 + (iy + 0.5) * step_y
                    # 双线性插值
                    value = self._bilinear_interpolate(
                        features[batch_idx], px, py
                    )
                    feat.append(value)
            feat = torch.stack(feat).view(self.output_size, self.output_size, -1)
            roi_features.append(feat)
        return torch.stack(roi_features).permute(0, 3, 1, 2)

    def _bilinear_interpolate(self, feature, x, y):
        # 双线性插值实现
        h, w = feature.shape[-2:]
        # 将采样点限制在特征图范围内
        x = min(max(float(x), 0.0), w - 1)
        y = min(max(float(y), 0.0), h - 1)
        # 四个角点
        x0, y0 = int(x), int(y)
        x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
        # 插值权重(由小数部分计算,权重之和恒为1)
        dx, dy = x - x0, y - y0
        wa = (1 - dx) * (1 - dy)
        wb = dx * (1 - dy)
        wc = (1 - dx) * dy
        wd = dx * dy
        # 插值计算
        value = (wa * feature[..., y0, x0] +
                 wb * feature[..., y0, x1] +
                 wc * feature[..., y1, x0] +
                 wd * feature[..., y1, x1])
        return value


# 使用示例
roi_align = ROIAlign(output_size=14, spatial_scale=1/16)
features = torch.randn(1, 256, 32, 32)
rois = torch.tensor([[0, 10, 10, 50, 50]])  # batch_idx, x1, y1, x2, y2
output = roi_align(features, rois)
print(f"ROI Align输出形状: {output.shape}")
"""
    print(code)


roi_align_code()
### 6.2 Mask分支实现
python
def mask_head_code():
    """Print a reference PyTorch implementation of the mask head.

    The function only prints: a banner followed by the sample source code.
    The printed sample keeps its indentation (so it is valid Python) and
    applies a ReLU between the transposed convolution and the 1x1 output
    convolution; without that non-linearity the two layers would collapse
    into a single linear map. The sample assumes ``torch`` and
    ``torch.nn as nn`` are already imported, as in the ROI Align sample.
    """
    print("\n" + "=" * 60)
    print("Mask分支代码")
    print("=" * 60)
    # The sample is user-facing text; its inline comments stay in the
    # tutorial's language.
    code = """
class MaskHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(MaskHead, self).__init__()
        # 卷积层
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # 反卷积(上采样)
        self.deconv = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        # 输出层
        self.mask = nn.Conv2d(256, num_classes, kernel_size=1)
        # 激活函数
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [num_rois, 256, 14, 14]
        x = self.conv_layers(x)  # [num_rois, 256, 14, 14]
        x = self.relu(self.deconv(x))  # [num_rois, 256, 28, 28]
        x = self.mask(x)  # [num_rois, num_classes, 28, 28]
        x = self.sigmoid(x)  # 二分类概率
        return x


# 使用示例
mask_head = MaskHead(in_channels=256, num_classes=80)
roi_features = torch.randn(10, 256, 14, 14)
masks = mask_head(roi_features)
print(f"Mask输出形状: {masks.shape}")
"""
    print(code)


mask_head_code()
### 6.3 完整Mask R-CNN简化版
python
def mask_rcnn_simple():
    """Print a simplified, schematic Mask R-CNN model definition.

    The function only prints: a banner followed by the sample source code.
    The printed sample keeps its indentation so it parses as Python. It is
    a sketch rather than a runnable program: ``ResNet50FPN`` and ``RPN``
    are placeholders that the sample does not define, which the sample
    itself now states.
    """
    print("\n" + "=" * 60)
    print("简化版Mask R-CNN")
    print("=" * 60)
    # The sample is user-facing text; its inline comments stay in the
    # tutorial's language.
    code = """
# 注意: ResNet50FPN 和 RPN 为示意用的占位模块,需自行实现或替换
class SimpleMaskRCNN(nn.Module):
    def __init__(self, num_classes=80):
        super(SimpleMaskRCNN, self).__init__()
        # 骨干网络
        self.backbone = ResNet50FPN()
        # RPN
        self.rpn = RPN(in_channels=256)
        # ROI Align
        self.roi_align = ROIAlign(output_size=14, spatial_scale=1/16)
        # 分类回归头
        self.cls_head = nn.Linear(256 * 14 * 14, num_classes + 1)
        self.bbox_head = nn.Linear(256 * 14 * 14, 4)
        # Mask头
        self.mask_head = MaskHead(in_channels=256, num_classes=num_classes)

    def forward(self, x):
        # 1. 特征提取
        features = self.backbone(x)
        # 2. RPN生成候选框
        proposals = self.rpn(features)
        # 3. ROI Align
        roi_features = self.roi_align(features, proposals)
        # 4. 分类和回归
        roi_flat = roi_features.view(roi_features.size(0), -1)
        cls_scores = self.cls_head(roi_flat)
        bbox_deltas = self.bbox_head(roi_flat)
        # 5. 掩码预测
        masks = self.mask_head(roi_features)
        return cls_scores, bbox_deltas, masks


# 推理示例
model = SimpleMaskRCNN(num_classes=80)
model.eval()
x = torch.randn(1, 3, 800, 800)
with torch.no_grad():
    cls_scores, bbox_deltas, masks = model(x)
print(f"分类分数: {cls_scores.shape}")
print(f"边界框偏移: {bbox_deltas.shape}")
print(f"掩码: {masks.shape}")
"""
    print(code)


mask_rcnn_simple()
## 七、总结
| 组件 | 作用 | 创新点 |
|---|---|---|
| ROI Align | 特征提取 | 双线性插值,无量化 |
| Mask分支 | 掩码预测 | 小FCN,与类别解耦 |
| 多任务损失 | 联合训练 | 分类+回归+分割 |

Mask R-CNN核心要点:

- 在Faster R-CNN基础上增加分割分支
- ROI Align是分割精度的关键
- 掩码与类别解耦(每个类别独立预测)
- 实例分割的经典方法

与其他方法对比:

| 方法 | 检测 | 分割 | 速度 |
|---|---|---|---|
| Faster R-CNN | ✅ | ❌ | 快 |
| Mask R-CNN | ✅ | ✅ | 中 |
| YOLACT | ✅ | ✅ | 快 |