关于VQ-GAN的记录

VQGAN

paper: Taming Transformers for High-Resolution Image Synthesis

这篇文章是德国海德堡大学IWR团队发表的,被收录于CVPR21。作者提出了VQGAN (Vector Quantised Variational Autoencoder) 模型。我们很自然会想到VQVAE,VAE和GAN都是生成模型,其中VAE要求latent feature是连续的,以扩展模型的泛化性;GAN以生成对抗方式,强化了对样本的辨识能力。再聊聊Vector Quantised, VQVAE作者认为语言、声音等是离散的,而图像可以用语言描述,即离散的对象表示;遂引入了codebook技术实现了latent feature离散化,即用codebook中的code量化了latent feature,最后就可以用离散化的latent feature生成图像。那为什么要整合VQ和GAN呢,VQ模型相比连续latent feature模型具有方差小,易于训练的优势,而GAN模型简单,有效。除此之外,考虑到CNN具有图像的归纳偏好,可以有效处理图像任务;而Transformer具有强大的表达能力;因此,作者使用了Transformer来合成图像。

贡献点

  • 作者结合CNN有效的归纳偏好与transformer强大的表达能力
  • 实验验证模型的有效性

模型

如图1所示,VQGAN部分包括了一个生成器一个判别器 ,其中生成器部分类似VQVAE ,判别器 由一个简单的CNN 组成。 具体流程简述如下,我们送入一张小狗图片,经过CNN编码器得到latent feature,利用codebook对隐特征进行量化,解码还原小狗图。那如何更新codebook呢,我们对生成的图与原图分别送入判别器,利用判别的结果计算损失从而学习code。基于良好的codebook,我们可以将图像用latent feature的索引表示,再对其按行拉伸即可得到序列化的表示。对于序列数据,我们可以用Transformer捕捉序列内部的依赖,进而生成图片。

图1 VQGAN

模型和公式都比较简单,更多的是一些技术的整合,比如Codebook,感知损失,Transformer等。感知损失需要预训练模型,自己实现的可能受到算力,数据限制。但是,感知损失并不是必须的,可以用MSE损失替代。对于文中自适应权重的计算,作者并没有给出解释,我们可以认为从梯度大小上平衡两种损失的影响,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> λ = ▽ G L [ L r e c ] ▽ G L [ L G A N ] + δ ( 1 ) \lambda=\frac{\bigtriangledown_{G_{L}[\mathcal{L}{rec} ]}}{\bigtriangledown{G_{L}[\mathcal{L}{GAN}]+\delta }}\qquad (1) </math>λ=▽GL[LGAN]+δ▽GL[Lrec](1)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ = a r g m i n E , G , Z m a x D E x ∼ p ( x ) [ L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ] ( 2 ) \mathcal{Q} ^{*}=argmin
{E,G,Z}max_{D}E_{x\sim p(x)}[\mathcal{L}{VQ}(E,G,Z)+\lambda\mathcal{L}{GAN}(\{E,G,Z\},D)]\qquad (2) </math>Q∗=argminE,G,ZmaxDEx∼p(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)](2)

如公式1,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L r e c ] \bigtriangledown_{G_{L}[\mathcal{L}{rec} ]} </math>▽GL[Lrec]表示重构损失关于解码器最后一层参数的梯度的大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L G A N ] \bigtriangledown{G_{L}[\mathcal{L}_{GAN}]} </math>▽GL[LGAN]表示判别损失关于解码器最后一层参数的梯度的大小。如果重构损失梯度更大,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ比较大,实际更新的时候会提高判别损失的梯度值,即统一了两者的梯度大小。下面给出一个小例子。

python 复制代码
import torch
x= torch.tensor([[0.2, 0.8], [0.1, 0.9], [1.0, 0]], requires_grad= True)
y= torch.tensor([1, 1.0, 0], requires_grad= False)
# loss
mse_loss= lambda label, pred: ((torch.gather(pred, 1, label.view(pred.shape[0], -1).long())- label)** 2).mean()
bce_loss= lambda label, pred: torch.clamp(-1* torch.log(torch.gather(pred, 1, label.view(pred.shape[0], -1).long())), 0, 100).mean()
loss1= mse_loss(y, x)
loss2= bce_loss(y, x)
# grad
g1= torch.autograd.grad(loss1, x, retain_graph= True)[0]
g2= torch.autograd.grad(loss2, x, retain_graph= True)[0]
# norm
lamuda= torch.norm(g1)/ (torch.norm(g2)+ 1e-10)
g2_norm= torch.autograd.grad(loss2* lamuda.detach(), x, retain_graph= True)[0]
print(f'origin g1.norm(): {torch.norm(g1)}\norigin g2.norm(): {torch.norm(g2)}\nnow g2.norm(): {torch.norm(g2_norm)}')
# origin g1.norm(): 0.2854495942592621
# origin g2.norm(): 0.649535596370697
# now g2.norm(): 0.2854495644569397

实验结果

作者利用Transformer合成的如下高分辨图。

图2 合成的雪山

为方便计,我们在MNIST数据集进行实现。通过对抗式训练,VQGAN生成了不错的手写数字,见图3。

图3 VQGAN生成的图

其中,图3第一行是生成的图,第二行是测试集的图。 利用学好的codebook,我们基于Transformer就实现对残缺图进行补全等有趣任务,见图4。

图4 对残缺的手写数字图进行补全

其中,图4第一张是原图,第二张是根据离散Latent feature生成的图,第三张是给定上半部分补全的图,第四张是随机合成的图。

实现

具体代码可到Github见guchengzhong/VQGAN。

参考

[1] VQ-GAN|PyTorch Implementation, Outlier, youtube.

相关推荐
珠海新立电子科技有限公司1 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
IT古董1 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
曼城周杰伦2 小时前
自然语言处理:第六十三章 阿里Qwen2 & 2.5系列
人工智能·阿里云·语言模型·自然语言处理·chatgpt·nlp·gpt-3
余炜yw2 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐3 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1233 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr3 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner3 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao3 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!3 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统