关于VQ-GAN的记录

VQGAN

paper: Taming Transformers for High-Resolution Image Synthesis

这篇文章是德国海德堡大学IWR团队发表的,被收录于CVPR21。作者提出了VQGAN (Vector Quantised Variational Autoencoder) 模型。我们很自然会想到VQVAE,VAE和GAN都是生成模型,其中VAE要求latent feature是连续的,以扩展模型的泛化性;GAN以生成对抗方式,强化了对样本的辨识能力。再聊聊Vector Quantised, VQVAE作者认为语言、声音等是离散的,而图像可以用语言描述,即离散的对象表示;遂引入了codebook技术实现了latent feature离散化,即用codebook中的code量化了latent feature,最后就可以用离散化的latent feature生成图像。那为什么要整合VQ和GAN呢,VQ模型相比连续latent feature模型具有方差小,易于训练的优势,而GAN模型简单,有效。除此之外,考虑到CNN具有图像的归纳偏好,可以有效处理图像任务;而Transformer具有强大的表达能力;因此,作者使用了Transformer来合成图像。

贡献点

  • 作者结合CNN有效的归纳偏好与transformer强大的表达能力
  • 实验验证模型的有效性

模型

如图1所示,VQGAN部分包括了一个生成器一个判别器 ,其中生成器部分类似VQVAE ,判别器 由一个简单的CNN 组成。 具体流程简述如下,我们送入一张小狗图片,经过CNN编码器得到latent feature,利用codebook对隐特征进行量化,解码还原小狗图。那如何更新codebook呢,我们对生成的图与原图分别送入判别器,利用判别的结果计算损失从而学习code。基于良好的codebook,我们可以将图像用latent feature的索引表示,再对其按行拉伸即可得到序列化的表示。对于序列数据,我们可以用Transformer捕捉序列内部的依赖,进而生成图片。

图1 VQGAN

模型和公式都比较简单,更多的是一些技术的整合,比如Codebook,感知损失,Transformer等。感知损失需要预训练模型,自己实现的可能受到算力,数据限制。但是,感知损失并不是必须的,可以用MSE损失替代。对于文中自适应权重的计算,作者并没有给出解释,我们可以认为从梯度大小上平衡两种损失的影响,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> λ = ▽ G L [ L r e c ] ▽ G L [ L G A N ] + δ ( 1 ) \lambda=\frac{\bigtriangledown_{G_{L}[\mathcal{L}{rec} ]}}{\bigtriangledown{G_{L}[\mathcal{L}{GAN}]+\delta }}\qquad (1) </math>λ=▽GL[LGAN]+δ▽GL[Lrec](1)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ = a r g m i n E , G , Z m a x D E x ∼ p ( x ) [ L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ] ( 2 ) \mathcal{Q} ^{*}=argmin
{E,G,Z}max_{D}E_{x\sim p(x)}[\mathcal{L}{VQ}(E,G,Z)+\lambda\mathcal{L}{GAN}(\{E,G,Z\},D)]\qquad (2) </math>Q∗=argminE,G,ZmaxDEx∼p(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)](2)

如公式1,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L r e c ] \bigtriangledown_{G_{L}[\mathcal{L}{rec} ]} </math>▽GL[Lrec]表示重构损失关于解码器最后一层参数的梯度的大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> ▽ G L [ L G A N ] \bigtriangledown{G_{L}[\mathcal{L}_{GAN}]} </math>▽GL[LGAN]表示判别损失关于解码器最后一层参数的梯度的大小。如果重构损失梯度更大,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ比较大,实际更新的时候会提高判别损失的梯度值,即统一了两者的梯度大小。下面给出一个小例子。

python 复制代码
import torch
x= torch.tensor([[0.2, 0.8], [0.1, 0.9], [1.0, 0]], requires_grad= True)
y= torch.tensor([1, 1.0, 0], requires_grad= False)
# loss
mse_loss= lambda label, pred: ((torch.gather(pred, 1, label.view(pred.shape[0], -1).long())- label)** 2).mean()
bce_loss= lambda label, pred: torch.clamp(-1* torch.log(torch.gather(pred, 1, label.view(pred.shape[0], -1).long())), 0, 100).mean()
loss1= mse_loss(y, x)
loss2= bce_loss(y, x)
# grad
g1= torch.autograd.grad(loss1, x, retain_graph= True)[0]
g2= torch.autograd.grad(loss2, x, retain_graph= True)[0]
# norm
lamuda= torch.norm(g1)/ (torch.norm(g2)+ 1e-10)
g2_norm= torch.autograd.grad(loss2* lamuda.detach(), x, retain_graph= True)[0]
print(f'origin g1.norm(): {torch.norm(g1)}\norigin g2.norm(): {torch.norm(g2)}\nnow g2.norm(): {torch.norm(g2_norm)}')
# origin g1.norm(): 0.2854495942592621
# origin g2.norm(): 0.649535596370697
# now g2.norm(): 0.2854495644569397

实验结果

作者利用Transformer合成的如下高分辨图。

图2 合成的雪山

为方便计,我们在MNIST数据集进行实现。通过对抗式训练,VQGAN生成了不错的手写数字,见图3。

图3 VQGAN生成的图

其中,图3第一行是生成的图,第二行是测试集的图。 利用学好的codebook,我们基于Transformer就实现对残缺图进行补全等有趣任务,见图4。

图4 对残缺的手写数字图进行补全

其中,图4第一张是原图,第二张是根据离散Latent feature生成的图,第三张是给定上半部分补全的图,第四张是随机合成的图。

实现

具体代码可到Github见guchengzhong/VQGAN。

参考

1\] VQ-GAN\|PyTorch Implementation, Outlier, youtube.

相关推荐
人工智能小豪4 小时前
2025年大模型平台落地实践研究报告|附75页PDF文件下载
大数据·人工智能·transformer·anythingllm·ollama·大模型应用
芯盾时代4 小时前
AI在网络安全领域的应用现状和实践
人工智能·安全·web安全·网络安全
黑鹿0224 小时前
机器学习基础(三) 逻辑回归
人工智能·机器学习·逻辑回归
电鱼智能的电小鱼5 小时前
虚拟现实教育终端技术方案——基于EFISH-SCB-RK3588的全场景国产化替代
linux·网络·人工智能·分类·数据挖掘·vr
天天代码码天天5 小时前
C# Onnx 动漫人物头部检测
人工智能·深度学习·神经网络·opencv·目标检测·机器学习·计算机视觉
Joseit6 小时前
从零打造AI面试系统全栈开发
人工智能·面试·职场和发展
小猪猪_16 小时前
多视角学习、多任务学习,迁移学习
人工智能·迁移学习
飞哥数智坊6 小时前
AI编程实战:Cursor 1.0 上手实测,刀更锋利马更快
人工智能·cursor
vlln6 小时前
【论文解读】ReAct:从思考脱离行动, 到行动反馈思考
人工智能·深度学习·机器学习
qq_430908576 小时前
华为ICT和AI智能应用
人工智能·华为