堆叠沙漏网络(stacked hourglass network)学习

定义

Stacked Hourglass Networks是2016年密歇根大学提出的经典网络架构。是曾经最具代表性的姿态识别SOTA之一。

hourglass network

hourglass network 本身其实可以理解成是一个encoder-decoder的结构,encoder最大程度的提取图像在每一个scale的特征以及空间信息(spatial information),decoder则是将网络在不同分辨率下提取的特征进行综合,最后得到一个与输入图像大小一致的heatmap。值得注意的是,很多情况网络会以一个或者多个全连接成作为最后的输出层,但是hourglass 网络使用了一个1x1的卷积层来代替了全连接层,这样做的目的是为了让网络可以接受不同维度的输入。

stacked hourglass network

Stacked Hourglass Module的架构很通俗易懂。首先它通过卷积层下采样,再通过最近邻的方法上采样。Upsampling的过程中,再将直接下采样过程中对应大小的特征图直接加到上采样的特征图上。通过最后一个上采样层,从而恢复回原来图像大小。最后,再加上一些1*1的卷积层,输出最终的heatmap。如果需要堆叠许多这样的沙漏模块的话,也可以等到堆叠完总网络,再最后再加上这些小卷积。

十分类似UNet

小点

stacked hourglass network 使用的是 immediate supervision,因为它是由很多个hourglass network组合而成,在每一个hourglass network输出之后就计算损失,而不是到跑完所有的hourglass networks都跑完之后在计算损失。

伪代码

bash 复制代码
#一个Stacked Hourglass模块,里面把普通的卷积层用残差卷积模块来替代了,u型架构则采用了递归实现
#(此代码只负责讲清楚思路,因此只包含了最基本的网络,而具体工程实现比这个要来的复杂,会加入更多卷积以及归一化层)
class Hourglass(nn.Module):
    def __init__(self, n, channel, increase=0):
        super(Hourglass, self).__init__()
        #先池化降分辨率
        self.pool = Pool(2, 2)
        #用residual_module来Downsample
        self.conv = Residual(channel, channel)
        self.n = n
        # Recursive hourglass: 使用递归的方式完成u型结构,每次向内递归一层,直到n=1不能再套u型层,此时就通过卷积层。
        if self.n > 1:
            self.smaller_hourglass = Hourglass(n-1, channel)
        else:
            self.smaller_hourglass = Residual(channel, channel)
        #最近邻上采样,恢复分辨率,对应于self.pool的池化层
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        #先池化降低分辨率,完成下采样
        pool = self.pool(x)
        #用residual_module来Downsample
        conv  = self.conv(pool)
        #hourglass,不过参数n=n-1
        low = self.smaller_hourglass(conv)
        #最近邻上采样
        up  = self.up(low)
        #将原图像和操作后图像相加,完成u型结构
        return x + up

参考文章

堆叠沙漏网络(stacked hourglass network)
Stacked Hourglass Networks精讲(含代码分析+Colab)- 姿态估计论文精读------文章不错,里面有Hourglass的贡献及可以改进的地方

相关推荐
ambition202427 分钟前
从“分组游戏”到数学结构:等价关系、等价类、商集与划分完全指南
人工智能·游戏
黎阳之光13 分钟前
AI数智筑防线 绿色科技启新篇——黎阳之光硬核技术赋能生态安全双升级
大数据·人工智能·算法·安全·数字孪生
高德开放平台13 分钟前
高德开放平台已全面接入“鹰眼守护”预警系统,两轮车版率先适配小牛电动
人工智能
twc82916 分钟前
Query 改写 大模型测试的数据倍增器
开发语言·人工智能·python·rag·大模型测试
程序员小郭8329 分钟前
Spring AI 06 提示词(Prompt)全场景实战:从基础到高级模板用法
人工智能·spring·prompt
zhangfeng113335 分钟前
`transformers` 的 `per_device_train_batch_size` 不支持小于 1 的浮点数值,llamafactory 支持
人工智能·算法·batch
zl_vslam35 分钟前
SLAM中的非线性优-3D图优化之绝对位姿SE3约束四元数形式(十九)
人工智能·算法·计算机视觉·3d
Predestination王瀞潞35 分钟前
1.3.1 AI->Tesseract OCR Engine标准(HP、Google):Tesseract OCR Engine
人工智能·ocr
Fleshy数模42 分钟前
基于PyTorch的食品图像分类:数据增强与调优实战
人工智能·pytorch·分类
岁岁种桃花儿44 分钟前
AI超级智能开发系列从入门到上天第十篇:SpringAI+云知识库服务
linux·运维·数据库·人工智能·oracle·llm