彻底理解 Variational Autoencoder

为什么要写这篇文章

我2020年的时候曾用 Variational Autoencoder (VAE) 做过一个研究项目。 当时觉得其中的理论太过复杂,于是就把 VAE 当作黑盒使用。 此后虽然也看到过一些关于 VAE 的文章和视频,但是并没有再做 VAE 相关的研究项目了,所以就没去深入研究 VAE。

去年底开始做 Deepfake Speech Detection,遇到了一个叫做 YourTTS 的模型,其中有使用 Variational Interence 来进行模型优化。Variational Inference 就是 VAE 中最关键的理论。 再加上最近看了 Stanford 大学 CS330 的免费课程,其中有一节课是讲解 Variational Inference 的,于是决心把 VAE 彻底弄懂,就有了这一篇文章。 根据费曼学习法,将学到的知识讲解出来,可以帮助自己巩固这些知识。 如果同时能帮助到需要学习 VAE 的朋友,那就更好了。

在Arxiv上,有一篇叫做 Tutorial on Variational Autoencoders 的文章。 我起初想通过这篇文章来弄懂VAE,结果发现自己阅读这篇文章时,很容易迷失(可能是自己的水平不够吧)。 相比来说,Prof. Chelsea Finn 对 Variational Inference 的讲解,深入浅出,配合上 Tutorial on Variational Autoencoders,让我终于理解了 VAE 的精髓。

接下来的内容大部分是基于 Prof. Chelsea Finn 的讲解,其在 YouTube 上可以直接观看。 部分补充内容来自于 Tutorial on Variational Autoencoders。 文中的部分术语,我直接用其英文原文表示,以避免误解。 如果文中有任何错误,欢迎大家指教。

对读者的假设

读者应掌握基础的概率论和统计学,比如 Law of total probability,Kullback--Leibler divergence。 这篇文章也假设读者已经了解了 Autoencoder 的基础知识,比如 Encoder 和 Decoder 分别是干什么的。

Latent Variable Model

VAE 可以被用于生成图片,声音等。 为了简化讨论,我们仅讨论图片的生成。 VAE 是一个 latent variable model,意思是我们假设有一些不可见的 latent variables 在控制着 VAE 生成图片。 用生成手写数字来做一个例子,这些 latent variables 可能控制着数字的倾斜度,笔画的宽度,笔迹潦草度,等等。 然后 VAE 根据这些 latent variables 的指示,来生成对应的数字。 不过在深度学习中,我们并不指望每个 latent variable 都有可解释性,我们最关心的还是最后生成出来的图片是否看起来和训练数据来是自于同一个 distribution。

拿全国高考成绩作为另一个简单的例子,如果我们假设所有的高考分数遵循 Normal Distribution ,那么我们就可以定义一个均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ,一个方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2 来表示这个distribution。 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2 就是两个可解释的 latent variables。 我们可以用所有的高考成绩来求得 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2,然后我们就可以生成无数个新的高考成绩了(当然得注意数值的最大最小值),而且生成的高考成绩和原始的高考成绩都来自于同一个 Normal Distribution。

VAE 的核心思路

上个章节里举的高考分数例子过于简单。 对于图片来说,其 distribution 要远远复杂于 Normal Distribution。 我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∈ Z z \in \cal{Z} </math>z∈Z 来表示生成图片的 latent variable,在深度学习中, <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z通常为一个高维度的向量。 我们的目标是找到一个可以生成图片的函数: <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ : Z → X f_\theta:\cal{Z} \rightarrow \cal{X} </math>fθ:Z→X,其可以将 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z转换为图片( <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 表示函数的参数)。 由于训练数据代表了我们的目标 distribution,我们需要 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( Z ) f_\theta(\cal{Z}) </math>fθ(Z) 和整个训练数据类似。 换言之,随机采样一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,我们需要 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_\theta(z) </math>fθ(z) 看起来大概率来自于训练数据的 distribution,这样我们就可以通过随机采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z来生成图片了。

但是,我们无法用数学公式定义训练数据的 distribution(试想一下如何用数学公式定义人脸图片的distribution),这意味着给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_\theta(z) </math>fθ(z),我们并不能计算出其有多大概率是来自于训练数据。 有人可能会说,那我们用一个深度神经网络(DNN)来模拟训练数据的 distribution,然后我们就可以用这个 DNN 来计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_\theta(z) </math>fθ(z) 来自于训练数据的概率了。 这个正是Generative Adversarial Network(GAN)的思路,事实证明,GAN 的思路是行得通的。

VAE 采用了另一条思路来解决这个问题:由于训练数据是我们的目标 distribution,那么 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( Z ) f_\theta(\cal{Z}) </math>fθ(Z) 至少得把全部训练数据给还原出来: <math xmlns="http://www.w3.org/1998/Math/MathML"> X ⊆ f θ ( Z ) \cal{X} \subseteq f_\theta(\cal{Z}) </math>X⊆fθ(Z)。如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( Z ) f_\theta(\cal{Z}) </math>fθ(Z) 还能生成除训练数据外的图片,那么我们"希望"这些图片也能和训练数据类似。

Variational Inference

理解了 VAE 的核心思路后,下面我们来看看 VAE 是怎么解决这个问题的。 给定一张训练图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ X x \in \cal{X} </math>x∈X,为了能用 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_\theta(z) </math>fθ(z) 还原出 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x, <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 需要对应着一个或多个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z。 但是我们并不知道有多少个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 以及这些 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 的值,那么我们只能从整体 <math xmlns="http://www.w3.org/1998/Math/MathML"> Z \cal{Z} </math>Z 空间的角度思考。 我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ) p_{\theta}(x) </math>pθ(x) 表示在整体 <math xmlns="http://www.w3.org/1998/Math/MathML"> Z \cal{Z} </math>Z 空间中, <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 可以被生成的概率,我们用 Law of total probability来展开 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ) p_{\theta}(x) </math>pθ(x):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p θ ( x ) = ∫ p θ ( x ∣ z ) p ( z ) d z p_{\theta}(x)=\int p_{\theta}(x|z)p(z)dz </math>pθ(x)=∫pθ(x∣z)p(z)dz

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( z ) p(z) </math>p(z) 表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 的 distribution, <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ∣ z ) p_{\theta}(x|z) </math>pθ(x∣z) 表示用一个特定的 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 的概率。 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ∣ z ) p_{\theta}(x|z) </math>pθ(x∣z) 隐性的包含了 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_\theta(z) </math>fθ(z) 函数: <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ∣ z ) = g ( x , f θ ( z ) ) p_{\theta}(x|z) = g(x,f_\theta(z)) </math>pθ(x∣z)=g(x,fθ(z))。 <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g 函数可以理解为计算两张图片其实是同一张图片的概率。

我们可以假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( z ) p(z) </math>p(z) 就是一个简单的 Normal Distribution,这是因为 Normal Distribution 可以通过函数转换为任意复杂的 distribution。 从理论上来说我们可以这样理解,在计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ) p_{\theta}(x) </math>pθ(x) 时,部分 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的作用是将 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( z ) p(z) </math>p(z) 转换为一个特定的 distribution, 然后剩下的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 部分才是实质的被用于计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 可以用来生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 的概率。

让 <math xmlns="http://www.w3.org/1998/Math/MathML"> D = { x 1 , x 2 , ... , x N } \cal{D}=\{x_1,x_2,\dots,x_N\} </math>D={x1,x2,...,xN} 表示所有的 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 个训练数据。 我们的目标是让 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 可以最大化每个训练数据被生成的概率(这也叫做 Maximum likelihood fit 策略):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ← arg max ⁡ θ 1 N ∑ i N log ⁡ p θ ( x i ) \theta \leftarrow \argmax_\theta \frac{1}{N} \sum_{i}^{N}\log p_{\theta}(x_i) </math>θ←θargmaxN1i∑Nlogpθ(xi)

这里需要注意的是,在概率前面加上 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ \log </math>log 函数是一个常见的做法,可以简化数学计算。 从优化的角度说, <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ \log </math>log 函数是一个单调递增函数,所以加上 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ \log </math>log 函数并不影响优化结果。

但是, <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x ) p_\theta(x) </math>pθ(x) 的计算包含了一个 intractable 的积分计算。 这意味着直接优化其值是不可能的,这正是我们需要 variational inference 的地方。 Variational inference 的策略是,得到一个可计算的 lower bound, 然后我们最大化这个 lower bound,这样就间接性的最大化了原函数值。

让我们开始计算 lower bound:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> l o g p θ ( x i ) = log ⁡ ∫ p θ ( x i ∣ z ) p ( z ) d z = log ⁡ ∫ p θ ( x i ∣ z ) p ( z ) q i ( z ) q i ( z ) d z = log ⁡ E z ∼ q i ( z ) [ p θ ( x i ∣ z ) p ( z ) q i ( z ) ] ≥ E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) p ( z ) q i ( z ) ] = E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] − E z ∼ q i ( z ) [ log ⁡ q i ( z ) ] = E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] + H ( q i ) = L i \begin{aligned} log p_{\theta}(x_i) &= \log \int p_{\theta}(x_i|z)p(z)dz \\ &= \log \int p_{\theta}(x_i|z)p(z) \frac{q_i(z)}{q_i(z)} dz \\ &= \log \mathbb{E}{z\sim q_i(z)}[\frac{p{\theta}(x_i|z)p(z)}{q_i(z)}] \\ &\geq \mathbb{E}{z\sim q_i(z)}[\log\frac{p{\theta}(x_i|z)p(z)}{q_i(z)}] \\ &= \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i|z) + \log p(z)] - \mathbb{E}{z\sim q_i(z)}[\log q_i(z)] \\ &= \mathbb{E}{z\sim q_i(z)}[\log p_{\theta}(x_i|z) + \log p(z)] + \cal{H}(q_i) \\ &= \cal{L}_i \end{aligned} </math>logpθ(xi)=log∫pθ(xi∣z)p(z)dz=log∫pθ(xi∣z)p(z)qi(z)qi(z)dz=logEz∼qi(z)[qi(z)pθ(xi∣z)p(z)]≥Ez∼qi(z)[logqi(z)pθ(xi∣z)p(z)]=Ez∼qi(z)[logpθ(xi∣z)+logp(z)]−Ez∼qi(z)[logqi(z)]=Ez∼qi(z)[logpθ(xi∣z)+logp(z)]+H(qi)=Li

推导的第一步是加入一个辅助的概率密度函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z)。 由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q i ( z ) = 1 \frac{q_i(z)}{q_i(z)}=1 </math>qi(z)qi(z)=1,我们可以乘以 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q i ( z ) = 1 \frac{q_i(z)}{q_i(z)}=1 </math>qi(z)qi(z)=1 而不改变函数值。 推导过程中的不等式来自于 Jesen's Inquality。 因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ \log </math>log 函数是一个凹函数,在把 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ \log </math>log 从期待函数里面拿出来时,要加上 <math xmlns="http://www.w3.org/1998/Math/MathML"> ≥ \geq </math>≥。 最后, <math xmlns="http://www.w3.org/1998/Math/MathML"> H ( q i ) \cal{H}(q_i) </math>H(qi) 表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 的 Entrophy,其值的范围是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 1 ] [0, 1] </math>[0,1],值越大,表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 越难预测。

现在我们得到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g p θ ( x i ) log p_{\theta}(x_i) </math>logpθ(xi) 的 evidence lower bound(ELBO)了,让我们仔细看看最大化这个 lower bound 意味着什么。 当我们最大化第一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i|z) + \log p(z)] </math>Ez∼qi(z)[logpθ(xi∣z)+logp(z)] 时,由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ p ( z ) \log p(z) </math>logp(z) 是固定的,我们仅仅在最大化 <math xmlns="http://www.w3.org/1998/Math/MathML"> E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) ] \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i|z)] </math>Ez∼qi(z)[logpθ(xi∣z)]。 这意味着从 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 中采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 后,我们最大化 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 可以用来生成训练数据的平均概率。 当我们最大化第二项 <math xmlns="http://www.w3.org/1998/Math/MathML"> H ( q i ) \cal{H}(q_i) </math>H(qi) 时,我们增加 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 的随机性,因此让采样出来的 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 可以覆盖更广的范围。

Lower Bound 是否 Tight

我们已经计算出了 lower bound。 接下来我们需要验证,这个 lower bound 是否足够 tight。 如果 lower bound 并不 tight 的话,最大化这个 lower bound 并不一定能最大化我们的目标函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g p θ ( x i ) log p_{\theta}(x_i) </math>logpθ(xi)。 我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g p θ ( x i ) log p_{\theta}(x_i) </math>logpθ(xi) 减去这个 lower bound 看看能得到什么:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> l o g p θ ( x i ) − L i = l o g p θ ( x i ) − E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] − H ( q i ) = l o g p θ ( x i ) − H ( q i ) − E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] = E z ∼ q i ( z ) [ l o g p θ ( x i ) ] + E z ∼ q i ( z ) [ q i ( z ) ] − E z ∼ q i ( z ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] = E z ∼ q i ( z ) [ log ⁡ p θ ( x i ) + q i ( z ) − log ⁡ p θ ( x i ∣ z ) − log ⁡ p ( z ) ] = E z ∼ q i ( z ) [ log ⁡ p θ ( x i ) q i ( z ) p θ ( x i , z ) ] = E z ∼ q i ( z ) [ log ⁡ q i ( z ) p θ ( z ∣ x i ) ] = D K L ( q i ( z ) ∣ ∣ p θ ( z ∣ x i ) ) \begin{aligned} &log p_{\theta}(x_i) - \cal{L}i \\ &=log p{\theta}(x_i) - \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i|z) + \log p(z)] - \cal{H}(q_i) \\ &= log p_{\theta}(x_i) - \cal{H}(q_i) - \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i|z) + \log p(z)] \\ &= \mathbb{E}{z\sim q_i(z)}[log p{\theta}(x_i)] + \mathbb{E}{z\sim q_i(z)}[q_i(z)] - \mathbb{E}{z\sim q_i(z)}[\log p_{\theta}(x_i|z) + \log p(z)] \\ &= \mathbb{E}{z\sim q_i(z)}[\log p{\theta}(x_i) + q_i(z) - \log p_{\theta}(x_i|z) - \log p(z)] \\ &= \mathbb{E}{z\sim q_i(z)}[\log \frac{p{\theta}(x_i) q_i(z)}{p_{\theta}(x_i,z)}] \\ &= \mathbb{E}{z\sim q_i(z)}[\log \frac{ q_i(z)}{p{\theta}(z|x_i)}] \\ &= D_{KL}(q_i(z) || p_{\theta}(z|x_i)) \end{aligned} </math>logpθ(xi)−Li=logpθ(xi)−Ez∼qi(z)[logpθ(xi∣z)+logp(z)]−H(qi)=logpθ(xi)−H(qi)−Ez∼qi(z)[logpθ(xi∣z)+logp(z)]=Ez∼qi(z)[logpθ(xi)]+Ez∼qi(z)[qi(z)]−Ez∼qi(z)[logpθ(xi∣z)+logp(z)]=Ez∼qi(z)[logpθ(xi)+qi(z)−logpθ(xi∣z)−logp(z)]=Ez∼qi(z)[logpθ(xi,z)pθ(xi)qi(z)]=Ez∼qi(z)[logpθ(z∣xi)qi(z)]=DKL(qi(z)∣∣pθ(z∣xi))

<math xmlns="http://www.w3.org/1998/Math/MathML"> D K L D_{KL} </math>DKL 表示 Kullback--Leibler Divergence。 因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L ( q i ( z ) ∣ ∣ p θ ( z ∣ x i ) ) ≥ 0 D_{KL}(q_i(z) || p_{\theta}(z|x_i)) \geq 0 </math>DKL(qi(z)∣∣pθ(z∣xi))≥0, 由上面的推导可知,当且仅当 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) = p θ ( z ∣ x i ) q_i(z) = p_{\theta}(z|x_i) </math>qi(z)=pθ(z∣xi) 时,lower bound 是 tight 的。

根据上面的推导,我们还可以得出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L i = l o g p θ ( x i ) − D K L ( q i ( z ) ∣ ∣ p θ ( z ∣ x i ) ) \cal{L}i = log p{\theta}(x_i) - D_{KL}(q_i(z) || p_{\theta}(z|x_i)) </math>Li=logpθ(xi)−DKL(qi(z)∣∣pθ(z∣xi))

这意味着,当我们通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 来最大化 <math xmlns="http://www.w3.org/1998/Math/MathML"> L i \cal{L}i </math>Li 时, <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L ( q i ( z ) ∣ ∣ p θ ( z ∣ x i ) ) D{KL}(q_i(z) || p_{\theta}(z|x_i)) </math>DKL(qi(z)∣∣pθ(z∣xi)) 会同时最小化,使得我们的 lower bound 变得 tighter。 因此,在优化过程中,我们不仅要最大化 lower bound, 我们同时还要使得 lower bound 变得 tighter:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> max ⁡ θ , q i 1 N ∑ i = 1 N L i \max_{\theta, q_i} \frac{1}{N} \sum_{i=1}^{N} \cal{L}_i </math>θ,qimaxN1i=1∑NLi

合并 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z)

对于每一个数据,我们独立定义了一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi。 这样做有个缺陷:在深度学习的任务中, <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 一般为一个比较大的值,比如说 ImageNet1000 就包含一百多万张训练图片,那么数量过多的 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi 可能会使得整个模型过于庞大。 因此,我们用一个 DNN 来模拟所有的 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q i ( z ) = q ϕ ( z ∣ x i ) = N ( μ ϕ ( x i ) , σ ϕ ( x i ) ) q_i(z) = q_\phi(z|x_i) = \cal{N}(\mu_{\phi}(x_i), \sigma_{\phi}(x_i)) </math>qi(z)=qϕ(z∣xi)=N(μϕ(xi),σϕ(xi))

上面这个公式表示,我们把 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z) 定义为 Normal Distribution, 其 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ \sigma </math>σ 会根据每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 计算出来。 我们为什么想要用 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 来计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z)? 一个 intuition 是我们最终目标是想让 <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L ( q i ( z ) ∣ ∣ p θ ( z ∣ x i ) ) = 0 D_{KL}(q_i(z) || p_{\theta}(z|x_i))=0 </math>DKL(qi(z)∣∣pθ(z∣xi))=0,那么将 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z) 定义为类似于 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( z ∣ x i ) p_{\theta}(z|x_i) </math>pθ(z∣xi) 的形式,更有利于取得这个目标。 至于为什么要用 Normal Distribution,这是因为用它可以得到 closed-form 公式。

优化以及 The Reparameterization Trick

在 <math xmlns="http://www.w3.org/1998/Math/MathML"> L i \cal{L}i </math>Li 中把 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ( z ) q_i(z) </math>qi(z) 替换成 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) q\phi(z|x_i) </math>qϕ(z∣xi) 后,我们的优化目标变成了:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L i = E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) + log ⁡ p ( z ) ] + H ( q ϕ ( z ∣ x i ) ) = E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] + log ⁡ p ( z ) − E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ q ϕ ( z ∣ x i ) ] = E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] − E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ q ϕ ( z ∣ x i ) − log ⁡ p ( z ) ] = E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] − D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] \begin{aligned} \cal{L}i &= \mathbb{E}{z\sim q_\phi(z|x_i)}[\log p_{\theta}(x_i|z) + \log p(z)] + \cal{H}(q_\phi(z|x_i)) \\ &= \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)] + \log p(z) - \mathbb{E}{z\sim q\phi(z|x_i)}[\log q_\phi(z|x_i)] \\ &= \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)] - \mathbb{E}{z\sim q\phi(z|x_i)}[\log q_\phi(z|x_i) - \log p(z)] \\ &= \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)] - D_{KL}[q_\phi(z|x_i)||p(z)] \\ \end{aligned} </math>Li=Ez∼qϕ(z∣xi)[logpθ(xi∣z)+logp(z)]+H(qϕ(z∣xi))=Ez∼qϕ(z∣xi)[logpθ(xi∣z)]+logp(z)−Ez∼qϕ(z∣xi)[logqϕ(z∣xi)]=Ez∼qϕ(z∣xi)[logpθ(xi∣z)]−Ez∼qϕ(z∣xi)[logqϕ(z∣xi)−logp(z)]=Ez∼qϕ(z∣xi)[logpθ(xi∣z)]−DKL[qϕ(z∣xi)∣∣p(z)]

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) q_\phi(z|x_i) </math>qϕ(z∣xi) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( z ) = N ( 0 , 1 ) p(z)= \cal{N}(0,1) </math>p(z)=N(0,1) 都是 Normal Distribution,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] D_{KL}[q_\phi(z|x_i)||p(z)] </math>DKL[qϕ(z∣xi)∣∣p(z)] 有 closed-form 公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] = 1 2 ( σ ϕ ( x i ) 2 + μ ϕ ( x i ) 2 − 1 − log ⁡ σ ϕ ( x i ) 2 ) D_{KL}[q_\phi(z|x_i)||p(z)] = \frac{1}{2}(\sigma_{\phi}(x_i)^2+ \mu_{\phi}(x_i)^2 - 1 -\log \sigma_{\phi}(x_i)^2) </math>DKL[qϕ(z∣xi)∣∣p(z)]=21(σϕ(xi)2+μϕ(xi)2−1−logσϕ(xi)2)

我们尝试使用 Gradient Ascent 算法进行优化 <math xmlns="http://www.w3.org/1998/Math/MathML"> L i \cal{L}_i </math>Li ,具体如下:

  1. 取一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi (或者 mini-batch 个 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi)
  2. 从 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) = N ( μ ϕ ( x i ) , σ ϕ ( x i ) ) q_\phi(z|x_i)=\cal{N}(\mu_{\phi}(x_i), \sigma_{\phi}(x_i)) </math>qϕ(z∣xi)=N(μϕ(xi),σϕ(xi)) 中采样出一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z
  3. <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ← θ + α ∇ θ L i \theta \leftarrow \theta + \alpha\nabla_\theta\cal{L}_i </math>θ←θ+α∇θLi
  4. <math xmlns="http://www.w3.org/1998/Math/MathML"> ϕ ← ϕ + α ∇ ϕ L i \phi \leftarrow \phi + \alpha\nabla_\phi\cal{L}_i </math>ϕ←ϕ+α∇ϕLi

仔细看这个算法,我们会发现一个问题,这个算法会让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) q_\phi(z|x_i) </math>qϕ(z∣xi) 简单的越来越趋近于 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 1 ) \cal{N}(0,1) </math>N(0,1)。 这是因为在第四步中, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ ϕ D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] ≠ 0 \nabla_\phi D_{KL}[q_\phi(z|x_i)||p(z)] \neq 0 </math>∇ϕDKL[qϕ(z∣xi)∣∣p(z)]=0,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ ϕ E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] = 0 \nabla_\phi \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)]=0 </math>∇ϕEz∼qϕ(z∣xi)[logpθ(xi∣z)]=0。 这是不对的,我们也需要从 <math xmlns="http://www.w3.org/1998/Math/MathML"> E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)] </math>Ez∼qϕ(z∣xi)[logpθ(xi∣z)] 中得到的信息来优化 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϕ \phi </math>ϕ。

以上错误是由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 是直接从 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) q_\phi(z|x_i) </math>qϕ(z∣xi) 中采样出来的,而采样的过程是 non-differentiable 的。 这里我们就需要 The Reparemeterization Trick 来正确计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ ϕ E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] \nabla_\phi \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)] </math>∇ϕEz∼qϕ(z∣xi)[logpθ(xi∣z)] 了。 其策略是在第二步中用以下公式采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> z = μ ϕ ( x i ) + ϵ σ ϕ ( x i ) z = \mu_{\phi}(x_i) + \epsilon \sigma_{\phi}(x_i) </math>z=μϕ(xi)+ϵσϕ(xi)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ∼ N ( 0 , 1 ) \epsilon \sim \cal{N}(0, 1) </math>ϵ∼N(0,1)。 用这个公式采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,我们就可以得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ ϕ E z ∼ q ϕ ( z ∣ x i ) [ log ⁡ p θ ( x i ∣ z ) ] ≠ 0 \nabla_\phi \mathbb{E}{z\sim q\phi(z|x_i)}[\log p_{\theta}(x_i|z)]\neq0 </math>∇ϕEz∼qϕ(z∣xi)[logpθ(xi∣z)]=0 了。

计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ p θ ( x i ∣ z ) \log p_\theta(x_i|z) </math>logpθ(xi∣z)

离代码实现前,我们还有最后一个问题没解决,那就是怎么计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ p θ ( x i ∣ z ) \log p_\theta(x_i|z) </math>logpθ(xi∣z)。 如前文所述, <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x i ∣ z ) p_\theta(x_i|z) </math>pθ(xi∣z) 表示用一个特定的 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 的概率。 我们可以用 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_{\theta}(z) </math>fθ(z) 把 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 转换为图片,但该怎么计算两张图片是否是同一张的概率呢? 实际情况是,目前我们还没有一个数学公式可以计算出这个结果。 在具体实现中,一个简单的办法是用 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( z ) f_{\theta}(z) </math>fθ(z) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 之间负的 Mean Squared Errors 替代 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ p θ ( x i ∣ z ) \log p_\theta(x_i|z) </math>logpθ(xi∣z):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> log ⁡ p θ ( x i ∣ z ) = − M S E ( f θ ( z ) , x i ) \log p_\theta(x_i|z) = -MSE(f_{\theta}(z), x_i) </math>logpθ(xi∣z)=−MSE(fθ(z),xi)

在 MSE 前面加上负号,那么其范围变成了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( − ∞ , 0 ] (-\infty, 0] </math>(−∞,0]。 如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> − M S E -MSE </math>−MSE 越小,表示两个图片越不像。 如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> − M S E -MSE </math>−MSE 等于 0, 那么表示两张图片一模一样。 这样当我们最大化 <math xmlns="http://www.w3.org/1998/Math/MathML"> − M S E -MSE </math>−MSE 时,就使得生成的图片越来越像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi。

到此,我们的理论准备就完成了,接下来看看怎么用代码实现 VAE 的优化。

关键代码

我们参考一下 Vanilla VAE 的开源实现,其用的工具是 pytorch。 我们看看代码是怎么计算 VAE 的 loss 的。

python 复制代码
recons_loss = F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
loss = recons_loss + kld_weight * kld_loss

首先得说明的是,pytorch 的优化器,都是用来做最小化的,而我们需要最大化 <math xmlns="http://www.w3.org/1998/Math/MathML"> L i \cal{L}i </math>Li。 所以在 pytorch 里求 loss 的时候,我们需要求 <math xmlns="http://www.w3.org/1998/Math/MathML"> − L i -\cal{L}i </math>−Li,然后对其最小化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> − L i = M S E ( f θ ( z ) , x i ) + D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] -\cal{L}i = MSE(f{\theta}(z), x_i) + D
{KL}[q
\phi(z|x_i)||p(z)] </math>−Li=MSE(fθ(z),xi)+DKL[qϕ(z∣xi)∣∣p(z)]

第一行的代码用了 pytorch 自带的函数求出 <math xmlns="http://www.w3.org/1998/Math/MathML"> M S E MSE </math>MSE。 第二行的 kld_loss 求出了 <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L [ q ϕ ( z ∣ x i ) ∣ ∣ p ( z ) ] D_{KL}[q_\phi(z|x_i)||p(z)] </math>DKL[qϕ(z∣xi)∣∣p(z)]。 第三行将 recons_loss 和 kld_loss 进行加权求和(在我们的公式中 <math xmlns="http://www.w3.org/1998/Math/MathML"> k l d _ w e i g h t = 1 kld\_weight = 1 </math>kld_weight=1),其结果就是最终的 loss。

更多的思考以及后面的文章

基础版本的 VAE 虽然可以用于生成图片,但其效果并不是非常优秀。 主要原因有两个:

  1. 我们为了得到 closed-form 公式,假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ϕ ( z ∣ x i ) q_\phi(z|x_i) </math>qϕ(z∣xi) 为一个 Normal Distribution,但这可能并不符合实际情况。
  2. 我们用 <math xmlns="http://www.w3.org/1998/Math/MathML"> − M S E ( f θ ( z ) , x i ) -MSE(f_{\theta}(z), x_i) </math>−MSE(fθ(z),xi) 来替代 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ p θ ( x i ∣ z ) \log p_\theta(x_i|z) </math>logpθ(xi∣z) 是非常不符合视觉机制的。比如我们把图片平移几个像素, <math xmlns="http://www.w3.org/1998/Math/MathML"> − M S E ( f θ ( z ) , x i ) -MSE(f_{\theta}(z), x_i) </math>−MSE(fθ(z),xi) 可能会变得非常小,但是理论上来说,轻微平移后的图片和原图片应大概率是同一张图片。

目前 Generative Models 已经霸占了 AI 领域的半壁江山。 我计划写一个系列的文章来梳理 Generative Models 中比较难懂的理论知识。 这样不仅仅可以极大的帮助我自己的研究工作,对于想研究 Generative Models 的朋友,希望也会有所帮助。

相关推荐
这个男人是小帅30 分钟前
【GAT】 代码详解 (1) 运行方法【pytorch】可运行版本
人工智能·pytorch·python·深度学习·分类
__基本操作__32 分钟前
边缘提取函数 [OPENCV--2]
人工智能·opencv·计算机视觉
Doctor老王37 分钟前
TR3:Pytorch复现Transformer
人工智能·pytorch·transformer
热爱生活的五柒37 分钟前
pytorch中数据和模型都要部署在cuda上面
人工智能·pytorch·深度学习
HyperAI超神经3 小时前
【TVM 教程】使用 Tensorize 来利用硬件内联函数
人工智能·深度学习·自然语言处理·tvm·计算机技术·编程开发·编译框架
扫地的小何尚4 小时前
NVIDIA RTX 系统上使用 llama.cpp 加速 LLM
人工智能·aigc·llama·gpu·nvidia·cuda·英伟达
埃菲尔铁塔_CV算法7 小时前
深度学习神经网络创新点方向
人工智能·深度学习·神经网络
艾思科蓝-何老师【H8053】7 小时前
【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)
人工智能·信号处理·论文发表·香港中文大学
weixin_452600697 小时前
《青牛科技 GC6125:驱动芯片中的璀璨之星,点亮 IPcamera 和云台控制(替代 BU24025/ROHM)》
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·智能充电枪
学术搬运工7 小时前
【珠海科技学院主办,暨南大学协办 | IEEE出版 | EI检索稳定 】2024年健康大数据与智能医疗国际会议(ICHIH 2024)
大数据·图像处理·人工智能·科技·机器学习·自然语言处理