【公式推导】AMP算法比BP算法强在哪(二)

文章目录

知识回顾

上一节中,我们讲了一个简单的线性高斯模型:
y = A x + w , w ∼ N ( 0 , σ 2 I ) y = Ax + w,\quad w \sim \mathcal N(0,\sigma^2 I) y=Ax+w,w∼N(0,σ2I)

如果使用BP算法,那在因子→变量这一步就会遇到这样的求解复杂度:

m i → j ( x j ) ∝ ∫ exp ⁡ ( − 1 2 σ 2 ( y i − ∑ k ≠ j A i k x k − A i j x j ) 2 ) ∏ k ≠ j n k → i ( x k ) d x k m_{i\to j}(x_j) \propto \int \exp\left( -\frac{1}{2\sigma^2} \left( y_i - \sum_{k\neq j} A_{ik} x_k - A_{ij} x_j \right)^2 \right) \prod_{k\neq j} n_{k\to i}(x_k)d x_k mi→j(xj)∝∫exp −2σ21 yi−k=j∑Aikxk−Aijxj 2 k=j∏nk→i(xk)dxk

当变量更多的时候,运算量就是指数级的增长,BP算法根本没办法跑起来,因此AMP使用了中心极限定理和高斯近似的方法,在稠密因子图中,让这一步简化为:
x t + 1 = η ( A T r t + x t ) r t = y − A x t + 1 δ r t − 1 ⟨ η ′ ( v t − 1 ) ⟩ x^{t+1} = \eta( A^T r^t + x^t )\\ r^t = y - A x^t + \frac{1}{\delta} r^{t-1} \langle \eta'(v^{t-1}) \rangle xt+1=η(ATrt+xt)rt=y−Axt+δ1rt−1⟨η′(vt−1)⟩

将若干个函数的相乘相加压缩到矩阵的线性加和乘 中,同时将复杂的变量关系"折叠"进 x x x 与 r r r,下面我们具体介绍思路

什么是中心极限定理(CLT)

中心极限定理(Central Limit Theorem)说的是一件非常朴素的事情:当你把很多个随机变量加起来时,无论每个变量原来是什么形状,它们的和都会趋向于高斯分布。 传统BP算法中 因子 → 变量消息是:

r m → n ( x n ) = ∑ x m ∖ n f m ( x m ) ∏ n ′ ≠ n q n ′ → m ( x n ′ ) r_{m\to n}(x_n) = \sum_{x_{m\setminus n}} f_m(x_m) \prod_{n' \ne n} q_{n'\to m}(x_{n'}) rm→n(xn)=xm∖n∑fm(xm)n′=n∏qn′→m(xn′)

看起来特别复杂,有:因子函数 f m f_m fm, 一堆消息函数 q q q, 对所有其他变量求和,表面上非常混乱。

但在 AMP 对应的线性高斯模型里:

y i = ∑ j A i j x j + w i y_i = \sum_j A_{ij} x_j + w_i yi=j∑Aijxj+wi

因子对应 y i y_i yi,变量对应 x j x_j xj,因子函数是(我们上一节推过):

f i ( x ) : = p ( y i ∣ x ) ∝ exp ⁡ ( − 1 2 σ 2 ( y i − ∑ k A i k x k ) 2 ) f_i(x) := p(y_i\mid x) \propto \exp\left( -\frac{1}{2\sigma^2}\big(y_i - \sum_k A_{ik} x_k \big)^2 \right) fi(x):=p(yi∣x)∝exp(−2σ21(yi−k∑Aikxk)2)

我们把除了 x j x_j xj 的所有项都打包成:

S i = ∑ k ≠ j A i k x k + w i S_i = \sum_{k\ne j} A_{ik} x_k + w_i Si=k=j∑Aikxk+wi

于是 y i = A i j x j + S i y_i = A_{ij} x_j + S_i yi=Aijxj+Si,其中 S i S_i Si 它是:

  • 由 N − 1 N-1 N−1 个变量 x k x_k xk 的消息分布产生
  • 每个乘上一个随机矩阵元素 A i k A_{ik} Aik
  • 再加噪声 w i w_i wi

所以是一个上千个随机变量加权后的和,这正是中心极限定理适用的条件!

我们把因子→变量消息写成:

r i → j ( x j ) ∝ E x k ≠ j [ exp ⁡ ( − 1 2 σ 2 ( y i − A i j x j − S i ) 2 ) ] r_{i\to j}(x_j) \propto \mathbb{E}{x{k\ne j}} \left[ \exp\left( -\frac{1}{2\sigma^2} (y_i - A_{ij}x_j - S_i)^2 \right) \right] ri→j(xj)∝Exk=j[exp(−2σ21(yi−Aijxj−Si)2)]

但我们刚刚看到: S i = ∑ k ≠ j A i k x k + w i S_i = \sum_{k\ne j} A_{ik} x_k + w_i Si=∑k=jAikxk+wi,在高维极限下 → 自动变成高斯,于是:

S i ≈ N ( μ S i , σ S i 2 ) S_i \approx \mathcal{N}(\mu_{S_i}, \sigma_{S_i}^2) Si≈N(μSi,σSi2)

这里 μ S i = ∑ k ≠ j A i k , E [ x k ] \mu_{S_i} = \sum_{k\neq j} A_{ik},\mathbb{E}[x_k] μSi=∑k=jAik,E[xk], σ S i 2 = ∑ k ≠ j A i k 2 V a r ( x k ) + σ 2 \sigma_{S_i}^2 = \sum_{k\neq j} A_{ik}^2\mathrm{Var}(x_k) + \sigma^2 σSi2=∑k=jAik2Var(xk)+σ2

(不用死记,只是"线性组合高斯 → 均值是系数×均值之和,方差是系数²×方差之和,再加噪声方差")

将CLT的结论推广到因子消息

一旦 S i S_i Si 是高斯,那么整个因子消息 r i → j ( x j ) r_{i→j}(x_j) ri→j(xj) 的形状就确定了!

因为因子函数是:

exp ⁡ ( − 1 2 σ 2 ( y i − A i j x j − S i ) 2 ) \exp\left(-\frac{1}{2\sigma^2}(y_i - A_{ij}x_j - S_i)^2\right) exp(−2σ21(yi−Aijxj−Si)2)

把 S i S_i Si 换成高斯变量: S i ∼ N ( μ S i , σ S i 2 ) S_i \sim \mathcal{N}(\mu_{S_i}, \sigma_{S_i}^2) Si∼N(μSi,σSi2) => 因子→变量消息是:

r i → j ( x j ) ∝ exp ⁡ ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\tau_{ij}} (x_j - \hat{x}_{ij})^2 \right) ri→j(xj)∝exp(−2τij1(xj−x^ij)2)

也就是一个高斯!,BP 复杂的积分 → 自动简化成"给你一个均值、一个方差",但是这个均值和方差到底怎么来的,如果感兴趣可以看一下下面的公式推导,不然可以直接跳到下一个章节。

均值和方差的推导

接下来我们系统推导一下上面这个公式,用"条件分布"的方式改写 r i → j ( x j ) r_{i\to j}(x_j) ri→j(xj),注意到:
y i = A i j x j + S i y_i = A_{ij}x_j + S_i yi=Aijxj+Si

其中 S i S_i Si是高斯随机变量。在固定 x j x_j xj 的前提下 , y i y_i yi 只是一个随机变量:

y i ∣ x j = A i j x j + S i y_i \mid x_j = A_{ij}x_j + S_i yi∣xj=Aijxj+Si

因为"常数 + 高斯"还是高斯,所以:

y i ∣ x j ∼ N ( A i j x j + μ S i , σ S i 2 ) y_i \mid x_j \sim \mathcal{N}\big( A_{ij}x_j + \mu_{S_i}, \sigma_{S_i}^2 \big) yi∣xj∼N(Aijxj+μSi,σSi2)

这句话非常关键,它其实就是我们要的 因子→变量消息的概率形式

因子节点 i i i 想告诉变量 x j x_j xj的,就是在目前的干扰 S i S_i Si分布下,给定 x j x_j xj,看到 y i y_i yi 的可能性是多少: p ( y i ∣ x j ) p(y_i \mid x_j) p(yi∣xj)。

所以可以直接写:

r i → j ( x j ) ∝ p ( y i ∣ x j ) ∝ exp ⁡ ( − 1 2 σ S i 2 ( y i − ( A i j x j + μ S i ) ) 2 ) r_{i\to j}(x_j) \propto p(y_i \mid x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} \big( y_i - (A_{ij}x_j + \mu_{S_i}) \big)^2 \right) ri→j(xj)∝p(yi∣xj)∝exp(−2σSi21(yi−(Aijxj+μSi))2)

这一步是从 S 高斯 → y ∣ x j y|x_j y∣xj 高斯 ,完全用的是"线性高斯模型"的性质。接着把这个式子改写成"关于 x j x_j xj的标准高斯形式",现在我们手里的式子是:

r i → j ( x j ) ∝ exp ⁡ ( − 1 2 σ S i 2 ( y i − A i j x j − μ S i ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} \big( y_i - A_{ij}x_j - \mu_{S_i} \big)^2 \right) ri→j(xj)∝exp(−2σSi21(yi−Aijxj−μSi)2)

我们想把它改写成: exp ⁡ ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) \exp\left(-\frac{1}{2\tau_{ij}}(x_j - \hat{x}{ij})^2\right) exp(−2τij1(xj−x^ij)2),也就是"关于 x j x_j xj 的高斯",这样就能直接读出: 方差 τ i j \tau{ij} τij, 均值 x ^ i j \hat{x}_{ij} x^ij

第一步:把括号里的项换个顺序

y i − A i j x j − μ S i = − ( A i j x j − ( y i − μ S i ) ) y_i - A_{ij}x_j - \mu_{S_i} = -(A_{ij}x_j - (y_i - \mu_{S_i})) yi−Aijxj−μSi=−(Aijxj−(yi−μSi))

平方以后符号没区别:

( y i − A i j x j − μ S i ) 2 = ( A i j x j − ( y i − μ S i ) ) 2 (y_i - A_{ij}x_j - \mu_{S_i})^2 = (A_{ij}x_j - (y_i - \mu_{S_i}))^2 (yi−Aijxj−μSi)2=(Aijxj−(yi−μSi))2

代回去:

r i → j ( x j ) ∝ exp ⁡ ( − 1 2 σ S i 2 ( A i j x j − ( y i − μ S i ) ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{1}{2\sigma_{S_i}^2} (A_{ij}x_j - (y_i - \mu_{S_i}))^2 \right) ri→j(xj)∝exp(−2σSi21(Aijxj−(yi−μSi))2)

第二步:把 A i j A_{ij} Aij 从平方里提出来

( A i j x j − ( y i − μ S i ) ) 2 = A i j 2 ( x j − y i − μ S i A i j ) 2 (A_{ij}x_j - (y_i - \mu_{S_i}))^2 = A_{ij}^2 \left(x_j - \frac{y_i - \mu_{S_i}}{A_{ij}}\right)^2 (Aijxj−(yi−μSi))2=Aij2(xj−Aijyi−μSi)2

于是:

r i → j ( x j ) ∝ exp ⁡ ( − A i j 2 2 σ S i 2 ( x j − y i − μ S i A i j ) 2 ) r_{i\to j}(x_j) \propto \exp\left( -\frac{A_{ij}^2}{2\sigma_{S_i}^2} \left(x_j - \frac{y_i - \mu_{S_i}}{A_{ij}}\right)^2 \right) ri→j(xj)∝exp(−2σSi2Aij2(xj−Aijyi−μSi)2)

对比标准形式:

exp ⁡ ( − 1 2 τ i j ( x j − x ^ i j ) 2 ) \exp\left( -\frac{1}{2\tau_{ij}} (x_j - \hat{x}_{ij})^2 \right) exp(−2τij1(xj−x^ij)2)

我们就可以一眼读出 : 方差: τ i j = σ S i 2 A i j 2 \tau_{ij} = \frac{\sigma_{S_i}^2}{A_{ij}^2} τij=Aij2σSi2,均值: x ^ i j = y i − μ ∗ S i A i j \hat{x}{ij} = \frac{y_i - \mu*{S_i}}{A{ij}} x^ij=Aijyi−μ∗Si


为什么整个 BP 会坍缩成 AMP 的两个向量( x x x 和 r r r)?

那通过上面的推导,我们知道每一个因子给变量的消息是:

N ( x ^ i j , τ i j ) \mathcal{N}(\hat{x}{ij}, \tau{ij}) N(x^ij,τij)

而每个变量要把所有因子消息乘起来:

x j t + 1 ∼ ∏ i r i → j ( x j ) x_j^{t+1} \sim \prod_i r_{i\to j}(x_j) xjt+1∼i∏ri→j(xj)

但是"高斯 × 高斯 × 高斯 ×..." 仍然是一个高斯!而它的均值就是:

x j t + 1 = η ( x t + A T r t ) x_j^{t+1} = \eta(x^t + A^T r^t) xjt+1=η(xt+ATrt)

在这里如果直接给出公式,估计很多人又会蒙圈,所以我还是多提一嘴,毕竟原论文中的推导确实太过复杂,所以这里给一个总体的路线并附上原文[1](#1)我们把逻辑线串起来,再说一遍"高斯×高斯×高斯 → η ( x t + A T r t ) η(x^t + A^T r^t) η(xt+ATrt)

  1. 因子→变量消息 :我们推出来了:
    r i → j ( x j ) ∝ N ( x j ; x ^ i j , τ i j ) r_{i\to j}(x_j) \propto \mathcal{N}(x_j;\hat{x}{ij},\tau{ij}) ri→j(xj)∝N(xj;x^ij,τij)

  2. 变量结点把所有因子消息相乘
    ∏ i r i → j ( x j ) ∝ N ( x j ; m j , v j ) \prod_i r_{i\to j}(x_j) \propto \mathcal{N}(x_j; m_j, v_j) i∏ri→j(xj)∝N(xj;mj,vj)

    其中 m j , v j m_j, v_j mj,vj 可以根据根据"高斯乘法公式"算出,意思是最后我们需要维护的也只有均值和方差。

  3. 这等价于: x j x_j xj 收到一个 noisy 观测 s j = m j s_j = m_j sj=mj ,噪声方差 v j v_j vj。

  4. 变量结点再结合先验 p 0 ( x j ) p_0(x_j) p0(xj) ,做一次贝叶斯更新:
    x j t + 1 = η ( s j ) x_j^{t+1} = \eta(s_j) xjt+1=η(sj)

    其中 η \eta η 就是"给定 s j s_j sj和先验,求 x j x_j xj 的后验均值"的函数。

  5. 把所有变量打包成向量 ,并且把 s j s_j sj利用上一轮的 x ^ i j \hat{x}_{ij} x^ij等表达式整理出来,就得到:
    s t = x t + A T r t s^t = x^t + A^T r^t st=xt+ATrt

    所以 x t + 1 = η ( x t + A T r t ) x^{t+1} = \eta(x^t + A^T r^t) xt+1=η(xt+ATrt)

因此:变量节点不再需要保存所有消息,只要保存最终的均值(一个向量)即可 → x x x

类似地: 每个因子节点不用保存所有 r r r 消息,只需要保存当前解释不了的误差 → 残差向量 r r r

于是 BP 的消息图坍缩成:

x t + 1 = η ( x t + A T r t ) r t = y − A x t + Onsager correction x^{t+1} = \eta(x^t + A^T r^t)\\ r^t = y - A x^t + \text{Onsager correction} xt+1=η(xt+ATrt)rt=y−Axt+Onsager correction

这就是 AMP。

AMP算法的局限性

当然AMP 并不是一个"万能、无脑好用"的算法。其主要存在以下问题:

1. AMP 非常依赖"大维度 + 随机矩阵"这一前提

这是最核心、也是最容易被忽略的限制。AMP 正确、收敛、可预测的数学基础来自:

  • N → ∞ N → ∞ N→∞(高维极限)
  • 测量矩阵 A 的元素是 i.i.d. 随机(高斯、子高斯等)
  • 因子图是"稠密且对称"的

说白了就是: AMP = 为"巨大随机矩阵"量身定做的算法

2. AMP 对噪声模型要求严格------主要是"加性白高斯噪声"

经典 AMP 的推导基于: w ∼ N ( 0 , σ 2 I ) w \sim \mathcal{N}(0, \sigma^2 I) w∼N(0,σ2I),如果噪声不是高斯,而是 Poisson 噪声或其他噪声,标准 AMP 就不对了,扩展算法 GAMP 可以处理部分情形,但也不完美。

3. AMP 容易发散(diverge)

即使在"看上去合法"的场景下,AMP 也经常发散。常见发散原因如下:

  • A 的列不是完全独立(实际工程中几乎都会相关)
  • A 是稠密但不是 i.i.d.
  • 去噪器 η η η 太强等等...

症状:

  • r t r^t rt 变得更大
  • x t x^t xt 在两轮之间来回震荡
  • η ′ η' η′ 太大 → Onsager correction 消不掉反馈

所以在实际运用的过程中,还是需要根据具体的情况选择对应的算法,而不是盲目使用!希望本文能让你对AMP算法有更深的理解,我是不懂代码的杰瑞学长,我们下期再见!


  1. Bayati M, Montanari A. The dynamics of message passing on dense graphs, with applications to compressed sensing[J]. IEEE Transactions on Information Theory, 2011, 57(2): 764-785. ↩︎
相关推荐
无垠的广袤1 小时前
【工业树莓派 CM0 NANO 单板计算机】小智语音聊天
人工智能·python·嵌入式硬件·语言模型·树莓派·智能体·小智
野蛮人6号1 小时前
力扣热题100道之45跳跃游戏2
算法·leetcode·游戏
唐僧洗头爱飘柔95271 小时前
【区块链技术(05)】区块链核心技术:哈希算法再区块链中的应用;区块哈希与默克尔树;公开密钥算法、编码和解码算法(BASE58、BASE64)
算法·区块链·哈希算法·base64·默克尔树·区块哈希·公私钥算法
BlackPercy1 小时前
[Matplotlib] 动态视频生成
python·matplotlib
B站计算机毕业设计之家1 小时前
大数据:基于python唯品会商品数据可视化分析系统 Flask框架 requests爬虫 Echarts可视化 数据清洗 大数据技术(源码+文档)✅
大数据·爬虫·python·信息可视化·spark·flask·唯品会
27669582921 小时前
闪购商家端 mtgsig
java·python·c#·node·c·mtgsig·mtgsig1.2
不能只会打代码1 小时前
力扣--3578. 统计极差最大为 K 的分割方式数(Java实现,代码注释及题目分析讲解)
算法·leetcode·动态规划·滑动窗口
AndrewHZ1 小时前
【Python与生活】Python文本分析:解码朱自清散文的语言密码
python·beautifulsoup·jieba·语言学·文本分析·文学分析·朱自清
小尧嵌入式2 小时前
QT软件开发知识流程及秒表计时器开发
开发语言·c++·qt·算法