分数匹配算法(Score Matching)的目标函数“散度+模”计算量大,两种解决方法, 分层分数匹配和降噪分数匹配

tr()虽然可以计算, 但是计算成本时非常高的,尤其是在变量 x 很高维或者神经网络层次很深的时候, 通常是无法接受的。针对这个情况,一般有两种解决方法, 分层分数匹配(Sliced score matching)和降噪分数匹配(Denoising score matching)。

好的,这是您提供的文本的 Markdown 版本。


分层分数匹配 (Sliced Score Matching)

分层匹配使用一个随机投影矩阵可以近似计算 tr(∇ₓ s_θ(x)),改变后的目标函数为

(4.1.57)

E p v E p d a t a ( x ) v T ∇ x s θ ( x ) v + 1 2 ∥ s θ ( x ) ∥ 2 2 \mathbb{E}{p_v} \mathbb{E}{p_{data}(x)} \left v\^T \\nabla_x s_{\\theta}(x) v + \\frac{1}{2} \\\| s_{\\theta}(x) \\\|_2\^2 \\right EpvEpdata(x)vT∇xsθ(x)v+21∥sθ(x)∥22

其中 p v p_v pv 是一个简单的随机向量即可,比如多元正态分布。其中 v T ∇ x s θ ( x ) v v^T \nabla_x s_{\theta}(x) v vT∇xsθ(x)v 可以直接利用正向模式的自动微分计算,但是仍然要四倍的计算量。

降噪分数匹配 (Denoising Score Matching)

另一种解决分数匹配的方法是降噪分数匹配 (Denoising Score Matching) ,它是分数匹配算法的一个变种,它可以完全避开 tr(∇ₓ s_θ(x)) 的计算。首先在观测数据 x x x 上添加一些预先设定好的噪声数据,得到了新的数据 x ~ \tilde{x} x~,这相当于构建了一条件概率分布 q σ ( x ~ ∣ x ) q_{\sigma}(\tilde{x}|x) qσ(x~∣x),根据边际化方法,边缘分布 q σ ( x ~ ) q_{\sigma}(\tilde{x}) qσ(x~) 的计算方法为

(4.1.58)

q σ ( x ~ ) ≜ ∫ q σ ( x ~ ∣ x ) p d a t a ( x ) d x q_{\sigma}(\tilde{x}) \triangleq \int q_{\sigma}(\tilde{x}|x) p_{data}(x) dx qσ(x~)≜∫qσ(x~∣x)pdata(x)dx

然后把分数匹配算法应用在这个加噪后的数据分布上,

(4.1.59)

1 2 E q σ ( x ~ ∣ x ) ∥ s θ ( x \~ ) − ∇ x \~ log ⁡ q σ ( x \~ ∣ x ) ∥ 2 2 = 1 2 E q σ ( x ~ ∣ x ) p d a t a ( x ) ∥ s θ ( x \~ ) − ∇ x \~ log ⁡ q σ ( x \~ ∣ x ) ∥ 2 2 \begin{aligned} \frac{1}{2} \mathbb{E}{q{\sigma}(\tilde{x}|x)} \\\| s_{\\theta}(\\tilde{x}) - \\nabla_{\\tilde{x}} \\log q_{\\sigma}(\\tilde{x}\|x) \\\|_2\^2 \\ = \frac{1}{2} \mathbb{E}{q{\sigma}(\tilde{x}|x)p_{data}(x)} \\\| s_{\\theta}(\\tilde{x}) - \\nabla_{\\tilde{x}} \\log q_{\\sigma}(\\tilde{x}\|x) \\\|_2\^2 \end{aligned} 21Eqσ(x~∣x)∥sθ(x\~)−∇x\~logqσ(x\~∣x)∥22=21Eqσ(x~∣x)pdata(x)∥sθ(x\~)−∇x\~logqσ(x\~∣x)∥22

这么做的一个前提是,如果添加的噪声足够小,那么 q σ ( x ~ ) ≈ p d a t a ( x ) q_{\sigma}(\tilde{x}) \approx p_{data}(x) qσ(x~)≈pdata(x) 成立,此时有 ∇ x ~ log ⁡ q σ ( x ~ ) ≈ ∇ x log ⁡ p d a t a ( x ) \nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x}) \approx \nabla_x \log p_{data}(x) ∇x~logqσ(x~)≈∇xlogpdata(x) 成立,这时我们可以用分数匹配算法估计出 q σ ( x ~ ) q_{\sigma}(\tilde{x}) qσ(x~) 的分数 ∇ x ~ log ⁡ q σ ( x ~ ) \nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x}) ∇x~logqσ(x~),并用它近似表示原数据分布 p d a t a ( x ) p_{data}(x) pdata(x) 的分数。