引言:当精确计算变得不可能
想象一下,你试图绘制一张整个城市的地图,但只能通过询问路人来获取信息。每个路人都只知道城市的一小部分,而且他们的描述有时会相互矛盾。你需要整合所有这些零散、不完整的信息,画出一张尽可能准确的地图。
这就是贝叶斯推断面临的核心挑战:我们有一些观测数据(路人描述),有一些未知的隐藏变量(城市真实布局),需要通过贝叶斯定理计算后验分布。但当模型复杂时,这个计算往往像解一个千层谜题------理论上可行,实际上不可能。
变分推断(Variational Inference)就是解决这个问题的优雅近似方法。它不追求精确解,而是寻找一个"足够好"的近似,用可处理的简单分布来逼近复杂的真实后验。
一、问题:为什么精确贝叶斯推断如此困难?
1.1 贝叶斯推断的困境
在贝叶斯框架中,我们通常关心后验分布:
p(z∣x)=p(x∣z)p(z)p(x) p(z|x) = \frac{p(x|z)p(z)}{p(x)} p(z∣x)=p(x)p(x∣z)p(z)
其中:
- zzz 是隐变量(我们想推断的)
- xxx 是观测数据(我们已知的)
- p(x∣z)p(x|z)p(x∣z) 是似然函数
- p(z)p(z)p(z) 是先验分布
- p(x)p(x)p(x) 是证据(边缘似然)
难点在于分母 p(x)p(x)p(x):
p(x)=∫p(x∣z)p(z)dz p(x) = \int p(x|z)p(z) dz p(x)=∫p(x∣z)p(z)dz
对于复杂模型,这个积分往往:
- 维度灾难 :zzz 通常是高维的,积分在指数级多的区域上进行
- 解析不可解:没有闭合形式的解
- 计算昂贵:即使用蒙特卡洛方法,收敛速度也很慢
1.2 MCMC的局限性
传统的马尔可夫链蒙特卡洛(MCMC)方法通过采样来近似后验:
- 优点:最终能收敛到精确后验
- 缺点:需要大量采样,收敛慢,难以判断何时停止
变分推断提供了一种完全不同的思路:与其通过采样慢慢逼近,不如直接寻找一个简单分布来近似。
二、核心思想:用简单逼近复杂
2.1 基本直觉
变分推断的核心思想可以用一个类比理解:
与其精确计算一个人的完整基因序列(几乎不可能),不如找到与他最相似的已知基因模板(容易得多)。
数学上,我们:
- 选择一个简单的分布族 Q={q(z;λ)}Q = \{q(z;\lambda)\}Q={q(z;λ)},其中 λ\lambdaλ 是参数
- 在这个族中寻找与真实后验 p(z∣x)p(z|x)p(z∣x) 最"接近"的分布 q∗(z)q^*(z)q∗(z)
- 用 q∗(z)q^*(z)q∗(z) 作为后验的近似
2.2 如何衡量"接近"?
我们使用KL散度 (Kullback-Leibler divergence)来衡量两个分布的差异:
DKL(q(z)∥p(z∣x))=Eq(z)[logq(z)p(z∣x)] D_{KL}(q(z) \| p(z|x)) = \mathbb{E}_{q(z)}\left[\log\frac{q(z)}{p(z|x)}\right] DKL(q(z)∥p(z∣x))=Eq(z)[logp(z∣x)q(z)]
KL散度越小,两个分布越相似。但这里有个问题:计算 DKLD_{KL}DKL 需要知道 p(z∣x)p(z|x)p(z∣x),而这正是我们不知道的!
变分推断的巧妙之处在于重新表述问题。
三、数学推导:从KL散度到ELBO
3.1 证据下界(ELBO)的诞生
让我们从对数边缘似然开始:
logp(x)=log∫p(x,z)dz \log p(x) = \log \int p(x,z) dz logp(x)=log∫p(x,z)dz
引入近似分布 q(z)q(z)q(z):
logp(x)=log∫q(z)p(x,z)q(z)dz \log p(x) = \log \int q(z) \frac{p(x,z)}{q(z)} dz logp(x)=log∫q(z)q(z)p(x,z)dz
由Jensen不等式(因为log是凹函数):
logp(x)≥∫q(z)logp(x,z)q(z)dz \log p(x) \geq \int q(z) \log \frac{p(x,z)}{q(z)} dz logp(x)≥∫q(z)logq(z)p(x,z)dz
我们得到了证据下界 (Evidence Lower BOund,ELBO):
L(q)=Eq(z)[logp(x,z)]−Eq(z)[logq(z)] \mathcal{L}(q) = \mathbb{E}{q(z)}[\log p(x,z)] - \mathbb{E}{q(z)}[\log q(z)] L(q)=Eq(z)[logp(x,z)]−Eq(z)[logq(z)]
3.2 KL散度与ELBO的关系
有趣的是,对数边缘似然可以分解为:
logp(x)=L(q)+DKL(q(z)∥p(z∣x)) \log p(x) = \mathcal{L}(q) + D_{KL}(q(z) \| p(z|x)) logp(x)=L(q)+DKL(q(z)∥p(z∣x))
因为:
logp(x)=Eq(z)[logp(x)]=Eq(z)[logp(x,z)p(z∣x)]=Eq(z)[logp(x,z)q(z)]+Eq(z)[logq(z)p(z∣x)]=L(q)+DKL(q(z)∥p(z∣x)) \begin{aligned} \log p(x) &= \mathbb{E}{q(z)}[\log p(x)] \\ &= \mathbb{E}{q(z)}\left[\log\frac{p(x,z)}{p(z|x)}\right] \\ &= \mathbb{E}{q(z)}\left[\log\frac{p(x,z)}{q(z)}\right] + \mathbb{E}{q(z)}\left[\log\frac{q(z)}{p(z|x)}\right] \\ &= \mathcal{L}(q) + D_{KL}(q(z) \| p(z|x)) \end{aligned} logp(x)=Eq(z)[logp(x)]=Eq(z)[logp(z∣x)p(x,z)]=Eq(z)[logq(z)p(x,z)]+Eq(z)[logp(z∣x)q(z)]=L(q)+DKL(q(z)∥p(z∣x))
由于 DKL≥0D_{KL} \geq 0DKL≥0,我们有 L(q)≤logp(x)\mathcal{L}(q) \leq \log p(x)L(q)≤logp(x),所以 L(q)\mathcal{L}(q)L(q) 确实是下界。
3.3 变分推断的优化问题
关键洞察来了:因为 logp(x)\log p(x)logp(x) 是常数(对于给定的数据和模型),最大化ELBO等价于最小化KL散度!
于是,我们不需要直接处理难以计算的 p(z∣x)p(z|x)p(z∣x),而是将问题转化为:
q∗=argmaxq∈QL(q) q^* = \arg\max_{q \in Q} \mathcal{L}(q) q∗=argq∈QmaxL(q)
其中 QQQ 是我们选择的简单分布族。
四、平均场变分推断
4.1 平均场假设
最常用的分布族是平均场 (Mean-Field)族,它假设隐变量之间相互独立:
q(z)=∏j=1mqj(zj) q(z) = \prod_{j=1}^m q_j(z_j) q(z)=j=1∏mqj(zj)
每个 zjz_jzj 可以是单个变量或一组变量。这种假设大大简化了问题,但也引入了近似误差(因为真实后验中的变量通常不是独立的)。
4.2 坐标上升法
在平均场假设下,我们可以使用坐标上升法 来优化ELBO。固定其他 qi≠jq_{i\neq j}qi=j,优化 qjq_jqj:
qj∗(zj)∝exp(E−j[logp(x,z)]) q_j^*(z_j) \propto \exp\left(\mathbb{E}_{-j}[\log p(x,z)]\right) qj∗(zj)∝exp(E−j[logp(x,z)])
其中 E−j\mathbb{E}_{-j}E−j 表示对除 zjz_jzj 外所有变量的期望。
这个公式有优美的直觉:每个因子 qjq_jqj 的更新只依赖于其他因子的期望。
4.3 算法流程
平均场变分推断的算法:
- 初始化 :为所有 qj(zj)q_j(z_j)qj(zj) 选择初始分布
- 迭代直到收敛 :
- 对 j=1j = 1j=1 到 mmm:
- 更新 qj(zj)∝exp(E−j[logp(x,z)])q_j(z_j) \propto \exp\left(\mathbb{E}_{-j}[\log p(x,z)]\right)qj(zj)∝exp(E−j[logp(x,z)])
- 对 j=1j = 1j=1 到 mmm:
- 返回 :近似后验 q(z)=∏jqj(zj)q(z) = \prod_j q_j(z_j)q(z)=∏jqj(zj)
五、实例:贝叶斯混合高斯模型
让我们通过一个具体例子理解变分推断。
5.1 模型设定
考虑一个贝叶斯混合高斯模型:
- 观测数据:x={x1,...,xN}x = \{x_1, ..., x_N\}x={x1,...,xN},xi∈RDx_i \in \mathbb{R}^Dxi∈RD
- 隐变量:
- 聚类分配:z={z1,...,zN}z = \{z_1, ..., z_N\}z={z1,...,zN},ziz_izi 是one-hot向量
- 混合权重:π∼Dirichlet(α)\pi \sim \text{Dirichlet}(\alpha)π∼Dirichlet(α)
- 聚类中心:μk∼N(0,σ2I)\mu_k \sim \mathcal{N}(0, \sigma^2 I)μk∼N(0,σ2I)
- 精度矩阵:Λk∼Wishart(W0,ν0)\Lambda_k \sim \text{Wishart}(W_0, \nu_0)Λk∼Wishart(W0,ν0)
联合分布:
p(x,z,π,μ,Λ)=p(π)∏k=1Kp(μk)p(Λk)∏i=1Np(zi∣π)p(xi∣zi,μ,Λ) p(x,z,\pi,\mu,\Lambda) = p(\pi)\prod_{k=1}^K p(\mu_k)p(\Lambda_k)\prod_{i=1}^N p(z_i|\pi)p(x_i|z_i,\mu,\Lambda) p(x,z,π,μ,Λ)=p(π)k=1∏Kp(μk)p(Λk)i=1∏Np(zi∣π)p(xi∣zi,μ,Λ)
5.2 平均场变分分布
我们假设变分分布可以因子化:
q(z,π,μ,Λ)=q(π)q(μ,Λ)q(z) q(z,\pi,\mu,\Lambda) = q(\pi)q(\mu,\Lambda)q(z) q(z,π,μ,Λ)=q(π)q(μ,Λ)q(z)
更具体地:
- q(π)=Dirichlet(γ)q(\pi) = \text{Dirichlet}(\gamma)q(π)=Dirichlet(γ)
- q(μ,Λ)=∏kN(μk∣mk,(βkΛk)−1)Wishart(Λk∣Wk,νk)q(\mu,\Lambda) = \prod_k \mathcal{N}(\mu_k|m_k, (\beta_k\Lambda_k)^{-1})\text{Wishart}(\Lambda_k|W_k,\nu_k)q(μ,Λ)=∏kN(μk∣mk,(βkΛk)−1)Wishart(Λk∣Wk,νk)
- q(z)=∏iCategorical(ϕi)q(z) = \prod_i \text{Categorical}(\phi_i)q(z)=∏iCategorical(ϕi)
5.3 更新公式
通过推导,我们得到更新公式:
对于 q(z)q(z)q(z):
ϕik∝exp(E[logπk]+12E[log∣Λk∣]−D2log(2π)−12E[(xi−μk)⊤Λk(xi−μk)]) \phi_{ik} \propto \exp\left(\mathbb{E}[\log \pi_k] + \frac{1}{2}\mathbb{E}[\log|\Lambda_k|] - \frac{D}{2}\log(2\pi) - \frac{1}{2}\mathbb{E}[(x_i-\mu_k)^\top\Lambda_k(x_i-\mu_k)]\right) ϕik∝exp(E[logπk]+21E[log∣Λk∣]−2Dlog(2π)−21E[(xi−μk)⊤Λk(xi−μk)])
对于 q(π)q(\pi)q(π):
γk=αk+∑i=1Nϕik \gamma_k = \alpha_k + \sum_{i=1}^N \phi_{ik} γk=αk+i=1∑Nϕik
对于 q(μk,Λk)q(\mu_k,\Lambda_k)q(μk,Λk):
βk=β0+Nkmk=1βk(β0m0+Nkxˉk)Wk−1=W0−1+NkSk+β0Nkβ0+Nk(xˉk−m0)(xˉk−m0)⊤νk=ν0+Nk \begin{aligned} \beta_k &= \beta_0 + N_k \\ m_k &= \frac{1}{\beta_k}(\beta_0 m_0 + N_k \bar{x}_k) \\ W_k^{-1} &= W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}(\bar{x}_k - m_0)(\bar{x}_k - m_0)^\top \\ \nu_k &= \nu_0 + N_k \end{aligned} βkmkWk−1νk=β0+Nk=βk1(β0m0+Nkxˉk)=W0−1+NkSk+β0+Nkβ0Nk(xˉk−m0)(xˉk−m0)⊤=ν0+Nk
其中 Nk=∑iϕikN_k = \sum_i \phi_{ik}Nk=∑iϕik,xˉk=1Nk∑iϕikxi\bar{x}k = \frac{1}{N_k}\sum_i \phi{ik}x_ixˉk=Nk1∑iϕikxi,Sk=1Nk∑iϕik(xi−xˉk)(xi−xˉk)⊤S_k = \frac{1}{N_k}\sum_i \phi_{ik}(x_i-\bar{x}_k)(x_i-\bar{x}_k)^\topSk=Nk1∑iϕik(xi−xˉk)(xi−xˉk)⊤。
5.4 直观解释
-
q(z)q(z)q(z) 更新 :每个数据点属于聚类 kkk 的概率 ϕik\phi_{ik}ϕik 取决于:
- 聚类权重 πk\pi_kπk 的期望
- 聚类的精度(逆协方差)Λk\Lambda_kΛk 的期望
- 数据点与聚类中心的距离
-
q(π)q(\pi)q(π) 更新 :聚类权重 γk\gamma_kγk 是先验 αk\alpha_kαk 加上分配给该聚类的"有效"数据点数
-
q(μk,Λk)q(\mu_k,\Lambda_k)q(μk,Λk) 更新:每个聚类的参数更新为数据加权平均的形式
这个过程交替进行,直到收敛。
六、变分推断 vs MCMC
6.1 比较
| 特性 | 变分推断 | MCMC |
|---|---|---|
| 哲学 | 优化(寻找最佳近似) | 采样(模拟后验) |
| 目标 | 最小化 DKL(q∣p)D_{KL}(q|p)DKL(q∣p) | 从 $p(z |
| 收敛速度 | 通常更快(迭代优化) | 通常较慢(需要burn-in) |
| 收敛判断 | ELBO收敛 | 链的统计量稳定 |
| 近似误差 | 有偏(取决于 QQQ) | 无偏(渐进精确) |
| 可扩展性 | 好(随机优化) | 差(串行依赖) |
| 提供 | 解析近似分布 | 经验分布(样本) |
6.2 何时使用变分推断?
- 数据量大:需要快速推断
- 需要多次推断:如在线学习、超参数调优
- 需要解析形式:后续分析需要分布形式
- 计算资源有限:无法承受MCMC的长链
6.3 何时使用MCMC?
- 小到中等数据:可以承受采样开销
- 需要精确推断:近似误差不可接受
- 模型复杂:难以选择好的变分族
- 需要后验样本:如计算分位数
七、现代扩展
7.1 随机变分推断(SVI)
对于大规模数据,传统的变分推断仍然需要全数据集计算梯度。SVI使用随机梯度上升:
- 从数据集中采样一个小批量
- 计算该批量的梯度估计
- 更新变分参数
这允许处理海量数据。
7.2 标准化流
平均场假设太强,可能丢失变量间的相关性。标准化流(Normalizing Flows)通过一系列可逆变换构建复杂分布:
zK=fK∘⋯∘f1(z0) z_K = f_K \circ \cdots \circ f_1(z_0) zK=fK∘⋯∘f1(z0)
其中 z0∼q0z_0 \sim q_0z0∼q0 是简单分布(如高斯),通过变换得到复杂分布 zKz_KzK。
7.3 变分自编码器(VAE)
VAE将变分推断与神经网络结合:
- 编码器 :学习变分分布 qϕ(z∣x)q_\phi(z|x)qϕ(z∣x) 的参数
- 解码器 :定义生成分布 pθ(x∣z)p_\theta(x|z)pθ(x∣z)
- 目标:最大化ELBO
VAE展示了如何用变分推断训练深度生成模型。
八、实践建议
8.1 选择变分族
- 从简单开始:先尝试平均场,如果效果不好再考虑更复杂的族
- 考虑共轭性:如果模型是指数族,选择共轭的变分分布可以简化计算
- 使用摊销推断:对于类似结构的数据(如图像),使用神经网络学习变分参数
8.2 监控收敛
- 跟踪ELBO:应该单调增加(除了SVI的随机波动)
- 检查参数变化:当参数变化很小时可以停止
- 多次初始化:避免局部最优
8.3 评估质量
- 预测性能:在测试集上评估
- 后验预测检查:从变分后验采样,生成数据,与真实数据比较
- 与MCMC比较:如果可能,用MCMC作为基准
结语:近似之美
变分推断代表了一种深刻的哲学转变:从追求精确解转向寻求实用近似。在现实世界的复杂问题中,完美往往是好的敌人。通过接受近似,我们获得了可扩展性、速度和实用性。
就像绘制城市地图:我们可能永远无法知道每一条小巷的精确位置,但一张"足够好"的地图已经能让我们高效导航。变分推断就是为我们绘制这种"足够好"的概率地图的工具。
它教会我们,在面对不确定性时,有时大胆近似比谨慎精确更有价值。在这个数据爆炸的时代,这种思想比以往任何时候都更加重要。
"所有模型都是错的,但有些是有用的。" ------ George Box
变分推断正是这一思想的完美体现:我们明知道近似分布 q(z)q(z)q(z) 不是真正的后验 p(z∣x)p(z|x)p(z∣x),但只要它能帮助我们做出更好的决策、生成更逼真的数据、发现更深层的模式,它就是有价值的。
在机器学习的工具箱中,变分推断不是最精确的工具,但它常常是最实用的------在精确与可行之间,它选择了明智的平衡。