Mask-Aware Transformer 大空洞修复。
1、图像修复 Introduction
定义
图像修复(Image inpainting、Image completion、image hole-filling)指的是合成图像中缺失区域的过程,可以帮助恢复被遮挡或降质的部分。
在下图中,左图是原图,左图蓝色区域是mask区域(原图的mask区域是不传给模型的),右图是模型输出图。
[图片]
一般输入的是Image图(扣掉后)+Mask图(单通道),如下图。
[图片]
用途
-
移除物体 remove objects
举例,如上图。
这个任务常用的数据集是Places,包含约400个场景的图像。
-
生成物体 generate novel objects
举例如下图,生成人脸面部的眼镜、鼻子。
这个任务常用的数据集是CelebA-HQ人脸数据集。
[图片]
[图片]
Tips
GAN模型对于mask区域是进行"移除物体"还是"生成物体",如果不外加干预(比如一些模型会加入人为互动绘制sketch),那么GAN模型的效果是取决于模型的。下图中,从Places数据集中训练出来的模型和从CelebA-HQ数据集中训练出来的模型对黑色mask区域的填充效果可见,效果取决于模型对数据的学习。
[图片]
- 图像超分 image super resolution
试想将原图每个像素之间都插入一个空像素,所有空像素构成mask区域,然后让模型对其进行图像补全。
举例如下图:
[图片] - 其他。
可以扩展用于图像压缩、隐私保护、照片修复、图像编辑、旧照片修复等场景。
难点
图像修复任务的难点大抵有如下:
- 语义和结构一致性。修复的图像区域应与周围环境保持一致的语义和结构,这需要算法能够理解图像内容并生成合理的修复结果。比如陆地上不能出现鱼。
- 细节保留和重建。修复区域可能包含复杂的纹理和细节,如面部特征、纹理等。算法需要能够精确地恢复这些细节,以生成逼真的修复图像。
- 多样性和创造性。不同的修复情况需要生成多样性的结果,以应对各种不同的损坏情况和修复需求。
- 遮挡和变形处理。修复区域可能被遮挡、变形或者包含不规则的形状,这需要算法能够适应不同的情况来填补缺失部分。下图是一个mask的发展过程,最初研究图中央区域规则的方形mask,后来mask逐渐往着任意形状、遮盖区域比例越来越大发展。
[图片] - 光照和阴影一致性。修复结果需要与周围环境的光照和阴影保持一致,以确保生成的图像看起来自然。
- 生成稳定性和模式崩溃:某些情况下,生成的修复结果可能不稳定,导致图像质量下降或出现异常。此外,一些算法可能会受到"模式崩溃"问题影响,即生成过于重复的图像内容。
方法
- 早期的方法:纹理合成、Patch之类,用图像其他区域填充的办法
- GAN
- diffusion
- GAN结合傅里叶变换、小波变换之类的特征提取
关键点
要想把图像修复做好,需要着重关注两点:
- 远距离上下文的图像内容推理。
- 大缺失区域或者任意形状区域的细粒度纹理合成。
基本所有论文都是围绕这两点做工作。
2、Related Work
Globally and Locally Consistent Image Completion
在关心全局语义的情况下,也注重局部细节。全局判别器网络将整个图像作为输入,而局部判别器网络仅将完成区域周围的小区域作为输入。训练两个判别器网络来确定图像是真实的还是由补全网络完成的,而训练补全网络来欺骗两个判别器网络。
缺点:色彩缺失,需要额外的后处理(快速行军和泊松图像混合)。
[图片]
[图片]
Generative Image Inpainting with Contextual Attention
使用上下文注意力层关注遥远空间位置的特征块。
双阶段图像生成,在Coarse Result之后再次精修。
全局判别器+局部判别器。
[图片]
Image Inpainting for Irregular Holes Using Partial Convolutions
Partial Convolutional Layer,包括一个masked和re-normalized的卷积操作,然后是一个mask-update step。
第一个证明在不规则形状的孔上训练图像绘制模型的有效性的人。
Free-Form Image Inpainting with Gated Convolution
引入门控卷积,为所有层中每个空间位置的每个通道学习动态特征选择机制,显著提高了自由形式掩模和输入的颜色一致性和修复质量。
提出了一种更实用的基于补丁的GAN鉴别器SN-PatchGAN,用于自由形式的图像修复。它简单、快速,并产生高质量的修复结果。
[图片]
Aggregated Contextual Transformations for High-Resolution Image Inpainting
建议学习高分辨率图像在绘画中的aggregated contextual transformations,这允许捕获信息 informative distant contexts 和rich patterns of interest for context reasoning进行上下文推理。
设计了一种新的掩模预测任务来训练适合图像绘制的discriminator。这样的设计迫使discriminator区分真实斑块和合成斑块的详细外观,这反过来又有利于生成器合成细粒度纹理。
[图片]
LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions
使用具有图像宽的感受野的快速傅里叶卷积fast Fourier convolutions (FFCs);
高感受野感知损失 (a high receptive field perceptual loss);
大面积mask训练。
[图片]
Co-ModGAN:Large Scale Image Completion via Co-Modulated Generative Adversarial Networks
提出Co-ModGAN,弥合了图像条件生成结构和最近的无条件调制生成结构之间的差距;
提出了新的P-IDS/U-IDS,用于对GAN的感知保真度进行稳健评估;
3、MAT:Mask-Aware Transformer
CVPR 2022 Best Paper Finalist, Oral
创新点
- 开发了一种新颖的修复框架 MAT,是第一个能够直接处理高分辨率图像的基于 transformer 的修复系统。
- 提出了一种新的多头自注意力 (MSA) 变体,称为多头上下文注意力 (MCA),只使用有效的token来计算注意力。
- 设计了一个风格操作模块,使模型能够通过调节卷积的权重来提供不同的预测结果。
模型结构
网络分为粗修复与细修复两个阶段。粗修复主要由一个卷积头,五个transformer模块和一个卷积尾构成;细修复采用一个 Conv-U-Net 来细化高频细节。
下图是粗修复的网络:
[图片]
细修复的网络是U-Net构型的,论文没有绘制此图。下图是整体的Generator网络。
[图片]
下图是整体的Discriminator网络,是VGG19网络构型。
[图片]
Convolutional Head
卷积头主要由四个卷积层构成,将3512 512的图像转换成1806464的特征图,用来提取token。
Transformer Body
本文对transformer模块进行了改进,一是删除了层归一化,二是采用融合学习(使用特征拼接)代替残差学习。
删除层归一化的原因:在大面积区域缺失的情况下,大部分的token是无效的,而层归一化会放大这些无效的token,从而导致训练不稳定;
替换残差连接的原因:残差连接鼓励模型学习高频内容,然而在刚开始大多数的token是无效的,在训练过程中没有适当的低频基础,很难直接学习高频细节,如果使用残差连接就会使优化变得困难。采用融合学习(使用特征拼接)代替残差学习,如下面的T图。
[图片]
[图片]
[图片]
Multi-Head Contextual Attention
为了处理大量的标记(对于512×512的图像,最多有4096个标记)和给定标记的低保真度(最多90%的标记是无用的),我们的注意力模块采用了位移窗口[36]和动态遮罩,能够利用少量可行的标记进行非局部交互。
注意力模块利用移位窗口和动态掩码,只使用有效的token进行加权求和。MCA输出是有效标记的加权和,如下图:
[图片]
Mask Updating Strategy
更新规则:只要当前窗口有一个token是有效的,经过注意力后,该窗口中的所有token都会更新为有效的。如果一个窗口中的所有token都是无效的,经过注意力后,它们仍然无效。
[图片]
Style Manipulation Module
设计了一个风格操作模块,使MAT具有多元化的生成。它通过在重建生成过程中使用额外的噪声输入改变卷积层的权值归一化来操纵输出。为了增强噪声输入的表示能力,我们强制图像条件样式sc从图像特征X和噪声无条件样式su中学习。
B是随机给的mask,由su和sc得到风格表达的s。
s将会改变权重W,从而让模型可以使用随机噪声作为输入,让模型可以有多元化生成。
[图片]
成绩
[图片]
4、图像修复 损失函数
重建损失 Reconstruction Loss
GAN(生成对抗网络)中的重建损失通常用于度量生成器生成的图像与真实图像之间的差异,帮助生成器学习生成更逼真的图像。在 GAN 中,生成器试图生成与真实图像相似的样本,而判别器则评估生成器生成的样本是否足够逼真。重建损失通常使用生成器生成的图像与对应的真实图像之间的差异来衡量。在实际应用中,可以根据任务和需求选择适当的损失函数,如 L1 损失、结构相似性损失(SSIM)等,下面用均方误差(MSE)作为重建损失举例。此外还有一些难以解释理解的重建损失:https://www.zhihu.com/question/521284760/answer/2384076383
import torch
import torch.nn as nn
生成器生成的图像
generated_image = torch.rand((16, 3, 64, 64)) # 16张3通道64x64的随机生成图像
真实图像
real_image = torch.rand((16, 3, 64, 64)) # 16张3通道64x64的随机真实图像
计算重建损失
reconstruction_loss = nn.MSELoss() # 使用均方误差损失
loss_value = reconstruction_loss(generated_image, real_image)
print("重建损失值:", loss_value.item())
对抗性损失 Adversarial Loss
在生成对抗网络(GAN)中,对抗性损失是用来训练判别器(Discriminator)和生成器(Generator)之间竞争的损失函数。它鼓励生成器生成逼真的样本,同时使判别器能够区分生成的样本和真实样本。对抗性损失通常是使用交叉熵损失函数来衡量生成样本被正确分类为真实样本的程度。
import torch
import torch.nn as nn
判别器的预测
discriminator_predictions = torch.rand((16, 1)) # 判别器对16个样本的预测结果
生成样本的标签(0表示生成样本)
generated_labels = torch.zeros((16, 1))
计算对抗性损失
adversarial_loss = nn.BCEWithLogitsLoss() # 使用二进制交叉熵损失
loss_value = adversarial_loss(discriminator_predictions, generated_labels)
print("对抗性损失值:", loss_value.item())
在上述示例中,我们使用了带有 logits 的二进制交叉熵损失(BCEWithLogitsLoss),将判别器的预测与生成样本的标签进行比较。对抗性损失的目标是使判别器能够正确区分生成样本和真实样本,同时促使生成器生成逼真的样本,从而使两者之间形成平衡竞争关系。
感知损失 Perceived Loss
在图像修复中,感知损失是一种用于训练生成对抗网络(GAN)的损失函数,它帮助网络学习更好地合成逼真的修复图像。感知损失通过比较生成图像和真实图像之间的特征表示来量化生成图像的质量。
下面是一个用PyTorch演示感知损失在图像修复中的示例代码片段:
import torch
import torch.nn as nn
import torchvision.models as models
class PerceptualLoss(nn.Module):
def init (self):
super(PerceptualLoss, self).init ()
self.vgg = models.vgg19(pretrained=True).features
self.layers = {
'3': 'relu1_2', # Conv3_2 -> ReLU1_2
'8': 'relu2_2', # Conv8_2 -> ReLU2_2
'17': 'relu3_3', # Conv17_3 -> ReLU3_3
'26': 'relu4_3' # Conv26_3 -> ReLU4_3
}
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, x, y):
x_features = self.get_features(x)
y_features = self.get_features(y)
loss = 0
for layer_name in self.layers:
loss += nn.functional.mse_loss(x_features[layer_name], y_features[layer_name])
return loss
def get_features(self, x):
features = {}
prev_x = x
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.layers:
features[self.layers[name]] = x
if name == '26':
break
return features
使用示例
criterion = PerceptualLoss()
fake_image = torch.randn(1, 3, 256, 256) # 生成的修复图像
real_image = torch.randn(1, 3, 256, 256) # 真实图像
loss = criterion(fake_image, real_image)
print("Perceptual Loss:", loss.item())
在这个示例中,PerceptualLoss 类从预训练的VGG19模型中提取了不同层的特征,并计算修复图像和真实图像之间的感知损失。这有助于生成对抗网络学习将合成图像的特征与真实图像的特征匹配,从而提高修复图像的质量。
风格损失 Style Loss
风格损失是一种用于训练生成对抗网络(GAN)的损失函数,它有助于确保修复图像在视觉上与原始图像在风格上保持一致。风格损失通过比较生成图像与原始图像之间的特定风格特征,如纹理、颜色和形状等,来量化生成图像的风格相似性。
以下是一个使用PyTorch编写的示例程序,演示如何计算图像修复中的风格损失:
import torch
import torch.nn as nn
import torchvision.models as models
class StyleLoss(nn.Module):
def init (self):
super(StyleLoss, self).init ()
self.vgg = models.vgg19(pretrained=True).features
self.layers = {
'3': 'relu1_2', # Conv3_2 -> ReLU1_2
'8': 'relu2_2', # Conv8_2 -> ReLU2_2
'17': 'relu3_3', # Conv17_3 -> ReLU3_3
'26': 'relu4_3' # Conv26_3 -> ReLU4_3
}
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, x, y):
x_features = self.get_features(x)
y_features = self.get_features(y)
loss = 0
for layer_name in self.layers:
loss += nn.functional.mse_loss(self.gram_matrix(x_features[layer_name]),
self.gram_matrix(y_features[layer_name]))
return loss
def gram_matrix(self, input):
b, c, h, w = input.size()
features = input.view(b, c, h * w)
gram = torch.bmm(features, features.transpose(1, 2))
gram = gram / (c * h * w)
return gram
def get_features(self, x):
features = {}
prev_x = x
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.layers:
features[self.layers[name]] = x
if name == '26':
break
return features
使用示例
criterion = StyleLoss()
fake_image = torch.randn(1, 3, 256, 256) # 生成的修复图像
original_image = torch.randn(1, 3, 256, 256) # 原始图像
loss = criterion(fake_image, original_image)
print("Style Loss:", loss.item())
在这个示例中,StyleLoss 类从预训练的VGG19模型中提取了不同层的特征,并计算修复图像与原始图像之间的风格损失。风格损失有助于生成对抗网络学习将修复图像的风格与原始图像的风格保持一致,从而提高修复图像的视觉品质。
感知损失和风格损失都是用于训练生成模型的损失函数,但它们分别强调了内容和风格两个不同的方面。
5、评价指标 Evaluation Metrics
- L1↓
- L2↓
- PSNR↑
[图片] - SSIM↑
[图片] - FID ↓
Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., Hochreiter, S.: Gans trained by a two time-scale update rule converge to a local nash equilibrium. In: Advances in Neural Information Processing Systems. pp. 6626{6637 (2017) 5, 9, 12
[图片] - LPIPS↓
Zhang, R., Isola, P., Efros, A.A., Shechtman, E., Wang, O.: The unreasonable effectiveness of deep features as a perceptual metric. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 586{595 (2018) 9, 12
[图片] - U-IDS↑ 和 P-IDS↑
Zhao, S., Cui, J., Sheng, Y., Dong, Y., Liang, X., Chang, E.I., Xu, Y.: Large scale image completion via co-modulated generative adversarial networks. arXiv preprint arXiv:2103.10428 (2021) 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 26, 27
[图片]