关键点检测——Deeppose源码解析篇

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

关键点检测------Deeppose源码解析篇

写在前面

Hello,大家好,我是小苏👦👦👦

在上一小结中,我为大家介绍了基于回归式的关键点检测方法Deeppose的原理,还不清楚的可以点击下面链接了解详情:

那么这节,我将来带大家梳理梳理Deeppose的代码,在具体介绍之前,我还是要提一句,看代码时,一定不要只停留在听别人说或看别人写的阶段,一定要多动手调试调试,看看各种参数是如何变换的,看看具体流程是怎么样的,也可以学学别人写的比较简练的一些代码。

好了,话不多说,让我们一起来学学Deeppose的源码叭~~~🥗🥗🥗

源码地址:Deeppose

Deeppose源码

首先,我们来看看有哪些参数,如下图所示:

接着,我们看看我们的网络结构,如下:

python 复制代码
model = create_deep_pose_model(num_keypoints)

def create_deep_pose_model(num_keypoints: int) -> nn.Module:
    res50 = resnet50(ResNet50_Weights.IMAGENET1K_V2)
    in_features = res50.fc.in_features
    res50.fc = nn.Linear(in_features=in_features, out_features=num_keypoints * 2)

    return res50

这里先是通过resnet50(ResNet50_Weights.IMAGENET1K_V2)下载了一个resnet50的预训练权重,如果你是第一次运行这段代码时,PyTorch 会从远程服务器下载预训练的权重文件到本地计算机的缓存目录,我的Windows下载位置是这里(你们自己的类似):

然后通过res50.fc.in_features获取最后一层全连接层(fc层)的输入通道数in_features,最后通过res50.fc = nn.Linear(in_features=in_features, out_features=num_keypoints * 2)将resnet50最后的全连接层的输出通道数换成num_keypoints * 2(这个在原理篇中已经解释),对于本项目,num_keypoints = 98,这样,我们就构建好了基于resnet50的Deeppose网络结构啦。其和标准resnet50的差距只在最后的全连接层输出的通道数量上,如下图所示:

接着我们会定义一些数据增强手段,如下:

python 复制代码
data_transform = {
        "train": transforms.Compose([
            transforms.AffineTransform(scale_factor=(0.65, 1.35), rotate=45, shift_factor=0.15, fixed_size=img_hw),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.AffineTransform(scale_prob=0., rotate_prob=0., shift_prob=0., fixed_size=img_hw),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

关于这里使用的放射变化和水平翻转数据增强,我在HRNet的源码解析篇已经非常详细的解释过,需要特别注意的是做这些数据增强时,不仅要变图片,还要变换标签!!!【在我的一段实习过程中,一位同事就因为水平翻转的标签问题踩过坑,所以这里大家要格外注意一下】而ToTensor和Normalize就是非常普遍的一种数据变换方式了,不清楚的可以看看我的这篇博客:

接下来是训练集和验证集WFLW数据集的构建(后文以训练集为例),如下:

python 复制代码
train_dataset = WFLWDataset(root=dataset_dir,
                                train=True,
                                transforms=data_transform["train"])
    val_dataset = WFLWDataset(root=dataset_dir,
                              train=False,
                              transforms=data_transform["val"])

进入WFLWDataset内部,首先我们会获取WFLW的标注信息路径和图片路径,并初始化一些存储关键点信息、位置信息和路径信息的列表,如下:

接下来就是从标注信息中循环读取每一行的信息,然后分别将98个关键点的x和y坐标、脸部矩形框位置坐标和图片路径分别存储在keypointsface_rectsimg_path三个列表中:

python 复制代码
with open(self.anno_path, "rt") as f:
for line in f.readlines():
    if not line.strip():
        continue

    split_list = line.strip().split(" ")
    keypoint_ = self.get_98_points(split_list)
    keypoint = np.array(keypoint_, dtype=np.float32).reshape((-1, 2))
    face_rect = list(map(int, split_list[196: 196 + 4]))  # xmin, ymin, xmax, ymax
    img_name = split_list[-1]

    self.keypoints.append(keypoint)
    self.face_rects.append(face_rect)
    self.img_paths.append(os.path.join(self.img_root, img_name))

def get_98_points(keypoints: List[str]) -> List[float]:
	return list(map(float, keypoints[:196]))

最终,可以来看看这三个列表:

可以看到,它们都有7500个值,因为训练集中有7500条数据。face_rects存储的每个矩形框数据有4个值,分别表示矩形框左上和右下的横纵坐标;keypoints存储的每个数据是98*2的numpy数组,表示98个关键点的横纵坐标。这里展示的是训练集数据的构建,验证集也是类似的。

dataset构建好之后,就是构建DataLoader、定义优化器和学习率调度器等内容,这部分比较常见,不作为本节的重点,所以这里也不过多介绍,在设置学习率过程中,这里用到了Warm up,感兴趣的可以看看我的这篇博客,有比较详细的讲解:


接下来就是训练过程了,训练过程还是分为几个经典步骤:

  1. 定义损失函数

    python 复制代码
    loss_func = L2Loss()  # 定义L2损失函数
    
    class L2Loss(nn.Module):
        def __init__(self) -> None:
            super().__init__()
    
        def forward(self, pred: torch.Tensor, label: torch.Tensor, mask: torch = None) -> torch.Tensor:
            losses = F.mse_loss(pred, label, reduction="none")
            if mask is not None:
                # filter invalid keypoints(e.g. out of range)
                losses = losses * mask.unsqueeze(2)
    
            return torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
  2. 前向传播、获取损失

    python 复制代码
    with torch.autocast(device_type=device.type):
        pred: torch.Tensor = model(imgs)
        loss: torch.Tensor = loss_func(pred.reshape((-1, num_keypoints, 2)), labels)
  3. 反向传播、参数更新与学习率更新

    python 复制代码
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

训练过程解释完了,接下来就是在验证集上做评估了,我们先来看看我们使用的评估指标NME:

python 复制代码
class NMEMetric:
    def __init__(self, device: torch.device) -> None:
        # 两眼外角点对应keypoint索引
        self.keypoint_idxs = [60, 72]
        self.nme_accumulator: float = 0.
        self.counter: float = 0.
        self.device = device

    def update(self, pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor = None):

        ion = torch.linalg.norm(gt[:, self.keypoint_idxs[0]] - gt[:, self.keypoint_idxs[1]], dim=1)

        valid_ion_mask = ion > 0
        if mask is None:
            mask = valid_ion_mask
        else:
            mask = torch.logical_and(mask, valid_ion_mask.unsqueeze_(dim=1)).sum(dim=1) > 0
        num_valid = mask.sum().item()

        # equal: (pred - gt).pow(2).sum(dim=2).pow(0.5).mean(dim=1)
        l2_dis = torch.linalg.norm(pred - gt, dim=2)[mask].mean(dim=1)  # [N]

        # avoid divide by zero
        ion = ion[mask]  # [N]

        self.nme_accumulator += l2_dis.div(ion).sum().item()
        self.counter += num_valid

首先,我们定义了keypoint_idxs = [60, 72],这就是我们原理篇说的d :两眼外眼角间距的索引。并且定义了nme_accumulator ,其表示累积了所有图像的 NME 误差, counter 记录有效图像的数量。在update方法中,我们首先通过计算两眼外角关键点的欧氏距离 ion 来获取基准距离,这个距离会用于归一化误差。接着,代码检查哪些图像的基准距离大于零,确保这些图像的关键点距离有效(防止无效数据影响评估)。如果 maskNone,直接使用有效的 ion,否则结合 mask 进一步筛选有效的图像。随后,l2_dis 计算了预测关键点和真实关键点的 L2 距离,即两者的欧氏距离,然后通过 mask 筛选出有效的样本。为了防止 ion为0,即会产生除零操作,因此确保 ion 中只有有效的基准距离,接着将 L2 距离除以基准距离 ion,得出每张图像的归一化误差,并将其累加到 nme_accumulator 中。🍗🍗🍗

有了NME评估指标后,就可以实现验证集的评估了,如下:

python 复制代码
def evaluate(model: torch.nn.Module,
             epoch: int,
             val_loader: DataLoader,
             device: torch.device,
             tb_writer: SummaryWriter,
             affine_points_torch_func: Callable,
             num_keypoints: int,
             img_hw: List[int]) -> None:
    model.eval()
    metric = NMEMetric(device=device)
    wh_tensor = torch.as_tensor(img_hw[::-1], dtype=torch.float32, device=device).reshape([1, 1, 2])
    eval_bar = val_loader
    if is_main_process():
        eval_bar = tqdm(val_loader, file=sys.stdout, desc="evaluation")

    for step, (imgs, targets) in enumerate(eval_bar):
        imgs = imgs.to(device)
        m_invs = targets["m_invs"].to(device)
        labels = targets["ori_keypoints"].to(device)

        pred = model(imgs)
        pred = pred.reshape((-1, num_keypoints, 2))  # [N, K, 2]
        pred = pred * wh_tensor  # rel coord to abs coord
        pred = affine_points_torch_func(pred, m_invs)

        metric.update(pred, labels)

    metric.synchronize_results()

首先,模型被设置为评估模式 model.eval(),这会关闭一些训练期间专用的功能(如 dropout),这些都是一些基础的操作啦。接着,定义了一个 NMEMetric 来跟踪每个验证步骤中的误差,并将输入图像的尺寸转换成张量 wh_tensor 以便后续使用。随后,代码将批次图像和目标标签加载到设备上,并从模型中获取预测的关键点坐标。预测的坐标是相对坐标,因此需要乘以 wh_tensor 来转换为绝对坐标。接着,通过 affine_points_torch_func 函数对预测结果进行仿射变换,关于仿射变换不了解的去阅读一下我之前HRNet的原理详解篇。之后在每个批次评估完成后,预测的关键点和真实标签会被传递给 metric.update() 来累计评估结果。🥗🥗🥗


我们利用训练好的网络,就可以来实现一下人脸关键点检测任务啦,这里来展示下检测的效果,如下图所示:

这里还有一点需要大家注意一下,Deeppose只支持单人检测,如果想实现多人检测,则还需要配合一个目标检测器进行使用,这也是我们在前几节中所说的自顶向下的关键点检测方法。

那么这里留下一个疑问考大家一下,目前介绍的HRNet、Openpose和Deeppose都有什么区别呢?

小结

呼呼呼~~~终于写完啦,如果有什么疑问欢迎评论区交流讨论喔,我们下期间。🌱🌱🌱

参考链接

如若文章对你有所帮助,那就🛴🛴🛴

相关推荐
old_power15 分钟前
【PCL】Segmentation 模块—— 基于图割算法的点云分割(Min-Cut Based Segmentation)
c++·算法·计算机视觉·3d
PaLu-LI2 小时前
ORB-SLAM2源码学习:Initializer.cc⑧: Initializer::CheckRT检验三角化结果
c++·人工智能·opencv·学习·ubuntu·计算机视觉
清图4 小时前
Python 预训练:打通视觉与大语言模型应用壁垒——Python预训练视觉和大语言模型
人工智能·python·深度学习·机器学习·计算机视觉·自然语言处理·ai作画
pchmi7 小时前
C# OpenCV机器视觉:红外体温检测
人工智能·数码相机·opencv·计算机视觉·c#·机器视觉·opencvsharp
好评笔记8 小时前
AIGC视频扩散模型新星:Video 版本的SD模型
论文阅读·深度学习·机器学习·计算机视觉·面试·aigc·transformer
Fxrain8 小时前
[Computer Vision]实验三:图像拼接
人工智能·计算机视觉
AI视觉网奇11 小时前
python 统计相同像素值个数
python·opencv·计算机视觉
好评笔记21 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
冰万森1 天前
【图像处理】——掩码
python·opencv·计算机视觉
Antonio9151 天前
【opencv】第10章 角点检测
人工智能·opencv·计算机视觉