变分推断(Variational Inference)

本篇博客的推到内容全部来自 变分推断(Variational Inference)初探,原本博客内容写的非常清楚,请感兴趣的读者移步原博客,这里只当作笔记整理,方便以后查阅。

变分推断(Variational Inference)

前言

在贝叶斯体系中,推断(inference)指的是利用已知变量推测未知变量的分布,即我们在已经输入变量 x x x 后,如何获得未知变量 y y y 的分布 p ( y ∣ x ) p(y|x) p(y∣x)。

精确推断方法准确地计算 p ( y ∣ x ) p(y|x) p(y∣x),该过程往往需要很大的计算开销,现实应用中近似推断更为常用。

近似推断的方法往往分为两大类,第一类是采样,常见的是MCMC方法,第二类是使用另一个分布近似 p ( y ∣ x ) p(y|x) p(y∣x),典型代表就是变分推断

变分推断(Variational Inference,下文简称VI)是一大类通过简单分布近似复杂分布、求解推断(inference)问题的方法的统称 ,具体包括平均场变分推断等算法。下面来看如何得到变分推断优化问题的具体形式

变分推断

我们假设 x x x是观测变量(或者叫证据变量、输入变量), z z z是隐变量(或者说是我们希望推断的label,在监督学习中通常用 y y y表示,但在贝叶斯中,一般会用 z z z表示隐变量)。

例如在线性回归问题中, x x x是线性回归模型的输入, z z z是线性回归模型的预测值;在图像分类问题中, x x x是图像的像素矩阵, z z z是图像的类别,即label。

贝叶斯模型中,我们的目的是得到后验分布 p ( z ∣ x , ϕ ) p(z|x,\phi) p(z∣x,ϕ),即我们观测到输入为 x x x时,输出变量 z z z的概率分布,其中 ϕ ϕ ϕ为模型参数

精确推断的方法,一般使用贝叶斯公式 p ( z ∣ x ) = p ( x ∣ z ) p ( z ) p ( x ) = p ( x ∣ z ) p ( z ) ∫ z p ( x , z ) d z p(z|x)=\frac{p(x|z)p(z)}{p(x)}=\frac{p(x|z)p(z)}{\int_zp(x,z)dz} p(z∣x)=p(x)p(x∣z)p(z)=∫zp(x,z)dzp(x∣z)p(z),然后精确计算每一项的值,得到后验分布,但 p ( x ) p(x) p(x)项涉及到积分的计算,很多时候是很难求解的,所以有了近似推断的方法,更加高效地求解该问题。

VI通过一个简单的分布 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)近似复杂的分布 p ( z ∣ x , ϕ ) p(z|x,\phi) p(z∣x,ϕ),其中 θ θ θ是 q q q分布的参数,我们希望 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)和 p ( z ∣ x , ϕ ) p(z|x,\phi) p(z∣x,ϕ)的差异越小越好。一般通过反向KL散度来度量这种差异性(什么是反向KL散度,为什么不用一般的KL散度,两者有什么差别等问题在文章最后会解释,这里就先接受这个想法就好)。

所以寻找一个与后验分布接近的简单分布的问题就变成了最小化反向KL散度的问题,即:
min ⁡ θ K L ( q ( z ∣ x , θ ) ∣ ∣ p ( z ∣ x , ϕ ) ) = ∫ z q ( z ∣ x , θ ) log ⁡ q ( z ∣ x , θ ) p ( z ∣ x , ϕ ) d z = E z ∼ q ( z ∣ x , θ ) [ log ⁡ q ( z ∣ x , θ ) p ( z ∣ x , ϕ ) ] \min_\theta KL(q(z|x,\theta)||p(z|x,\phi))=\int_zq(z|x,\theta)\log\frac{q(z|x,\theta)}{p(z|x,\phi)}dz\\=E_{z\sim q(z|x,\theta)}[\log\frac{q(z|x,\theta)}{p(z|x,\phi)}] θminKL(q(z∣x,θ)∣∣p(z∣x,ϕ))=∫zq(z∣x,θ)logp(z∣x,ϕ)q(z∣x,θ)dz=Ez∼q(z∣x,θ)[logp(z∣x,ϕ)q(z∣x,θ)]

但因为后验分布 p ( z ∣ x , ϕ ) p(z|x,\phi) p(z∣x,ϕ)未知,这个式子是没有办法直接求解的,变分推断通过一系列的变换,然后进行优化。

下面我们直接把积分项 ∫ z q ( z ∣ x , θ ) f ( z ) d z \int_zq(z|x,\theta)f(z)dz ∫zq(z∣x,θ)f(z)dz写成等价的期望形式 E z ∼ q ( z ∣ x , θ ) [ f ( z ) ] E_{z\sim q(z|x,\theta)}[f(z)] Ez∼q(z∣x,θ)[f(z)],网上的很多推导中,是写成积分或求和形式的,推导过程是完全相同的,但积分和求和形式的推导只针对连续或离散变量中的一种,这里选择用期望的形式进行推导,保证推导过程对于连续和离散变量都是成立的。
K L ( q ( z ∣ x , θ ) ∣ ∣ p ( z ∣ x , ϕ ) ) = E z ∼ q ( z ∣ x , θ ) [ log ⁡ q ( z ∣ x , θ ) p ( z ∣ x , ϕ ) ] = E z ∼ q ( z ∣ x , θ ) [ log ⁡ q ( z ∣ x , θ ) p ( x ∣ ϕ ) p ( z , x ∣ ϕ ) ] , 根据 p ( z ∣ x , ϕ ) = p ( z , x ∣ ϕ ) p ( x ∣ ϕ ) = E z ∼ q ( z ∣ x , θ ) [ log ⁡ q ( z ∣ x , θ ) p ( z , x ∣ ϕ ) ] + E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ ϕ ) ] = − L + E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ ϕ ) ] = − L + log ⁡ p ( x ∣ ϕ ) \begin{gathered} KL(q(z|x,\theta)||p(z|x,\phi)) \\ =E_{z\sim q(z|x,\theta)}[\log\frac{q(z|x,\theta)}{p(z|x,\phi)}] \\ =E_{z\sim q(z|x,\theta)}[\log\frac{q(z|x,\theta)p(x|\phi)}{p(z,x|\phi)}],\text{根据}p(z|x,\phi)=\frac{p(z,x|\phi)}{p(x|\phi)} \\ =E_{z\sim q(z|x,\theta)}[\log\frac{q(z|x,\theta)}{p(z,x|\phi)}]+E_{z\sim q(z|x,\theta)}[\log p(x|\phi)] \\ =-\mathcal{L}+E_{z\sim q(z|x,\theta)}[\log p(x|\phi)] \\ =-\mathcal{L}+\log p(x|\phi) \end{gathered} KL(q(z∣x,θ)∣∣p(z∣x,ϕ))=Ez∼q(z∣x,θ)[logp(z∣x,ϕ)q(z∣x,θ)]=Ez∼q(z∣x,θ)[logp(z,x∣ϕ)q(z∣x,θ)p(x∣ϕ)],根据p(z∣x,ϕ)=p(x∣ϕ)p(z,x∣ϕ)=Ez∼q(z∣x,θ)[logp(z,x∣ϕ)q(z∣x,θ)]+Ez∼q(z∣x,θ)[logp(x∣ϕ)]=−L+Ez∼q(z∣x,θ)[logp(x∣ϕ)]=−L+logp(x∣ϕ)

这里我们定义 L = − E z ∼ q ( z ∣ x , θ ) [ − log ⁡ q ( z ∣ x , θ ) p ( z , x ∣ ϕ ) ] \mathcal{L}=-E_{z\sim q(z|x,\theta)}[-\log\frac{q(z|x,\theta)}{p(z,x|\phi)}] L=−Ez∼q(z∣x,θ)[−logp(z,x∣ϕ)q(z∣x,θ)]。

注意到第二项 E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ ϕ ) ] E_{z\sim q(z|x,\theta)}[\log p(x|\phi)] Ez∼q(z∣x,θ)[logp(x∣ϕ)]与 z z z无关,所以求期望的结果为 log ⁡ p ( x ∣ ϕ ) \log p(x|\phi) logp(x∣ϕ),对于优化变量 θ θ θ是一个常数,不需要优化,之后只考虑第一项的最小化问题,即 m a x L max\mathcal{L} maxL,这里的 L \mathcal{L} L被叫做证据下界(Evidence Lower BOund, 即ELBO),至于为什么叫ELBO会在文章后面解释。

因为联合分布 p ( z , x ∣ ϕ ) {p}(z,x|\phi) p(z,x∣ϕ)也是很难获得的,所以我们还需要进行进一步的转化,才能求解该问题。

L = E z ∼ q ( z ∣ x , θ ) [ − log ⁡ q ( z ∣ x , θ ) p ( z , x ∣ ϕ ) ] = E z ∼ q ( z ∣ x , θ ) [ − log ⁡ q ( z ∣ x , θ ) p ( x ∣ z , ϕ ) p ( z ∣ ϕ ) ] ,根据 p ( z , x ∣ ϕ ) = p ( x ∣ z , ϕ ) p ( z ∣ ϕ ) \mathcal{L}=E_{z\sim q(z|x,\theta)}[-\log\frac{q(z|x,\theta)}{p(z,x|\phi)}]\\=E_{z\sim q(z|x,\theta)}[-\log\frac{q(z|x,\theta)}{p(x|z,\phi)p(z|\phi)}]\text{,根据}p(z,x|\phi)=p(x|z,\phi)p(z|\phi) L=Ez∼q(z∣x,θ)[−logp(z,x∣ϕ)q(z∣x,θ)]=Ez∼q(z∣x,θ)[−logp(x∣z,ϕ)p(z∣ϕ)q(z∣x,θ)],根据p(z,x∣ϕ)=p(x∣z,ϕ)p(z∣ϕ)

转化到这里其实已经可以求解了,式子里的 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)是我们引入的简单的分布,是已知的。
p ( x ∣ z , ϕ ) p(x|z,\phi) p(x∣z,ϕ)是似然函数,也是已知的。
p ( z ∣ ϕ ) p(z|\phi) p(z∣ϕ)是对于 z z z的先验,与 ϕ ϕ ϕ是无关的,后面直接写成 p ( z ) p(z) p(z),贝叶斯模型中会假设先验为特定的形式,所以也是已知的。

到这里就已经转化为了我们可以计算的形式,推导就已经结束了。但一般会对这个结果进行一个简单的转化,变为直观上更容易理解的形式。
L = E z ∼ q ( z ∣ x , θ ) [ − log ⁡ q ( z ∣ x , θ ) p ( x ∣ z , ϕ ) p ( z ) ] = E z ∼ q ( z ∣ x , θ ) [ − log ⁡ q ( z ∣ x , θ ) p ( z ) + log ⁡ p ( x ∣ z , ϕ ) ] = E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ z , ϕ ) ] − E z ∼ q ( z ∣ x , θ ) [ log ⁡ q ( z ∣ x , θ ) p ( z ) ] = E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ z , ϕ ) ] − K L ( q ( z ∣ x , θ ) ∣ ∣ p ( z ) ) \begin{gathered} \mathcal{L}=E_{z\sim q(z|x,\theta)}[-\log\frac{q(z|x,\theta)}{p(x|z,\phi)p(z)}] \\ =E_{z\sim q(z|x,\theta)}[-\log\frac{q(z|x,\theta)}{p(z)}+\log p(x|z,\phi)] \\ =E_{z\sim q(z|x,\theta)}[\log p(x|z,\phi)]-E_{z\sim q(z|x,\theta)}[\log\frac{q(z|x,\theta)}{p(z)}] \\ =E_{z\sim q(z|x,\theta)}[\log p(x|z,\phi)]-KL(q(z|x,\theta)||p(z)) \end{gathered} L=Ez∼q(z∣x,θ)[−logp(x∣z,ϕ)p(z)q(z∣x,θ)]=Ez∼q(z∣x,θ)[−logp(z)q(z∣x,θ)+logp(x∣z,ϕ)]=Ez∼q(z∣x,θ)[logp(x∣z,ϕ)]−Ez∼q(z∣x,θ)[logp(z)q(z∣x,θ)]=Ez∼q(z∣x,θ)[logp(x∣z,ϕ)]−KL(q(z∣x,θ)∣∣p(z))

最后一步是根据KL散度的定义直接转化的,推导到这里就结束了。

回忆一下整体的流程:VI中使用简单的分布 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)近似复杂分布 p ( z ∣ x , ϕ ) p(z|x,\phi) p(z∣x,ϕ),所以最小化二者的KL散度,但无法直接求解,所以通过一系列的变换,转化为最大化ELBO的形式,进行求解。所以VI问题就是最大化证据下界,即:
max ⁡ θ L = max ⁡ θ { E z ∼ q ( z ∣ x , θ ) [ log ⁡ p ( x ∣ z , ϕ ) ] − K L ( q ( z , θ ) ∣ ∣ p ( z ) ) } \max_\theta\mathcal{L}=\max_\theta \{E_{z\sim q(z|x,\theta)}[\log p(x|z,\phi)]-KL(q(z,\theta)||p(z))\} θmaxL=θmax{Ez∼q(z∣x,θ)[logp(x∣z,ϕ)]−KL(q(z,θ)∣∣p(z))}

直观上理解一下最后的结果:

  1. 第一项中,简单的分布 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)是在已知 x x x的情况下,使用近似分布获得 z z z的过程,可以看做是 x x x编码到 z z z的过程;似然函数 p ( x ∣ z , ϕ ) p(x|z,\phi) p(x∣z,ϕ)是在已知 z z z后,获得 x x x的过程,可以看做是 z z z编码到 x x x的过程。第一项直观上衡量了从简单分布 q ( z ∣ x , θ ) q(z|x,\theta) q(z∣x,θ)中获得一个编码后的结果,多大程度上能够得到编码前的数据 p ( x ∣ z , ϕ ) p(x|z,\phi) p(x∣z,ϕ)。
  2. 第二项是希望我们的简单分布和真实的 z z z的先验分布尽量接近。

接下来我们看一下前面遗留的两个小问题,即为什么使用反向KL散度和为什么 L \mathcal{L} L被称为证据下界。

为什么使用反向KL散度

关于正向KL散度和反向KL散度的解释在博客《概率论相关知识随记》已经阐述清楚。

变分推断为什么使用反向KL?(这里是猜的,我也不太清楚)我感觉就是要在多峰时,尽量逼近其中一个峰,而不是尝试逼近所有峰,导致每个位置的近似效果都不好。

为什么 L \mathcal{L} L被叫做ELBO

第一个长公式的结果为:
K L ( q ( z ∣ x , θ ) ∣ ∣ p ( z ∣ x , ϕ ) ) = − L + log ⁡ p ( x ∣ ϕ ) KL(q(z|x,\theta)||p(z|x,\phi))=-\mathcal{L}+\log p(x|\phi) KL(q(z∣x,θ)∣∣p(z∣x,ϕ))=−L+logp(x∣ϕ)

变换形式后:
log ⁡ p ( x ∣ ϕ ) = K L ( q ( z ∣ x , θ ) ∣ ∣ p ( z ∣ x , ϕ ) ) + L \log p(x|\phi)=KL(q(z|x,\theta)||p(z|x,\phi))+\mathcal{L} logp(x∣ϕ)=KL(q(z∣x,θ)∣∣p(z∣x,ϕ))+L

公式左边的是关于 x x x的函数,右边是 L \mathcal{L} L与KL散度的和,KL散度结果一定大于等于0,所以一定有:
log ⁡ p ( x ∣ ϕ ) ≥ L \log p(x|\phi)\geq\mathcal{L} logp(x∣ϕ)≥L

在文章开头我们说在贝叶斯模型中,我们称 x x x为证据变量,右边可以看做是证据变量的下界,所以叫做证据下界(ELBO)。

变分推断的应用

变分推断在多种概率模型中有广泛的应用,尤其在贝叶斯深度学习生成模型中。以下是一些常见的应用场景:

  • 变分自编码器(VAE):VAE使用变分推断来近似后验分布,在生成模型中具有重要应用。VAE的目标是通过最大化变分下界(ELBO)来近似后验分布。

  • 高斯混合模型(GMM):变分推断可用于高斯混合模型的参数估计,尤其在数据集非常大的情况下,变分推断提供了一种高效的近似方法。

  • 贝叶斯神经网络:在贝叶斯神经网络中,通常需要对网络的权重分布进行推断。通过变分推断,可以得到这些权重的近似后验分布。

  • Topic Modeling(主题模型):如LDA(Latent Dirichlet Allocation)模型,变分推断被广泛用于估计主题分布和文档分布的后验。

变分推断与MCMC的区别

变分推断和马尔可夫链蒙特卡罗(MCMC)方法都是用于贝叶斯推断的技术。它们的主要区别在于:

  • 计算效率:MCMC通过构建马尔可夫链来进行采样,计算量较大且可能需要较长的时间才能收敛。变分推断则通过优化一个下界来进行近似推断,通常计算速度较快。

  • 精度:MCMC方法通过随机采样生成后验分布的样本,因此在理论上它可以提供精确的后验分布。变分推断则通过优化近似分布,可能无法得到精确的后验分布,但在大规模数据和复杂模型中通常能提供足够好的近似。

变分推断的优缺点

优点

  • 计算效率高:相比MCMC方法,变分推断计算速度较快,可以在较短的时间内得到近似结果。
  • 可扩展性:变分推断特别适合大规模数据集,可以处理许多需要贝叶斯推断的复杂模型。
  • 易于实现:通过优化下界的形式,变分推断可以通过标准的优化方法(如梯度下降)来进行实现。

缺点

  • 精度较低:变分推断的目标是逼近真实的后验分布,而不是精确采样,因此得到的近似结果可能不如MCMC方法准确。
  • 选择变分分布的困难:选择合适的变分分布形式对于近似的精度和计算效率有很大影响,有时选择不当可能导致不理想的结果。
相关推荐
风象南12 小时前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
Mintopia13 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮14 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬14 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia14 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区15 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两17 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪18 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat2325518 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源