VQGAN
paper: Taming Transformers for High-Resolution Image Synthesis
这篇文章是德国海德堡大学IWR团队发表的,被收录于CVPR21。作者提出了VQGAN (Vector Quantised Variational Autoencoder) 模型。我们很自然会想到VQVAE,VAE和GAN都是生成模型,其中VAE要求latent feature是连续的,以扩展模型的泛化性;GAN以生成对抗方式,强化了对样本的辨识能力。再聊聊Vector Quantised, VQVAE作者认为语言、声音等是离散的,而图像可以用语言描述,即离散的对象表示;遂引入了codebook技术实现了latent feature离散化,即用codebook中的code量化了latent feature,最后就可以用离散化的latent feature生成图像。那为什么要整合VQ和GAN呢,VQ模型相比连续latent feature模型具有方差小,易于训练的优势,而GAN模型简单,有效。除此之外,考虑到CNN具有图像的归纳偏好,可以有效处理图像任务;而Transformer具有强大的表达能力;因此,作者使用了Transformer来合成图像。
贡献点
- 作者结合CNN有效的归纳偏好与transformer强大的表达能力
- 实验验证模型的有效性
模型
如图1所示,VQGAN部分包括了一个生成器 与一个判别器 ,其中生成器部分类似VQVAE ,判别器 由一个简单的CNN 组成。 具体流程简述如下,我们送入一张小狗图片,经过CNN编码器得到latent feature,利用codebook对隐特征进行量化,解码还原小狗图。那如何更新codebook呢,我们对生成的图与原图分别送入判别器,利用判别的结果计算损失从而学习code。基于良好的codebook,我们可以将图像用latent feature的索引表示,再对其按行拉伸即可得到序列化的表示。对于序列数据,我们可以用Transformer捕捉序列内部的依赖,进而生成图片。
图1 VQGAN
模型和公式都比较简单,更多的是一些技术的整合,比如Codebook,感知损失,Transformer等。感知损失需要预训练模型,自己实现的可能受到算力,数据限制。但是,感知损失并不是必须的,可以用MSE损失替代。对于文中自适应权重的计算,作者并没有给出解释,我们可以认为从梯度大小上平衡两种损失的影响,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> λ = ▽ G L [ L r e c ] ▽ G L [ L G A N ] + δ ( 1 ) \lambda=\frac{\bigtriangledown_{G_{L}[\mathcal{L}{rec} ]}}{\bigtriangledown{G_{L}[\mathcal{L}{GAN}]+\delta }}\qquad (1) </math>λ=▽GL[LGAN]+δ▽GL[Lrec](1)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ = a r g m i n E , G , Z m a x D E x ∼ p ( x ) [ L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ] ( 2 ) \mathcal{Q} ^{*}=argmin{E,G,Z}max_{D}E_{x\sim p(x)}[\mathcal{L}{VQ}(E,G,Z)+\lambda\mathcal{L}{GAN}(\{E,G,Z\},D)]\qquad (2) </math>Q∗=argminE,G,ZmaxDEx∼p(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)](2)
如公式1,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L r e c ] \bigtriangledown_{G_{L}[\mathcal{L}{rec} ]} </math>▽GL[Lrec]表示重构损失关于解码器最后一层参数的梯度的大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L G A N ] \bigtriangledown{G_{L}[\mathcal{L}_{GAN}]} </math>▽GL[LGAN]表示判别损失关于解码器最后一层参数的梯度的大小。如果重构损失梯度更大,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ比较大,实际更新的时候会提高判别损失的梯度值,即统一了两者的梯度大小。下面给出一个小例子。
python
import torch
x= torch.tensor([[0.2, 0.8], [0.1, 0.9], [1.0, 0]], requires_grad= True)
y= torch.tensor([1, 1.0, 0], requires_grad= False)
# loss
mse_loss= lambda label, pred: ((torch.gather(pred, 1, label.view(pred.shape[0], -1).long())- label)** 2).mean()
bce_loss= lambda label, pred: torch.clamp(-1* torch.log(torch.gather(pred, 1, label.view(pred.shape[0], -1).long())), 0, 100).mean()
loss1= mse_loss(y, x)
loss2= bce_loss(y, x)
# grad
g1= torch.autograd.grad(loss1, x, retain_graph= True)[0]
g2= torch.autograd.grad(loss2, x, retain_graph= True)[0]
# norm
lamuda= torch.norm(g1)/ (torch.norm(g2)+ 1e-10)
g2_norm= torch.autograd.grad(loss2* lamuda.detach(), x, retain_graph= True)[0]
print(f'origin g1.norm(): {torch.norm(g1)}\norigin g2.norm(): {torch.norm(g2)}\nnow g2.norm(): {torch.norm(g2_norm)}')
# origin g1.norm(): 0.2854495942592621
# origin g2.norm(): 0.649535596370697
# now g2.norm(): 0.2854495644569397
实验结果
作者利用Transformer合成的如下高分辨图。
图2 合成的雪山
为方便计,我们在MNIST数据集进行实现。通过对抗式训练,VQGAN生成了不错的手写数字,见图3。
图3 VQGAN生成的图
其中,图3第一行是生成的图,第二行是测试集的图。 利用学好的codebook,我们基于Transformer就实现对残缺图进行补全等有趣任务,见图4。
图4 对残缺的手写数字图进行补全
其中,图4第一张是原图,第二张是根据离散Latent feature生成的图,第三张是给定上半部分补全的图,第四张是随机合成的图。
实现
具体代码可到Github见guchengzhong/VQGAN。
参考
[1] VQ-GAN|PyTorch Implementation, Outlier, youtube.