VAE原理及代码实现

VAE

文章目录

前置知识

后验概率

(1)已知结果,根据结果估计原因的概率分布。

公式表示为,
P ( θ ∣ x ) P(\theta|x) P(θ∣x)
θ \theta θ表示事情发生的原因, x x x表示事情发生的结果。该式表示 x x x发生后 θ \theta θ的概率。

(2)已知原因,根据原因估计结果的概率分布。

公式表示为,
P ( x ∣ θ ) P(x|\theta) P(x∣θ)
θ \theta θ表示事情发生的原因, x x x表示事情发生的结果。该式表示 θ \theta θ发生后 x x x的概率。

先验概率:未知结果(在结果发生之前),根据历史规律确定原因的概率分布。

公式表示为,
P ( θ ) P(\theta) P(θ)
θ \theta θ表示事情发生的原因。

贝叶斯公式
P ( θ ∣ x ) = P ( x ∣ θ ) ∗ p ( θ ) p ( x ) P(\theta|x)=\frac{P(x|\theta)*p(\theta)}{p(x)} P(θ∣x)=p(x)P(x∣θ)∗p(θ)
p ( x ) p(x) p(x)表示不考虑原因,只看结果的概率分布。

似然概率:似然函数是一种关于统计模型参数的函数。

已知输出 x x x,关于模型参数变量 θ \theta θ的似然函数为 L ( θ ∣ x ) L(\theta|x) L(θ∣x)

等价于 已知模型参数 θ \theta θ,变量 x x x的概率为 P ( x ∣ θ ) P(x|\theta) P(x∣θ)

ML/MAP

给定一些数据样本 x x x ,假定我们知道样本是从某一种分布中随机取出的,但我们不知道这个分布具体的参数 θ \theta θ.

++最大似然估计(ML)++ 可以估计模型的参数,公式表示如下,
L ( θ ∣ x ) = a r g m a x θ P ( x ∣ θ ) L(\theta|x)=\underset{\theta}{argmax}P(x|\theta) L(θ∣x)=θargmaxP(x∣θ)
++最大后验估计(MAP)++ ,可以直接由贝叶斯公式给出,
a r g m a x θ P ( θ ∣ x ) = a r g m a x θ P ( x ∣ θ ) ∗ p ( θ ) p ( x ) \underset{\theta}{argmax}P(\theta|x)=\underset{\theta}{argmax}\frac{P(x|\theta)*p(\theta)}{p(x)} θargmaxP(θ∣x)=θargmaxp(x)P(x∣θ)∗p(θ)

因为给定样本 x x x后,$p(x) $会在 θ \theta θ空间上为一个定值,和 θ \theta θ的大小没有关系,所以可以省略分母 $p(x) $。

高斯混合模型:高斯混合模型可以看作是由 K 个单高斯模型组合而成的模型,这 K 个子模型是混合模型的隐变量(Hidden variable)。一般来说,一个混合模型可以使用任何概率分布,这里使用高斯混合模型是因为高斯分布具备很好的数学性质以及良好的计算性能。

假设 x j x_j xj表示第 j j j个观测数据, j = 1 , 2 , 3... , N j=1,2,3...,N j=1,2,3...,N,

K K K是混合模型中子高斯模型的数量, k = 1 , 2 , . . . , K k=1,2,...,K k=1,2,...,K

α k \alpha_k αk是观测数据属于第 k k k个子模型的概率, α k ≥ 0 \alpha_k \geq 0 αk≥0, ∑ k = 1 K α k = 1 \sum_{k=1}^{K}\alpha_k=1 ∑k=1Kαk=1

ψ ( x ∣ θ k ) \psi(x|\theta_k) ψ(x∣θk)是第 k k k个子模型的高斯分布密度函数, θ k = ( μ k , σ k 2 ) \theta_k=(\mu_k,\sigma^2_k) θk=(μk,σk2)

故高斯混合模型的概率分布为,
P ( x ∣ θ ) = ∑ k = 1 K α k ψ ( x ∣ θ k ) P(x|\theta)=\sum_{k=1}^{K}\alpha_k\psi(x|\theta_k) P(x∣θ)=k=1∑Kαkψ(x∣θk)

基本介绍

原理

VAE中文为变分自编码器。主要由两部分组成,编码器网络(推断网络)和解码器网络(生成网络)。

基本思路是:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个自编码器模型。

我们使用神经网络来替换编码网络和解码网络(AE但还不是VAE),替换的明显好处是,引入了神经网络强大的拟合能力,使得编码的维度能够比原始图像的维度低非常多。

采用神经网络之后,对于一个生成模型而言,还需要达到两个标准,

  1. 编码器和解码器部分应该是单独能够提取出来的(独立拆分的)
  2. 对于在规定维度下任意采样的一个编码,都应该能通过解码器产生一张清晰且真实的图片。
AE的局限(参考网络)

如下图所示,我们用一张全月图和一张半月图去训练一个AE,经过训练,模型能够很好地还原出这两张图片。接下来,我们在latent code上中间一点,即两张图片编码点中间处任取一点,将这点交给解码器进行解码,直觉上我们会得到一张介于全月图和半月图之间的图片(比如阴影面积覆盖3/4的样子)。然而,实际当你那这个点去decode的时候你会发现AE还原出来的图片不仅模糊而且还是乱码的。

为什么会出现这种现象?一个直观上的解释是AE的Encoder和Decoder都使用了DNN,DNN是一个非线性的变换过程,因此在latent space上点与点之间transform往往没有规律可循。

如何解决这个问题呢?一个思想就是引入噪声,扩大图片的编码区域,从而能够覆盖到失真的空白编码区。其实说白了就是通过增加输入的多样性从而增强输出的鲁棒性。当我们给输入图片进行编码之前引入一点噪声,使得每张图片的编码点出现在绿色箭头范围内,这样一来所得到的latent space就能覆盖到更多的编码点。此时我们再从中间点抽取去还原便可以得到一个我们比较希望得到的输出,如下所示:

虽然我们为输入图片增添了一些噪声使得latent space能够覆盖到比较多的区域,但是还是有不少地方没有被覆盖到,比如上图右边黄色的部分因为离得比较远所以就没编码到。因此,我们是不是可以尝试利用更多的噪音,使得对于每一个输入样本,它的编码都能够覆盖到整个编码空间?只不过这里我们需要保证的是,对于源编码附近的编码我们应该给定一个高的概率值,而对于距离原编码点距离较远的,我们应该给定一个低的概率值。没错,总体来说,我们就是要将原先一个单点拉伸到整个编码空间,即将离散的编码点引申为一条连续的接近正态分布的编码曲线,如下所示:

上述的这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。

数学推导

VAE的模型架构如下图所示,

在AE中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码,一个是原有编码 μ \mu μ,另外一个是控制噪音干扰程度的编码 σ \sigma σ,第二个编码其实很好理解,就是为随机噪音码 ϵ \epsilon ϵ分配权重,然后加上 e x p ( σ i ) exp(σ_i) exp(σi)的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果 z z z。

损失函数方面,除了必要的重构损失外,VAE还增添了一个损失函数(见上图Minimize2内容),这同样是必要的部分,因为如果不加的话,整个模型就会出现问题:为了保证生成图片的质量越高,编码器肯定希望噪音对自身生成图片的干扰越小,于是分配给噪音的权重越小,这样只需要将 σ σ σ赋为接近负无穷大的值就好了。所以,第二个损失函数就有限制编码器走这样极端路径的作用,这也从图像上就能看出来,exp(σi)-(1+σi)在σi=0处取得最小值,于是 σ σ σ就会避免被赋值为负无穷大。

我们假设经过整个VAE网络之后的数据 x x x服从分布为 P ( x ) P(x) P(x), M M M为 P ( x ) P(x) P(x)分布可以分解为的子高斯函数的个数, M M M服从一个概率分布 P ( m ) P(m) P(m), P ( x ∣ m ) P(x|m) P(x∣m)表示对于某个采样m对应的高斯分布概率密度函数,对应高斯混合模型有,
P ( x ) = ∑ m P ( m ) P ( x ∣ m ) P(x)=\sum_{m}P(m)P(x|m) P(x)=m∑P(m)P(x∣m)

进一步,我们将离散的变量 m m m换成一个连续的变量 z z z, z z z服从正太分布 N ( 0 , 1 ) N(0,1) N(0,1), x ∣ z x|z x∣z服从分布高斯分布 N ( μ ( z ) ∣ σ ( z ) ) N(\mu(z)|\sigma(z)) N(μ(z)∣σ(z)),对应高斯混合模型,
P ( x ) = ∫ z P ( z ) P ( x ∣ z ) d z P(x)=\int_{z}P(z)P(x|z)dz P(x)=∫zP(z)P(x∣z)dz

原则上,我们希望 P ( x ) P(x) P(x)这个分布越大越好,使得能够覆盖到更多的区域,根据极大似然估计法,
L ( θ ∣ x ) = P ( x ∣ θ ) = ∏ i = 1 n P ( x i ∣ θ ) M a x i m u m L ( θ ∣ x ) = a r g m a x θ ∏ i = 1 n P ( x i ∣ θ ) 转化成 l o g 函数的目的是 = a r g m a x θ l o g ∏ i = 1 n P ( x i ∣ θ ) 避免下溢出和加法方便运算 = a r g m a x θ ∑ i = 1 n l o g P ( x i ∣ θ ) 这一步运用了期望和贝叶斯公式 = a r g m a x θ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) P ( z ∣ x ) ) d z 分子分母同时乘一个 q ( z ∣ x ) = a r g m a x θ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) P ( z ∣ x ) q ( z ∣ x ) ) d z 右边式子同 K L 散度 = a r g m a x θ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z = a r g m a x θ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z + K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) \begin{aligned} L(\theta|x)=P(x|\theta)&=\prod_{i=1}^{n}P(x_i|\theta)\\ Maximum L(\theta|x)&=\underset{\theta}{argmax}\prod_{i=1}^{n}P(x_i|\theta)\\ 转化成log函数的目的是&=\underset{\theta}{argmax}~log\prod_{i=1}^{n}P(x_i|\theta)\\ 避免下溢出和加法方便运算&=\underset{\theta}{argmax} \sum_{i=1}^nlogP(x_i|\theta)\\ 这一步运用了期望和贝叶斯公式&=\underset{\theta}{argmax}\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{P(z|x)})dz\\ 分子分母同时乘一个q(z|x)&=\underset{\theta}{argmax}\int_{z}q(z|x)log(\frac{P(x|z)P(z)~q(z|x)}{P(z|x)~q(z|x)})dz\\ 右边式子同KL散度&=\underset{\theta}{argmax}\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{q(z|x)})dz+\int_{z}q(z|x)log(\frac{q(z|x)}{P(z|x)})dz\\ &=\underset{\theta}{argmax}\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{q(z|x)})dz+KL(q(z|x)||P(z|x)) \end{aligned} L(θ∣x)=P(x∣θ)MaximumL(θ∣x)转化成log函数的目的是避免下溢出和加法方便运算这一步运用了期望和贝叶斯公式分子分母同时乘一个q(z∣x)右边式子同KL散度=i=1∏nP(xi∣θ)=θargmaxi=1∏nP(xi∣θ)=θargmax logi=1∏nP(xi∣θ)=θargmaxi=1∑nlogP(xi∣θ)=θargmax∫zq(z∣x)log(P(z∣x)P(x∣z)P(z))dz=θargmax∫zq(z∣x)log(P(z∣x) q(z∣x)P(x∣z)P(z) q(z∣x))dz=θargmax∫zq(z∣x)log(q(z∣x)P(x∣z)P(z))dz+∫zq(z∣x)log(P(z∣x)q(z∣x))dz=θargmax∫zq(z∣x)log(q(z∣x)P(x∣z)P(z))dz+KL(q(z∣x)∣∣P(z∣x))

上面式子右边那一项为这两个分布的KL散度距离,根据KL散度公式的性质,我们可以知道右边项恒大于等于0,于是我们找到了 的一个下确界:
K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) ≥ 0 L ( θ ∣ x ) ≥ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z KL(q(z|x)||P(z|x))\geq0\\ L(\theta|x)\geq\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{q(z|x)})dz KL(q(z∣x)∣∣P(z∣x))≥0L(θ∣x)≥∫zq(z∣x)log(q(z∣x)P(x∣z)P(z))dz

我们将这个下界记为,
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z L_b=\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{q(z|x)})dz Lb=∫zq(z∣x)log(q(z∣x)P(x∣z)P(z))dz

代入原式为,
L ( θ ∣ x ) = L b + K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) L(\theta|x)=L_b+KL(q(z|x)||P(z|x)) L(θ∣x)=Lb+KL(q(z∣x)∣∣P(z∣x))

由上式可知,当我们固定住 P ( x ∣ z ) P(x|z) P(x∣z)时,因为 l o g P ( x ) logP(x) logP(x)只与 P ( x ∣ z ) 有关 P(x|z)有关 P(x∣z)有关(根据高斯混合模型),所以 l o g P ( x ) logP(x) logP(x)的值是会不变的,此时我们去调节 q ( z ∣ x ) q(z|x) q(z∣x),使得 L b L_b Lb越来越高,同时 K L KL KL散度越来越小,当我们调节到​ q ( z ∣ x ) q(z|x) q(z∣x)与 P ( z ∣ x ) P(z|x) P(z∣x)完全一致时, K L KL KL散度就消失为 0 0 0, L b L_b Lb与 l o g P ( x ) logP(x) logP(x)完全一致。由此可以得出,不论 l o g P ( x ) logP(x) logP(x)的值如何,我们总能够通过调节使得 L b L_b Lb等于$ logP(x) ,又因为 ,又因为 ,又因为 L b L_b Lb 是 是 是logP(x) 的下界,所以求解 的下界,所以求解 的下界,所以求解 M a x i m u m l o g P ( x ) Maximum logP(x) MaximumlogP(x) 等价为求解 等价为求解 等价为求解 M a x i m u m L b Maximum Lb MaximumLb 。对 。 对 。对L_b$进行化解,
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z \begin{aligned} L_b&=\int_{z}q(z|x)log(\frac{P(x|z)P(z)}{q(z|x)})dz\\ &=\int_{z}q(z|x)log(\frac{P(z)}{q(z|x)})dz+\int_{z}q(z|x)logP(x|z)dz\\ &=-KL(q(z|x)||P(z))+\int_{z}q(z|x)logP(x|z)dz \end{aligned} Lb=∫zq(z∣x)log(q(z∣x)P(x∣z)P(z))dz=∫zq(z∣x)log(q(z∣x)P(z))dz+∫zq(z∣x)logP(x∣z)dz=−KL(q(z∣x)∣∣P(z))+∫zq(z∣x)logP(x∣z)dz

已知, P ( z ) P(z) P(z)是服从标准正太分布的, q ( z ∣ x ) q(z|x) q(z∣x)为任意某种分布,
P ( z ) = N ( 0 , 1 ) q ( z ∣ x ) = N ( z , μ , σ 2 ) P(z)=N(0,1)\\ q(z|x)=N(z,\mu,\sigma^2) P(z)=N(0,1)q(z∣x)=N(z,μ,σ2)

然后我们将 L b L_b Lb分解为,
L b = L 1 + L 2 L_b=L_1+L_2 Lb=L1+L2

其中, L 1 = − K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) L_1=-KL(q(z|x)||P(z)) L1=−KL(q(z∣x)∣∣P(z)), L 2 = ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z L_2=\int_{z}q(z|x)logP(x|z)dz L2=∫zq(z∣x)logP(x∣z)dz

对于 L 1 L_1 L1有,
L 1 = ∫ z q ( z ∣ x ) l o g ( P ( z ) ) d z − ∫ z q ( z ∣ x ) l o g q ( z ∣ x ) d z \begin{aligned} L_1&=\int_{z}q(z|x)log(P(z))dz-\int_{z}q(z|x)log~q(z|x)dz\\ \end{aligned} L1=∫zq(z∣x)log(P(z))dz−∫zq(z∣x)log q(z∣x)dz

对于 L 1 L_1 L1的第一项有,
∫ z q ( z ∣ x ) l o g ( P ( z ) ) d z = ∫ z N ( z , μ , σ 2 ) l o g N ( 0 , 1 ) d z = E z ∼ N ( μ , σ 2 ) [ l o g N ( 0 , 1 ) ] = E z ∼ N ( μ , σ 2 ) [ l o g ( 1 2 π e ( z 2 2 ) ) ] = − 1 2 l o g 2 π − 1 / 2 E z ∼ N ( μ , σ 2 ) [ z 2 ] = − 1 2 l o g 2 π − 1 2 ( μ 2 + σ 2 ) \begin{aligned} \int_{z}q(z|x)log(P(z))dz&=\int_{z}N(z,\mu,\sigma^2)logN(0,1)dz\\ &=E_{z\sim{N(\mu,\sigma^2)}}[logN(0,1)]\\ &=E_{z\sim{N(\mu,\sigma^2)}}[log(\frac{1}{\sqrt{2\pi}}e^{(\frac{z^2}{2})})]\\ &=-\frac{1}{2}log2\pi-1/2E_{z\sim{N(\mu,\sigma^2)}}[z^2]\\ &=-\frac{1}{2}log2\pi-\frac{1}{2}(\mu^2+\sigma^2) \end{aligned} ∫zq(z∣x)log(P(z))dz=∫zN(z,μ,σ2)logN(0,1)dz=Ez∼N(μ,σ2)[logN(0,1)]=Ez∼N(μ,σ2)[log(2π 1e(2z2))]=−21log2π−1/2Ez∼N(μ,σ2)[z2]=−21log2π−21(μ2+σ2)

对于 L 1 L_1 L1的第二项有,
∫ z q ( z ∣ x ) l o g q ( z ∣ x ) d z = ∫ z N ( z , μ , σ 2 ) l o g N ( z , μ , σ 2 ) d z = E z ∼ N ( μ , σ 2 ) [ l o g N ( z , μ , σ 2 ) ] = E z ∼ N ( μ , σ 2 ) [ l o g 1 2 π σ 2 e ( z − μ ) 2 2 σ 2 ] = − 1 2 l o g 2 π − 1 2 l o g σ 2 − 1 2 σ 2 E z ∼ N ( μ , σ 2 ) [ ( z − μ ) 2 ] = − 1 2 l o g 2 π − 1 2 ( l o g σ 2 + 1 ) \begin{aligned} \int_{z}q(z|x)log~q(z|x)dz&=\int_{z}N(z,\mu,\sigma^2)logN(z,\mu,\sigma^2)dz\\ &=E_{z\sim{N(\mu,\sigma^2)}}[logN(z,\mu,\sigma^2)]\\ &=E_{z\sim{N(\mu,\sigma^2)}}[log\frac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{(z-\mu)^2}{2\sigma^2}}]\\ &=-\frac{1}{2}log2\pi-\frac{1}{2}log\sigma^2-\frac{1}{2\sigma^2}E_{z\sim N(\mu,\sigma^2)}[(z-\mu)^2]\\ &=-\frac{1}{2}log2\pi-\frac{1}{2}(log\sigma^2+1) \end{aligned} ∫zq(z∣x)log q(z∣x)dz=∫zN(z,μ,σ2)logN(z,μ,σ2)dz=Ez∼N(μ,σ2)[logN(z,μ,σ2)]=Ez∼N(μ,σ2)[log2πσ2 1e2σ2(z−μ)2]=−21log2π−21logσ2−2σ21Ez∼N(μ,σ2)[(z−μ)2]=−21log2π−21(logσ2+1)

我们用第一项减去第二项有,
L 1 = 1 2 ∑ j − 1 J [ 1 + l o g ( σ j 2 ) − μ j 2 − σ j 2 ] L_1=\frac{1}{2}\sum_{j-1}^{J}[1+log(\sigma_j^2)-\mu_j^2-\sigma_j^2] L1=21j−1∑J[1+log(σj2)−μj2−σj2]

对于 L 2 L_2 L2,由蒙特卡洛方法计算得到,
L 2 = ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = E q ( z ∣ x ) l o g P ( x ∣ z ) ≈ 1 L ∑ l = 1 L l o g ( P ( x ( i ) ∣ z ( i , l ) ) ) L_2=\int_{z}q(z|x)logP(x|z)dz=E_{q(z|x)}logP(x|z) \approx\frac{1}{L}\sum_{l=1}^{L}log(P(x^{(i)}|z^{(i,l)})) L2=∫zq(z∣x)logP(x∣z)dz=Eq(z∣x)logP(x∣z)≈L1l=1∑Llog(P(x(i)∣z(i,l)))

其中 x ( i ) x^{(i)} x(i)是从真实数据中采样得到的第 i i i个数据;以 x ( i ) x^{(i)} x(i)作为Encoder的输入,随后从编码器 q ( z ∣ x ( i ) ) q(z|x^{(i)}) q(z∣x(i))中抽取 L L L个数据 z ( i , l ) z^{(i,l)} z(i,l)。实际上就是调整Decoder( 即P(x\|z) ),使得以 ),使得以 ),使得以x\^{(i)} 作为 E n c o d e r 的输入,编码采样得到多个 作为Encoder的输入,编码采样得到多个 作为Encoder的输入,编码采样得到多个 z\^{(i,l)} ,最后能用Decoder最大概率地从 z ( i , l ) z^{(i,l)} z(i,l)中恢复出 x ( i ) x^{(i)} x(i)

故,
L b = L 1 + L 2 = 1 2 ∑ j − 1 J [ 1 + l o g ( σ j 2 ) − μ j 2 − σ j 2 ] + 1 L ∑ l = 1 L l o g ( P ( x ( i ) ∣ z ( i , l ) ) ) \begin{aligned} L_b&=L_1+L_2\\ &=\frac{1}{2}\sum_{j-1}^{J}[1+log(\sigma_j^2)-\mu_j^2-\sigma_j^2]+\frac{1}{L}\sum_{l=1}^{L}log(P(x^{(i)}|z^{(i,l)})) \end{aligned} Lb=L1+L2=21j−1∑J[1+log(σj2)−μj2−σj2]+L1l=1∑Llog(P(x(i)∣z(i,l)))

M a x i m u m L b = M a x i m u m { 1 2 ∑ j − 1 J [ 1 + l o g ( σ j 2 ) − μ j 2 − σ j 2 ] + 1 L ∑ l = 1 L l o g ( P ( x ( i ) ∣ z ( i , l ) ) ) } Maximum~ Lb=Maximum~\{\frac{1}{2}\sum_{j-1}^{J}[1+log(\sigma_j^2)-\mu_j^2-\sigma_j^2]+\frac{1}{L}\sum_{l=1}^{L}log(P(x^{(i)}|z^{(i,l)}))\} Maximum Lb=Maximum {21j−1∑J[1+log(σj2)−μj2−σj2]+L1l=1∑Llog(P(x(i)∣z(i,l)))}

综上,我们再来回顾一个变量的定义, x x x表示输入的真实数据, P ( x ) P(x) P(x)表示整个VAE系统产生数据 x x x的概率为 P ( x ) P(x) P(x),我们在求取 P ( x ) P(x) P(x)的最大值过程中,加入了一个分布 q ( z ∣ x ) q(z|x) q(z∣x),这个 q ( z ∣ x ) q(z|x) q(z∣x)我们就可以看做是,当真实数据 x x x通过编码器网络输出编码 z z z的概率, P ( z ) P(z) P(z)表示子高斯模型(标准正太分布)随机采样得到编码 z z z的概率, P ( x ∣ z ) P(x|z) P(x∣z)就表示我们随机采样的编码 z z z通过解码器网络输出数据 x x x的概率。

对于整个过程,我个人的理解是,在编码阶段,由于 P ( x ∣ z ) P(x|z) P(x∣z)只与解码器有关,因此 P ( x ∣ z ) P(x|z) P(x∣z)在这个阶段可以看做是固定的,因此 l o g P ( x ) logP(x) logP(x)也就是固定的,这时候我们主要调整编码网络 q ( z ∣ x ) q(z|x) q(z∣x)使得 K L KL KL为0,让求取 l o g P ( x ) logP(x) logP(x)的最大值转化为求取 L b L_b Lb的最大值、

在解码阶段,我们可以调整的是 P ( x ∣ z ) P(x|z) P(x∣z),并且 L b = L 1 + L 2 L_b=L_1+L_2 Lb=L1+L2,其中第一项 L 1 L_1 L1的 q ( z ∣ x ) q(z|x) q(z∣x)已经在编码阶段确定了, P ( z ) P(z) P(z)也是已知的。只有 L 2 L_2 L2与 P ( x ∣ z ) P(x|z) P(x∣z)有关,因此我们可以调整 P ( x ∣ z ) P(x|z) P(x∣z)使得第二项 L 2 L_2 L2增大,因此来增大 L b L_b Lb。

代码实现

我们采用手写数字MNIST数据集来验证,将原始维度为[batch,1,28,28],将其展平为[batch,784]。

编码网络
python 复制代码
# 定义编码器(Encoder)类
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # 全连接层,将输入特征映射到隐藏层
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间均值
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间的log方差
        self.relu = nn.ReLU()  # ReLU激活函数

    def forward(self, x):
        h = self.relu(self.fc1(x))  # 隐藏层特征
        mu = self.fc2_mu(h)  # 潜在空间均值
        logvar = self.fc2_logvar(h)  # 潜在空间的log方差
        return mu, logvar
解码网络
python 复制代码
# 定义解码器(Decoder)类
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)  # 全连接层,将潜在空间特征映射到隐藏层
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # 全连接层,将隐藏层特征映射到输出空间
        self.relu = nn.ReLU()  # ReLU激活函数
        self.sigmoid = nn.Sigmoid()  # Sigmoid激活函数

    def forward(self, z):
        h = self.relu(self.fc1(z))  # 隐藏层特征
        x_recon = self.sigmoid(self.fc2(h))  # 重建输入
        return x_recon
综合代码
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义编码器(Encoder)类
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # 全连接层,将输入特征映射到隐藏层
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间均值
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间的log方差
        self.relu = nn.ReLU()  # ReLU激活函数

    def forward(self, x):
        h = self.relu(self.fc1(x))  # 隐藏层特征
        mu = self.fc2_mu(h)  # 潜在空间均值
        logvar = self.fc2_logvar(h)  # 潜在空间的log方差
        return mu, logvar

# 定义解码器(Decoder)类
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)  # 全连接层,将潜在空间特征映射到隐藏层
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # 全连接层,将隐藏层特征映射到输出空间
        self.relu = nn.ReLU()  # ReLU激活函数
        self.sigmoid = nn.Sigmoid()  # Sigmoid激活函数

    def forward(self, z):
        h = self.relu(self.fc1(z))  # 隐藏层特征
        x_recon = self.sigmoid(self.fc2(h))  # 重建输入
        return x_recon

# 定义VAE(Variational Autoencoder)类
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)  # 实例化编码器
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)  # 实例化解码器

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # 计算标准差
        eps = torch.randn_like(std)  # 从标准正态分布采样
        return mu + eps * std  # 重参数化技巧

    def forward(self, x):
        mu, logvar = self.encoder(x)  # 编码器输出均值和log方差
        z = self.reparameterize(mu, logvar)  # 重参数化
        x_recon = self.decoder(z)  # 解码器重建输入
        return x_recon, mu, logvar

# 设置参数
input_dim = 784  # 输入维度(28x28图像展开为784维)
hidden_dim = 400  # 隐藏层维度
latent_dim = 20  # 潜在空间维度
batch_size = 128  # 批量大小
learning_rate = 0.001  # 学习率
num_epochs = 50  # 训练周期

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))# 归一化到 [-1, 1]
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 实例化VAE模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器

# 损失函数(主要公式)
def loss_function(x_recon, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(x_recon, torch.sigmoid(x), reduction='sum')  # 重建损失
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL散度损失
    return BCE + KLD

# 训练VAE模型
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, input_dim).to(device)  # 将数据展平并移动到设备上
        optimizer.zero_grad()
        x_recon, mu, logvar = model(data)  # 前向传播
        loss = loss_function(x_recon, data, mu, logvar)  # 计算损失
        loss.backward()  # 反向传播
        train_loss += loss.item()
        optimizer.step()  # 更新参数

    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}')

# 保存编码器和解码器模型
torch.save(model.encoder.state_dict(), 'vae_encoder.pth')
torch.save(model.decoder.state_dict(), 'vae_decoder.pth')

print("训练和保存完成")

由于训练时间较久,我们训练50轮看看效果即可,

接着我们用自己的图片来测试一下生成的模型的效果,代码如下,

python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


# 定义编码器(Encoder)类
from torch import nn


class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # 全连接层,将输入特征映射到隐藏层
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间均值
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)  # 全连接层,将隐藏层特征映射到潜在空间的log方差
        self.relu = nn.ReLU()  # ReLU激活函数

    def forward(self, x):
        h = self.relu(self.fc1(x))  # 隐藏层特征
        mu = self.fc2_mu(h)  # 潜在空间均值
        logvar = self.fc2_logvar(h)  # 潜在空间的log方差
        return mu, logvar


# 定义解码器(Decoder)类
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)  # 全连接层,将潜在空间特征映射到隐藏层
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # 全连接层,将隐藏层特征映射到输出空间
        self.relu = nn.ReLU()  # ReLU激活函数
        self.sigmoid = nn.Sigmoid()  # Sigmoid激活函数

    def forward(self, z):
        h = self.relu(self.fc1(z))  # 隐藏层特征
        x_recon = self.sigmoid(self.fc2(h))  # 重建输入
        return x_recon


# 实例化VAE模型并加载保存的模型参数
input_dim = 784  # 输入维度(28x28图像展开为784维)
hidden_dim = 400  # 隐藏层维度
latent_dim = 20  # 潜在空间维度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(input_dim, hidden_dim, latent_dim).to(device)
decoder = Decoder(latent_dim, hidden_dim, input_dim).to(device)

encoder.load_state_dict(torch.load('vae_encoder.pth'))
decoder.load_state_dict(torch.load('vae_decoder.pth'))

encoder.eval()
decoder.eval()
print("模型已加载")


# 预处理函数,将图片转换为模型可处理的张量
def preprocess_image(image_path):
    image = Image.open(image_path).convert('L')  # 以灰度模式加载图片
    transform = transforms.Compose([
        transforms.Resize((28, 28)),  # 调整图片大小为28x28
        transforms.ToTensor(),  # 转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1, 1]
    ])
    image = transform(image).view(-1, input_dim).to(device)  # 展平并移动到设备
    return image


# 假设你有一张自己的图片,路径为'path_to_your_image.png'
image_path = 'img_5.png'
image = preprocess_image(image_path)

# 使用编码器生成潜在变量
with torch.no_grad():
    mu, logvar = encoder(image)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std  # 重参数化

# 使用解码器生成重建图片
with torch.no_grad():
    reconstructed_image = decoder(z)


# 显示原始图片和重建图片
def show_images(original, reconstructed):
    original = original.view(28, 28).cpu().numpy()
    reconstructed = reconstructed.view(28, 28).cpu().numpy()

    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(original, cmap='gray')
    axes[0].set_title("Original Image")
    axes[1].imshow(reconstructed, cmap='gray')
    axes[1].set_title("Reconstructed Image")
    plt.show()

show_images(image, reconstructed_image)

可以看到效果还可以,我多次尝试了几次,效果有差别,可能是因为训练轮数较少的缘故。

相关推荐
杨浦老苏2 分钟前
开源PDF翻译工具PDFMathTranslate
人工智能·docker·ai·pdf·群晖·翻译
地中海~11 分钟前
DENIAL-OF-SERVICE POISONING ATTACKS ON LARGE LANGUAGE MODELS
人工智能·语言模型·自然语言处理
HSunR24 分钟前
概率论 期末 笔记
笔记·概率论
边缘计算社区1 小时前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
游客5201 小时前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
一位小说男主1 小时前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
深圳南柯电子2 小时前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ2 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter00882 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习