通俗易懂讲透 EM 算法(期望最大化)

通俗易懂讲透 EM 算法(期望最大化)|本科生/研究生一看就懂

EM 算法是带隐藏变量的概率模型求解神器 ,最经典的落地就是高斯混合模型 GMM 聚类 。它和 K-Means 很像,但更强:能分椭圆形、不同大小、不同密度的簇,还能输出每个点属于每一类的概率。

本文用大白话+生活例子+数学推导+可直接运行代码+对比总结,适合课堂笔记、实验报告。


一、EM 算法是什么?一句话讲明白

EM = E 步猜概率 + M 步更新参数,循环迭代直到收敛。

  • E 步(Expectation) :用当前模型,算每个样本属于每一类的后验概率(软归属)。
  • M 步(Maximization) :用这些概率,重新估计每一类的高斯参数(均值、协方差、权重)。

一句话记忆:
先瞎猜 → 按概率算分布 → 再修正猜测 → 越来越准


二、超通俗例子:黑暗中分辨 3 种水果

你在黑暗里摸到一堆水果:苹果、香蕉、橙子。

不知道谁是谁 ,也不知道每类长啥样,只能靠触感估计。

EM 怎么做:

  1. 随便猜初始参数:比如猜苹果是圆的、香蕉是长的。
  2. E 步:对每个水果,算它有多像苹果/香蕉/橙子。
  3. M 步:根据"像苹果"的所有水果,重新算苹果的平均特征。
  4. 循环:直到特征不再变,你就准确分出了三类。

三、EM 算法核心流程(最标准 4 步)

  1. 初始化
    随机给定 K 个高斯分布的:均值、协方差、混合系数。
  2. E 步(求后验概率)
    对每个样本,计算它属于第 k 个高斯的概率 γ_ik。
  3. M 步(更新高斯参数)
    用 γ_ik 加权更新:均值、协方差、混合系数。
  4. 收敛判断
    似然函数变化很小就停止。

四、核心公式(报告/作业直接用)

1. 后验概率(E 步核心)

γik=πk⋅N(xi∣μk,Σk)∑j=1Kπj⋅N(xi∣μj,Σj)\gamma_{ik} = \frac{\pi_k \cdot \mathcal{N}(x_i | \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \cdot \mathcal{N}(x_i | \mu_j, \Sigma_j)}γik=∑j=1Kπj⋅N(xi∣μj,Σj)πk⋅N(xi∣μk,Σk)

2. 混合系数(M 步)

πk=1N∑i=1Nγik\pi_k = \frac{1}{N}\sum_{i=1}^N \gamma_{ik}πk=N1∑i=1Nγik

3. 均值(M 步)

μk=∑i=1Nγikxi∑i=1Nγik\mu_k = \frac{\sum_{i=1}^N \gamma_{ik} x_i}{\sum_{i=1}^N \gamma_{ik}}μk=∑i=1Nγik∑i=1Nγikxi

4. 协方差(M 步)

Σk=∑i=1Nγik(xi−μk)(xi−μk)T∑i=1Nγik\Sigma_k = \frac{\sum_{i=1}^N \gamma_{ik} (x_i-\mu_k)(x_i-\mu_k)^T}{\sum_{i=1}^N \gamma_{ik}}Σk=∑i=1Nγik∑i=1Nγik(xi−μk)(xi−μk)T


五、EM 为什么比 K-Means 强?

  • K-Means:硬分类、只认球形簇、等方差
  • EM(GMM):软分类、支持椭圆簇、可不同方差

简单说:
K-Means 是 EM 的特例(方差固定、概率0/1)。


六、完整实战代码(可直接复制运行)

包含:手动实现 EM + 可视化迭代 + BIC 选最优簇数 + Sklearn 对比

python 复制代码
# 安装依赖
# pip install numpy matplotlib scikit-learn scipy

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from matplotlib.patches import Ellipse

# ====================== 1. 生成测试数据 ======================
np.random.seed(42)
means = [(-2,-2), (0,3), (4,0)]
covs = [
    [[1.0,0.5],[0.5,1.0]],
    [[0.3,-0.2],[-0.2,0.3]],
    [[0.8,0],[0,0.5]]
]
X = np.vstack([
    np.random.multivariate_normal(m, c, 600)
    for m,c in zip(means, covs)
])

# ====================== 2. 手动实现 EM-GMM ======================
class EMGMM:
    def __init__(self, n_components=3, max_iter=100, tol=1e-4):
        self.K = n_components
        self.max_iter = max_iter
        self.tol = tol

    def _init(self, X):
        n,d = X.shape
        idx = np.random.choice(n, self.K, replace=False)
        self.mu = X[idx]
        self.cov = np.array([np.cov(X.T)]*self.K)
        self.pi = np.ones(self.K)/self.K

    def _e_step(self, X):
        N = len(X)
        gamma = np.zeros((N, self.K))
        for k in range(self.K):
            gamma[:,k] = self.pi[k] * multivariate_normal.pdf(X, self.mu[k], self.cov[k])
        gamma /= gamma.sum(axis=1, keepdims=True)
        return gamma

    def _m_step(self, X, gamma):
        Nk = gamma.sum(axis=0)
        self.pi = Nk / len(X)
        self.mu = (gamma.T @ X) / Nk[:,None]
        for k in range(self.K):
            diff = X - self.mu[k]
            self.cov[k] = (gamma[:,k][:,None]*diff).T @ diff / Nk[k]
            self.cov[k] += 1e-6 * np.eye(X.shape[1])

    def fit(self, X):
        self._init(X)
        old_loglike = -np.inf
        for i in range(self.max_iter):
            gamma = self._e_step(X)
            self._m_step(X, gamma)
            loglike = self._loglike(X)
            if abs(loglike - old_loglike) < self.tol:
                break
            old_loglike = loglike
        return self

    def _loglike(self, X):
        ll = np.zeros(len(X))
        for k in range(self.K):
            ll += self.pi[k] * multivariate_normal.pdf(X, self.mu[k], self.cov[k])
        return np.log(ll).mean()

    def predict(self, X):
        gamma = self._e_step(X)
        return np.argmax(gamma, axis=1)

# ====================== 3. 训练与可视化 ======================
model = EMGMM(n_components=3).fit(X)
labels = model.predict(X)

plt.figure(figsize=(8,6))
plt.scatter(X[:,0], X[:,1], c=labels, s=15, cmap='viridis')
plt.scatter(model.mu[:,0], model.mu[:,1], marker='X', s=300, c='red')
plt.title('EM-GMM 聚类结果')
plt.show()

# ====================== 4. BIC 自动选最优簇数 ======================
ks = range(1,7)
bics = []
for k in ks:
    gmm = GaussianMixture(k, random_state=42)
    gmm.fit(X)
    bics.append(gmm.bic(X))

plt.figure(figsize=(6,4))
plt.plot(ks, bics, 'o-')
plt.title('BIC 曲线 --- 越小越好')
plt.xlabel('簇数 K')
plt.show()

七、EM 算法优缺点(面试/报告必背)

✅ 优点

  1. 软聚类:输出概率,更贴近真实场景。
  2. 支持任意椭圆簇:不局限球形。
  3. 理论扎实:统计学基石,可用于缺失值、HMM、LDA 等。
  4. 结果可解释:有清晰的概率意义。

❌ 缺点

  1. 必须指定 K
  2. 对初值敏感,易局部最优。
  3. 比 K-Means 慢
  4. 对噪声敏感

八、EM vs K-Means vs DBSCAN(超清晰对比)

算法 类型 簇形状 是否要K 对噪声 速度
EM-GMM 软聚类 椭圆/高斯 一般
K-Means 硬聚类 球形
DBSCAN 密度聚类 任意

九、适用场景(什么时候用 EM?)

👉 首选 EM

  • 数据近似高斯混合分布
  • 需要概率归属(风险评分、异常检测、医学诊断)
  • 簇是椭圆形、不等方差
  • 语音识别、图像分割、文本主题模型

👉 不要用

  • 大数据量 → 用 K-Means
  • 噪声极多 → 用 DBSCAN
  • 不知道簇数 → 用 Mean Shift 或层次聚类

十、一句话总结

EM 算法是求解带隐藏变量概率模型的迭代框架,在 GMM 聚类中通过 E 步算概率、M 步更新分布,能处理 K-Means 无法解决的非球形、异方差数据,是机器学习最核心算法之一。

相关推荐
海海不掉头发2 小时前
【AI大模型实战项目】大模型入门实战:两个落地项目保姆级教程12月14日-【项目】基于知识库RAG的物流行业信息问答系统
人工智能·python·深度学习·语言模型·自然语言处理·pycharm·scikit-learn
2301_773553622 小时前
mysql执行SQL查询时结果不一致_检查事务隔离级别设置与幻读
jvm·数据库·python
mpr0xy2 小时前
《AI怎么一步步变聪明的?》系列(六)中国大模型崛起之路:从“追赶者”到“解题人”
人工智能·ai·大语言模型·qwen·deepseek
游了个戏2 小时前
OPC × AI × 快手:小游戏蓝海中的第三极突围
人工智能·游戏
神奇小汤圆2 小时前
Harness Engineering 时代的失败经验
人工智能
ok_hahaha2 小时前
AI从头开始-黑马LongChain-RAG开发3
人工智能
Pentane.2 小时前
【力扣hot100】【Leetcode 15】三数之和|暴力枚举 双指针 算法笔记及打卡(14/100)
数据结构·笔记·算法·leetcode
m0_377618232 小时前
mysql如何解决乱码问题_检查客户端与服务器字符集一致性
jvm·数据库·python
糖炒栗子03262 小时前
让 AI 在大项目中做修改的标准操作模板
人工智能