Normalizing flows are generative machine learning models that transform a simple base probability distribution (like a standard Gaussian) into a complex, multi-modal target distribution by passing it through a sequence of invertible and differentiable mathematical functions.
They are highly valued for two exact computations:
- generating novel data by transforming random noise,
- and accurately calculating the explicit likelihood of any given data point.
How They Work
Normalizing flows rely on the statistical change of variables theorem. Because the neural network transformations are fully invertible, you can trace a complex data point (e.g., an image) backwards into a simple latent representation, or forwards to generate new data.
The model tracks how the "volume" of probability space expands or contracts during these transformations using a mathematical concept called the Jacobian determinant.
Mathematically, the relationship between the initial probability density pZ(z)p_Z(z)pZ(z) and the transformed target density pX(x)p_X(x)pX(x) is defined as:
pX(x)=pZ(f−1(x))∣det(∂f−1(x)∂x)∣p_X(x) = p_Z(f^{-1}(x)) \left\vert{} \det \left( \frac{\partial f^{-1}(x)}{\partial x} \right) \right\vert{}pX(x)=pZ(f−1(x)) det(∂x∂f−1(x))
Key Advantages
-
Exact Likelihoods: Unlike Variational Autoencoders (VAEs) or Generative Adversarial Networks (GANs), flows yield exact, tractable probability evaluations. You get a concrete number representing exactly how likely a data point is to occur.
-
No Mode Collapse: Because the model is strictly trained by maximizing the exact data likelihood, it doesn't suffer from the training instabilities or mode collapse often seen in GANs.
-
Invertible Mappings: Every piece of generated data corresponds to a distinct latent point, allowing for precise data editing and interpolation.
我们常常看到两种公式
这两个公式都是对的,它们没有矛盾,只是对函数符号 fff 的定义方向(映射方向)截然相反。
两者的核心差别在于:你把 fff 定义为"从简单变复杂"的变换,还是"从复杂变简单"的变换。
公式一 (从 xxx 映射到 zzz):
px(x)=pz(f(x))⋅∣det∂f(x)∂x∣p_x(x) = p_z(f(x)) \cdot \left\vert{} \det \frac{\partial f(x)}{\partial x} \right\vert{}px(x)=pz(f(x))⋅ det∂x∂f(x)
- 这里的定义:变换函数 fff 的方向是 z=f(x)z = f(x)z=f(x)。
- 物理含义:把复杂的真实数据 xxx 通过网络 fff 映射(编码)到简单的潜在分布 zzz。因为导数是 ∂f(x)∂x\frac{\partial f(x)}{\partial x}∂x∂f(x)(即 ∂z∂x\frac{\partial z}{\partial x}∂x∂z),所以在计算 xxx 的概率时,雅可比行列式直接放在乘号后面。
- 常见出处:在流模型的训练(最大似然估计)阶段,代码通常执行的是这个方向(从数据 xxx 计算似然值),因此很多学术论文(如变分推断、MAF等)喜欢这样写。
公式二 (从 zzz 映射到 xxx):
pX(x)=pZ(f−1(x))∣det(∂f−1(x)∂x)∣p_X(x) = p_Z(f^{-1}(x)) \left\vert{} \det \left( \frac{\partial f^{-1}(x)}{\partial x} \right) \right\vert{}pX(x)=pZ(f−1(x)) det(∂x∂f−1(x))
- 这里的定义:变换函数 fff 的方向是 x=f(z)x = f(z)x=f(z)。
- 物理含义:把简单的噪声 zzz 通过网络 fff 映射(生成)到复杂的真实数据 xxx。此时,逆变换就是 z=f−1(x)z = f^{-1}(x)z=f−1(x)。
- 常见出处:在流模型的生成/采样(Sampling)阶段,我们是从 zzz 生成 xxx,所以很多教科书为了强调"生成器 fff"的概念,会把 fff 的方向定为从 z→xz \to xz→x。根据概率论的变量代换标准公式,写出来就是公式二的样子。
当你阅读不同的 Normalizing Flow 论文时,一定要先看作者在文章开头对 fff 的定义:
- 如果作者说 z=f(x)z = f(x)z=f(x) ⟹ \implies⟹ 用第一个。
- 如果作者说 x=f(z)x = f(z)x=f(z) ⟹ \implies⟹ 用第二个。
它们的数学本质完全相同,因为根据逆矩阵行列式性质:∣det∂f−1(x)∂x∣=∣det∂f(z)∂z∣−1\left\vert{} \det \frac{\partial f^{-1}(x)}{\partial x} \right\vert{} = \left\vert{} \det \frac{\partial f(z)}{\partial z} \right\vert{}^{-1} det∂x∂f−1(x) = det∂z∂f(z) −1。
px(x)=pz(f(x))⋅∣det∂f(x)∂x∣p_x(x) = p_z(f(x)) \cdot \left\vert{} \det \frac{\partial f(x)}{\partial x} \right\vert{}px(x)=pz(f(x))⋅ det∂x∂f(x) 和变量变换公式
从概率论中变量变换公式(Change of Variables Formula) ,以及概率密度如何随坐标变换而变化说起。
1. 直观理解:体积膨胀与密度收缩
根据变量变换公式:
px(x)=pz(f(x))⋅∣det∂f(x)∂x∣ p_x(x) = p_z(f(x)) \cdot \left| \det \frac{\partial f(x)}{\partial x} \right| px(x)=pz(f(x))⋅ det∂x∂f(x)
这里:
- pz(f(x))p_z(f(x))pz(f(x)) 是变换后的点在简单分布下的概率密度。
- ∣det∂f(x)∂x∣\left| \det \frac{\partial f(x)}{\partial x} \right| det∂x∂f(x) 是雅可比行列式,表示变换引起的体积扩张因子。
想象一下水流或者气体:
- 如果我们将空间中的一个区域进行拉伸 (体积变大),那么该区域内的物质密度就会降低。
- 如果我们将空间中的一个区域进行压缩 (体积变小),那么该区域内的物质密度就会升高。
在概率分布中:
- xxx 是数据空间中的点。
- z=f(x)z = f(x)z=f(x) 是通过神经网络变换后的潜在空间中的点(简单潜在分布 zzz)。
- px(x)p_x(x)px(x) 是 xxx 处的概率密度。
- pz(z)p_z(z)pz(z) 是 zzz 处的概率密度。
- 物理含义:把复杂的真实数据 xxx 通过网络 fff 映射(编码)到简单的潜在分布 zzz。因为导数是 ∂f(x)∂x\frac{\partial f(x)}{\partial x}∂x∂f(x)(即 ∂z∂x\frac{\partial z}{\partial x}∂x∂z),所以在计算 xxx 的概率时,雅可比行列式直接放在乘号后面。
当我们从 zzz 变回 xxx(或者反过来)时,空间的"体积"发生了扭曲。为了保持总概率为 1,概率密度必须根据体积的变化进行调整。
2. 数学推导:从积分到密度
假设我们要保证变换前后的概率质量守恒。考虑 xxx 空间中的一个微小区域 dxdxdx,它对应 zzz 空间中的一个微小区域 dzdzdz。
概率守恒意味着:
px(x)dx=pz(z)dz p_x(x) dx = p_z(z) dz px(x)dx=pz(z)dz
我们需要找到 dxdxdx 和 dzdzdz 之间的关系。这由**雅可比矩阵(Jacobian Matrix)**决定。
雅可比矩阵 JfJ_fJf 定义为变换 z=f(x)z = f(x)z=f(x) 的导数矩阵:
Jf=∂z∂x J_f = \frac{\partial z}{\partial x} Jf=∂x∂z
在线性近似下,体积的变化由雅可比矩阵的行列式决定:
dz≈∣det(Jf)∣⋅dx dz \approx |\det(J_f)| \cdot dx dz≈∣det(Jf)∣⋅dx
或者更准确地说,逆变换的雅可比行列式表示从 zzz 到 xxx 的体积变化:
dx≈∣det(∂x∂z)∣⋅dz=∣det(Jf−1)∣⋅dz dx \approx |\det(\frac{\partial x}{\partial z})| \cdot dz = |\det(J_{f^{-1}})| \cdot dz dx≈∣det(∂z∂x)∣⋅dz=∣det(Jf−1)∣⋅dz
根据微积分基本定理,det(∂x∂z)=(det(∂z∂x))−1\det(\frac{\partial x}{\partial z}) = \left( \det(\frac{\partial z}{\partial x}) \right)^{-1}det(∂z∂x)=(det(∂x∂z))−1。
所以:
px(x)=pz(z)⋅∣det∂z∂x∣ p_x(x) = p_z(z) \cdot \left| \det \frac{\partial z}{\partial x} \right| px(x)=pz(z)⋅ det∂x∂z
(注:这里假设 zzz 是先验分布,xxx 是数据分布。通常我们写 px(x)p_x(x)px(x) 时,是用 z=f(x)z=f(x)z=f(x) 代入。)
3. Normalizing Flow 的标准定义:
通常我们定义一个可逆变换 z=f(x)z = f(x)z=f(x),将复杂分布 px(x)p_x(x)px(x) 映射到简单分布 pz(z)p_z(z)pz(z)(如标准高斯分布)。
根据变量变换公式:
px(x)=pz(f(x))⋅∣det∂f(x)∂x∣ p_x(x) = p_z(f(x)) \cdot \left| \det \frac{\partial f(x)}{\partial x} \right| px(x)=pz(f(x))⋅ det∂x∂f(x)
这里:
- pz(f(x))p_z(f(x))pz(f(x)) 是变换后的点在简单分布下的概率密度。
- ∣det∂f(x)∂x∣\left| \det \frac{\partial f(x)}{\partial x} \right| det∂x∂f(x) 是雅可比行列式,表示变换引起的体积扩张因子。
取对数
为了计算稳定性和数值精度,我们通常对等式两边取自然对数:
logpx(x)=log(pz(f(x))⋅∣det∂f(x)∂x∣) \log p_x(x) = \log \left( p_z(f(x)) \cdot \left| \det \frac{\partial f(x)}{\partial x} \right| \right) logpx(x)=log(pz(f(x))⋅ det∂x∂f(x) )
利用对数的性质 log(ab)=loga+logb\log(ab) = \log a + \log blog(ab)=loga+logb,我们得到:
logpx(x)=logpz(f(x))+log∣det∂f(x)∂x∣ \log p_x(x) = \log p_z(f(x)) + \log \left| \det \frac{\partial f(x)}{\partial x} \right| logpx(x)=logpz(f(x))+log det∂x∂f(x)
这就得到了你看到的公式:
logpx(x)=logpz(z)+log∣detJf∣ \log p_x(x) = \log p_z(z) + \log |\det J_f| logpx(x)=logpz(z)+log∣detJf∣
其中:
- logpz(z)\log p_z(z)logpz(z):是先验对数概率 。因为 zzz 是标准高斯分布,这部分很容易计算。
- log∣detJf∣\log |\det J_f|log∣detJf∣:是雅可比行列式的对数 ,被称为复杂性项(Complexity Term)或体积变化项 。它衡量了变换 fff 将 zzz 空间的单位体积映射到 xxx 空间时,体积膨胀或收缩了多少。
4. 为什么需要雅可比行列式?
如果没有雅可比行列式这一项,我们就假设变换是"等体积"的(就像旋转或平移,行列式为 1)。
但是,神经网络通常进行的是非线性的拉伸和压缩:
- 如果模型想把很多数据点挤到一小块区域(高密度),雅可比行列式会很小,log∣detJ∣\log |\det J|log∣detJ∣ 是负数,这会抵消 pz(z)p_z(z)pz(z) 的高概率,使得 logpx(x)\log p_x(x)logpx(x) 不会无限大。
- 如果模型把数据点分散开(低密度),雅可比行列式会很大,log∣detJ∣\log |\det J|log∣detJ∣ 是正数,这会增加 logpx(x)\log p_x(x)logpx(x)。
总结:
log∣detJf∣\log |\det J_f|log∣detJf∣ 项的作用是校正概率密度,以补偿因坐标变换引起的体积变化,从而保证概率质量守恒。这就是为什么对数似然计算必须包含这一项。
对数似然计算 (Log-Likelihood)
在生成模型(包括规范化流 Normalizing Flows)中,极大似然估计(Maximum Likelihood Estimation, MLE) 经常被直接用作训练目标。
具体来说,最大化对数似然(Log-Likelihood),在数值计算上等价于最小化负对数似然(Negative Log-Likelihood, NLL)。所以,负对数似然(NLL)就是我们丢给深度学习优化器的损失函数(Loss Function)。
1. 为什么"最大化似然"能作为优化目标?
- 似然(Likelihood)的含义:输入一个真实的数据点 x(比如一张猫的图片),模型输出的 px(x)p_x(x)px(x) 表示"我的模型能生成这张猫图的概率有多大"。
- 训练的目的:我们希望模型完美拟合真实数据,也就是说,当输入真实数据时,模型输出的概率 px(x)p_x(x)px(x) 越高越好。
- 数学表示:我们的目标是 maxpx(x)\max p_x(x)maxpx(x)。
2. 为什么要加"对数(Log)"?
在数学和代码实现中,我们不直接最大化 px(x)p_x(x)px(x),而是最大化 logpx(x)\log p_x(x)logpx(x)。原因有两个:
- 防止数值下溢:高维数据(如 128 × 128 的图像)的概率值通常极其微小(例如 10⁻⁵⁰),计算机无法精确存储。取对数后,极小的乘积会变成容易计算的加法(例如 log(10−50)=−50\log(10^{-50}) = -50log(10−50)=−50)。
- 简化变换公式:规范化流的公式里有乘号和行列式。取对数后,乘法变加法:
logpx(x)=logpz(f(x))+log∣det∂f(x)∂x∣\log p_x(x) = \log p_z(f(x)) + \log \left\vert{} \det \frac{\partial f(x)}{\partial x} \right\vert{}logpx(x)=logpz(f(x))+log det∂x∂f(x)
这样求导和反向传播就变得非常简单。
3. 对数似然怎么变成代码里的 Loss?
PyTorch 或 TensorFlow 的优化器(如 Adam)在设计时,默认都是通过梯度下降来寻找最小值的(Minimize)。
为了顺应框架,我们将"最大化"目标乘以一个负号,变成"最小化"目标:
Loss=−logpx(x)\text{Loss} = -\log p_x(x)Loss=−logpx(x)
这就是负对数似然损失(Negative Log-Likelihood Loss, NLL Loss)。
- 当 Loss 越来越小(趋近于负无穷),意味着 logpx(x)\log p_x(x)logpx(x) 越来越大,即模型生成真实数据的概率越来越高。
4. 规范化流中的完整 Loss 公式
假设我们使用前文的第一种定义(z = f(x),将数据映射到高斯噪声),那么训练时每个样本的 Loss 表达式 就是:
Loss=−logpz(f(x))⏟第一部分:隐变量的概率−log∣det∂f(x)∂x∣⏟第二部分:雅可比行列式对数\text{Loss} = \underbrace{-\log p_z(f(x))}{\text{第一部分:隐变量的概率}} - \underbrace{\log \left\vert{} \det \frac{\partial f(x)}{\partial x} \right\vert{}}{\text{第二部分:雅可比行列式对数}}Loss=第一部分:隐变量的概率 −logpz(f(x))−第二部分:雅可比行列式对数 log det∂x∂f(x)
- 第一部分:把 x 喂入网络得到 z = f(x)。因为 z 服从标准高斯分布,我们直接用高斯公式算出这个 z 的概率对数。如果 z 离原点太远,这部分 Loss 就会很大,它迫使网络把数据均匀映射在标准高斯分布的核心区域。(先验分布的对数概率(如高斯分布))
- 第二部分:这是流模型特有的空间缩放修正(对数绝对雅可比行列式,通常代码里叫 ladj 或 log_abs_det_jacobian)。它防止网络通过无限放大空间来偷懒作弊。(雅可比行列式的对数,衡量体积变化。)
一旦我们计算出了对数似然(或负对数似然损失),我们就需要知道如何调整参数 θ\thetaθ 来让它变大(或损失变小)。
这时,我们使用反向传播算法 ,其核心就是链式法则(Chain Rule)。
- 链式法则的作用 :计算损失函数 LLL 对每个参数 θi\theta_iθi 的梯度 ∂L∂θi\frac{\partial L}{\partial \theta_i}∂θi∂L。
- 反向传播的过程:从输出层(损失)开始,沿着计算图向后传递误差信号,逐层计算偏导数。
公式关系:
θnew=θold−η⋅∂(−logpθ(x))∂θ \theta_{new} = \theta_{old} - \eta \cdot \frac{\partial (-\log p_\theta(x))}{\partial \theta} θnew=θold−η⋅∂θ∂(−logpθ(x))
这里的 ∂(−logpθ(x))∂θ\frac{\partial (-\log p_\theta(x))}{\partial \theta}∂θ∂(−logpθ(x)) 就是由反向传播(链式法则)计算出来的。
举个具体例子:Planar Flow 中的雅可比行列式
以你之前问的 Planar Flow 为例,看看对数似然和反向传播是如何结合的:
Step 1: 前向传播(计算对数似然)
- 输入 xxx。
- 计算 z=x+utanh(wTx+b)z = x + u \tanh(w^T x + b)z=x+utanh(wTx+b)。
- 计算雅可比行列式:J=1+wTutanh′(wTx+b)J = 1 + w^T u \tanh'(w^T x + b)J=1+wTutanh′(wTx+b)。
- 计算对数似然:
logpx(x)=logN(z;0,I)+log∣J∣ \log p_x(x) = \log \mathcal{N}(z; 0, I) + \log |J| logpx(x)=logN(z;0,I)+log∣J∣
此时,logpx(x)\log p_x(x)logpx(x) 是一个标量值。
Step 2: 反向传播(链式法则求梯度)
我们需要求 ∂(−logpx(x))∂w\frac{\partial (-\log p_x(x))}{\partial w}∂w∂(−logpx(x)), ∂(−logpx(x))∂u\frac{\partial (-\log p_x(x))}{\partial u}∂u∂(−logpx(x)) 等。
使用链式法则:
∂(−logpx(x))∂w=∂(−logpx(x))∂J⋅∂J∂w \frac{\partial (-\log p_x(x))}{\partial w} = \frac{\partial (-\log p_x(x))}{\partial J} \cdot \frac{\partial J}{\partial w} ∂w∂(−logpx(x))=∂J∂(−logpx(x))⋅∂w∂J
- ∂(−logpx(x))∂J\frac{\partial (-\log p_x(x))}{\partial J}∂J∂(−logpx(x)):来自损失函数对雅可比的导数。
- ∂J∂w\frac{\partial J}{\partial w}∂w∂J:来自雅可比公式对 www 的导数。
这个"逐层求导"的过程,就是反向传播(链式法则)。
最经典的 Planar Flow(平面流) 为例。
这是 Danilo Rezende 在 2015 年奠基论文中提出的最简单的一步变换 1。我们来看看它如何一步步把一个 2 维的隐变量 zzz 变换,以及如何计算 Loss。
1. Planar Flow 的数学定义
假设我们的输入数据 xxx 是 2 维的(比如一个二维坐标点)。
Planar Flow 定义的前向变换(从简单到复杂,z→xz \to xz→x)公式如下:
x=f(z)=z+u⋅h(wTz+b)x = f(z) = z + u \cdot h(w^T z + b)x=f(z)=z+u⋅h(wTz+b)
这里面各部分的含义非常直观:
- zzz: 2 维的输入噪声(服从标准正态分布)。
- w,uw, uw,u: 都是 2 维的权重向量(网络要学习的参数)。
- bbb: 一个标量偏置(网络要学习的参数)。
- h(⋅)h(\cdot)h(⋅): 一个非线性激活函数,通常使用 tanh\tanhtanh。
- 物理含义:把空间沿着 www 的方向进行有选择的拉伸或压缩。
这个变换被称为"平面流",因为对于固定的 w 和 b,变换主要沿着向量 w 的方向发生偏移,且在垂直于 w 的方向上保持不变(这就是为什么叫"平面"------它在 w 正交的子空间上是恒等映射)。
2. 核心:雅可比行列式(Jacobian)怎么算?
还记得前文提到的 Loss 公式需要计算雅可比行列式吗?对于 Planar Flow,它的导数矩阵(Jacobian)形式非常特殊,根据矩阵行列式引理(Matrix Determinant Lemma),
根据链式法则,变换的雅可比矩阵 JJJ 为:
J=∂zl+1∂zl=I+u⋅h′(wTzl+b)⋅wT J = \frac{\partial z_{l+1}}{\partial z_l} = I + u \cdot h'(w^T z_l + b) \cdot w^T J=∂zl∂zl+1=I+u⋅h′(wTzl+b)⋅wT
这是一个 单位矩阵 III 加上一个外积(Outer Product,外积(叉乘,Cross Product)) uwTu w^TuwT。这种结构称为秩-1更新(Rank-1 Update)。
雅可比行列式化简公式 (Jacobian Determinant)---内积(点积)
利用 Sherman-Morrison 公式求行列式
在矩阵运算中,两个列向量的内积满足交换律,即:
uTw=wTuu^T w = w^T uuTw=wTu
直接计算 d×dd \times dd×d 矩阵的行列式复杂度是 O(d3)O(d^3)O(d3)。但因为 JJJ 是秩-1更新,我们可以使用 Sherman-Morrison 公式 的推论来高效计算行列式:
det(J)=1+wTu⋅h′(wTzl+b) \det(J) = 1 + w^T u \cdot h'(w^T z_l + b) det(J)=1+wTu⋅h′(wTzl+b)
等价:
∣det∂f(z)∂z∣=∣1+uTψ(z)∣\left\vert{} \det \frac{\partial f(z)}{\partial z} \right\vert{} = \left\vert{} 1 + u^T \psi(z) \right\vert{} det∂z∂f(z) = 1+uTψ(z)
其中 ψ(z)=h′(wTz+b)⋅w\psi(z) = h'(w^T z + b) \cdot wψ(z)=h′(wTz+b)⋅w (即激活函数的导数乘以 www)。
- h′(wTz+b)h'(w^T z + b)h′(wTz+b) 是一个标量(纯数字)
- h'(⋅):是激活函数的导数。如果 h=tanhh = \tanhh=tanh,则 h′(x)=1−tanh2(x)h'(x) = 1 - \tanh^2(x)h′(x)=1−tanh2(x)。因为 tanh\tanhtanh 的导数是 1−tanh21 - \tanh^21−tanh2,所以这个值在代码里极其好算,不需要真正去算复杂的矩阵行列式!
- 此时 uTψ(z)u^T \psi(z)uTψ(z) 演变成一个内积(点积),计算复杂度直接从 O(D³) 降到了 O(D)。
关键约束:
为了确保可逆性,uuu 和 www 不能是任意值。必须满足以下约束:
u=(ϕ(u~)−1)w∥w∥2 u = \frac{(\phi(\tilde{u}) - 1) w}{\|w\|^2} u=∥w∥2(ϕ(u~)−1)w
其中:
- u~\tilde{u}u~ 是一个自由学习的向量。
- ϕ\phiϕ 是激活函数 hhh 的原函数(即 ϕ′(x)=h(x)\phi'(x) = h(x)ϕ′(x)=h(x))。例如,如果 h(x)=tanh(x)h(x) = \tanh(x)h(x)=tanh(x),则 ϕ(x)=log(cosh(x))\phi(x) = \log(\cosh(x))ϕ(x)=log(cosh(x))。
这个约束的物理意义 :它限制了向量 uuu 必须位于 www 的方向上,从而保证变换沿着 www 方向"拉伸"或"压缩",而在正交方向不变。
给定数据 xxx,我们要计算其概率密度 px(x)p_x(x)px(x)。假设先验是标准高斯分布 z∼N(0,I)z \sim \mathcal{N}(0, I)z∼N(0,I)。
logpx(x)=logpz(fθ(x))+log∣det∂fθ(x)∂x∣ \log p_x(x) = \log p_z(f_\theta(x)) + \log \left| \det \frac{\partial f_\theta(x)}{\partial x} \right| logpx(x)=logpz(fθ(x))+log det∂x∂fθ(x)
由于 fθf_\thetafθ 是多层 Planar Flow 的复合,根据链式法则,总雅可比行列式是每一层行列式的乘积,对数后变成求和:
logpx(x)=logpz(zL)+∑l=1Llog(1+wlTul⋅h′(wlTzl−1+bl)) \log p_x(x) = \log p_z(z_L) + \sum_{l=1}^{L} \log \left( 1 + w_l^T u_l \cdot h'(w_l^T z_{l-1} + b_l) \right) logpx(x)=logpz(zL)+l=1∑Llog(1+wlTul⋅h′(wlTzl−1+bl))
其中 zL=fθ(x)z_L = f_\theta(x)zL=fθ(x) 是最终映射到的潜在空间。
3. PyTorch 代码实现与 Loss 计算
在实际训练时,我们手头有真实数据 xxx,为了算 Loss,我们需要计算 logpx(x)\log p_x(x)logpx(x)。
这里有一个微妙的细节:Planar Flow 的反向函数 f−1(x)f^{-1}(x)f−1(x) 没有解析解(很难从 xxx 逆推回 zzz)。所以我们在论文中,通常用它来做变分推断(Variational Inference):即从已知噪声 zzz 出发,推演到复杂分布,并在 zzz 的空间里集成计算。
为了让你看懂最核心的 "概率对数(Log-Probability)+ 雅可比对数(Log-Det)= 最终 Loss" 的逻辑,请看下面这段精简的 PyTorch 代码:
py
import torchimport torch.nn as nnimport numpy as np
class PlanarFlow(nn.Module):
def __init__(self, dim=2):
super().__init__()
# 初始化模型参数 w, u, b
self.w = nn.Parameter(torch.randn(dim, 1))
self.u = nn.Parameter(torch.randn(dim, 1))
self.b = nn.Parameter(torch.randn(1))
def forward(self, z):
# h 是 tanh 激活函数
# z 的形状为 [batch_size, 2]
# 1. 计算前向变换: x = z + u * tanh(w^T * z + b)
lin = torch.mm(z, self.w) + self.b # [batch_size, 1]
x = z + self.u.t() * torch.tanh(lin) # [batch_size, 2]
# 2. 计算雅可比行列式的对数 (Log-Absolute-Determinant-Jacobian, LADJ)
# tanh 的导数是 1 - tanh^2
h_prime = 1 - torch.tanh(lin) ** 2 # [batch_size, 1]
psi = h_prime * self.w.t() # [batch_size, 2]
# det = |1 + u^T * psi|
det = 1 + torch.mm(psi, self.u) # [batch_size, 1]
# 为了防止训练不稳定,通常需要对参数进行约束确保 det > 0,这里简化取绝对值
log_abs_det = torch.log(torch.abs(det) + 1e-6) # [batch_size, 1]
return x, log_abs_det
# --- 模拟一次训练和 Loss 计算 ---
# 1. 实例化模型flow = PlanarFlow(dim=2)
# 2. 从简单的基础分布(标准高斯分布)中采样噪声 zbatch_size = 64z = torch.randn(batch_size, 2)
# 3. 计算基础分布的对数概率 log p_z(z)# 二维标准高斯公式的对数形式:-0.5 * (2 * log(2*pi) + x_1^2 + x_2^2)log_p_z = -0.5 * (2 * np.log(2 * np.pi) + torch.sum(z ** 2, dim=1, keepdim=True))
# 4. 通过 Planar Flow 变换空间x, log_abs_det = flow(z)
# 5. 根据变量代换公式,计算变换后的复杂目标分布的对数似然 log p_x(x)# 注意:因为是从 z -> x 的方向,公式为:log p_x(x) = log p_z(z) - log_abs_detlog_p_x = log_p_z - log_abs_det
# 6. 我们的目标是最大化对数似然,所以 Loss 是负对数似然(Negative Log-Likelihood)loss = -torch.mean(log_p_x)
print(f"当前批次的 Loss 值为: {loss.item():.4f}")
4. 直观总结这个 Loss 在干什么?
看一下代码中的第 5、6 步公式:Loss = - (log_p_z - log_abs_det),也就是:
Loss=−logpz(z)+log∣det∂f(z)∂z∣\text{Loss} = -\log p_z(z) + \log \left\vert{} \det \frac{\partial f(z)}{\partial z} \right\vert{}Loss=−logpz(z)+log det∂z∂f(z)
- 前项 −logpz(z)-\log p_z(z)−logpz(z):希望生成的 zzz 尽量靠近标准高斯的核心(惩罚飞得太远的点)。
- 后项 log∣detJ∣\log \vert{} \det J \vert{}log∣detJ∣:这是空间膨胀惩罚项。如果网络走捷径,通过无限放大或无限压缩空间来让概率看起来很高,这一项就会变得极大,从而拉高 Loss。它强迫网络做"平滑、保持能量守恒"的空间拉伸。
通过反向传播更新 w,u,bw, u, bw,u,b,原本一团圆形的标准高斯噪声 zzz,就会被逐渐拉扯、弯曲成你指定的任何复杂形状(比如双峰分布、月牙形分布等)。
1. 理论效果:它证明了什么?
- 可微生成的可行性:它证明了可以通过一系列简单、可逆的变换,将标准高斯分布映射到任意复杂的数据分布。
- 精确似然的计算 :它确立了 Normalizing Flow 的核心优势------可以精确计算数据生成的概率密度(不像 GAN 只能打分,也不像 VAE 只有下界)。
- 单步采样能力:一旦训练完成,从噪声生成样本只需一次前向传播(Single-pass),速度极快。
2. 实际效果(在图像/高维数据上):较差
如果你拿 Planar Flow 去生成 MNIST 或 CIFAR-10 图像,效果通常不如以下模型:
- GANs(如 DCGAN, StyleGAN):视觉上更清晰,模式崩溃问题少。
- DDPM/Diffusion Models:目前的主流,生成质量极高。
- Better NFs(如 Glow, RealNVP):在似然值(NLL)上远超 Planar Flow。
为什么效果不好?核心原因是"表达能力受限":
A. 秩-1 更新限制 (Rank-1 Constraint)
Planar Flow 的每一层变换公式是:
znew=zold+uh(wTz+b) z_{new} = z_{old} + u h(w^T z + b) znew=zold+uh(wTz+b)
这本质上是一个秩-1 更新。
- 这意味着每一层只能在一个特定的方向(由向量 www 定义)上拉伸或压缩数据。
- 在垂直于 www 的方向上,数据完全不变。
- 后果 :要捕捉高维数据(如图像有数万像素)中复杂的、多维度的相关性,Planar Flow 需要堆叠非常非常多层(可能几百甚至上千层)。
B. 训练困难 (Vanishing Gradients)
由于需要堆叠很多层,深层网络会导致:
- 梯度消失/爆炸:反向传播时梯度可能变得极小或极大。
- 优化困难:损失函数地形非常崎岖,很难收敛到全局最优。
C. 难以捕捉高维流形
Planar Flow 擅长处理低维、结构简单的数据。对于高维、复杂流形(如真实照片),它很难通过简单的"平面"变换来拟合。
3. 实际效果(在低维/简单数据上):中等偏上
如果数据维度较低(如 2D、10D),或者分布相对简单,Planar Flow 可以工作得不错:
- 2D 可视化:在教科书中,Planar Flow 经常用于演示如何将高斯球变成双峰分布或环形分布。
- 基准测试:在低维基准数据集上,它能比 VAE 提供更好的似然估计。
Planar Flow 是 Normalizing Flow 的"Hello World"。它在理论上完美,但在实践中因表达能力太弱而难以应用于高维复杂任务。它是理解现代生成模型的重要基石,但本身已不再是主流的技术选择。