论文解析 | RobustNet / ISW

本文是论文 RobustNet: Improving Domain Generalization in Urban-Scene Segmentation via Instance Selective Whitening (ISW) 的解析。ISW 的贡献是提出了一个能够有选择性的白化协方差矩阵部分区域的损失函数,论文行文脉络十分清晰。

Instance Whitening Loss

首先,论文指出了 deep whitening transformation (DWT) (通过设计损失函数使得各特征图之间组成的协方差矩阵主对角线元素为 1,其它元素为 0)不能同时优化对角线元素和其它元素的缺陷,使用 IN 先将协方差矩阵的对角线元素归一,这样之后只需要优化非对角线元素即可,如下图 (a) 所示。

Instance Whitening 最基本的思路就是将协方差矩阵非对角线元素 -> 0,通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> L I W L_{IW} </math>LIW 损失函数优化即可,如下图 (c) (d) 所示。

源码中每过一次 IW 就做一次 IN 操作,并额外返回经过 IN 处理的特征图

python 复制代码
class InstanceWhitening(nn.Module):

    def __init__(self, dim):
        super(InstanceWhitening, self).__init__()
        self.instance_standardization = nn.InstanceNorm2d(dim, affine=False)

    def forward(self, x):

        x = self.instance_standardization(x)
        w = x

        return x, w

在模型中保留这些特征图,用于计算 loss

python 复制代码
for module in i_block:
    if isinstance(module, InstanceWhitening):
        x, w = module(x)
        w_arr.append(w)
...
data_dict['w_arr'] = w_arr

计算 loss,传入的 mask 是一个 C x C 大小的主对角线元素为 0,其它元素为 1 的矩阵

python 复制代码
w_arr = output_dict['w_arr']
...
wt_loss = torch.FloatTensor([0]).cuda()
for index, f_map in enumerate(w_arr):
    B, C, H, W = f_map.shape
    M_ones = torch.ones(C,C).cuda()
    diag = torch.diag(M_ones)
    diag = torch.diag_embed(diag)
    M_ones = M_ones - diag
    loss = instance_whitening_loss(f_map, None, M_ones, 0, 10000)
    wt_loss = wt_loss + loss
wt_loss = wt_loss / len(w_arr)
total_loss += wt_loss.item()

<math xmlns="http://www.w3.org/1998/Math/MathML"> L I W L_{IW} </math>LIW 源码,协方差矩阵通过矩阵乘积得到,与掩码矩阵逐位相乘,得到需要优化的协方差矩阵,这里其实应该是协方差矩阵的上三角,传入的掩码应该下三角为 0

python 复制代码
def instance_whitening_loss(f_map, eye, mask_matrix, margin, num_remove_cov):
    f_cor, B = get_covariance_matrix(f_map, eye=eye)
    f_cor_masked = f_cor * mask_matrix

    off_diag_sum = torch.sum(torch.abs(f_cor_masked), dim=(1,2), keepdim=True) - margin # B X 1 X 1
    loss = torch.clamp(torch.div(off_diag_sum, num_remove_cov), min=0) # B X 1 X 1
    loss = torch.sum(loss) / B

    return loss


def get_covariance_matrix(f_map, eye=None):
    eps = 1e-5
    B, C, H, W = f_map.shape  # i-th feature size (B X C X H X W)
    HW = H * W
    if eye is None:
        eye = torch.eye(C).cuda()
    f_map = f_map.contiguous().view(B, C, -1)  # B X C X H X W > B X C X (H X W)
    f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(HW-1) + (eps * eye)  # B X C X C / HW

    return f_cor, B

Margin-based relaxation of whitening loss

作者认为将协方差矩阵的非对角线元素全部优化为 0,会影响模型的鉴别能力,因此设计了一个 margin 参数,在上面的代码中已有体现:off_diag_sum = torch.sum(torch.abs(f_cor_masked), dim=(1,2), keepdim=True) - margin

Separating Covariance Elements

又到了特征解耦的时候了,本文的出发点是对原始数据引入一个光照变换,比较原图和数据增强后图像对应特征图的协方差矩阵,差异较小的部分认为是 domain-invariant 部分,其它部分为 domain-specific 部分。通过这个操作来得到一个 Selective 的掩码矩阵,只对原始协方差矩阵的这些部分做优化。

相关推荐
Perishell2 小时前
无人机避障——感知篇(Ego_Planner_v2中的滚动窗口实现动态实时感知建图grid_map ROS节点理解与参数调整影响)
计算机视觉·无人机·slam·地图生成·建图感知·双目视觉
kyle~2 小时前
Opencv---深度学习开发
人工智能·深度学习·opencv·计算机视觉·机器人
看到我,请让我去学习7 小时前
OpenCV 图像进阶处理:特征提取与车牌识别深度解析
人工智能·opencv·计算机视觉
音视频牛哥18 小时前
打造实时AI视觉系统:OpenCV结合RTSP|RTMP播放器的工程落地方案
人工智能·opencv·计算机视觉·大牛直播sdk·rtsp播放器·rtmp播放器·android rtmp
云卓SKYDROID20 小时前
无人机环境感知系统运行与技术难点!
人工智能·计算机视觉·目标跟踪·无人机·科普·高科技·云卓科技
金山几座1 天前
OpenCV探索之旅:形态学魔法
opencv·计算机视觉
presenttttt1 天前
用Python和OpenCV从零搭建一个完整的双目视觉系统(六 最终篇)
开发语言·python·opencv·计算机视觉
棱镜研途1 天前
学习笔记丨卷积神经网络(CNN):原理剖析与多领域Github应用
图像处理·笔记·学习·计算机视觉·cnn·卷积神经网络·信号处理
蹦蹦跳跳真可爱5891 天前
Python----OpenCV(几何变换--图像平移、图像旋转、放射变换、图像缩放、透视变换)
开发语言·人工智能·python·opencv·计算机视觉
加油加油的大力1 天前
入门基于深度学习(以yolov8和unet为例)的计算机视觉领域的学习路线
深度学习·yolo·计算机视觉