[模式识别-从入门到入土] 拓展-EM算法

[模式识别-从入门到入土] 拓展-EM算法

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

EM算法

EM 算法是处理含隐变量的统计模型 的迭代优化方法,用于求模型参数的最大似然 / 最大后验概率估计

核心痛点:模型中存在未观测的隐变量

核心思路: "先估计隐变量,再更新参数"的交替迭代(先猜后验、再修正 )

-> 逼近最优参数,直到收敛

  1. 随机初始化一组参数
  2. 用当前参数估计隐变量的分布/期望
  3. 用隐变量的估计值更新参数
  4. 重复步骤2-3,直到参数或似然函数稳定

E步 (期望步):

假设当前参数已知,计算隐变量的期望/后验概率(即隐变量取不同值的可能性)

M步 (最大化步):

用 E 步得到的隐变量估计值,通过最大似然估计更新模型参数

完整执行流程(高斯混合模型为例)

  1. 初始化 :定义聚类数K,初始化参数 π k , μ k , Σ k \pi_k, \mu_k, \Sigma_k πk,μk,Σk

  2. E 步 :计算隐变量的后验概率 γ ( i , k ) \gamma(i,k) γ(i,k)

  3. M 步 :用 γ ( i , k ) \gamma(i,k) γ(i,k)更新参数 π k , μ k , Σ k \pi_k, \mu_k, \Sigma_k πk,μk,Σk

  4. 计算对数似然 函数:
    ∑ i = 1 N log ⁡ { ∑ k = 1 K π k N ( x i ∣ μ k , Σ k ) } \sum_{i=1}^N \log\left\{ \sum_{k=1}^K \pi_k N(x_i|\mu_k, \Sigma_k) \right\} i=1∑Nlog{k=1∑KπkN(xi∣μk,Σk)}

  5. 收敛判断:若参数或似然函数变化小于阈值,停止迭代;否则回到 E 步

E步

计算 "数据点 x i x_i xi由第k个高斯分量生成" 的概率 γ ( i , k ) = p ( y = k ∣ x i ) \gamma(i,k) = p(y=k|x_i) γ(i,k)=p(y=k∣xi),公式为:
γ ( i , k ) = π k N ( x i ∣ μ k , Σ k ) ∑ j = 1 K π j N ( x i ∣ μ j , Σ j ) \gamma(i,k) = \frac{\pi_k N(x_i|\mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j N(x_i|\mu_j, \Sigma_j)} γ(i,k)=∑j=1KπjN(xi∣μj,Σj)πkN(xi∣μk,Σk)
π k \pi_k πk是第k个分量的权重, N ( ⋅ ) N(\cdot) N(⋅)是高斯分布

M步

更新 3 类参数:

  • 分量权重 π k \pi_k πk:
    π k = m k m \pi_k = \frac{m_k}{m} πk=mmk

    m k = ∑ i = 1 m γ ( i , k ) m_k = \sum_{i=1}^m \gamma(i,k) mk=∑i=1mγ(i,k),m是总数据量

  • 均值 μ k \mu_k μk:
    μ k = 1 m k ∑ i = 1 m γ ( i , k ) x i \mu_k = \frac{1}{m_k} \sum_{i=1}^m \gamma(i,k) x_i μk=mk1i=1∑mγ(i,k)xi

  • 协方差 Σ k \Sigma_k Σk:
    Σ k = 1 m k ∑ i = 1 m γ ( i , k ) ( x i − μ k ) ( x i − μ k ) T \Sigma_k = \frac{1}{m_k} \sum_{i=1}^m \gamma(i,k)(x_i - \mu_k)(x_i - \mu_k)^T Σk=mk1i=1∑mγ(i,k)(xi−μk)(xi−μk)T

完整执行流程(两硬币抛投模型为例)

  1. 初始化 :定义硬币数量(2枚,记为硬币1、硬币2),初始化参数 π 1 , π 2 \pi_1, \pi_2 π1,π2(选择两枚硬币的先验概率,满足 π 1 + π 2 = 1 \pi_1+\pi_2=1 π1+π2=1)和 θ 1 , θ 2 \theta_1, \theta_2 θ1,θ2(两枚硬币的正面概率)

  2. E 步 :计算隐变量的后验概率 γ ( i , k ) \gamma(i,k) γ(i,k)(第i轮实验选择第k枚硬币的概率)

  3. M 步 :用 γ ( i , k ) \gamma(i,k) γ(i,k)更新参数 π 1 , π 2 , θ 1 , θ 2 \pi_1, \pi_2, \theta_1, \theta_2 π1,π2,θ1,θ2

  4. 计算对数似然 函数:
    ∑ i = 1 m log ⁡ { ∑ k = 1 2 π k ∏ t = 1 n P ( x i t ∣ k ) } \sum_{i=1}^m \log\left\{ \sum_{k=1}^2 \pi_k \prod_{t=1}^n P(x_{it}|k) \right\} i=1∑mlog{k=1∑2πkt=1∏nP(xit∣k)}

    (其中m为实验总轮数,n为每轮抛投次数, x i t x_{it} xit为第i轮第t次抛投结果, P ( x i t ∣ k ) P(x_{it}|k) P(xit∣k)为第k枚硬币抛出该结果的概率:正面为 θ k \theta_k θk,反面为 1 − θ k 1-\theta_k 1−θk)

  5. 收敛判断 :若参数 π 1 , π 2 , θ 1 , θ 2 \pi_1, \pi_2, \theta_1, \theta_2 π1,π2,θ1,θ2或似然函数变化小于阈值,停止迭代;否则回到 E 步

E步

计算 "第i轮实验选择第k枚硬币" 的概率 γ ( i , k ) = p ( y = k ∣ x i ) \gamma(i,k) = p(y=k|x_i) γ(i,k)=p(y=k∣xi),公式为:
γ ( i , k ) = π k ∏ t = 1 n P ( x i t ∣ k ) ∑ j = 1 2 π j ∏ t = 1 n P ( x i t ∣ j ) \gamma(i,k) = \frac{\pi_k \prod_{t=1}^n P(x_{it}|k)}{\sum_{j=1}^2 \pi_j \prod_{t=1}^n P(x_{it}|j)} γ(i,k)=∑j=12πj∏t=1nP(xit∣j)πk∏t=1nP(xit∣k)

说明: π k \pi_k πk是选择第k枚硬币的先验概率, ∏ t = 1 n P ( x i t ∣ k ) \prod_{t=1}^n P(x_{it}|k) ∏t=1nP(xit∣k)是第k枚硬币产生第i轮抛投结果的联合概率, P ( x i t ∣ k ) P(x_{it}|k) P(xit∣k)为第k枚硬币抛出 x i t x_{it} xit结果的概率(正面为 θ k \theta_k θk,反面为 1 − θ k 1-\theta_k 1−θk)

M步

更新 2 类核心参数:

  • 正面概率 θ k \theta_k θk:

θ k = ∑ i = 1 m γ ( i , k ) ⋅ n i + ∑ i = 1 m γ ( i , k ) ⋅ n \theta_k = \frac{\sum_{i=1}^m \gamma(i,k) \cdot n_{i+}}{\sum_{i=1}^m \gamma(i,k) \cdot n} θk=∑i=1mγ(i,k)⋅n∑i=1mγ(i,k)⋅ni+

注: n i + n_{i+} ni+是第i轮实验的正面次数,n是每轮抛投次数

  • 选择概率 π k \pi_k πk:

π k = m k m \pi_k = \frac{m_k}{m} πk=mmk

注: m k = ∑ i = 1 m γ ( i , k ) m_k = \sum_{i=1}^m \gamma(i,k) mk=∑i=1mγ(i,k)(第k枚硬币被选择的期望轮数),m是实验总轮数

EM 算法的典型例子

"硬币正反面概率估计":

已知:多轮抛硬币的结果(正面 / 反面)

隐变量:每一轮用的是哪枚硬币

模型参数:每枚硬币的正面概率

过程:

  • 先假设硬币的正面概率
  • 估计每轮用哪枚硬币
  • 用 "硬币选择结果" 更新正面概率
  • 重复直到稳定
相关推荐
努力学算法的蒟蒻2 小时前
day41(12.22)——leetcode面试经典150
算法·leetcode·面试
liliangcsdn2 小时前
Python拒绝采样算法优化与微调模拟
人工智能·算法·机器学习
Christo32 小时前
2024《A Rapid Review of Clustering Algorithms》
人工智能·算法·机器学习·数据挖掘
AndrewHZ2 小时前
【图像处理基石】图像梯度:核心算法原理与经典应用场景全解析
图像处理·算法·计算机视觉·cv·算子·边缘提取·图像梯度
让学习成为一种生活方式2 小时前
组蛋白短链酰化修饰--文献精读187
算法
fei_sun2 小时前
数字图像处理
人工智能·算法·计算机视觉
Tisfy2 小时前
LeetCode 960.删列造序 III:动态规划(最长递增子序列)
算法·leetcode·动态规划·字符串·题解·逆向思维
多米Domi0112 小时前
0x3f第十天复习(考研日2)(9.18-12.30,14.00-15.00)
python·算法·leetcode
listhi5202 小时前
支持向量机多分类解决方案
算法·支持向量机·分类