背景与动机
传统 VAE 使用连续的潜在空间,但自然语言、音频帧、图像 patch 等数据在语义上更接近离散结构。VQ-VAE(Vector Quantized VAE)由 van den Oord 等人于 2017 年提出,其核心思想是将连续的编码器输出"对齐"到一个有限的离散码本(Codebook)上,从而学习离散的潜在表示,同时避免了 VAE 中后验坍塌(posterior collapse)的问题。
模型架构
VQ-VAE 由三个关键模块构成:编码器(Encoder) 、向量量化层(VQ Layer) 和 解码器(Decoder) ,另外维护一个可学习的码本 (\mathcal{E} = {e_k}_{k=1}^{K}),其中每个码向量 (e_k \in \mathbb{R}^D),(K) 为码本大小,(D) 为嵌入维度。
向量量化过程
编码器将输入 (x) 映射为连续向量 (z_e(x) \in \mathbb{R}^D),随后通过最近邻查找,将其替换为码本中距离最近的码向量:
zq(x)=ek,k=argminj∥ze(x)−ej∥2 z_q(x) = e_k, \quad k = \arg\min_{j} \| z_e(x) - e_j \|_2 zq(x)=ek,k=argjmin∥ze(x)−ej∥2
量化后的 (z_q(x)) 被送入解码器进行重建。这一"硬分配"操作是不可微的,因此需要特殊的梯度估计方法。
训练目标与损失函数
VQ-VAE 的总损失由三项组成:
L=∥x−x^∥22⏟重建损失+∥sg[ze(x)]−e∥22⏟码本损失+β∥ze(x)−sg[e]∥22⏟承诺损失 \mathcal{L} = \underbrace{\| x - \hat{x} \|2^2}{\text{重建损失}} + \underbrace{\| \text{sg}[z_e(x)] - e \|2^2}{\text{码本损失}} + \underbrace{\beta \| z_e(x) - \text{sg}[e] \|2^2}{\text{承诺损失}} L=重建损失 ∥x−x^∥22+码本损失 ∥sg[ze(x)]−e∥22+承诺损失 β∥ze(x)−sg[e]∥22
其中 (\text{sg}[\cdot]) 表示 stop-gradient(停止梯度)操作,在前向传播中是恒等变换,反向传播时梯度为零。
重建损失(Reconstruction Loss) 衡量解码器输出 (\hat{x}) 与原始输入 (x) 之间的差距。对于连续数据通常用均方误差,对于离散数据(如像素值)则使用交叉熵。
码本损失(Codebook Loss / Dictionary Loss) 的目的是将码本中的码向量 (e) 拉向编码器输出 (z_e(x))。这里 (\text{sg}[z_e(x)]) 表示编码器输出被视为固定目标,只更新码向量。
承诺损失(Commitment Loss) 方向相反:固定码向量,驱使编码器输出 (z_e(x)) 向已选中的码向量靠拢。超参数 (\beta)(论文中通常取 0.25)控制承诺损失的权重,防止编码器输出游荡不定(即让编码器对选择的码向量"承诺")。
直通估计器(Straight-Through Estimator)
由于量化操作 (\arg\min) 不可微,梯度无法从解码器传回编码器。VQ-VAE 采用直通估计器近似处理:在反向传播时,将 (z_q(x)) 的梯度直接复制到 (z_e(x)$) 上,即:
∂L∂ze≈∂L∂zq \frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q} ∂ze∂L≈∂zq∂L
这是一种有偏但实践中有效的近似方法。
EMA 更新码本(替代方案)
码本损失也可以用指数移动平均(EMA) 来替代梯度下降更新码向量,这往往更稳定:
Nk←γNk+(1−γ)nk,mk←γmk+(1−γ)∑ze∈Bkze N_k \leftarrow \gamma N_k + (1-\gamma) n_k, \quad m_k \leftarrow \gamma m_k + (1-\gamma) \sum_{z_e \in \mathcal{B}_k} z_e Nk←γNk+(1−γ)nk,mk←γmk+(1−γ)ze∈Bk∑ze
ek←mkNk e_k \leftarrow \frac{m_k}{N_k} ek←Nkmk
其中 (\gamma) 为衰减系数(如 0.99),(n_k) 为当前 batch 中被分配到第 (k) 个码的样本数,(\mathcal{B}_k) 为对应的编码器输出集合。使用 EMA 时,训练损失中的码本损失项可以去掉。
评估指标
困惑度(Codebook Perplexity)
困惑度是衡量码本利用率最核心的指标,反映有效使用了多少个码向量。其定义基于码向量的使用频率分布:
pk=count(k)∑jcount(j),Perplexity=exp (−∑k=1Kpklogpk) p_k = \frac{\text{count}(k)}{\sum_j \text{count}(j)}, \quad \text{Perplexity} = \exp\!\left(-\sum_{k=1}^{K} p_k \log p_k\right) pk=∑jcount(j)count(k),Perplexity=exp(−k=1∑Kpklogpk)
困惑度的值域为 ([1, K])。当所有码向量被均匀使用时,困惑度达到最大值 (K);当只有少数几个码被激活时,困惑度趋近于 1。困惑度越高,码本利用越充分,表明模型学到了更丰富多样的离散表示。
码本利用率不足(codebook collapse)是训练 VQ-VAE 的常见问题,即大量码向量从未被使用,本质上浪费了码本容量。常见的缓解策略包括 EMA 更新、码向量随机重置(将"死亡"码向量重置为近期出现的编码器输出)以及码本扩大后聚类初始化等。
重建质量指标
重建质量通常用以下指标衡量:MSE/PSNR (像素级误差)、SSIM (结构相似性)、以及感知指标 LPIPS(学习感知图像块相似度),后者与人类感知相关性更强,在图像生成任务中更常用。
下游任务指标
VQ-VAE 的真正价值在于其离散编码序列可以被语言模型(如 PixelCNN、Transformer)直接建模。因此常用负对数似然(NLL) 或 bits-per-dim(BPD) 评估先验模型的建模能力,用 FID(Fréchet Inception Distance) 衡量生成样本的整体质量。
VQ-VAE-2 的改进
VQ-VAE-2(Razavi et al., 2019)引入了层级化的离散表示,使用粗粒度(global/top)和细粒度(local/bottom)两个 VQ 层分别捕获全局结构与局部细节,配合强大的 Transformer 先验,在 ImageNet 上实现了接近 GAN 的生成质量,同时保持了更好的多样性。
小结
VQ-VAE 的精妙之处在于用一个简单的最近邻量化操作,将连续表示学习与离散结构建模优雅地结合起来。码本困惑度是训练过程中最重要的监控指标,它直接反映模型是否在充分利用离散空间的表达能力。在实践中,VQ-VAE 已成为图像生成(DALL-E、VQ-GAN)、音频合成(SoundStream、EnCodec)和视频压缩等领域的基础架构之一。