Adversarial Learning forSemi-Supervised Semantic Segmentation

首先来了解一下对抗学习:

**对抗样本:**将真实的样本添加扰动而合成的新样本,是由深度神经网络的输入的数据和人工精心设计好的噪声合成得到的,但它不会被人类视觉系统识别错误。然而在对抗数据面前,深度神经网络却是脆弱的,可以轻易迷惑深度神经网络。

**对抗训练:**想要在模型训练中提升模型的对抗防御能力,识别一些对抗样本

Abstract

本文提出了一种基于对抗网络的半监督语义分割方法,设计了一个全卷积判别器来判断预测值和GT,而且可以通过将对抗损失和交叉熵损失 相结合来提高准确率,判别器分析无标签图像的初步预测,识别出其中比较确定或较为可靠的部分,并将这些区域作为监督信号,进一步指导模型进行学习。现有的一些方法可能采用弱标签 (如图像级标签或者不完全标注的信息)来进行训练,而这篇方法则不同,它使用无标签图像,通过判别器识别可信区域来帮助模型训练,进一步提升了无标签数据的使用效率。

Introduction

在语义分割方面,有很多方法和数据集已经被提出,但因为物体/场景外观变化、遮挡和缺乏上下文理解,这项任务仍然具有挑战性,基于CNN的全卷积网络在语义分割上表现了良好的效果,许多方法的提出也是基于FCN的。

不同于 图像分类和目标检测,语义分割需要逐像素的标注,所以成本比较高,因此提出了半监督或弱监督来降低成本。

在本论文中,将语义分割网络视为GAN框架中的生成器 ,并提出一个FCN的判别器 。不同于传统的GAN生成器(用于生成图片),语义分割网络的"输出"是输入图像的每个像素的类别概率图 (即概率地图 ),而不是生成图像。因此,语义分割网络的任务是输出每个像素对应的类别概率 ,表示该像素属于某个语义类别的可能性。通过对抗训练的方式,FCN的判别器强制要求分割网络的输出尽可能接近真实的标签地图。

在本文中结合两个半监督损失,首先,利用判别器生成的置信度给出可信区域,可以作为有效的训练信号,通过信任图作为监督信号来**引导交叉熵损失(交叉熵损失一般用于监督学习),**信任图可以作为掩码,使得模型只在可信区域内进行训练。在未标记数据上使用对抗损失,可以指导生成器生成逼近GT的mask。

基于CNN的最新方法

  • 近年来,卷积神经网络(CNN)的进步使得语义分割方法得到了显著提升。例如,一些经典的分类网络(如AlexNet、VGG、ResNet等)可以通过改造为全卷积网络(FCN)来执行语义分割任务。这种改造过程的核心思想是,将分类网络中的全连接层替换为卷积层,使得网络能够输出与输入图像尺寸相同的像素级预测。
  • 这种方法虽然有效,但需要大量像素级标注的真实数据,这在实际应用中是非常困难和昂贵的。

弱监督方法的不同实现

  • 多实例学习(Multiple Instance Learning, MIL):在[36]和[35]的研究中,使用了多实例学习的方法,利用图像级标签生成潜在的分割标签图。MIL是一种弱监督学习方法,它可以利用图像级标签来推断图像中不同区域(或"实例")的类别,从而生成伪标签进行训练。

  • 图像级标签惩罚:Papandreou等人[33]通过图像级标签来惩罚网络对不存在的类别进行预测。这种方法的核心思想是,如果图像级标签没有标明某个类别,网络就不应该在预测中出现该类别,从而强制模型学习到更加精准的分割边界。

  • 对象定位优化 :Qi等人[37]通过对象定位来精细化分割结果,即通过识别物体的大致位置来优化网络的分割精度,从而减少对精确像素级标注的需求。

  • 分类网络作为特征提取器:Hong等人[15]使用有标签的图像训练一个分类网络作为特征提取器,用来进行解卷积操作(即上采样)。这样,分割网络能够利用分类网络提取到的特征进行更精确的分割。

还有利用边界框、点标注、Web视频数据进行训练的。半监督学习通过结合全标注数据和弱标注数据,提升了模型在语义分割任务中的表现。具体来说,半监督学习方法不仅利用图像级标签等弱监督信号来训练网络,还结合了少量的完全标注数据(如像素级标签),从而在保证性能的同时,降低了标注成本。

Algorithm Overview

该模块包含两个模块:分割网络和判别器网络

分割网络用于生成类别概率分布图,可以是任何语义分割的框架,比如FCN、Deeplab等,给定一个输入图像H*W*3,输出一个H*W*C的类别概率分布图。

判别器网络: 是一个基于FCN的网络,用于评估分割网络的输出(类别概率图)与真实标签(ground truth label maps)之间的差异。

输入:是类别概率分图(来自分割网络或者真实标签)

输出:输出是 H × W × 1 的空间概率图,表示每个像素是否来自真实标签图p=1,否则来自分割网络的输出p=0.

训练过程: 对于有标注的数据,分割网络 在训练时会受到两个损失函数(交叉熵损失和对抗损失)的监督;对于无标签数据,我们采用半监督训练的方式,首先,Unlabeled image会进入到分割网络生成一个类别概图,然后将这个类别概率图输入进入判别器网络中,生成一个置信度图,用置信度高的区域指导训练,相当于一个伪标签的作用。判别器训练判别器网络仅使用带标签的数据进行训练,它的任务是区分分割网络输出的类别概率图和真实标签之间的差异。

Semi-Supervised Training with Adversarial Network

Network Architecture

**Segmentation network.**DeepLab-v2框架with ResNet-101model pre-trained on the ImageNet dataset and MSCOCO。去掉最后一个分类层,并将最后两个卷积的stride从2改为1,从而使输出特征映射的分辨率有效地为输入图像大小的1/8。为了扩大感受野,我们在conv4和conv5层分别应用扩展卷积,stride分别为2和4。在最后一层使用了ASPP方法,最后,应用一个上采样层和softmax输出来匹配输入图像的大小。

**Discriminator network.**它由5个卷积层组成,其中4×4内核和{64,128,256,512,1}通道,步幅为2。 每个卷积层后面都有一个Leaky-ReLU[30]参数化为0.2,最后一层除外。为了将模型转换为全卷积网络,在最后一层添加上采样层,以将输出重新缩放为输入映射的大小。没有使用任何批处理归一化层,因为它只有在批处理大小足够大时才表现良好。

Loss Function

Discriminator network. 区别真实标签和分割网络的预测标签

第一项是针对来自分割网络的标签yn=0,我们希望它越趋近于0越好,因为想要判别器更能区分生成的标签和真实标签

第二项是针对真实标签yn=1,我们希望它越趋近于1越好,因为这是表示判别器认为该输入来自真实标签。

Yn真实的标注数据,经过 one-hot 编码处理后得到的概率图,因此判别器可以很容易地区分标签是来自真实标签,还是来自分割网络生成的标签,比如:

Yn​=[0,1,0,0,0]

S(Xn​)=[0.1,0.7,0.1,0.05,0.05]

解决办法:

  • 全卷积结构 :在判别器中采用 全卷积网络 (Fully Convolutional Network),这样判别器的输入不仅仅是一个全局的标签(像是一个标量值),而是包含了空间信息的 置信度图 。这种设计使得判别器要根据 空间局部信息 来判断每个像素点的真实性,而不仅仅依赖于标签的全局结构。
  • 扩散方案(Scale scheme) :为了增加判别器的难度,论文还尝试了一种扩散方案,将真实标签的 one-hot 编码进行 轻微扩散,让标签在不同的类别通道之间稍微分布,避免判别器依赖于 one-hot 编码的明确结构

Segmentation network.

过最小化多任务损失函数来训练分割网络:

**交叉熵损失:**使分割网络的预测趋近于真实值

**对抗损失:**使判别器判断不出来是预测值,最大化判别器的输出,欺骗判别器

Training with unlabeled data.

不使用交叉熵损失,因为没有GT,但是还是使用对抗损失,因为对抗损失仅依赖于判别器。

Self-taught Learning: 使用训练好的判别器来处理分割网络的预测结果,得到一个置信度图 D(S(Xn​)),表示判别器对于每个像素预测是否可信。通过设置一个阈值 Tsemi​,将置信度图二值化,得到可信区域(即置信度高于阈值的区域)。

半监督损失:

在实验中:阈值 Tsemi通常被设置在 0.1 到 0.3 之间、

Experiment

如果分割网络不试图欺骗鉴别器,由鉴别器生成的置信度图将是无意义的,提供较弱的监督信号。

相关推荐
就爱学编程5 分钟前
重生之我在异世界学编程之C语言:选择结构与循环结构篇
c语言·数据结构·算法
一只大侠11 分钟前
输入一串字符,以“?”结束。统计其中字母个数,数字个数,其它符号个数。:JAVA
java·开发语言·算法
资讯分享周22 分钟前
思特奇亮相2024数字科技生态大会,以“智”谋新共赢AI新时代
人工智能·科技
winstongit23 分钟前
捷联惯导原理和算法预备知识
算法·机器人
HuggingAI27 分钟前
Stable Diffusion Controlnet常用控制类型解析与实战课程 2
人工智能·ai·stable diffusion·ai绘画
一尘之中1 小时前
基于Transformer的编码器-解码器图像描述模型在AMD GPU上的应用
人工智能·深度学习·transformer
£suPerpanda1 小时前
P3916 图的遍历(Tarjan缩点和反向建边)
数据结构·c++·算法·深度优先·图论
IT古董1 小时前
【机器学习】机器学习的基本分类-监督学习-决策树-C4.5 算法
人工智能·学习·算法·决策树·机器学习·分类
IT古董1 小时前
【机器学习】机器学习的基本分类-监督学习-决策树-CART(Classification and Regression Tree)
学习·决策树·机器学习·分类
电子工程师UP学堂1 小时前
电子应用设计方案-37:智能鼠标系统方案设计
人工智能·单片机·嵌入式硬件·计算机外设