VQ-GAN是一种自动编码器,它学习将数据压缩为离散的潜在表示 ,由编码器 E E E、解码器 G G G、码本 C C C和判别器 D D D组成。过程如下:
给定一个图像 x ∈ R H × W × 3 x \in \mathbb{R}^{H×W×3} x∈RH×W×3,编码器 E E E将 x x x映射到其潜在表示 h ∈ R H ′ × W ′ × D h \in \mathbb{R}^{H'×W'×D} h∈RH′×W′×D,
通过在由嵌入 C = { e i } i = 1 K C = \{e_{i}\}_{i = 1}^{K} C={ei}i=1K组成的码本中 进行最近邻查找 对其进行量化,生成 z ∈ R H ′ × W ′ × D z \in \mathbb{R}^{H'×W'×D} z∈RH′×W′×D。
z z z通过解码器 G G G进行重构得到 x ^ \hat{x} x^。
在这个过程中,直通估计器(Bengio,2013)用于在量化步骤中保持梯度流动。码本优化以下损失:
L V Q = ∥ s g ( h ) − e ∥ 2 2 + β ∥ h − s g ( e ) ∥ 2 2 ( 1 ) \mathcal{L}_{VQ}=\| sg(h)-e\| _{2}^{2}+\beta\| h-sg(e)\| _{2}^{2} (1) LVQ=∥sg(h)−e∥22+β∥h−sg(e)∥22(1)
其中:
L V Q \mathcal{L}_{VQ} LVQ:VQ - GAN中码本优化的矢量量化损失,用于衡量量化过程误差。
h h h:编码器输出的潜在表示。
e e e :码本 c c c中与 h h h最接近的嵌入向量。
s g ( ⋅ ) sg(\cdot) sg(⋅):停止梯度操作符,保证量化过程中梯度正确流动。
β \beta β:超参数,常取0.25,控制两部分损失的相对权重。
公式由两部分组成, ∥ s g ( h ) − e ∥ 2 2 \| sg(h)-e\| _{2}^{2} ∥sg(h)−e∥22关注编码误差, β ∥ h − s g ( e ) ∥ 2 2 \beta\| h - sg(e)\| _{2}^{2} β∥h−sg(e)∥22关注解码误差。
其中 β = 0.25 \beta = 0.25 β=0.25是一个超参数, e e e是从码本 C C C中得到的最近邻嵌入。为了进行重构,VQ-GAN用感知损失(Zhang等人,2012) L L P I P S \mathcal{L}{LPIPS} LLPIPS代替了原来的 ℓ 2 \ell{2} ℓ2损失。最后,为了鼓励生成更高保真度的样本,训练补丁级判别器 D D D 对真实图像和重构图像进行分类,损失为:
L G A N = l o g D ( x ) + l o g ( 1 − D ( x ^ ) ) ( 2 ) \mathcal{L}_{GAN}=log D(x)+log (1-D(\hat{x})) (2) LGAN=logD(x)+log(1−D(x^))(2)
其中:
L G A N \mathcal{L}_{GAN} LGAN:生成对抗网络(GAN)的损失函数,用于训练判别器以区分真实图像和生成图像。
D D D :判别器,是一个神经网络,用于判断输入图像是真实图像的概率,输出值范围在 [ 0 , 1 ] [0, 1] [0,1]之间。
x x x:真实图像,来自原始的训练数据集。
x ^ \hat{x} x^:生成的(重构的)图像,由VQ - GAN的解码器生成。
该公式通过使 D ( x ) D(x) D(x)趋近于1(判别真实图像), D ( x ^ ) D(\hat{x}) D(x^)趋近于0(判别生成图像)来优化判别器。
总体而言,VQ-GAN优化以下损失:
min E , G , C max D L L P I P S + L V Q + λ L G A N ( 3 ) \min_{E, G, C} \max_{D} \mathcal{L}{LPIPS}+\mathcal{L}{VQ}+\lambda \mathcal{L}_{GAN} (3) E,G,CminDmaxLLPIPS+LVQ+λLGAN(3)
L L P I P S \mathcal{L}_{LPIPS} LLPIPS:基于学习的感知图像块相似性损失(Learned Perceptual Image Patch Similarity)。它是一种感知损失,用于衡量生成图像与真实图像在感知上的差异,更符合人类对图像相似性的主观判断。
L V Q \mathcal{L}_{VQ} LVQ :矢量量化损失,用于优化码本。它包含两部分,主要衡量编码器输出的潜在表示与码本中最近邻嵌入向量之间的编码和解码误差,公式为 L V Q = ∥ s g ( h ) − e ∥ 2 2 + β ∥ h − s g ( e ) ∥ 2 2 \mathcal{L}_{VQ}=\| sg(h)-e\| _{2}^{2}+\beta\| h - sg(e)\| _{2}^{2} LVQ=∥sg(h)−e∥22+β∥h−sg(e)∥22。
L G A N \mathcal{L}_{GAN} LGAN :生成对抗网络的损失,用于训练判别器区分真实图像和生成图像,公式为 L G A N = log D ( x ) + log ( 1 − D ( x ^ ) ) \mathcal{L}_{GAN}=\log D(x)+\log (1 - D(\hat{x})) LGAN=logD(x)+log(1−D(x^))。
λ \lambda λ :自适应权重,用于平衡 L G A N \mathcal{L}{GAN} LGAN与其他损失项的相对重要性,其计算公式为 λ = ∥ ∇ G L L L P I P S ∥ 2 ∥ ∇ G L L G A N ∥ 2 + δ \lambda=\frac{\left\|\nabla{G_{L}} L_{LPIPS}\right\|{2}}{\left\|\nabla{G_{L}} L_{GAN}\right\|{2}+\delta} λ=∥∇GLLGAN∥2+δ∥∇GLLLPIPS∥2,其中 G L G{L} GL是解码器的最后一层, δ \delta δ是一个小的常数(如 δ = 1 0 − 6 \delta = 10^{-6} δ=10−6)。
其中 λ = ∥ ∇ G L L L P I P S ∥ 2 ∥ ∇ G L L G A N ∥ 2 + δ \lambda=\frac{\left\|\nabla_{G_{L}} \mathcal{L}{LPIPS}\right\|{2}}{\left\|\nabla_{G_{L}} \mathcal{L}{GAN}\right\|{2}+\delta} λ=∥∇GLLGAN∥2+δ∥∇GLLLPIPS∥2是一个自适应权重, G L G_{L} GL是解码器的最后一层, δ = 1 0 − 6 \delta = 10^{-6} δ=10−6, L L P I P S \mathcal{L}_{LPIPS} LLPIPS是Zhang等人(2012)中描述的相同距离度量。
形式上,将 z ∈ Z H × W z \in \mathbb{Z}^{H×W} z∈ZH×W表示为代表图像的离散潜在标记。对于每个训练步骤,均匀采样 t ∈ [ 0 , 1 ) t \in [0, 1) t∈[0,1),并随机生成一个掩码 m ∈ { 0 , 1 } H × W m \in \{0, 1\}^{H×W} m∈{0,1}H×W,其中有 N = ⌈ γ H W ⌉ N=\lceil\gamma H W\rceil N=⌈γHW⌉个被掩码的值,这里 γ = cos ( π 2 t ) \gamma=\cos (\frac{\pi}{2} t) γ=cos(2πt)。然后,MaskGit通过以下目标学习预测被掩码的标记:
L m a s k = − E z ∈ D [ log p ( z ∣ z ⊙ m ) ] \mathcal{L}{mask }=-\mathbb{E}{z \in \mathcal{D}}[\log p(z | z \odot m)] Lmask=−Ez∈D[logp(z∣z⊙m)]
编码器 : z t = E ( x t , x t − 1 ) z_{t}=E(x_{t}, x_{t - 1}) zt=E(xt,xt−1)
时间变换器 : h t = H ( z ≤ t ) h_{t}=H(z_{\leq t}) ht=H(z≤t)
空间MaskGit : p ( z t ∣ h t − 1 ) p(z_{t} | h_{t - 1}) p(zt∣ht−1)
解码器 : p ( x t ∣ z t , h t − 1 ) p(x_{t} | z_{t}, h_{t - 1}) p(xt∣zt,ht−1)
编码器
利用视频数据中的时空冗余 来实现压缩 表示。为此,作者提出学习一个CNN编码器 z t = E ( x t , x t − 1 ) z_{t}=E(x_{t}, x_{t - 1}) zt=E(xt,xt−1) ,它通过在通道维度 上连接前一帧 x t − 1 x_{t - 1} xt−1对当前帧 x t x_{t} xt进行编码 ,然后使用码本 c c c对输出进行量化以生成 z t z_{t} zt。作者还做了如下的优化:
与连续潜在表示相比,压缩的离散潜在表示损失更大,并且往往需要更高的空间分辨率 。因此,在对时间信息进行建模之前 ,先应用一个跨步卷积对每个离散潜在 z t z_{t} zt进行下采样,在视觉上更简单的数据集可以进行更多的下采样,而视觉上复杂的数据集则需要较少的下采样。之后,再学习一个大型变换器对时间依赖关系进行建模,然后应用转置卷积将表示上采样回 z t z_{t} zt的原始分辨率 。总之,使用以下架构:
h t = H ( z < t ) = ConvT ( Transformer ( Conv ( z < t ) ) ) h_{t}=H\left(z_{<t}\right)=\text{ConvT}\left(\text{Transformer}\left(\text{Conv}\left(z_{<t}\right)\right)\right) ht=H(z<t)=ConvT(Transformer(Conv(z<t)))
解码器
解码器是一个上采样CNN ,用于重建 x ^ t = D ( z t , h t ) \hat{x}{t}=D(z{t}, h_{t}) x^t=D(zt,ht),其中 z t z_{t} zt可以解释为时间步 t t t的后验, h t h_{t} ht是时间变换器的输出,它汇总了先前时间步的信息。z t z_{t} zt和 h t h_{t} ht在通道维度上连接后输入到解码器 中。解码器与编码器一起优化以下交叉熵重建损失:
L r e c o n = − 1 T ∑ t = 1 T log p ( x t ∣ z t , h t ) \mathcal{L}{recon }=-\frac{1}{T} \sum{t = 1}^{T} \log p\left(x_{t} | z_{t}, h_{t}\right) Lrecon=−T1t=1∑Tlogp(xt∣zt,ht)
这鼓励 z t z_{t} zt特征编码帧之间的相对信息 ,因为时间变换器输出 h t h_{t} ht随时间聚合信息,从而学习更压缩的代码,以便在更长的序列上进行高效建模。
空间MaskGit
最后,使用MaskGit对先验 p ( z t ∣ h t ) p(z_{t} | h_{t}) p(zt∣ht)进行建模 。作者表明,与自回归先验相比,使用MaskGit先验不仅可以实现更快的采样,还能提高采样质量 。在每次训练迭代中,我们按照先前的工作对随机掩码 m t m_{t} mt进行采样,并优化 :
L p r i o r = − 1 T ∑ t = 1 T log p ( z t ∣ z t ⊙ m t ) \mathcal{L}{prior }=-\frac{1}{T} \sum{t = 1}^{T} \log p\left(z_{t} | z_{t} \odot m_{t}\right) Lprior=−T1t=1∑Tlogp(zt∣zt⊙mt)
其中 h t h_{t} ht与被掩码的 z t z_{t} zt在通道维度上连接,以预测被掩码的标记。在生成过程中,作者遵循Lee等人(2022)的方法,即最初每次以 8 个为一组生成每一帧,然后经过两轮修正,每次重新生成一半的标记。
训练目标
最终目标如下:
L T E C O = L V Q + L r e c o n + L p r i o r \mathcal{L}{TECO }=\mathcal{L}{VQ}+\mathcal{L}{recon }+\mathcal{L}{prior } LTECO=LVQ+Lrecon+Lprior
可以通过随机丢弃不进行解码的时间步来提高训练效率 ,这些时间步从重建损失中省略。例如,给定一个有 T T T帧的视频,我们计算所有 t ∈ { 1 , ... , T } t \in \{1, \ldots, T\} t∈{1,...,T}的 h t h_{t} ht,然后仅对10%的索引计算损失 L p r i o r L_{prior} Lprior和 L r e c o n L_{recon} Lrecon。