关于VQ-GAN的记录

VQGAN

paper: Taming Transformers for High-Resolution Image Synthesis

这篇文章是德国海德堡大学IWR团队发表的,被收录于CVPR21。作者提出了VQGAN (Vector Quantised Variational Autoencoder) 模型。我们很自然会想到VQVAE,VAE和GAN都是生成模型,其中VAE要求latent feature是连续的,以扩展模型的泛化性;GAN以生成对抗方式,强化了对样本的辨识能力。再聊聊Vector Quantised, VQVAE作者认为语言、声音等是离散的,而图像可以用语言描述,即离散的对象表示;遂引入了codebook技术实现了latent feature离散化,即用codebook中的code量化了latent feature,最后就可以用离散化的latent feature生成图像。那为什么要整合VQ和GAN呢,VQ模型相比连续latent feature模型具有方差小,易于训练的优势,而GAN模型简单,有效。除此之外,考虑到CNN具有图像的归纳偏好,可以有效处理图像任务;而Transformer具有强大的表达能力;因此,作者使用了Transformer来合成图像。

贡献点

  • 作者结合CNN有效的归纳偏好与transformer强大的表达能力
  • 实验验证模型的有效性

模型

如图1所示,VQGAN部分包括了一个生成器一个判别器 ,其中生成器部分类似VQVAE ,判别器 由一个简单的CNN 组成。 具体流程简述如下,我们送入一张小狗图片,经过CNN编码器得到latent feature,利用codebook对隐特征进行量化,解码还原小狗图。那如何更新codebook呢,我们对生成的图与原图分别送入判别器,利用判别的结果计算损失从而学习code。基于良好的codebook,我们可以将图像用latent feature的索引表示,再对其按行拉伸即可得到序列化的表示。对于序列数据,我们可以用Transformer捕捉序列内部的依赖,进而生成图片。

图1 VQGAN

模型和公式都比较简单,更多的是一些技术的整合,比如Codebook,感知损失,Transformer等。感知损失需要预训练模型,自己实现的可能受到算力,数据限制。但是,感知损失并不是必须的,可以用MSE损失替代。对于文中自适应权重的计算,作者并没有给出解释,我们可以认为从梯度大小上平衡两种损失的影响,
λ = ▽ G L L r e c ▽ G L L G A N + δ ( 1 ) \lambda=\frac{\bigtriangledown_{G_{L}\\mathcal{L}_{rec} }}{\bigtriangledown_{G_{L}\\mathcal{L}_{GAN}+\delta }}\qquad (1) λ=▽GLLGAN+δ▽GLLrec(1)
Q ∗ = a r g m i n E , G , Z m a x D E x ∼ p ( x ) L V Q ( E , G , Z ) + λ L G A N ( { E , G , Z } , D ) ( 2 ) \mathcal{Q} ^{*}=argmin_{E,G,Z}max_{D}E_{x\sim p(x)}\\mathcal{L}_{VQ}(E,G,Z)+\\lambda\\mathcal{L}_{GAN}(\\{E,G,Z\\},D)\qquad (2) Q∗=argminE,G,ZmaxDEx∼p(x)LVQ(E,G,Z)+λLGAN({E,G,Z},D)(2)

如公式1,其中 ▽ G L L r e c \bigtriangledown_{G_{L}\\mathcal{L}_{rec} } ▽GLLrec表示重构损失关于解码器最后一层参数的梯度的大小, ▽ G L L G A N \bigtriangledown_{G_{L}\\mathcal{L}_{GAN}} ▽GLLGAN表示判别损失关于解码器最后一层参数的梯度的大小。如果重构损失梯度更大,则 λ \lambda λ比较大,实际更新的时候会提高判别损失的梯度值,即统一了两者的梯度大小。下面给出一个小例子。

python 复制代码
import torch
x= torch.tensor([[0.2, 0.8], [0.1, 0.9], [1.0, 0]], requires_grad= True)
y= torch.tensor([1, 1.0, 0], requires_grad= False)
# loss
mse_loss= lambda label, pred: ((torch.gather(pred, 1, label.view(pred.shape[0], -1).long())- label)** 2).mean()
bce_loss= lambda label, pred: torch.clamp(-1* torch.log(torch.gather(pred, 1, label.view(pred.shape[0], -1).long())), 0, 100).mean()
loss1= mse_loss(y, x)
loss2= bce_loss(y, x)
# grad
g1= torch.autograd.grad(loss1, x, retain_graph= True)[0]
g2= torch.autograd.grad(loss2, x, retain_graph= True)[0]
# norm
lamuda= torch.norm(g1)/ (torch.norm(g2)+ 1e-10)
g2_norm= torch.autograd.grad(loss2* lamuda.detach(), x, retain_graph= True)[0]
print(f'origin g1.norm(): {torch.norm(g1)}\norigin g2.norm(): {torch.norm(g2)}\nnow g2.norm(): {torch.norm(g2_norm)}')
# origin g1.norm(): 0.2854495942592621
# origin g2.norm(): 0.649535596370697
# now g2.norm(): 0.2854495644569397

实验结果

作者利用Transformer合成的如下高分辨图。

图2 合成的雪山

为方便计,我们在MNIST数据集进行实现。通过对抗式训练,VQGAN生成了不错的手写数字,见图3。

图3 VQGAN生成的图

其中,图3第一行是生成的图,第二行是测试集的图。 利用学好的codebook,我们基于Transformer就实现对残缺图进行补全等有趣任务,见图4。

图4 对残缺的手写数字图进行补全

其中,图4第一张是原图,第二张是根据离散Latent feature生成的图,第三张是给定上半部分补全的图,第四张是随机合成的图。

实现

具体代码可到Github见guchengzhong/VQGAN。

参考

1 VQ-GAN|PyTorch Implementation, Outlier, youtube.

相关推荐
前沿科技说i2 小时前
2026年AI大模型API中转站:主流服务商性能与成本
人工智能
黄啊码4 小时前
【黄啊码】程序员真正该担心的,不是 AI 会写代码
人工智能
weixin_468466855 小时前
Ava 2.0 智能应用场景落地指南
人工智能·自然语言处理·大模型·智能交互·ava
John_ToDebug5 小时前
MCP 深度解析:大模型的“万能插头”
人工智能·经验分享·ai
浦信仿真大讲堂5 小时前
CST 仿真软件与 AI 融合的工程应用实战
人工智能·仿真软件·达索仿真·达索软件
mit6.8245 小时前
A Software Engineer‘s Apology | CODA
人工智能
段一凡-华北理工大学5 小时前
2026 高炉炼铁智能化技术全景与演进路径~系列文章11:演进路径与行业未来
大数据·网络·人工智能·算法·工业智能体·高炉炼铁智能化
小脑斧1235 小时前
AI技能化落地:从对话式大模型到可生产、可复用的AI工程体系
人工智能·skills·openclaw·hermes·marvis
西陵5 小时前
Agent 为什么会陷入 Doom Loop?OpenClaw 的破解之道
前端·人工智能·ai编程
飞哥数智坊5 小时前
动动嘴皮子就把事干了,Mic Air + TRAE SOLO 让我越来越懒
人工智能