通俗易懂讲透 EM 算法(期望最大化)|本科生/研究生一看就懂
EM 算法是带隐藏变量的概率模型求解神器 ,最经典的落地就是高斯混合模型 GMM 聚类 。它和 K-Means 很像,但更强:能分椭圆形、不同大小、不同密度的簇,还能输出每个点属于每一类的概率。
本文用大白话+生活例子+数学推导+可直接运行代码+对比总结,适合课堂笔记、实验报告。
一、EM 算法是什么?一句话讲明白
EM = E 步猜概率 + M 步更新参数,循环迭代直到收敛。
- E 步(Expectation) :用当前模型,算每个样本属于每一类的后验概率(软归属)。
- M 步(Maximization) :用这些概率,重新估计每一类的高斯参数(均值、协方差、权重)。
一句话记忆:
先瞎猜 → 按概率算分布 → 再修正猜测 → 越来越准。
二、超通俗例子:黑暗中分辨 3 种水果
你在黑暗里摸到一堆水果:苹果、香蕉、橙子。
你不知道谁是谁 ,也不知道每类长啥样,只能靠触感估计。
EM 怎么做:
- 随便猜初始参数:比如猜苹果是圆的、香蕉是长的。
- E 步:对每个水果,算它有多像苹果/香蕉/橙子。
- M 步:根据"像苹果"的所有水果,重新算苹果的平均特征。
- 循环:直到特征不再变,你就准确分出了三类。
三、EM 算法核心流程(最标准 4 步)
- 初始化
随机给定 K 个高斯分布的:均值、协方差、混合系数。 - E 步(求后验概率)
对每个样本,计算它属于第 k 个高斯的概率 γ_ik。 - M 步(更新高斯参数)
用 γ_ik 加权更新:均值、协方差、混合系数。 - 收敛判断
似然函数变化很小就停止。
四、核心公式(报告/作业直接用)
1. 后验概率(E 步核心)
γik=πk⋅N(xi∣μk,Σk)∑j=1Kπj⋅N(xi∣μj,Σj)\gamma_{ik} = \frac{\pi_k \cdot \mathcal{N}(x_i | \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \cdot \mathcal{N}(x_i | \mu_j, \Sigma_j)}γik=∑j=1Kπj⋅N(xi∣μj,Σj)πk⋅N(xi∣μk,Σk)
2. 混合系数(M 步)
πk=1N∑i=1Nγik\pi_k = \frac{1}{N}\sum_{i=1}^N \gamma_{ik}πk=N1∑i=1Nγik
3. 均值(M 步)
μk=∑i=1Nγikxi∑i=1Nγik\mu_k = \frac{\sum_{i=1}^N \gamma_{ik} x_i}{\sum_{i=1}^N \gamma_{ik}}μk=∑i=1Nγik∑i=1Nγikxi
4. 协方差(M 步)
Σk=∑i=1Nγik(xi−μk)(xi−μk)T∑i=1Nγik\Sigma_k = \frac{\sum_{i=1}^N \gamma_{ik} (x_i-\mu_k)(x_i-\mu_k)^T}{\sum_{i=1}^N \gamma_{ik}}Σk=∑i=1Nγik∑i=1Nγik(xi−μk)(xi−μk)T
五、EM 为什么比 K-Means 强?
- K-Means:硬分类、只认球形簇、等方差。
- EM(GMM):软分类、支持椭圆簇、可不同方差。
简单说:
K-Means 是 EM 的特例(方差固定、概率0/1)。
六、完整实战代码(可直接复制运行)
包含:手动实现 EM + 可视化迭代 + BIC 选最优簇数 + Sklearn 对比
python
# 安装依赖
# pip install numpy matplotlib scikit-learn scipy
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from matplotlib.patches import Ellipse
# ====================== 1. 生成测试数据 ======================
np.random.seed(42)
means = [(-2,-2), (0,3), (4,0)]
covs = [
[[1.0,0.5],[0.5,1.0]],
[[0.3,-0.2],[-0.2,0.3]],
[[0.8,0],[0,0.5]]
]
X = np.vstack([
np.random.multivariate_normal(m, c, 600)
for m,c in zip(means, covs)
])
# ====================== 2. 手动实现 EM-GMM ======================
class EMGMM:
def __init__(self, n_components=3, max_iter=100, tol=1e-4):
self.K = n_components
self.max_iter = max_iter
self.tol = tol
def _init(self, X):
n,d = X.shape
idx = np.random.choice(n, self.K, replace=False)
self.mu = X[idx]
self.cov = np.array([np.cov(X.T)]*self.K)
self.pi = np.ones(self.K)/self.K
def _e_step(self, X):
N = len(X)
gamma = np.zeros((N, self.K))
for k in range(self.K):
gamma[:,k] = self.pi[k] * multivariate_normal.pdf(X, self.mu[k], self.cov[k])
gamma /= gamma.sum(axis=1, keepdims=True)
return gamma
def _m_step(self, X, gamma):
Nk = gamma.sum(axis=0)
self.pi = Nk / len(X)
self.mu = (gamma.T @ X) / Nk[:,None]
for k in range(self.K):
diff = X - self.mu[k]
self.cov[k] = (gamma[:,k][:,None]*diff).T @ diff / Nk[k]
self.cov[k] += 1e-6 * np.eye(X.shape[1])
def fit(self, X):
self._init(X)
old_loglike = -np.inf
for i in range(self.max_iter):
gamma = self._e_step(X)
self._m_step(X, gamma)
loglike = self._loglike(X)
if abs(loglike - old_loglike) < self.tol:
break
old_loglike = loglike
return self
def _loglike(self, X):
ll = np.zeros(len(X))
for k in range(self.K):
ll += self.pi[k] * multivariate_normal.pdf(X, self.mu[k], self.cov[k])
return np.log(ll).mean()
def predict(self, X):
gamma = self._e_step(X)
return np.argmax(gamma, axis=1)
# ====================== 3. 训练与可视化 ======================
model = EMGMM(n_components=3).fit(X)
labels = model.predict(X)
plt.figure(figsize=(8,6))
plt.scatter(X[:,0], X[:,1], c=labels, s=15, cmap='viridis')
plt.scatter(model.mu[:,0], model.mu[:,1], marker='X', s=300, c='red')
plt.title('EM-GMM 聚类结果')
plt.show()
# ====================== 4. BIC 自动选最优簇数 ======================
ks = range(1,7)
bics = []
for k in ks:
gmm = GaussianMixture(k, random_state=42)
gmm.fit(X)
bics.append(gmm.bic(X))
plt.figure(figsize=(6,4))
plt.plot(ks, bics, 'o-')
plt.title('BIC 曲线 --- 越小越好')
plt.xlabel('簇数 K')
plt.show()
七、EM 算法优缺点(面试/报告必背)
✅ 优点
- 软聚类:输出概率,更贴近真实场景。
- 支持任意椭圆簇:不局限球形。
- 理论扎实:统计学基石,可用于缺失值、HMM、LDA 等。
- 结果可解释:有清晰的概率意义。
❌ 缺点
- 必须指定 K。
- 对初值敏感,易局部最优。
- 比 K-Means 慢。
- 对噪声敏感。
八、EM vs K-Means vs DBSCAN(超清晰对比)
| 算法 | 类型 | 簇形状 | 是否要K | 对噪声 | 速度 |
|---|---|---|---|---|---|
| EM-GMM | 软聚类 | 椭圆/高斯 | 是 | 一般 | 中 |
| K-Means | 硬聚类 | 球形 | 是 | 差 | 快 |
| DBSCAN | 密度聚类 | 任意 | 否 | 强 | 中 |
九、适用场景(什么时候用 EM?)
👉 首选 EM
- 数据近似高斯混合分布
- 需要概率归属(风险评分、异常检测、医学诊断)
- 簇是椭圆形、不等方差
- 语音识别、图像分割、文本主题模型
👉 不要用
- 大数据量 → 用 K-Means
- 噪声极多 → 用 DBSCAN
- 不知道簇数 → 用 Mean Shift 或层次聚类
十、一句话总结
EM 算法是求解带隐藏变量概率模型的迭代框架,在 GMM 聚类中通过 E 步算概率、M 步更新分布,能处理 K-Means 无法解决的非球形、异方差数据,是机器学习最核心算法之一。