OCRNet原理与代码解析(ECCV 2020)

paper:Object-Contextual Representations for Semantic Segmentation

official implementation:https://github.com/HRNet/HRNet-Semantic-Segmentation

third-party implementation:https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/ocr_head.py

本文的创新点

本文聚焦于用上下文聚合策略context aggregation strategy来处理语义分割问题。本文的启发来源于一个像素的类别应该是这个像素所属对象的类别。本文提出了一个简单有效的方法,对象-上下文表示object-contextual representations,通过利用对应对象的类别来描述一个像素。首先在ground-truth分割的监督下学习目标区域,然后通过聚合目标区域内像素的表示来计算目标区域的表示,最后计算每个像素与每个目标区域的关系并使用object-contextual representation来增强每个像素的表示,其中object-contextual representation是所有目标区域表示与像素关系的加权聚合。

方法介绍

像素 \(p_{i}\) 的类别 \(l_{i}\) 本质上是 \(p_{i}\) 所在对象的类别。对象-上下文表示包括:(1)将图像 \(I\) 中的所有像素结构化分为 \(K\) 个软对象区域(2)通过聚合第 \(k\) 个对象区域中所有像素的表示用 \(\mathbf{f}_{k}\) 来示每个目标区域(3)基于 \(K\) 个目标区域和所有目标区域的关系,通过聚合 \(K\) 个目标区域的表示来增强每个像素的表示

其中 \(\mathbf{f}{k}\) 是第 \(k\) 个目标区域的表示,\(w{ik}\) 表示第 \(i\) 个像素和第 \(k\) 个目标区域的关系,\(\delta(\cdot)\) 和 \(\rho (\cdot)\) 是变换函数。

Soft object regions

将图像 \(I\) 划分为 \(K\) 个软目标区域 \(\left \{ \mathbf{M}{1},\mathbf{M}{2},...,\mathbf{M}{K} \right \} \),每个目标区域 \(\mathbf{M}{k}\) 对应类别 \(k\),并由一个2D map或一个粗略的segmentation map来表示,其中每个值表示这个位置的像素属于类别 \(k\) 的程度。我们根据骨干网络的中间输出来计算这 \(K\) 个目标区域。在训练过程中,在ground-truth segmentation的监督下用交叉熵损失来学习目标区域。

Object region representations

我们根据所有像素对第 \(k\) 个目标区域的所属程度进行加权聚合,从而得到第 \(k\) 个目标区域的表示

其中 \(\mathbf{x}{i}\) 表示像素 \(p{i}\),\(\widetilde{m}{ki} \) 表示像素 \(p{i}\) 属于第 \(k\) 个目标区域的归一化后的程度。我们使用spatial softmax来归一化每个目标区域 \(\mathbf{M}_{k}\)。

Object contextual representations

我们按下式计算每个像素和每个目标区域的关系

其中 \(\kappa (\mathbf{x},\mathbf{f})=\phi(\mathbf{x})^{\mathsf{T} }\psi (\mathbf{f})\) 是未归一化的关系函数,\(\phi(\cdot)\) 和 \(\psi(\cdot)\) 是两个转换函数具体实现为1x1 conv --->BN --->ReLU。这里是受到了self-attention的启发。

像素 \(p_{i}\) 的object contextual representation \(\mathbf{y}_{i}\) 根据式(3)计算得到。其中 \(\rho(\cdot)\) 和 \(\delta(\cdot)\) 也是由1x1 conv --->BN --->ReLU实现的两个转换函数。这里是受到non-local networks的启发。

Augmented representations

像素 \(p_{i}\) 的最终表示由两部分组成,一是原始表示 \(\mathbf{x}{i}\),二是对象-上下文表示 \(\mathbf{y}{i}\)

其中 \(g(\cdot)\) 是由1x1 conv --->BN --->ReLU实现的转换函数,用于融合原始表示和对象上下文表示。

整个pipeline如下图所示

代码解析

这里以mmsegmentation中的实现为例,介绍一下实现代码。输入shape=(8, 3, 480, 480),backbone采用ResNet-50,配置如下

可以看出和原始的ResNet-50不同的是,只有stage1的stride=2,因此经过backbone的输出shape为[(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)]。

接下来head部分是一个级联head的设计,首先是一个FCNHead,然后是一个OCRHead,配置如下。

Head部分的实现如下

python 复制代码
def _decode_head_forward_train(self, inputs: Tensor,
                               data_samples: SampleList) -> dict:
    """Run forward function and calculate loss for decode head in
    training."""
    losses = dict()

    loss_decode = self.decode_head[0].loss(inputs, data_samples,
                                           self.train_cfg)

    losses.update(add_prefix(loss_decode, 'decode_0'))
    # get batch_img_metas
    batch_size = len(data_samples)
    batch_img_metas = []
    for batch_index in range(batch_size):
        metainfo = data_samples[batch_index].metainfo
        batch_img_metas.append(metainfo)

    for i in range(1, self.num_stages):
        # forward test again, maybe unnecessary for most methods.
        if i == 1:
            prev_outputs = self.decode_head[0].forward(inputs)  # ocrnet_r50, (8,2,60,60)
        else:
            prev_outputs = self.decode_head[i - 1].forward(
                inputs, prev_outputs)
        loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
                                               data_samples,
                                               self.train_cfg)
        losses.update(add_prefix(loss_decode, f'decode_{i}'))

    return losses

其中self.decode_head[0] 就是FCNHead,对应图3中的粉色框。从图3和代码实现中可以看出,这里FCNHead既受GT的监督计算损失即line7的loss_decode ,同时又输出Soft Object Regions即line21的prev_outputs作为输入送入OCRHead中。

OCRHead的实现部分如下

python 复制代码
def forward(self, inputs, prev_output):  # [(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)], (8,2,60,60)
    """Forward function."""
    x = self._transform_inputs(inputs)  # (8,2048,60,60)
    feats = self.bottleneck(x)  # (8,512,60,60)
    context = self.spatial_gather_module(feats, prev_output)  # 式(4)得到f_{k}, prev_output就是M_{k}, (8,512,2,1)
    object_context = self.object_context_block(feats, context)  # (8,512,60,60)
    output = self.cls_seg(object_context)  # (8,2,60,60)

    return output

其中输入inputs 就是backbone的输出,prev_output 就是FCNHead的输出。首先self._transform_inputs对输入进行转换,配置文件中in_index=3,这里直接就是根据索引取值。

然后self.bottleneck就是一个3x3卷积。

接着self.spatial_gather_module的实现如下,对应的是式(4)。其中probs就是FCNHead得到的object region \(M_{k}\),通过F.softmax来实现spatial softmax,最终得到的就是式(4)的输出 \(\mathbf{f}_{k}\)。

python 复制代码
class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):  # (8,512,60,60),(8,2,60,60)
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)  # (8,2,3600)
        feats = feats.view(batch_size, channels, -1)  # (8,512,3600)
        # [batch_size, height*width, num_classes]
        feats = feats.permute(0, 2, 1)  # (8,3600,512)
        # [batch_size, channels, height*width]
        probs = F.softmax(self.scale * probs, dim=2)  # 式(4)中的spatial softmax, (8,2,3600)
        # [batch_size, channels, num_classes]
        ocr_context = torch.matmul(probs, feats)  # (8,2,512)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)  # (8,512,2,1)
        return ocr_context

然后self.object_context_block 是式(5)、式(3)、式(6)的具体实现,forward部分如下,先看第二步对应的就是式(6),其中self.bottleneck就是 \(g(\cdot)\),context就是最终得到的对象-上下文表示 \(\mathbf{y}_{i}\)。

python 复制代码
def forward(self, query_feats, key_feats):
    """Forward function."""
    context = super().forward(query_feats, key_feats)
    output = self.bottleneck(torch.cat([context, query_feats], dim=1))  # 式(6),对应g(.)
    if self.query_downsample is not None:
        output = resize(query_feats)

    return output

然后看第一步,这里直接调用的是SelfAttentionBlock,forward部分如下

python 复制代码
def forward(self, query_feats, key_feats):  # 式(5)中的x_{i}, f_{k}
    # ocrnet_r50_d8
    # (8,512,60,60), (8,512,2,1)
    """Forward function."""
    batch_size = query_feats.size(0)
    query = self.query_project(query_feats)  # \phi, (8,256,60,60)
    if self.query_downsample is not None:
        query = self.query_downsample(query)
    query = query.reshape(*query.shape[:2], -1)  # (8,256,3600)
    query = query.permute(0, 2, 1).contiguous()  # (8,3600,256)

    key = self.key_project(key_feats)  # \psi, (8,256,2,1)
    value = self.value_project(key_feats)  # 式(3)中的\delta, (8,256,2,1)
    if self.key_downsample is not None:
        key = self.key_downsample(key)
        value = self.key_downsample(value)
    key = key.reshape(*key.shape[:2], -1)  # (8,256,2)
    value = value.reshape(*value.shape[:2], -1)  # (8,256,2)
    value = value.permute(0, 2, 1).contiguous()  # (8,2,256)

    sim_map = torch.matmul(query, key)  # (8,3600,2)
    if self.matmul_norm:  # True
        sim_map = (self.channels**-.5) * sim_map  # 256,0.0625
    sim_map = F.softmax(sim_map, dim=-1)  # 用softmax来实现式(5),得到w_{ik}, (8,3600,2)

    context = torch.matmul(sim_map, value)  # (8,3600,256)
    context = context.permute(0, 2, 1).contiguous()  # (8,256,3600)
    context = context.reshape(batch_size, -1, *query_feats.shape[2:])  # (8,256,60,60)
    if self.out_project is not None:
        context = self.out_project(context)  # 式(3)中的\rho, (8,512,60,60)
    return context

其中self.query_project、self.key_project、self.value_project 分别对应 \(\phi(\cdot),\psi(\cdot),\delta(\cdot)\),最后的self.out_project对应 \(\rho(\cdot)\),line24用F.softmax来实现式(5)得到 \(w_{ik}\)。

相关推荐
多恩Stone1 分钟前
Post-train 入门(1):SFT / DPO / Online RL 概念理解和分类
人工智能·分类·数据挖掘
bin915335 分钟前
解锁Java开发新姿势:飞算JavaAI深度探秘 #飞算JavaAl炫技赛 #Java开发
java·人工智能·python·java开发·飞算javaai·javaai·飞算javaal炫技赛
居然JuRan39 分钟前
LangChain从0到1实战:手把手教你实现RAG
人工智能
摆烂工程师41 分钟前
GPT-5 对应用户可以使用的次数,以及解决 GPT-5 没有推送的问题
人工智能·gpt·程序员
cscshaha1 小时前
《从零构建大语言模型》学习笔记1,环境配置
人工智能·深度学习·语言模型·llm·从零构建大语言模型
双翌视觉2 小时前
机械手的眼睛,视觉系统如何让机器人学会精准抓取
人工智能·机器人·自动化
IvanCodes3 小时前
OpenAI 最新开源模型 gpt-oss (Windows + Ollama/ubuntu)本地部署详细教程
人工智能·语言模型·chatgpt·开源
2301_769006783 小时前
祝贺!1464种期刊被收录,CSCD 核心期刊目录更新!(附下载)
大数据·数据库·人工智能·搜索引擎·期刊
天天代码码天天3 小时前
C# OnnxRuntime Yolov8 纸箱检测
人工智能
猫头虎-人工智能3 小时前
ChatGPT模型选择器详解:全面了解GPT-4o、GPT-4.5、o3等模型的切换与使用策略(2025最新版)
人工智能·chatgpt·开源·aigc·ai编程·ai写作·ai-native