文章目录
- 知识回顾
- 什么是中心极限定理(CLT)
- [为什么整个 BP 会坍缩成 AMP 的两个向量( x x x 和 r r r)?](#为什么整个 BP 会坍缩成 AMP 的两个向量( x x x 和 r r r)?)
- AMP算法的局限性
知识回顾
在上一节中,我们讲了一个简单的线性高斯模型:
y = A x + w , w ∼ N ( 0 , σ 2 I ) y = Ax + w,\quad w \sim \mathcal N(0,\sigma^2 I) y=Ax+w,w∼N(0,σ2I)
如果使用BP算法,那在因子→变量这一步就会遇到这样的求解复杂度:
m i → j ( x j ) ∝ ∫ exp ( − 1 2 σ 2 ( y i − ∑ k ≠ j A i k x k − A i j x j ) 2 ) ∏ k ≠ j n k → i ( x k ) d x k m_{i\to j}(x_j) \propto \int \exp\left( -\frac{1}{2\sigma^2} \left( y_i - \sum_{k\neq j} A_{ik} x_k - A_{ij} x_j \right)^2 \right) \prod_{k\neq j} n_{k\to i}(x_k)d x_k mi→j(xj)∝∫exp −2σ21 yi−k=j∑Aikxk−Aijxj 2 k=j∏nk→i(xk)dxk
当变量更多的时候,运算量就是指数级的增长,BP算法根本没办法跑起来,因此AMP使用了中心极限定理和高斯近似的方法,在稠密因子图中,让这一步简化为:
x t + 1 = η ( A T r t + x t ) r t = y − A x t + 1 δ r t − 1 ⟨ η ′ ( v t − 1 ) ⟩ x^{t+1} = \eta( A^T r^t + x^t )\\ r^t = y - A x^t + \frac{1}{\delta} r^{t-1} \langle \eta'(v^{t-1}) \rangle xt+1=η(ATrt+xt)rt=y−Axt+δ1rt−1⟨η′(vt−1)⟩
将若干个函数的相乘相加压缩到矩阵的线性加和乘 中,同时将复杂的变量关系"折叠"进 x x x 与 r r r,下面我们具体介绍思路
什么是中心极限定理(CLT)
中心极限定理(Central Limit Theorem)说的是一件非常朴素的事情:当你把很多个随机变量加起来时,无论每个变量原来是什么形状,它们的和都会趋向于高斯分布。 传统BP算法中 因子 → 变量消息是:
r m → n ( x n ) = ∑ x m ∖ n f m ( x m ) ∏ n ′ ≠ n q n ′ → m ( x n ′ ) r_{m\to n}(x_n) = \sum_{x_{m\setminus n}} f_m(x_m) \prod_{n' \ne n} q_{n'\to m}(x_{n'}) rm→n(xn)=xm∖n∑fm(xm)n′=n∏qn′→m(xn′)
看起来特别复杂,有:因子函数 f m f_m fm, 一堆消息函数 q q q, 对所有其他变量求和,表面上非常混乱。
但在 AMP 对应的线性高斯模型里:
y i = ∑ j A i j x j + w i y_i = \sum_j A_{ij} x_j + w_i yi=j∑Aijxj+wi
因子对应 y i y_i yi,变量对应 x j x_j xj,因子函数是(我们上一节推过):
f i ( x ) : = p ( y i ∣ x ) ∝ exp ( − 1 2 σ 2 ( y i − ∑ k A i k x k ) 2 ) f_i(x) := p(y_i\mid x) \propto \exp\left( -\frac{1}{2\sigma^2}\big(y_i - \sum_k A_{ik} x_k \big)^2 \right) fi(x):=p(yi∣x)∝exp(−2σ21(yi−k∑Aikxk)2)
我们把除了 x j x_j xj 的所有项都打包成:
S i = ∑ k ≠ j A i k x k + w i S_i = \sum_{k\ne j} A_{ik} x_k + w_i Si=k=j∑Aikxk+wi
于是 y i = A i j x j + S i y_i = A_{ij} x_j + S_i yi=Aijxj+Si,其中 S i S_i Si 它是:
- 由 N − 1 N-1 N−1 个变量 x k x_k xk 的消息分布产生
- 每个乘上一个随机矩阵元素 A i k A_{ik} Aik
- 再加噪声 w i w_i wi
所以是一个上千个随机变量加权后的和,这正是中心极限定理适用的条件!
我们把因子→变量消息写成:
r i → j ( x j ) ∝ E x k ≠ j [ exp ( − 1 2 σ 2 ( y i − A i j x j − S i ) 2 ) ] r_{i\to j}(x_j) \propto \mathbb{E}{x{k\ne j}} \left[ \exp\left( -\frac{1}{2\sigma^2} (y_i - A_{ij}x_j - S_i)^2 \right) \right] ri→j(xj)∝Exk=j[exp(−2σ21(yi−Aijxj−Si)2)]
但我们刚刚看到: S i = ∑ k ≠ j A i k x k + w i S_i = \sum_{k\ne j} A_{ik} x_k + w_i Si=∑k=jAikxk+wi,在高维极限下 → 自动变成高斯,于是:
S i ≈ N ( μ S i , σ S i 2 ) S_i \approx \mathcal{N}(\mu_{S_i}, \sigma_{S_i}^2) Si≈N(μSi,σSi2)
这里 μ S i = ∑ k ≠ j A i k , E [ x k ] \mu_{S_i} = \sum_{k\neq j} A_{ik},\mathbb{E}[x_k] μSi=∑k=jAik,E[xk], σ S i 2 = ∑ k ≠ j A i k 2 V a r ( x k ) + σ 2 \sigma_{S_i}^2 = \sum_{k\neq j} A_{ik}^2\mathrm{Var}(x_k) + \sigma^2 σSi2=∑k=jAik2Var(xk)+σ2
(不用死记,只是"线性组合高斯 → 均值是系数×均值之和,方差是系数²×方差之和,再加噪声方差")
将CLT的结论推广到因子消息
一旦 S i S_i Si 是高斯,那么整个因子消息 r i → j ( x j ) r_{i→j}(x_j) ri→j(xj) 的形状就确定了!
因为因子函数是:
exp ( − 1 2 σ 2 ( y i − A i j x j − S i ) 2 ) \exp\left(-\frac{1}{2\sigma^2}(y_i - A_{ij}x_j - S_i)^2\right) exp(−2σ21(yi−Aijxj−Si)2)
把 S i S_i Si 换成高斯变量: S i ∼ N ( μ S i , σ S i 2 ) S_i \sim \mathcal{N}(\mu_{S_i}, \sigma_{S_i}^2) Si∼N(μSi,σSi2) => 因子→变量消息是:
r i → j ( x j ) ∝ exp ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\tau_{ij}} (x_j - \hat{x}_{ij})^2 \right) ri→j(xj)∝exp(−2τij1(xj−x^ij)2)
也就是一个高斯!,BP 复杂的积分 → 自动简化成"给你一个均值、一个方差",但是这个均值和方差到底怎么来的,如果感兴趣可以看一下下面的公式推导,不然可以直接跳到下一个章节。
均值和方差的推导
接下来我们系统推导一下上面这个公式,用"条件分布"的方式改写 r i → j ( x j ) r_{i\to j}(x_j) ri→j(xj),注意到:
y i = A i j x j + S i y_i = A_{ij}x_j + S_i yi=Aijxj+Si
其中 S i S_i Si是高斯随机变量。在固定 x j x_j xj 的前提下 , y i y_i yi 只是一个随机变量:
y i ∣ x j = A i j x j + S i y_i \mid x_j = A_{ij}x_j + S_i yi∣xj=Aijxj+Si
因为"常数 + 高斯"还是高斯,所以:
y i ∣ x j ∼ N ( A i j x j + μ S i , σ S i 2 ) y_i \mid x_j \sim \mathcal{N}\big( A_{ij}x_j + \mu_{S_i}, \sigma_{S_i}^2 \big) yi∣xj∼N(Aijxj+μSi,σSi2)
这句话非常关键,它其实就是我们要的 因子→变量消息的概率形式:
因子节点 i i i 想告诉变量 x j x_j xj的,就是在目前的干扰 S i S_i Si分布下,给定 x j x_j xj,看到 y i y_i yi 的可能性是多少: p ( y i ∣ x j ) p(y_i \mid x_j) p(yi∣xj)。
所以可以直接写:
r i → j ( x j ) ∝ p ( y i ∣ x j ) ∝ exp ( − 1 2 σ S i 2 ( y i − ( A i j x j + μ S i ) ) 2 ) r_{i\to j}(x_j) \propto p(y_i \mid x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} \big( y_i - (A_{ij}x_j + \mu_{S_i}) \big)^2 \right) ri→j(xj)∝p(yi∣xj)∝exp(−2σSi21(yi−(Aijxj+μSi))2)
这一步是从 S 高斯 → y ∣ x j y|x_j y∣xj 高斯 ,完全用的是"线性高斯模型"的性质。接着把这个式子改写成"关于 x j x_j xj的标准高斯形式",现在我们手里的式子是:
r i → j ( x j ) ∝ exp ( − 1 2 σ S i 2 ( y i − A i j x j − μ S i ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} \big( y_i - A_{ij}x_j - \mu_{S_i} \big)^2 \right) ri→j(xj)∝exp(−2σSi21(yi−Aijxj−μSi)2)
我们想把它改写成: exp ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) \exp\left(-\frac{1}{2\tau_{ij}}(x_j - \hat{x}{ij})^2\right) exp(−2τij1(xj−x^ij)2),也就是"关于 x j x_j xj 的高斯",这样就能直接读出: 方差 τ i j \tau{ij} τij, 均值 x ^ i j \hat{x}_{ij} x^ij
第一步:把括号里的项换个顺序
y i − A i j x j − μ S i = − ( A i j x j − ( y i − μ S i ) ) y_i - A_{ij}x_j - \mu_{S_i} = -(A_{ij}x_j - (y_i - \mu_{S_i})) yi−Aijxj−μSi=−(Aijxj−(yi−μSi))
平方以后符号没区别:
( y i − A i j x j − μ S i ) 2 = ( A i j x j − ( y i − μ S i ) ) 2 (y_i - A_{ij}x_j - \mu_{S_i})^2 = (A_{ij}x_j - (y_i - \mu_{S_i}))^2 (yi−Aijxj−μSi)2=(Aijxj−(yi−μSi))2
代回去:
r i → j ( x j ) ∝ exp ( − 1 2 σ S i 2 ( A i j x j − ( y i − μ S i ) ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} (A_{ij}x_j - (y_i - \mu_{S_i}))^2 \right) ri→j(xj)∝exp(−2σSi21(Aijxj−(yi−μSi))2)
第二步:把 A i j A_{ij} Aij 从平方里提出来
( A i j x j − ( y i − μ S i ) ) 2 = A i j 2 ( x j − y i − μ S i A i j ) 2 (A_{ij}x_j - (y_i - \mu_{S_i}))^2 = A_{ij}^2 \left(x_j - \frac{y_i - \mu_{S_i}}{A_{ij}}\right)^2 (Aijxj−(yi−μSi))2=Aij2(xj−Aijyi−μSi)2
于是:
r i → j ( x j ) ∝ exp ( − A i j 2 2 σ S i 2 ( x j − y i − μ S i A i j ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{A_{ij}^2}{2\sigma_{S_i}^2} \left(x_j - \frac{y_i - \mu_{S_i}}{A_{ij}}\right)^2 \right) ri→j(xj)∝exp(−2σSi2Aij2(xj−Aijyi−μSi)2)
对比标准形式:
exp ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) \exp\left( -\frac{1}{2\tau_{ij}} (x_j - \hat{x}_{ij})^2 \right) exp(−2τij1(xj−x^ij)2)
我们就可以一眼读出 : 方差: τ i j = σ S i 2 A i j 2 \tau_{ij} = \frac{\sigma_{S_i}^2}{A_{ij}^2} τij=Aij2σSi2,均值: x ^ i j = y i − μ ∗ S i A i j \hat{x}{ij} = \frac{y_i - \mu*{S_i}}{A{ij}} x^ij=Aijyi−μ∗Si
为什么整个 BP 会坍缩成 AMP 的两个向量( x x x 和 r r r)?
那通过上面的推导,我们知道每一个因子给变量的消息是:
N ( x ^ i j , τ i j ) \mathcal{N}(\hat{x}{ij}, \tau{ij}) N(x^ij,τij)
而每个变量要把所有因子消息乘起来:
x j t + 1 ∼ ∏ i r i → j ( x j ) x_j^{t+1} \sim \prod_i r_{i\to j}(x_j) xjt+1∼i∏ri→j(xj)
但是"高斯 × 高斯 × 高斯 ×..." 仍然是一个高斯!而它的均值就是:
x j t + 1 = η ( x t + A T r t ) x_j^{t+1} = \eta(x^t + A^T r^t) xjt+1=η(xt+ATrt)
在这里如果直接给出公式,估计很多人又会蒙圈,所以我还是多提一嘴,毕竟原论文中的推导确实太过复杂,所以这里给一个总体的路线并附上原文[1](#1),我们把逻辑线串起来,再说一遍"高斯×高斯×高斯 → η ( x t + A T r t ) η(x^t + A^T r^t) η(xt+ATrt)
-
因子→变量消息 :我们推出来了:
r i → j ( x j ) ∝ N ( x j ; x ^ i j , τ i j ) r_{i\to j}(x_j) \propto \mathcal{N}(x_j;\hat{x}{ij},\tau{ij}) ri→j(xj)∝N(xj;x^ij,τij) -
变量结点把所有因子消息相乘 :
∏ i r i → j ( x j ) ∝ N ( x j ; m j , v j ) \prod_i r_{i\to j}(x_j) \propto \mathcal{N}(x_j; m_j, v_j) i∏ri→j(xj)∝N(xj;mj,vj)其中 m j , v j m_j, v_j mj,vj 可以根据根据"高斯乘法公式"算出,意思是最后我们需要维护的也只有均值和方差。
-
这等价于: x j x_j xj 收到一个 noisy 观测 s j = m j s_j = m_j sj=mj ,噪声方差 v j v_j vj。
-
变量结点再结合先验 p 0 ( x j ) p_0(x_j) p0(xj) ,做一次贝叶斯更新:
x j t + 1 = η ( s j ) x_j^{t+1} = \eta(s_j) xjt+1=η(sj)其中 η \eta η 就是"给定 s j s_j sj和先验,求 x j x_j xj 的后验均值"的函数。
-
把所有变量打包成向量 ,并且把 s j s_j sj利用上一轮的 x ^ i j \hat{x}_{ij} x^ij等表达式整理出来,就得到:
s t = x t + A T r t s^t = x^t + A^T r^t st=xt+ATrt所以 x t + 1 = η ( x t + A T r t ) x^{t+1} = \eta(x^t + A^T r^t) xt+1=η(xt+ATrt)
因此:变量节点不再需要保存所有消息,只要保存最终的均值(一个向量)即可 → x x x
类似地: 每个因子节点不用保存所有 r r r 消息,只需要保存当前解释不了的误差 → 残差向量 r r r
于是 BP 的消息图坍缩成:
x t + 1 = η ( x t + A T r t ) r t = y − A x t + Onsager correction x^{t+1} = \eta(x^t + A^T r^t)\\ r^t = y - A x^t + \text{Onsager correction} xt+1=η(xt+ATrt)rt=y−Axt+Onsager correction
这就是 AMP。
AMP算法的局限性
当然AMP 并不是一个"万能、无脑好用"的算法。其主要存在以下问题:
1. AMP 非常依赖"大维度 + 随机矩阵"这一前提
这是最核心、也是最容易被忽略的限制。AMP 正确、收敛、可预测的数学基础来自:
- N → ∞ N → ∞ N→∞(高维极限)
- 测量矩阵 A 的元素是 i.i.d. 随机(高斯、子高斯等)
- 因子图是"稠密且对称"的
说白了就是: AMP = 为"巨大随机矩阵"量身定做的算法
2. AMP 对噪声模型要求严格------主要是"加性白高斯噪声"
经典 AMP 的推导基于: w ∼ N ( 0 , σ 2 I ) w \sim \mathcal{N}(0, \sigma^2 I) w∼N(0,σ2I),如果噪声不是高斯,而是 Poisson 噪声或其他噪声,标准 AMP 就不对了,扩展算法 GAMP 可以处理部分情形,但也不完美。
3. AMP 容易发散(diverge)
即使在"看上去合法"的场景下,AMP 也经常发散。常见发散原因如下:
- A 的列不是完全独立(实际工程中几乎都会相关)
- A 是稠密但不是 i.i.d.
- 去噪器 η η η 太强等等...
症状:
- r t r^t rt 变得更大
- x t x^t xt 在两轮之间来回震荡
- η ′ η' η′ 太大 → Onsager correction 消不掉反馈
所以在实际运用的过程中,还是需要根据具体的情况选择对应的算法,而不是盲目使用!希望本文能让你对AMP算法有更深的理解,我是不懂代码的杰瑞学长,我们下期再见!
- Bayati M, Montanari A. The dynamics of message passing on dense graphs, with applications to compressed sensing[J]. IEEE Transactions on Information Theory, 2011, 57(2): 764-785. ↩︎