一、背景
传统的 CNN 目标检测任务在算法本质上根本不需要它。在传统的 CNN 目标检测中,我们只和 VJP(Vector-Jacobian Product,向量-雅可比积) 打交道,也就是我们挂在嘴边的反向传播(Backpropagation / loss.backward())。
1. 目标检测的输出最终是"标量损失(Scalar Loss)"
- 目标检测的流程:
输入图片 → CNN 骨干网络 → 检测头 → 输出成千上万个 Bounding Box 和类别概率 → 计算 Loss(一个单维度的标量,如 Loss = 2.3)。 - 数学本质:
因为最终的 Loss 是一个标量(Scalar),我们求的是 Loss 对网络参数 W 的梯度(∂Loss∂W\frac{\partial Loss}{\partial W}∂W∂Loss)。 - 为什么不需要 JVP:
当输出是标量时,雅可比矩阵其实只有 1 行(变成了一个行向量)。我们要计算的是向量与矩阵的乘积,PyTorch 的 backward() 底层自动用的就是 VJP。VJP 完美契合"单输出、多输入"的反向传播,因此完全不需要 JVP。
2. 什么时候才会轮到 JVP 登场?
JVP(前向模式微分)擅长处理的是"输入维度小,输出维度大",或者"我们需要知道输入沿着某个特定特定方向变化时,整个复杂输出系统的即时反应"。
在深度学习中,JVP 主要活跃在以下非传统目标检测的领域:
- 生成模型(Flow Matching / 扩散模型 / CNF):如我们最开始讨论的,这类模型需要计算漂移场(Velocity Field)的散度(Divergence)。散度需要对网络输出的每一个维度求偏导并求和,此时利用 JVP 配合随机采样(Hutchinson 估计),可以在不显式计算大矩阵的情况下算出散度。
- 内隐层模型(Implicit Models / DEQ):在神经常微分方程(Neural ODEs)或寻根网络中,前向传播本身就是一个迭代优化过程,需要用 JVP 来加速前向的敏感度分析。
- 物理驱动深度学习(PINN):当用神经网络去解偏微分方程(PDE)时,损失函数直接包含了解的导数项,需要频繁用到 JVP。
3. 目标检测里真的完全没有雅可比(Jacobian)吗?
其实是有的,但它隐式地藏在坐标回归的 Loss 以及 无锚框(Anchor-free)/ 变换网络(Transformer) 的几何变换中:
- 比如在一些旋转目标检测(Oriented Object Detection)或仿射变换(Affine Transformation)中,为了将特征图从一个坐标系映射到另一个坐标系,数学上会引入雅可比矩阵来计算面积或梯度的缩放比例。
- 但由于这些变换的维度非常小(通常只是 (x,y,w,h,θ)(x, y, w, h, \theta)(x,y,w,h,θ) 这 5 个维度的变换),框架直接硬算或者写死了公式,不需要调用显式的 JVP 工具。
二、JVP的方法
在深度学习和数值计算中,"不用完整计算雅可比矩阵,而是采用 JVP (Jacobian-Vector Product,雅可比-向量积)" 是降低显存消耗和计算复杂度的核心技巧。简单来说:我们不需要知道完整的矩阵长什么样,我们只需要知道这个矩阵乘以一个特定向量后的"结果"是什么。
1. 为什么完整计算雅可比矩阵不可行?
显存开销(空间复杂度):高维灾难
假设你的输入 xxx 是一个大小为 nnn 的向量,输出 f(x)f(x)f(x) 是一个大小为 mmm 的向量。
- 雅可比矩阵 JJJ 的大小为 m×nm \times nm×n。
- 实际例子:在图像生成(如 Flow Matching)中,假设你处理的是一张大小为 256×256256 \times 256256×256 的单通道图片,n=65,536n = 65,536n=65,536。如果输出维度相同(m=nm = nm=n),那么雅可比矩阵包含的元素数量为:
65,536×65,536≈43 亿个元素65,536 \times 65,536 \approx 43 \text{ 亿个元素}65,536×65,536≈43 亿个元素 - 光是存储这一个矩阵(使用 FP32 单精度,每个元素 4 字节),就需要约 16 GB 的显存。这还仅仅是一张图片,如果加上 Batch Size,显存会直接溢出(OOM)。
计算量(时间复杂度):需要多次"反向传播"
现代深度学习框架(如 PyTorch、TensorFlow)擅长计算的是标量对向量的导数。当输出是 mmm 维向量时,框架在底层无法一次性吐出整个矩阵。
- 底层原理:框架为了构建完整的 m×nm \times nm×n 矩阵,通常需要进行 mmm 次独立的反向传播(Backpropagation),或者 nnn 次独立的前向传播。
- 每一次反向传播,只能填满雅可比矩阵的其中一行(即某一个输出分量 fif_ifi 对所有输入 xxx 的梯度)。
- 代价:如果你的输出维度 m=1000m = 1000m=1000,模型就必须把整个网络的前向和反向计算重复跑 1000 次。这使得训练或推理速度直接变慢成千上万倍。
2. 什么是 JVP(雅可比-向量积)?
JVP 计算的是 J⋅vJ \cdot vJ⋅v,其中 vvv 是一个与输入 xxx 维度相同的已知向量(大小为 n×1n \times 1n×1)。
- 矩阵 JJJ 是 m×nm \times nm×n,向量 vvv 是 n×1n \times 1n×1。
- 相乘后的结果 J⋅vJ \cdot vJ⋅v 是一个大小仅为 m×1m \times 1m×1 的向量。
- 核心优势:JVP 允许我们绕过显式构建 m×nm \times nm×n 巨型矩阵的过程,直接暴露出最终的 m×1m \times 1m×1 向量结果。
直观理解 JVP 的物理意义,在流匹配(Flow Matching)或微积分中:
- 向量 vvv 可以理解为输入空间中的一个变化方向(或速度)。
- J⋅vJ \cdot vJ⋅v(即 JVP)则代表:当输入 xxx 沿着 vvv 方向发生微小改变时,输出 yyy 随之产生的变化方向(或速度)。
- 它本质上就是前向模式自动微分(Forward-mode Automatic Differentiation)。
3. 怎么做到"不完整计算"却能得到结果?
PyTorch (通过 torch.func.jvp) 或 JAX 等框架在底层实现 JVP 时,利用了链式法则的伴随传播。
它不会先算 JJJ 再乘 vvv,而是在计算图前向传播的同时,把 vvv 顺着计算图一步步传下去:
- 假设操作是 y=f(g(x))y = f(g(x))y=f(g(x))。
- 框架会先计算 g(x)g(x)g(x) 的 JVP 得到 v′v'v′。
- 然后立刻将 v′v'v′ 传给 fff,计算 fff 的 JVP。
- 整个过程中,内存里始终只有和原始数据一样大的向量,而没有任何中间的大矩阵。
在流匹配(Flow Matching)、CNF(连续常微分方程归一化流)等领域,我们虽然理论上遇到了 Jacobian,但实际代码中都会通过数学技巧绕过去:
- 如果你需要计算 JJJ 乘以一个向量(如 J⋅v\boldsymbol{J} \cdot \boldsymbol{v}J⋅v):
- 采用 JVP (Jacobian-Vector Product)。计算量等同于运行一次普通的前向传播,空间复杂度瞬间从 O(m×n)\mathcal{O}(m \times n)O(m×n) 降到 O(m+n)\mathcal{O}(m + n)O(m+n)。
- 如果你需要计算一个向量乘以 JJJ(如 uT⋅J\boldsymbol{u}^T \cdot \boldsymbol{J}uT⋅J):
- 采用 VJP (Vector-Jacobian Product)。这本质上就是 PyTorch 最拿手的 backward()。计算量等同于运行一次普通的反向传播。
- 如果你需要计算 Jacobian 的迹(Trace,即散度 ∇⋅f(x)\nabla \cdot f(x)∇⋅f(x)):
- 比如在计算似然(Likelihood)时,需要算 Tr(J)\text{Tr}(J)Tr(J)。
- 我们不会硬算,而是借用 Hutchinson 随机估计器:
Tr(J)=Eϵ∼N(0,I)ϵTJϵ\text{Tr}(J) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)} \\epsilon\^T J \\epsilonTr(J)=Eϵ∼N(0,I)ϵTJϵ - 这里 JϵJ\epsilonJϵ 就是一个标准的 JVP(或 ϵTJ\epsilon^T JϵTJ 是 VJP),只需要一次自动微分就能估算出来。
- 我们不会硬算,而是借用 Hutchinson 随机估计器:
4. 物理意义:我们关心的不是"所有的导数",而是"某个方向的变化"
这是数值计算和深度学习中极其精妙的核心思想。之所以"只需要结果,不需要矩阵本身",是因为我们最终的物理目的、优化目标以及计算机的底层实现,都只需要那个"相乘后的向量"。
雅可比矩阵 JJJ 存储了所有输入分量对所有输出分量的导数。这就好比一本厚厚的地形字典,记录了地图上每一个点在所有方向上的坡度。
但在实际应用中,我们往往不需要知道整本地形字典,我们只关心"如果我往特定方向迈出一步(向量 vvv),我的高度会发生什么变化(结果 J⋅vJ \cdot vJ⋅v)"。
- JVP (J⋅vJ \cdot vJ⋅v) 的直观意义:输入沿着 vvv 方向移动时,输出移动的方向和速度。
- VJP (uT⋅Ju^T \cdot JuT⋅J) 的直观意义:在输出的各个分量中,如果我给它们分配不同的权重/扰动(向量 uuu),这会如何反向作用于输入的每一个分量。
举个具体例子(自动驾驶/目标检测):
假设输入 xxx 是汽车的控制参数(油门、方向盘),输出 yyy 是汽车在 3D 空间中的位置和姿态。
- 完整的雅可比矩阵 JJJ 会告诉你:每一个控制微调对每一个位置坐标的所有可能影响(信息量巨大)。
- 但现在你只想知道:"如果我猛打方向盘(给一个特定的方向向量 vvv),汽车的姿态会怎么变?"
- 你只需要计算 J⋅vJ \cdot vJ⋅v,它直接给你一个反映汽车新姿态的变化向量。在这个过程中,你根本不需要知道如果踩油门会发生什么,因此完整计算 JJJ 是极大的浪费。
5. 数学目的:矩阵只是中间桥梁,最终目标是"降维"
在深度学习的算法设计中,雅可比矩阵几乎永远只是一个中间过渡,它最终都会被压缩成一个标量或者低维向量。
- 在损失函数优化中:我们最终的目的是为了求 Loss(标量)。
在反向传播时,我们求的是 ∂Loss∂x\frac{\partial Loss}{\partial x}∂x∂Loss。根据链式法则,它等于:
∂Loss∂x=∂Loss∂y⋅∂y∂x=uT⋅J\frac{\partial Loss}{\partial x} = \frac{\partial Loss}{\partial y} \cdot \frac{\partial y}{\partial x} = u^T \cdot J∂x∂Loss=∂y∂Loss⋅∂x∂y=uT⋅J
看到了吗?我们数学上真正的目标是求出最左边的梯度向量,雅可比矩阵 JJJ 只是夹在中间的工具。我们直接用 VJP 算出 uT⋅Ju^T \cdot JuT⋅J 的最终向量即可。 - 在流匹配(Flow Matching)计算散度时:我们的目的是为了求 Jacobian 的迹(Trace,对角线之和)。
通过 Hutchinson 随机估计,我们将求迹问题转化为了求 ϵT(Jϵ)\epsilon^T (J \epsilon)ϵT(Jϵ)。
我们先用 JVP 算出向量 v′=Jϵv' = J \epsilonv′=Jϵ,再让 ϵT\epsilon^TϵT 与之做内积。自始至终,我们的数学目标只是为了得到一个单一的标量数值。为了得到一个标量,去算一个几万亿元素的矩阵,显然是不合理的。
6. 计算效率:动态图"边走边算",绕过显式矩阵
从计算机科学的角度来看,不计算完整矩阵不仅是"不想算",更是因为框架在底层实现 JVP/VJP 时,根本不需要构造这个矩阵。
现代自动微分(如 PyTorch, JAX)利用了计算图的线性相继性:
假设你的网络是由很多层堆叠而成的:x→h1→h2→yx \to h_1 \to h_2 \to yx→h1→h2→y。
那么总的雅可比矩阵是各层雅可比矩阵的乘积:J=J3⋅J2⋅J1J = J_3 \cdot J_2 \cdot J_1J=J3⋅J2⋅J1。
如果要完整计算 JJJ:
- 计算机必须把庞大的 J1J_1J1、J2J_2J2、J3J_3J3 全部显式计算并存下来。
- 然后把这些巨型矩阵做矩阵乘法(极其消耗显存和算力)。
如果要计算 JVP (J⋅vJ \cdot vJ⋅v):
- 计算机从输入端出发,拿着向量 vvv。
- 经过第一层时,直接计算 v1=J1⋅vv_1 = J_1 \cdot vv1=J1⋅v(矩阵与向量相乘,结果还是个小向量)。
- 经过第二层时,直接计算 v2=J2⋅v1v_2 = J_2 \cdot v_1v2=J2⋅v1。
- 经过第三层时,直接计算 v3=J3⋅v2v_3 = J_3 \cdot v_2v3=J3⋅v2。
核心秘密就在这里:在整个前向传播的过程中,计算机从来没有在内存里拼凑过任何一个大矩阵,它只是在一层层地传递和更新一个大小和特征图一模一样的小向量。这种"边走边算"的机制,用极低的代价换取了我们想要的最终结果。
5.总结对照
"知其果,不必知其因。" 在高维世界中,把所有变量之间的勾连关系(完整的 Jacobian)巨细靡遗地打印出来是计算机的灾难。抓住我们关心的特定方向(向量),让它顺着网络流淌,直接拿到高维空间坍缩后的结果,才是深度学习能跑在普通 GPU 上的关键。
这个"向量代替矩阵"的逻辑,在不同的数学场景下有不同的妙用。
- 你想了解如何用这个原理,在 Flow Matching 中不费吹灰之力就估算出图像生成的似然(Likelihood/散度)吗?
- 或者说,你在代码中遇到了某些具体的矩阵乘法导致了显存溢出(OOM),想看看怎么用 JVP/VJP 去改造它?
| 特性 | 完整雅可比矩阵 (JJJ) | 雅可比-向量积 (JVPJVPJVP) |
|---|---|---|
| 空间复杂度 | O(m×n)\mathcal{O}(m \times n)O(m×n) (灾难级) | O(m+n)\mathcal{O}(m + n)O(m+n) (极小) |
| 计算逻辑 | 算出所有输入对所有输出的偏导 | 只算输出在特定输入方向上的投影 |
| 对应工具 | torch.autograd.functional.jacobian | torch.func.jvp (前向微分) |
三、一般代码示例
在 PyTorch 中,计算完整的雅可比矩阵、JVP(雅可比-向量积)以及 VJP(向量-雅可比积)有不同的高效实现方式。
下面为你提供一套完整的 Python 代码示例,对比完整计算(低效)与 JVP/VJP(高效)的写法。
我们需要先安装或确保 PyTorch 版本较新(推荐 PyTorch 2.0+),因为它内置了强大的 torch.func 模块(原名 functorch),专门用于处理这类高级自动微分。
import torch
from torch.func import jacrev, jacfwd, jvp, vjp
# 1. 定义一个多维输入、多维输出的复杂函数 f(x)
# 假设输入 x 是 3 维,输出 y 是 2 维: y = f(x)
def f(x):
y1 = x[0]**2 + x[1] * x[2]
y2 = torch.sin(x[0]) + x[2]**3
return torch.stack([y1, y2])
# 初始化输入数据 x (3维)
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
print(f"输入 x: {x}\n")
示例 1:完整计算 Jacobian 矩阵(高维时不推荐)
PyTorch 提供了 jacrev(基于反向传播,适合 m<nm < nm<n)和 jacfwd(基于前向传播,适合 m>nm > nm>n)。
# 完整计算 Jacobian (大小为 2 x 3)
jacobian_matrix = jacrev(f)(x)
print("--- 1. 完整的 Jacobian 矩阵 ---")
print(jacobian_matrix)
# 输出结构为:
# [[dy1/dx0, dy1/dx1, dy1/dx2],
# [dy2/dx0, dy2/dx1, dy2/dx2]]
示例 2:使用 JVP(雅可比-向量积)------ 高效绕过大矩阵
当你只需要计算 J⋅vJ \cdot vJ⋅v,且 vvv 的维度与输入 xxx 相同时使用。整个过程不构建上面的 2×32 \times 32×3 矩阵。
# 定义一个与输入 x 维度相同的切向量 v
v = torch.tensor([1.0, 0.5, -1.0])
# 计算 JVP
# jvp 接收的参数依次是:函数、(输入元组,)、(向量元组,)
# 它会同时返回:f(x) 的输出结果,以及 J*v 的结果
f_x_output, jvp_result = jvp(f, (x,), (v,))
print("\n--- 2. JVP (Jacobian-Vector Product) ---")
print(f"f(x) 的输出: {f_x_output}")
print(f"J * v 的结果: {jvp_result}") # 这是一个 2 维向量
示例 3:使用 VJP(向量-雅可比积)------ 深度学习反向传播的核心
当你只需要计算 uT⋅Ju^T \cdot JuT⋅J,且 uuu 的维度与输出 f(x)f(x)f(x) 相同时使用。
# 定义一个与输出 f(x) 维度相同的余切向量 u
u = torch.tensor([2.0, 1.0])
# 计算 VJP
# vjp 返回两样东西:f(x) 的输出结果,以及一个"能计算 VJP 的函数 (vjp_fn)"
f_x_output, vjp_fn = vjp(f, x)
# 将向量 u 传给 vjp_fn,得到 u^T * J
vjp_result = vjp_fn(u)[0]
print("\n--- 3. VJP (Vector-Jacobian Product) ---")
print(f"f(x) 的输出: {f_x_output}")
print(f"u^T * J 的结果: {vjp_result}") # 这是一个 3 维向量
示例 4:实战应用 ------ 利用 JVP 快速估计 Jacobian 的迹(Trace / 散度)
在 Flow Matching 或 CNF 中,我们需要计算 Tr(J)=∑∂fi∂xi\text{Tr}(J) = \sum \frac{\partial f_i}{\partial x_i}Tr(J)=∑∂xi∂fi。如果维度有几百万,算完整 Jacobian 再求对角线和会直接爆显存。
我们用 Hutchinson 估计器,只需运行一次 JVP 即可无偏估计出 Trace:``
# 假设高维场景下,随机采样一个与 x 维度相同的噪声向量 ϵ ~ N(0, I)
epsilon = torch.randn_like(x)
# 只需要算一次 JVP,得到 J * ϵ
_, j_epsilon = jvp(f, (x,), (epsilon,))
# 计算 ϵ^T * J * ϵ
trace_estimate = torch.dot(epsilon, j_epsilon)
print("\n--- 4. 实战:Jacobian 迹的随机估计 (Hutchinson) ---")
print(f"Jacobian 迹的估计值: {trace_estimate.item()}")
print(f"对比真实迹 (对角线之和): {jacobian_matrix[0,0] + jacobian_matrix[1,1]}")
四、在 Flow Matching应用
在 Flow Matching(以及连续归一化流 CNF)中,要计算一幅图像生成的精确似然(Likelihood),数学上要求解连续变量的变化率,这就必须计算速度场(Velocity Field,即你的神经网络 f(x))的散度(Divergence)。
散度的定义是雅可比矩阵的迹(Trace),即对角线元素之和:
Div(f(x))=Tr(J)=∑i=1n∂fi(x)∂xi\text{Div}(f(x)) = \text{Tr}(J) = \sum_{i=1}^{n} \frac{\partial f_i(x)}{\partial x_i}Div(f(x))=Tr(J)=i=1∑n∂xi∂fi(x)
如果图像维度 n 是几十万,算完整的 J 再取对角线会直接导致显存崩溃。而利用 JVP/VJP + Hutchinson 随机估计器,我们可以"不费吹灰之力",只需一次额外的前向或反向传播就能精准估算出它。
以下是具体的演进逻辑和 PyTorch 实战代码:
1. 核心数学转换(Hutchinson 技巧)
Hutchinson 估计器提出了一个精妙的数学恒等式:对于任意方阵 J,如果我们引入一个随机噪声向量 ε,只要这个噪声满足均值为 0、协方差矩阵为单位阵 I(例如标准的正态分布或 Rademacher 分布),那么就有:
Tr(J)=EϵϵTJϵ\text{Tr}(J) = \mathbb{E}_{\epsilon} \left \\epsilon\^T J \\epsilon \\rightTr(J)=EϵϵTJϵ
我们把这个公式拆开看,就会发现它完美的变成了 JVP(或 VJP) 的形状:
ϵTJϵ=ϵT⋅(J⋅ϵ)=ϵT⋅JVP(f,x,ϵ)\epsilon^T J \epsilon = \epsilon^T \cdot (J \cdot \epsilon) = \epsilon^T \cdot \text{JVP}(f, x, \epsilon)ϵTJϵ=ϵT⋅(J⋅ϵ)=ϵT⋅JVP(f,x,ϵ)
这就是不费吹灰之力的秘密:
- 随机抽一个和图像尺寸完全一样的噪声 ε。
- 不用算 J,直接用 jvp 工具算出向量 Jε(它的大小和图像一模一样)。
- 让 ε 和 Jε 做点积(Dot Product),直接得到一个标量数值!
2. PyTorch 实战代码示例
在 Flow Matching 的实际训练或评估中,输入数据通常包含 Batch 维度。为了高效处理 Batch 数据,我们需要结合 PyTorch 的 torch.func.vmap(自动向量化映射)和 jvp。
以下是标准的工业级实现:
import torch
from torch.func import jvp, vmap
# 1. 模拟一个 Flow Matching 的速度场网络 (Velocity Field / Drift)
# 假设输入图像特征图大小为 [Channels, H, W],例如 3 * 32 * 32 = 3072 维
class SimpleVelocityField(torch.nn.Module):
def __init__(self):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(3072, 4096),
torch.nn.GELU(),
torch.nn.Linear(4096, 3072)
)
def forward(self, x):
# x 形状为 [3072]
return self.net(x)
model = SimpleVelocityField().eval()
# 2. 准备 Batch 数据
batch_size = 16
img_dim = 3072
# 模拟网络输入:[Batch, Dim]
x_batch = torch.randn(batch_size, img_dim)
# 3. 核心:定义单个样本的 JVP 散度计算函数
def compute_single_div(x, eps):
# jvp 接收:函数、(输入元组,)、(特定方向向量元组,)
# 返回:(网络输出, J * eps)
_, j_eps = jvp(model, (x,), (eps,))
# eps^T * J * eps
# 因为是单个样本(一维向量),直接做内积(点积)
return torch.dot(eps, j_eps)
# 4. 关键加速:使用 vmap 批量并行处理整个 Batch
# 采样与 x_batch 形状完全一致的随机噪声 ϵ ~ N(0, I)
epsilon_batch = torch.randn_like(x_batch)
# vmap 会自动并行处理 Batch 维(第 0 维)
batch_div_fn = vmap(compute_single_div, in_dims=(0, 0))
# 只需要一步,算出整个 Batch 每一个样本的速度场散度!
divergence = batch_div_fn(x_batch, epsilon_batch)
print(f"输入 Batch 形状: {x_batch.shape}")
print(f"估算出的散度 (每个样本一个标量): \n{divergence}")
print(f"散度输出形状: {divergence.shape}") # 形状为 [16],没有产生任何大矩阵!
3. 为什么说它"不费吹灰之力"?
我们可以从两个硬指标来看看这个方案有多优雅:
- 显存几乎归零:如果图像尺寸是 3072,完整雅可比矩阵需要存储 3072 × 3072 ≈ 940 万个浮点数。而上面这段代码在运行过程中,显存里最大的变量也就是和输入一样大的 epsilon_batch(16 × 3072),显存开销降低了成千上万倍。
- 计算速度极快:普通的 jvp 底层是通过前向自动微分(Forward-mode AD)实现的。计算一次 jvp 的耗时大约仅仅相当于网络做一次普通前向传播(Forward Pass)的 2 到 3 倍。
4. 了解这个对算似然(Likelihood)有什么用?
在 Flow Matching 中,根据连续连续变量公式(Instantaneous Change of Variables 公式),一幅图像 x₀ 变换到噪声 x₁ 的对数似然(Log-Likelihood)变化,可以通过沿着 ODE 轨迹对散度进行积分得到:
logp1(x1)−logp0(x0)=−∫01Div(f(xt))dt\log p_1(x_1) - \log p_0(x_0) = -\int_{0}^{1} \text{Div}(f(x_t)) dtlogp1(x1)−logp0(x0)=−∫01Div(f(xt))dt
有了上面这个高效的散度估计器,你就可以在使用 ODE 求解器(如 torchdiffeq)对图像进行生成或重建的同时,顺手把散度丢给求解器一起做积分。整个生成结束时,图像的精确似然(比如用于衡量生成质量的 BPD 位每像素指标)就被同步算出来了,完全不会拖慢你的训练和推理速度。
这个 Hutchinson 估计器不仅可以使用正态分布(Gaussian)噪声,在生图领域中大家往往更喜欢使用 Rademacher 噪声(即只有 +1 和 -1 的随机噪声),后者的方差通常更小。