关于VQ-GAN的记录

VQGAN

paper: Taming Transformers for High-Resolution Image Synthesis

这篇文章是德国海德堡大学IWR团队发表的,被收录于CVPR21。作者提出了VQGAN (Vector Quantised Variational Autoencoder) 模型。我们很自然会想到VQVAE,VAE和GAN都是生成模型,其中VAE要求latent feature是连续的,以扩展模型的泛化性;GAN以生成对抗方式,强化了对样本的辨识能力。再聊聊Vector Quantised, VQVAE作者认为语言、声音等是离散的,而图像可以用语言描述,即离散的对象表示;遂引入了codebook技术实现了latent feature离散化,即用codebook中的code量化了latent feature,最后就可以用离散化的latent feature生成图像。那为什么要整合VQ和GAN呢,VQ模型相比连续latent feature模型具有方差小,易于训练的优势,而GAN模型简单,有效。除此之外,考虑到CNN具有图像的归纳偏好,可以有效处理图像任务;而Transformer具有强大的表达能力;因此,作者使用了Transformer来合成图像。

贡献点

  • 作者结合CNN有效的归纳偏好与transformer强大的表达能力
  • 实验验证模型的有效性

模型

如图1所示,VQGAN部分包括了一个生成器一个判别器 ,其中生成器部分类似VQVAE ,判别器 由一个简单的CNN 组成。 具体流程简述如下,我们送入一张小狗图片,经过CNN编码器得到latent feature,利用codebook对隐特征进行量化,解码还原小狗图。那如何更新codebook呢,我们对生成的图与原图分别送入判别器,利用判别的结果计算损失从而学习code。基于良好的codebook,我们可以将图像用latent feature的索引表示,再对其按行拉伸即可得到序列化的表示。对于序列数据,我们可以用Transformer捕捉序列内部的依赖,进而生成图片。

图1 VQGAN

模型和公式都比较简单,更多的是一些技术的整合,比如Codebook,感知损失,Transformer等。感知损失需要预训练模型,自己实现的可能受到算力,数据限制。但是,感知损失并不是必须的,可以用MSE损失替代。对于文中自适应权重的计算,作者并没有给出解释,我们可以认为从梯度大小上平衡两种损失的影响,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> λ = ▽ G L [ L r e c ] ▽ G L [ L G A N ] + δ ( 1 ) \lambda=\frac{\bigtriangledown_{G_{L}[\mathcal{L}{rec} ]}}{\bigtriangledown{G_{L}[\mathcal{L}{GAN}]+\delta }}\qquad (1) </math>λ=▽GL[LGAN]+δ▽GL[Lrec](1)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ = a r g m i n E , G , Z m a x D E x ∼ p ( x ) [ L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ] ( 2 ) \mathcal{Q} ^{*}=argmin
{E,G,Z}max_{D}E_{x\sim p(x)}[\mathcal{L}{VQ}(E,G,Z)+\lambda\mathcal{L}{GAN}(\{E,G,Z\},D)]\qquad (2) </math>Q∗=argminE,G,ZmaxDEx∼p(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)](2)

如公式1,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L r e c ] \bigtriangledown_{G_{L}[\mathcal{L}{rec} ]} </math>▽GL[Lrec]表示重构损失关于解码器最后一层参数的梯度的大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L G A N ] \bigtriangledown{G_{L}[\mathcal{L}_{GAN}]} </math>▽GL[LGAN]表示判别损失关于解码器最后一层参数的梯度的大小。如果重构损失梯度更大,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ比较大,实际更新的时候会提高判别损失的梯度值,即统一了两者的梯度大小。下面给出一个小例子。

python 复制代码
import torch
x= torch.tensor([[0.2, 0.8], [0.1, 0.9], [1.0, 0]], requires_grad= True)
y= torch.tensor([1, 1.0, 0], requires_grad= False)
# loss
mse_loss= lambda label, pred: ((torch.gather(pred, 1, label.view(pred.shape[0], -1).long())- label)** 2).mean()
bce_loss= lambda label, pred: torch.clamp(-1* torch.log(torch.gather(pred, 1, label.view(pred.shape[0], -1).long())), 0, 100).mean()
loss1= mse_loss(y, x)
loss2= bce_loss(y, x)
# grad
g1= torch.autograd.grad(loss1, x, retain_graph= True)[0]
g2= torch.autograd.grad(loss2, x, retain_graph= True)[0]
# norm
lamuda= torch.norm(g1)/ (torch.norm(g2)+ 1e-10)
g2_norm= torch.autograd.grad(loss2* lamuda.detach(), x, retain_graph= True)[0]
print(f'origin g1.norm(): {torch.norm(g1)}\norigin g2.norm(): {torch.norm(g2)}\nnow g2.norm(): {torch.norm(g2_norm)}')
# origin g1.norm(): 0.2854495942592621
# origin g2.norm(): 0.649535596370697
# now g2.norm(): 0.2854495644569397

实验结果

作者利用Transformer合成的如下高分辨图。

图2 合成的雪山

为方便计,我们在MNIST数据集进行实现。通过对抗式训练,VQGAN生成了不错的手写数字,见图3。

图3 VQGAN生成的图

其中,图3第一行是生成的图,第二行是测试集的图。 利用学好的codebook,我们基于Transformer就实现对残缺图进行补全等有趣任务,见图4。

图4 对残缺的手写数字图进行补全

其中,图4第一张是原图,第二张是根据离散Latent feature生成的图,第三张是给定上半部分补全的图,第四张是随机合成的图。

实现

具体代码可到Github见guchengzhong/VQGAN。

参考

[1] VQ-GAN|PyTorch Implementation, Outlier, youtube.

相关推荐
深圳市青牛科技实业有限公司12 分钟前
【青牛科技】应用方案|D2587A高压大电流DC-DC
人工智能·科技·单片机·嵌入式硬件·机器人·安防监控
水豚AI课代表32 分钟前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
几两春秋梦_33 分钟前
符号回归概念
人工智能·数据挖掘·回归
用户691581141651 小时前
Ascend Extension for PyTorch的源码解析
人工智能
用户691581141652 小时前
Ascend C的编程模型
人工智能
成富2 小时前
文本转SQL(Text-to-SQL),场景介绍与 Spring AI 实现
数据库·人工智能·sql·spring·oracle
CSDN云计算3 小时前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
艾派森3 小时前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing11233 小时前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机
小蜗子3 小时前
Multi‐modal knowledge graph inference via media convergenceand logic rule
人工智能·知识图谱