计算雅可比矩阵时,不用完整计算,采用JVP?如何理解?传统的 CNN 目标检测任务在算法本质上根本不需要JVP

一、背景

传统的 CNN 目标检测任务在算法本质上根本不需要它。在传统的 CNN 目标检测中,我们只和 VJP(Vector-Jacobian Product,向量-雅可比积) 打交道,也就是我们挂在嘴边的反向传播(Backpropagation / loss.backward())。

1. 目标检测的输出最终是"标量损失(Scalar Loss)"

  • 目标检测的流程:
    输入图片 → CNN 骨干网络 → 检测头 → 输出成千上万个 Bounding Box 和类别概率 → 计算 Loss(一个单维度的标量,如 Loss = 2.3)。
  • 数学本质:
    因为最终的 Loss 是一个标量(Scalar),我们求的是 Loss 对网络参数 W 的梯度(∂Loss∂W\frac{\partial Loss}{\partial W}∂W∂Loss)。
  • 为什么不需要 JVP:
    当输出是标量时,雅可比矩阵其实只有 1 行(变成了一个行向量)。我们要计算的是向量与矩阵的乘积,PyTorch 的 backward() 底层自动用的就是 VJP。VJP 完美契合"单输出、多输入"的反向传播,因此完全不需要 JVP。

2. 什么时候才会轮到 JVP 登场?

JVP(前向模式微分)擅长处理的是"输入维度小,输出维度大",或者"我们需要知道输入沿着某个特定特定方向变化时,整个复杂输出系统的即时反应"。

在深度学习中,JVP 主要活跃在以下非传统目标检测的领域:

  • 生成模型(Flow Matching / 扩散模型 / CNF):如我们最开始讨论的,这类模型需要计算漂移场(Velocity Field)的散度(Divergence)。散度需要对网络输出的每一个维度求偏导并求和,此时利用 JVP 配合随机采样(Hutchinson 估计),可以在不显式计算大矩阵的情况下算出散度。
  • 内隐层模型(Implicit Models / DEQ):在神经常微分方程(Neural ODEs)或寻根网络中,前向传播本身就是一个迭代优化过程,需要用 JVP 来加速前向的敏感度分析。
  • 物理驱动深度学习(PINN):当用神经网络去解偏微分方程(PDE)时,损失函数直接包含了解的导数项,需要频繁用到 JVP。

3. 目标检测里真的完全没有雅可比(Jacobian)吗?

其实是有的,但它隐式地藏在坐标回归的 Loss 以及 无锚框(Anchor-free)/ 变换网络(Transformer) 的几何变换中:

  • 比如在一些旋转目标检测(Oriented Object Detection)或仿射变换(Affine Transformation)中,为了将特征图从一个坐标系映射到另一个坐标系,数学上会引入雅可比矩阵来计算面积或梯度的缩放比例。
  • 但由于这些变换的维度非常小(通常只是 (x,y,w,h,θ)(x, y, w, h, \theta)(x,y,w,h,θ) 这 5 个维度的变换),框架直接硬算或者写死了公式,不需要调用显式的 JVP 工具。

二、JVP的方法

在深度学习和数值计算中,"不用完整计算雅可比矩阵,而是采用 JVP (Jacobian-Vector Product,雅可比-向量积)" 是降低显存消耗和计算复杂度的核心技巧。简单来说:我们不需要知道完整的矩阵长什么样,我们只需要知道这个矩阵乘以一个特定向量后的"结果"是什么。

1. 为什么完整计算雅可比矩阵不可行?

显存开销(空间复杂度):高维灾难

假设你的输入 xxx 是一个大小为 nnn 的向量,输出 f(x)f(x)f(x) 是一个大小为 mmm 的向量。

  • 雅可比矩阵 JJJ 的大小为 m×nm \times nm×n。
  • 实际例子:在图像生成(如 Flow Matching)中,假设你处理的是一张大小为 256×256256 \times 256256×256 的单通道图片,n=65,536n = 65,536n=65,536。如果输出维度相同(m=nm = nm=n),那么雅可比矩阵包含的元素数量为:
    65,536×65,536≈43 亿个元素65,536 \times 65,536 \approx 43 \text{ 亿个元素}65,536×65,536≈43 亿个元素
  • 光是存储这一个矩阵(使用 FP32 单精度,每个元素 4 字节),就需要约 16 GB 的显存。这还仅仅是一张图片,如果加上 Batch Size,显存会直接溢出(OOM)。

计算量(时间复杂度):需要多次"反向传播"

现代深度学习框架(如 PyTorch、TensorFlow)擅长计算的是标量对向量的导数。当输出是 mmm 维向量时,框架在底层无法一次性吐出整个矩阵。

  • 底层原理:框架为了构建完整的 m×nm \times nm×n 矩阵,通常需要进行 mmm 次独立的反向传播(Backpropagation),或者 nnn 次独立的前向传播。
  • 每一次反向传播,只能填满雅可比矩阵的其中一行(即某一个输出分量 fif_ifi 对所有输入 xxx 的梯度)。
  • 代价:如果你的输出维度 m=1000m = 1000m=1000,模型就必须把整个网络的前向和反向计算重复跑 1000 次。这使得训练或推理速度直接变慢成千上万倍。

2. 什么是 JVP(雅可比-向量积)?

JVP 计算的是 J⋅vJ \cdot vJ⋅v,其中 vvv 是一个与输入 xxx 维度相同的已知向量(大小为 n×1n \times 1n×1)。

  • 矩阵 JJJ 是 m×nm \times nm×n,向量 vvv 是 n×1n \times 1n×1。
  • 相乘后的结果 J⋅vJ \cdot vJ⋅v 是一个大小仅为 m×1m \times 1m×1 的向量。
  • 核心优势:JVP 允许我们绕过显式构建 m×nm \times nm×n 巨型矩阵的过程,直接暴露出最终的 m×1m \times 1m×1 向量结果。

直观理解 JVP 的物理意义,在流匹配(Flow Matching)或微积分中:

  • 向量 vvv 可以理解为输入空间中的一个变化方向(或速度)。
  • J⋅vJ \cdot vJ⋅v(即 JVP)则代表:当输入 xxx 沿着 vvv 方向发生微小改变时,输出 yyy 随之产生的变化方向(或速度)。
  • 它本质上就是前向模式自动微分(Forward-mode Automatic Differentiation)。

3. 怎么做到"不完整计算"却能得到结果?

PyTorch (通过 torch.func.jvp) 或 JAX 等框架在底层实现 JVP 时,利用了链式法则的伴随传播。

它不会先算 JJJ 再乘 vvv,而是在计算图前向传播的同时,把 vvv 顺着计算图一步步传下去:

  • 假设操作是 y=f(g(x))y = f(g(x))y=f(g(x))。
  • 框架会先计算 g(x)g(x)g(x) 的 JVP 得到 v′v'v′。
  • 然后立刻将 v′v'v′ 传给 fff,计算 fff 的 JVP。
  • 整个过程中,内存里始终只有和原始数据一样大的向量,而没有任何中间的大矩阵。

在流匹配(Flow Matching)、CNF(连续常微分方程归一化流)等领域,我们虽然理论上遇到了 Jacobian,但实际代码中都会通过数学技巧绕过去:

  1. 如果你需要计算 JJJ 乘以一个向量(如 J⋅v\boldsymbol{J} \cdot \boldsymbol{v}J⋅v):
  • 采用 JVP (Jacobian-Vector Product)。计算量等同于运行一次普通的前向传播,空间复杂度瞬间从 O(m×n)\mathcal{O}(m \times n)O(m×n) 降到 O(m+n)\mathcal{O}(m + n)O(m+n)。
  1. 如果你需要计算一个向量乘以 JJJ(如 uT⋅J\boldsymbol{u}^T \cdot \boldsymbol{J}uT⋅J):
  • 采用 VJP (Vector-Jacobian Product)。这本质上就是 PyTorch 最拿手的 backward()。计算量等同于运行一次普通的反向传播。
  1. 如果你需要计算 Jacobian 的迹(Trace,即散度 ∇⋅f(x)\nabla \cdot f(x)∇⋅f(x)):
  • 比如在计算似然(Likelihood)时,需要算 Tr(J)\text{Tr}(J)Tr(J)。
    • 我们不会硬算,而是借用 Hutchinson 随机估计器:
      Tr(J)=Eϵ∼N(0,I)ϵTJϵ\text{Tr}(J) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)} \\epsilon\^T J \\epsilonTr(J)=Eϵ∼N(0,I)ϵTJϵ
    • 这里 JϵJ\epsilonJϵ 就是一个标准的 JVP(或 ϵTJ\epsilon^T JϵTJ 是 VJP),只需要一次自动微分就能估算出来。

4. 物理意义:我们关心的不是"所有的导数",而是"某个方向的变化"

这是数值计算和深度学习中极其精妙的核心思想。之所以"只需要结果,不需要矩阵本身",是因为我们最终的物理目的、优化目标以及计算机的底层实现,都只需要那个"相乘后的向量"。

雅可比矩阵 JJJ 存储了所有输入分量对所有输出分量的导数。这就好比一本厚厚的地形字典,记录了地图上每一个点在所有方向上的坡度。

但在实际应用中,我们往往不需要知道整本地形字典,我们只关心"如果我往特定方向迈出一步(向量 vvv),我的高度会发生什么变化(结果 J⋅vJ \cdot vJ⋅v)"。

  • JVP (J⋅vJ \cdot vJ⋅v) 的直观意义:输入沿着 vvv 方向移动时,输出移动的方向和速度。
  • VJP (uT⋅Ju^T \cdot JuT⋅J) 的直观意义:在输出的各个分量中,如果我给它们分配不同的权重/扰动(向量 uuu),这会如何反向作用于输入的每一个分量。

举个具体例子(自动驾驶/目标检测):

假设输入 xxx 是汽车的控制参数(油门、方向盘),输出 yyy 是汽车在 3D 空间中的位置和姿态。

  • 完整的雅可比矩阵 JJJ 会告诉你:每一个控制微调对每一个位置坐标的所有可能影响(信息量巨大)。
  • 但现在你只想知道:"如果我猛打方向盘(给一个特定的方向向量 vvv),汽车的姿态会怎么变?"
  • 你只需要计算 J⋅vJ \cdot vJ⋅v,它直接给你一个反映汽车新姿态的变化向量。在这个过程中,你根本不需要知道如果踩油门会发生什么,因此完整计算 JJJ 是极大的浪费。

5. 数学目的:矩阵只是中间桥梁,最终目标是"降维"

在深度学习的算法设计中,雅可比矩阵几乎永远只是一个中间过渡,它最终都会被压缩成一个标量或者低维向量。

  • 在损失函数优化中:我们最终的目的是为了求 Loss(标量)。
    在反向传播时,我们求的是 ∂Loss∂x\frac{\partial Loss}{\partial x}∂x∂Loss。根据链式法则,它等于:
    ∂Loss∂x=∂Loss∂y⋅∂y∂x=uT⋅J\frac{\partial Loss}{\partial x} = \frac{\partial Loss}{\partial y} \cdot \frac{\partial y}{\partial x} = u^T \cdot J∂x∂Loss=∂y∂Loss⋅∂x∂y=uT⋅J
    看到了吗?我们数学上真正的目标是求出最左边的梯度向量,雅可比矩阵 JJJ 只是夹在中间的工具。我们直接用 VJP 算出 uT⋅Ju^T \cdot JuT⋅J 的最终向量即可。
  • 在流匹配(Flow Matching)计算散度时:我们的目的是为了求 Jacobian 的迹(Trace,对角线之和)。
    通过 Hutchinson 随机估计,我们将求迹问题转化为了求 ϵT(Jϵ)\epsilon^T (J \epsilon)ϵT(Jϵ)。
    我们先用 JVP 算出向量 v′=Jϵv' = J \epsilonv′=Jϵ,再让 ϵT\epsilon^TϵT 与之做内积。自始至终,我们的数学目标只是为了得到一个单一的标量数值。为了得到一个标量,去算一个几万亿元素的矩阵,显然是不合理的。

6. 计算效率:动态图"边走边算",绕过显式矩阵

从计算机科学的角度来看,不计算完整矩阵不仅是"不想算",更是因为框架在底层实现 JVP/VJP 时,根本不需要构造这个矩阵。

现代自动微分(如 PyTorch, JAX)利用了计算图的线性相继性:

假设你的网络是由很多层堆叠而成的:x→h1→h2→yx \to h_1 \to h_2 \to yx→h1→h2→y。

那么总的雅可比矩阵是各层雅可比矩阵的乘积:J=J3⋅J2⋅J1J = J_3 \cdot J_2 \cdot J_1J=J3⋅J2⋅J1。

如果要完整计算 JJJ:

  1. 计算机必须把庞大的 J1J_1J1、J2J_2J2、J3J_3J3 全部显式计算并存下来。
  2. 然后把这些巨型矩阵做矩阵乘法(极其消耗显存和算力)。

如果要计算 JVP (J⋅vJ \cdot vJ⋅v):

  1. 计算机从输入端出发,拿着向量 vvv。
  2. 经过第一层时,直接计算 v1=J1⋅vv_1 = J_1 \cdot vv1=J1⋅v(矩阵与向量相乘,结果还是个小向量)。
  3. 经过第二层时,直接计算 v2=J2⋅v1v_2 = J_2 \cdot v_1v2=J2⋅v1。
  4. 经过第三层时,直接计算 v3=J3⋅v2v_3 = J_3 \cdot v_2v3=J3⋅v2。

核心秘密就在这里:在整个前向传播的过程中,计算机从来没有在内存里拼凑过任何一个大矩阵,它只是在一层层地传递和更新一个大小和特征图一模一样的小向量。这种"边走边算"的机制,用极低的代价换取了我们想要的最终结果。


5.总结对照

"知其果,不必知其因。" 在高维世界中,把所有变量之间的勾连关系(完整的 Jacobian)巨细靡遗地打印出来是计算机的灾难。抓住我们关心的特定方向(向量),让它顺着网络流淌,直接拿到高维空间坍缩后的结果,才是深度学习能跑在普通 GPU 上的关键。

这个"向量代替矩阵"的逻辑,在不同的数学场景下有不同的妙用。

  • 你想了解如何用这个原理,在 Flow Matching 中不费吹灰之力就估算出图像生成的似然(Likelihood/散度)吗?
  • 或者说,你在代码中遇到了某些具体的矩阵乘法导致了显存溢出(OOM),想看看怎么用 JVP/VJP 去改造它?
特性 完整雅可比矩阵 (JJJ) 雅可比-向量积 (JVPJVPJVP)
空间复杂度 O(m×n)\mathcal{O}(m \times n)O(m×n) (灾难级) O(m+n)\mathcal{O}(m + n)O(m+n) (极小)
计算逻辑 算出所有输入对所有输出的偏导 只算输出在特定输入方向上的投影
对应工具 torch.autograd.functional.jacobian torch.func.jvp (前向微分)

三、一般代码示例

在 PyTorch 中,计算完整的雅可比矩阵、JVP(雅可比-向量积)以及 VJP(向量-雅可比积)有不同的高效实现方式。

下面为你提供一套完整的 Python 代码示例,对比完整计算(低效)与 JVP/VJP(高效)的写法。

我们需要先安装或确保 PyTorch 版本较新(推荐 PyTorch 2.0+),因为它内置了强大的 torch.func 模块(原名 functorch),专门用于处理这类高级自动微分。

复制代码
import torch
from torch.func import jacrev, jacfwd, jvp, vjp

# 1. 定义一个多维输入、多维输出的复杂函数 f(x)
# 假设输入 x 是 3 维,输出 y 是 2 维: y = f(x)
def f(x):
    y1 = x[0]**2 + x[1] * x[2]
    y2 = torch.sin(x[0]) + x[2]**3
    return torch.stack([y1, y2])

# 初始化输入数据 x (3维)
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
print(f"输入 x: {x}\n")

示例 1:完整计算 Jacobian 矩阵(高维时不推荐)

PyTorch 提供了 jacrev(基于反向传播,适合 m<nm < nm<n)和 jacfwd(基于前向传播,适合 m>nm > nm>n)。

复制代码
# 完整计算 Jacobian (大小为 2 x 3)
jacobian_matrix = jacrev(f)(x)

print("--- 1. 完整的 Jacobian 矩阵 ---")
print(jacobian_matrix)
# 输出结构为:
# [[dy1/dx0, dy1/dx1, dy1/dx2],
#  [dy2/dx0, dy2/dx1, dy2/dx2]]

示例 2:使用 JVP(雅可比-向量积)------ 高效绕过大矩阵

当你只需要计算 J⋅vJ \cdot vJ⋅v,且 vvv 的维度与输入 xxx 相同时使用。整个过程不构建上面的 2×32 \times 32×3 矩阵。

复制代码
# 定义一个与输入 x 维度相同的切向量 v
v = torch.tensor([1.0, 0.5, -1.0])

# 计算 JVP
# jvp 接收的参数依次是:函数、(输入元组,)、(向量元组,)
# 它会同时返回:f(x) 的输出结果,以及 J*v 的结果
f_x_output, jvp_result = jvp(f, (x,), (v,))

print("\n--- 2. JVP (Jacobian-Vector Product) ---")
print(f"f(x) 的输出: {f_x_output}")
print(f"J * v 的结果: {jvp_result}")  # 这是一个 2 维向量

示例 3:使用 VJP(向量-雅可比积)------ 深度学习反向传播的核心

当你只需要计算 uT⋅Ju^T \cdot JuT⋅J,且 uuu 的维度与输出 f(x)f(x)f(x) 相同时使用。

复制代码
# 定义一个与输出 f(x) 维度相同的余切向量 u
u = torch.tensor([2.0, 1.0])

# 计算 VJP
# vjp 返回两样东西:f(x) 的输出结果,以及一个"能计算 VJP 的函数 (vjp_fn)"
f_x_output, vjp_fn = vjp(f, x)

# 将向量 u 传给 vjp_fn,得到 u^T * J
vjp_result = vjp_fn(u)[0]

print("\n--- 3. VJP (Vector-Jacobian Product) ---")
print(f"f(x) 的输出: {f_x_output}")
print(f"u^T * J 的结果: {vjp_result}")  # 这是一个 3 维向量

示例 4:实战应用 ------ 利用 JVP 快速估计 Jacobian 的迹(Trace / 散度)

在 Flow Matching 或 CNF 中,我们需要计算 Tr(J)=∑∂fi∂xi\text{Tr}(J) = \sum \frac{\partial f_i}{\partial x_i}Tr(J)=∑∂xi∂fi。如果维度有几百万,算完整 Jacobian 再求对角线和会直接爆显存。

我们用 Hutchinson 估计器,只需运行一次 JVP 即可无偏估计出 Trace:``

复制代码
# 假设高维场景下,随机采样一个与 x 维度相同的噪声向量 ϵ ~ N(0, I)
epsilon = torch.randn_like(x)

# 只需要算一次 JVP,得到 J * ϵ
_, j_epsilon = jvp(f, (x,), (epsilon,))

# 计算 ϵ^T * J * ϵ
trace_estimate = torch.dot(epsilon, j_epsilon)

print("\n--- 4. 实战:Jacobian 迹的随机估计 (Hutchinson) ---")
print(f"Jacobian 迹的估计值: {trace_estimate.item()}")
print(f"对比真实迹 (对角线之和): {jacobian_matrix[0,0] + jacobian_matrix[1,1]}")

四、在 Flow Matching应用

在 Flow Matching(以及连续归一化流 CNF)中,要计算一幅图像生成的精确似然(Likelihood),数学上要求解连续变量的变化率,这就必须计算速度场(Velocity Field,即你的神经网络 f(x))的散度(Divergence)。

散度的定义是雅可比矩阵的迹(Trace),即对角线元素之和:

Div(f(x))=Tr(J)=∑i=1n∂fi(x)∂xi\text{Div}(f(x)) = \text{Tr}(J) = \sum_{i=1}^{n} \frac{\partial f_i(x)}{\partial x_i}Div(f(x))=Tr(J)=i=1∑n∂xi∂fi(x)

如果图像维度 n 是几十万,算完整的 J 再取对角线会直接导致显存崩溃。而利用 JVP/VJP + Hutchinson 随机估计器,我们可以"不费吹灰之力",只需一次额外的前向或反向传播就能精准估算出它。

以下是具体的演进逻辑和 PyTorch 实战代码:

1. 核心数学转换(Hutchinson 技巧)

Hutchinson 估计器提出了一个精妙的数学恒等式:对于任意方阵 J,如果我们引入一个随机噪声向量 ε,只要这个噪声满足均值为 0、协方差矩阵为单位阵 I(例如标准的正态分布或 Rademacher 分布),那么就有:

Tr(J)=EϵϵTJϵ\text{Tr}(J) = \mathbb{E}_{\epsilon} \left \\epsilon\^T J \\epsilon \\rightTr(J)=EϵϵTJϵ

我们把这个公式拆开看,就会发现它完美的变成了 JVP(或 VJP) 的形状:

ϵTJϵ=ϵT⋅(J⋅ϵ)=ϵT⋅JVP(f,x,ϵ)\epsilon^T J \epsilon = \epsilon^T \cdot (J \cdot \epsilon) = \epsilon^T \cdot \text{JVP}(f, x, \epsilon)ϵTJϵ=ϵT⋅(J⋅ϵ)=ϵT⋅JVP(f,x,ϵ)

这就是不费吹灰之力的秘密:

  1. 随机抽一个和图像尺寸完全一样的噪声 ε。
  2. 不用算 J,直接用 jvp 工具算出向量 Jε(它的大小和图像一模一样)。
  3. 让 ε 和 Jε 做点积(Dot Product),直接得到一个标量数值!

2. PyTorch 实战代码示例

在 Flow Matching 的实际训练或评估中,输入数据通常包含 Batch 维度。为了高效处理 Batch 数据,我们需要结合 PyTorch 的 torch.func.vmap(自动向量化映射)和 jvp。

以下是标准的工业级实现:

复制代码
import torch
from torch.func import jvp, vmap

# 1. 模拟一个 Flow Matching 的速度场网络 (Velocity Field / Drift)
# 假设输入图像特征图大小为 [Channels, H, W],例如 3 * 32 * 32 = 3072 维
class SimpleVelocityField(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3072, 4096),
            torch.nn.GELU(),
            torch.nn.Linear(4096, 3072)
        )
    def forward(self, x):
        # x 形状为 [3072]
        return self.net(x)

model = SimpleVelocityField().eval()

# 2. 准备 Batch 数据
batch_size = 16
img_dim = 3072
# 模拟网络输入:[Batch, Dim]
x_batch = torch.randn(batch_size, img_dim) 

# 3. 核心:定义单个样本的 JVP 散度计算函数
def compute_single_div(x, eps):
    # jvp 接收:函数、(输入元组,)、(特定方向向量元组,)
    # 返回:(网络输出, J * eps)
    _, j_eps = jvp(model, (x,), (eps,))
    
    # eps^T * J * eps 
    # 因为是单个样本(一维向量),直接做内积(点积)
    return torch.dot(eps, j_eps)

# 4. 关键加速:使用 vmap 批量并行处理整个 Batch 
# 采样与 x_batch 形状完全一致的随机噪声 ϵ ~ N(0, I)
epsilon_batch = torch.randn_like(x_batch)

# vmap 会自动并行处理 Batch 维(第 0 维)
batch_div_fn = vmap(compute_single_div, in_dims=(0, 0))

# 只需要一步,算出整个 Batch 每一个样本的速度场散度!
divergence = batch_div_fn(x_batch, epsilon_batch)

print(f"输入 Batch 形状: {x_batch.shape}")
print(f"估算出的散度 (每个样本一个标量): \n{divergence}")
print(f"散度输出形状: {divergence.shape}") # 形状为 [16],没有产生任何大矩阵!

3. 为什么说它"不费吹灰之力"?

我们可以从两个硬指标来看看这个方案有多优雅:

  • 显存几乎归零:如果图像尺寸是 3072,完整雅可比矩阵需要存储 3072 × 3072 ≈ 940 万个浮点数。而上面这段代码在运行过程中,显存里最大的变量也就是和输入一样大的 epsilon_batch(16 × 3072),显存开销降低了成千上万倍。
  • 计算速度极快:普通的 jvp 底层是通过前向自动微分(Forward-mode AD)实现的。计算一次 jvp 的耗时大约仅仅相当于网络做一次普通前向传播(Forward Pass)的 2 到 3 倍。

4. 了解这个对算似然(Likelihood)有什么用?

在 Flow Matching 中,根据连续连续变量公式(Instantaneous Change of Variables 公式),一幅图像 x₀ 变换到噪声 x₁ 的对数似然(Log-Likelihood)变化,可以通过沿着 ODE 轨迹对散度进行积分得到:

log⁡p1(x1)−log⁡p0(x0)=−∫01Div(f(xt))dt\log p_1(x_1) - \log p_0(x_0) = -\int_{0}^{1} \text{Div}(f(x_t)) dtlogp1(x1)−logp0(x0)=−∫01Div(f(xt))dt

有了上面这个高效的散度估计器,你就可以在使用 ODE 求解器(如 torchdiffeq)对图像进行生成或重建的同时,顺手把散度丢给求解器一起做积分。整个生成结束时,图像的精确似然(比如用于衡量生成质量的 BPD 位每像素指标)就被同步算出来了,完全不会拖慢你的训练和推理速度。

这个 Hutchinson 估计器不仅可以使用正态分布(Gaussian)噪声,在生图领域中大家往往更喜欢使用 Rademacher 噪声(即只有 +1 和 -1 的随机噪声),后者的方差通常更小。