OCRNet原理与代码解析(ECCV 2020)

paper:Object-Contextual Representations for Semantic Segmentation

official implementation:https://github.com/HRNet/HRNet-Semantic-Segmentation

third-party implementation:https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/ocr_head.py

本文的创新点

本文聚焦于用上下文聚合策略context aggregation strategy来处理语义分割问题。本文的启发来源于一个像素的类别应该是这个像素所属对象的类别。本文提出了一个简单有效的方法,对象-上下文表示object-contextual representations,通过利用对应对象的类别来描述一个像素。首先在ground-truth分割的监督下学习目标区域,然后通过聚合目标区域内像素的表示来计算目标区域的表示,最后计算每个像素与每个目标区域的关系并使用object-contextual representation来增强每个像素的表示,其中object-contextual representation是所有目标区域表示与像素关系的加权聚合。

方法介绍

像素 \(p_{i}\) 的类别 \(l_{i}\) 本质上是 \(p_{i}\) 所在对象的类别。对象-上下文表示包括:(1)将图像 \(I\) 中的所有像素结构化分为 \(K\) 个软对象区域(2)通过聚合第 \(k\) 个对象区域中所有像素的表示用 \(\mathbf{f}_{k}\) 来示每个目标区域(3)基于 \(K\) 个目标区域和所有目标区域的关系,通过聚合 \(K\) 个目标区域的表示来增强每个像素的表示

其中 \(\mathbf{f}{k}\) 是第 \(k\) 个目标区域的表示,\(w{ik}\) 表示第 \(i\) 个像素和第 \(k\) 个目标区域的关系,\(\delta(\cdot)\) 和 \(\rho (\cdot)\) 是变换函数。

Soft object regions

将图像 \(I\) 划分为 \(K\) 个软目标区域 \(\left \{ \mathbf{M}{1},\mathbf{M}{2},...,\mathbf{M}{K} \right \} \),每个目标区域 \(\mathbf{M}{k}\) 对应类别 \(k\),并由一个2D map或一个粗略的segmentation map来表示,其中每个值表示这个位置的像素属于类别 \(k\) 的程度。我们根据骨干网络的中间输出来计算这 \(K\) 个目标区域。在训练过程中,在ground-truth segmentation的监督下用交叉熵损失来学习目标区域。

Object region representations

我们根据所有像素对第 \(k\) 个目标区域的所属程度进行加权聚合,从而得到第 \(k\) 个目标区域的表示

其中 \(\mathbf{x}{i}\) 表示像素 \(p{i}\),\(\widetilde{m}{ki} \) 表示像素 \(p{i}\) 属于第 \(k\) 个目标区域的归一化后的程度。我们使用spatial softmax来归一化每个目标区域 \(\mathbf{M}_{k}\)。

Object contextual representations

我们按下式计算每个像素和每个目标区域的关系

其中 \(\kappa (\mathbf{x},\mathbf{f})=\phi(\mathbf{x})^{\mathsf{T} }\psi (\mathbf{f})\) 是未归一化的关系函数,\(\phi(\cdot)\) 和 \(\psi(\cdot)\) 是两个转换函数具体实现为1x1 conv --->BN --->ReLU。这里是受到了self-attention的启发。

像素 \(p_{i}\) 的object contextual representation \(\mathbf{y}_{i}\) 根据式(3)计算得到。其中 \(\rho(\cdot)\) 和 \(\delta(\cdot)\) 也是由1x1 conv --->BN --->ReLU实现的两个转换函数。这里是受到non-local networks的启发。

Augmented representations

像素 \(p_{i}\) 的最终表示由两部分组成,一是原始表示 \(\mathbf{x}{i}\),二是对象-上下文表示 \(\mathbf{y}{i}\)

其中 \(g(\cdot)\) 是由1x1 conv --->BN --->ReLU实现的转换函数,用于融合原始表示和对象上下文表示。

整个pipeline如下图所示

代码解析

这里以mmsegmentation中的实现为例,介绍一下实现代码。输入shape=(8, 3, 480, 480),backbone采用ResNet-50,配置如下

可以看出和原始的ResNet-50不同的是,只有stage1的stride=2,因此经过backbone的输出shape为[(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)]。

接下来head部分是一个级联head的设计,首先是一个FCNHead,然后是一个OCRHead,配置如下。

Head部分的实现如下

python 复制代码
def _decode_head_forward_train(self, inputs: Tensor,
                               data_samples: SampleList) -> dict:
    """Run forward function and calculate loss for decode head in
    training."""
    losses = dict()

    loss_decode = self.decode_head[0].loss(inputs, data_samples,
                                           self.train_cfg)

    losses.update(add_prefix(loss_decode, 'decode_0'))
    # get batch_img_metas
    batch_size = len(data_samples)
    batch_img_metas = []
    for batch_index in range(batch_size):
        metainfo = data_samples[batch_index].metainfo
        batch_img_metas.append(metainfo)

    for i in range(1, self.num_stages):
        # forward test again, maybe unnecessary for most methods.
        if i == 1:
            prev_outputs = self.decode_head[0].forward(inputs)  # ocrnet_r50, (8,2,60,60)
        else:
            prev_outputs = self.decode_head[i - 1].forward(
                inputs, prev_outputs)
        loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
                                               data_samples,
                                               self.train_cfg)
        losses.update(add_prefix(loss_decode, f'decode_{i}'))

    return losses

其中self.decode_head[0] 就是FCNHead,对应图3中的粉色框。从图3和代码实现中可以看出,这里FCNHead既受GT的监督计算损失即line7的loss_decode ,同时又输出Soft Object Regions即line21的prev_outputs作为输入送入OCRHead中。

OCRHead的实现部分如下

python 复制代码
def forward(self, inputs, prev_output):  # [(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)], (8,2,60,60)
    """Forward function."""
    x = self._transform_inputs(inputs)  # (8,2048,60,60)
    feats = self.bottleneck(x)  # (8,512,60,60)
    context = self.spatial_gather_module(feats, prev_output)  # 式(4)得到f_{k}, prev_output就是M_{k}, (8,512,2,1)
    object_context = self.object_context_block(feats, context)  # (8,512,60,60)
    output = self.cls_seg(object_context)  # (8,2,60,60)

    return output

其中输入inputs 就是backbone的输出,prev_output 就是FCNHead的输出。首先self._transform_inputs对输入进行转换,配置文件中in_index=3,这里直接就是根据索引取值。

然后self.bottleneck就是一个3x3卷积。

接着self.spatial_gather_module的实现如下,对应的是式(4)。其中probs就是FCNHead得到的object region \(M_{k}\),通过F.softmax来实现spatial softmax,最终得到的就是式(4)的输出 \(\mathbf{f}_{k}\)。

python 复制代码
class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):  # (8,512,60,60),(8,2,60,60)
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)  # (8,2,3600)
        feats = feats.view(batch_size, channels, -1)  # (8,512,3600)
        # [batch_size, height*width, num_classes]
        feats = feats.permute(0, 2, 1)  # (8,3600,512)
        # [batch_size, channels, height*width]
        probs = F.softmax(self.scale * probs, dim=2)  # 式(4)中的spatial softmax, (8,2,3600)
        # [batch_size, channels, num_classes]
        ocr_context = torch.matmul(probs, feats)  # (8,2,512)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)  # (8,512,2,1)
        return ocr_context

然后self.object_context_block 是式(5)、式(3)、式(6)的具体实现,forward部分如下,先看第二步对应的就是式(6),其中self.bottleneck就是 \(g(\cdot)\),context就是最终得到的对象-上下文表示 \(\mathbf{y}_{i}\)。

python 复制代码
def forward(self, query_feats, key_feats):
    """Forward function."""
    context = super().forward(query_feats, key_feats)
    output = self.bottleneck(torch.cat([context, query_feats], dim=1))  # 式(6),对应g(.)
    if self.query_downsample is not None:
        output = resize(query_feats)

    return output

然后看第一步,这里直接调用的是SelfAttentionBlock,forward部分如下

python 复制代码
def forward(self, query_feats, key_feats):  # 式(5)中的x_{i}, f_{k}
    # ocrnet_r50_d8
    # (8,512,60,60), (8,512,2,1)
    """Forward function."""
    batch_size = query_feats.size(0)
    query = self.query_project(query_feats)  # \phi, (8,256,60,60)
    if self.query_downsample is not None:
        query = self.query_downsample(query)
    query = query.reshape(*query.shape[:2], -1)  # (8,256,3600)
    query = query.permute(0, 2, 1).contiguous()  # (8,3600,256)

    key = self.key_project(key_feats)  # \psi, (8,256,2,1)
    value = self.value_project(key_feats)  # 式(3)中的\delta, (8,256,2,1)
    if self.key_downsample is not None:
        key = self.key_downsample(key)
        value = self.key_downsample(value)
    key = key.reshape(*key.shape[:2], -1)  # (8,256,2)
    value = value.reshape(*value.shape[:2], -1)  # (8,256,2)
    value = value.permute(0, 2, 1).contiguous()  # (8,2,256)

    sim_map = torch.matmul(query, key)  # (8,3600,2)
    if self.matmul_norm:  # True
        sim_map = (self.channels**-.5) * sim_map  # 256,0.0625
    sim_map = F.softmax(sim_map, dim=-1)  # 用softmax来实现式(5),得到w_{ik}, (8,3600,2)

    context = torch.matmul(sim_map, value)  # (8,3600,256)
    context = context.permute(0, 2, 1).contiguous()  # (8,256,3600)
    context = context.reshape(batch_size, -1, *query_feats.shape[2:])  # (8,256,60,60)
    if self.out_project is not None:
        context = self.out_project(context)  # 式(3)中的\rho, (8,512,60,60)
    return context

其中self.query_project、self.key_project、self.value_project 分别对应 \(\phi(\cdot),\psi(\cdot),\delta(\cdot)\),最后的self.out_project对应 \(\rho(\cdot)\),line24用F.softmax来实现式(5)得到 \(w_{ik}\)。

相关推荐
加百力17 分钟前
AI基建还能投多久?高盛:2-3年不是问题,回报窗口才刚开启
大数据·人工智能
魔力之心18 分钟前
TensorFlow2 study notes[1]
人工智能·python·tensorflow
Swift社区23 分钟前
日志不再孤立!用 Jaeger + TraceId 实现链路级定位
人工智能·chatgpt
AI扶我青云志2 小时前
BPE(Byte Pair Encoding)分词算法
人工智能·自然语言处理
Web3_Daisy2 小时前
想要抢早期筹码?FourMeme专区批量交易教学
大数据·人工智能·区块链·比特币
东风西巷4 小时前
NealFun安卓版:创意无限,娱乐至上
android·人工智能·智能手机·娱乐·软件需求
肥猪猪爸5 小时前
BP神经网络对时序数据进行分类
人工智能·深度学习·神经网络·算法·机器学习·分类·时序数据
Keep learning!5 小时前
深度学习入门代码详细注释-ResNet18分类蚂蚁蜜蜂
人工智能·深度学习·分类
Liudef066 小时前
神经辐射场 (NeRF):重构三维世界的AI新视角
人工智能·重构
音视频牛哥7 小时前
打造实时AI视觉系统:OpenCV结合RTSP|RTMP播放器的工程落地方案
人工智能·opencv·计算机视觉·大牛直播sdk·rtsp播放器·rtmp播放器·android rtmp