第五章:计算机视觉(Computer Vision)- 项目实战之目标检测实战
第二部分:图像分割实战:人像抠图
第五节:模型训练与测试
在人像抠图任务中,训练与测试是从 模型设计到实际落地 的关键阶段。本节将介绍 数据准备、训练流程、优化策略与测试方法,并结合 PyTorch 代码给出实战示例。
1. 数据准备
训练人像抠图模型需要高质量的 输入图像 (RGB) 与 对应的 Alpha Matte (标签)。常见数据格式包括:
-
输入图像:JPEG/PNG 格式的人像图片。
-
Alpha Matte:灰度图,取值范围 [0,1],0 表示背景,1 表示前景,中间值表示半透明区域(如头发)。
数据加载方式通常采用 torchvision.datasets
或 自定义Dataset
,并进行以下预处理:
-
Resize/CenterCrop:统一图像大小。
-
Normalization :归一化到 [0,1] 或标准化为
mean,std
。 -
数据增强:如随机裁剪、水平翻转、颜色抖动,以提升模型鲁棒性。
示例数据集类:
python
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
class HumanMattingDataset(Dataset):
def __init__(self, img_paths, alpha_paths, transform=None):
self.img_paths = img_paths
self.alpha_paths = alpha_paths
self.transform = transform
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img = Image.open(self.img_paths[idx]).convert("RGB")
alpha = Image.open(self.alpha_paths[idx]).convert("L")
if self.transform:
img = self.transform(img)
alpha = self.transform(alpha)
return img, alpha
2. 训练流程
训练目标是最小化损失函数,使预测的 Alpha Matte 与真实标签尽可能接近。流程如下:
-
模型初始化(如 Semantic Human Matting 架构)。
-
定义损失函数:组合 L1 Loss、BCE Loss、Gradient Loss、Composition Loss。
-
优化器设置:Adam/AdamW 通常比 SGD 收敛更快,学习率 1e-4 是常见起点。
-
训练循环:
-
前向传播 → 得到预测 Alpha。
-
计算损失 → 反向传播。
-
参数更新 → 迭代优化。
-
训练代码示例:
python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义损失函数 (示例:L1 + BCE)
l1_loss = nn.L1Loss()
bce_loss = nn.BCEWithLogitsLoss()
def matting_loss(pred_alpha, gt_alpha, semantic_pred=None):
loss_alpha = l1_loss(pred_alpha, gt_alpha)
if semantic_pred is not None:
loss_semantic = bce_loss(semantic_pred, (gt_alpha > 0.5).float())
return loss_alpha + 0.5 * loss_semantic
return loss_alpha
# 优化器
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# 训练循环
for epoch in range(10):
for imgs, alphas in dataloader:
imgs, alphas = imgs.cuda(), alphas.cuda()
semantic_out, refine_out, alpha_pred = model(imgs)
loss = matting_loss(alpha_pred, alphas, semantic_out)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
3. 模型验证与测试
在测试阶段,我们需要 评估模型抠图质量,主要包括:
-
MAE (Mean Absolute Error):预测与真实 Alpha 的像素差异。
-
SAD (Sum of Absolute Differences):整体误差衡量。
-
MSE (Mean Squared Error):适合平滑区域评估。
-
Gradient Loss:在边缘细节上效果评估。
-
Composition Loss:基于前景合成后的感知误差。
测试代码示例:
python
model.eval()
with torch.no_grad():
for imgs, alphas in test_dataloader:
imgs, alphas = imgs.cuda(), alphas.cuda()
_, _, alpha_pred = model(imgs)
mae = torch.mean(torch.abs(alpha_pred - alphas)).item()
print(f"MAE: {mae:.4f}")
4. 结果可视化
可视化是评估模型性能的重要手段,可以直观比较输入、GT Alpha 和预测结果。
python
import matplotlib.pyplot as plt
def visualize_result(img, alpha_gt, alpha_pred):
plt.subplot(1, 3, 1)
plt.imshow(img.permute(1,2,0).cpu())
plt.title("Input Image")
plt.subplot(1, 3, 2)
plt.imshow(alpha_gt.squeeze().cpu(), cmap="gray")
plt.title("Ground Truth Alpha")
plt.subplot(1, 3, 3)
plt.imshow(alpha_pred.squeeze().cpu(), cmap="gray")
plt.title("Predicted Alpha")
plt.show()
5. 总结
-
训练阶段:通过 L1、BCE、Gradient、Composition Loss 联合优化,确保全局和边缘细节都准确。
-
测试阶段:采用 MAE、SAD、MSE、Gradient Loss 等指标进行全面评估。
-
可视化:直观展示模型在抠图任务上的表现,尤其是头发丝、衣物边缘等细节区域。
在实际应用中,模型还可以通过 混合精度训练 (AMP)、学习率调度、数据增强 来进一步提升性能。