独立成分分析 (Independent Component Analysis, ICA) 是一种用于信号分离和降维的统计方法,常用于盲源分离 (Blind Source Separation, BSS) 问题,例如音频信号分离或脑电信号 (EEG) 处理。
实现 ICA(独立成分分析)
步骤
- 生成混合信号数据:创建多个独立信号并混合它们。
- 中心化 (Centering) & 白化 (Whitening):对数据进行标准化以提高收敛速度。
- 迭代优化解混矩阵:使用非高斯性 (Negentropy) 作为优化目标,应用梯度上升法。
- 获得独立成分:通过训练的解混矩阵恢复源信号。
例子代码:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# 1. 生成数据
torch.manual_seed(42)
num_samples = 1000
s1 = torch.sin(torch.linspace(0, 8 * torch.pi, num_samples)) # 正弦波
s2 = torch.sign(torch.sin(torch.linspace(0, 8 * torch.pi, num_samples))) # 方波
S = torch.stack([s1, s2]) # (2, num_samples)
# 2. 生成混合信号 X = A @ S
mixing_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.0]], dtype=torch.float32)
X = mixing_matrix @ S # (2, num_samples)
# 3. 数据预处理 (去中心化)
X_mean = X.mean(dim=1, keepdim=True)
X_centered = X - X_mean
# 4. 白化处理 (ZCA 白化)
cov = (X_centered @ X_centered.T) / num_samples
eigvals, eigvecs = torch.linalg.eigh(cov)
eigvals = torch.clamp(eigvals, min=1e-5) # 避免负数
whitening_matrix = eigvecs @ torch.diag(1.0 / torch.sqrt(eigvals)) @ eigvecs.T
X_white = whitening_matrix @ X_centered # 白化后的数据
# 5. 定义 ICA 模型
class ICA(nn.Module):
def __init__(self, n_components):
super().__init__()
self.W = nn.Parameter(torch.eye(n_components)) # 初始化为单位矩阵
def forward(self, X):
return self.W @ X
# 6. 训练 ICA
ica = ICA(n_components=2)
optimizer = optim.Adam([ica.W], lr=0.01)
def neg_entropy(y):
return torch.mean(torch.tanh(y), dim=1)
num_epochs = 1000
for epoch in range(num_epochs):
optimizer.zero_grad()
Y = ica(X_white) # 通过 W 提取信号
loss = -torch.sum(neg_entropy(Y)) # 负熵最大化
loss.backward()
optimizer.step()
# 7. 使用 QR 分解保持 W 近似正交
with torch.no_grad():
ica.W.copy_(torch.linalg.qr(ica.W)[0]) # QR 正交化
# 8. 信号恢复
separated = ica(X_white).detach().cpu().numpy() # 确保 NumPy 兼容性
# 9. 绘图
plt.figure(figsize=(10, 5))
plt.subplot(3, 1, 1)
plt.plot(S.T.detach().cpu().numpy()) # 确保 NumPy 兼容
plt.title("Original Source Signals")
plt.subplot(3, 1, 2)
plt.plot(X.T.detach().cpu().numpy()) # 确保 NumPy 兼容
plt.title("Mixed Signals")
plt.subplot(3, 1, 3)
plt.plot(separated.T) # 直接使用 NumPy 数据
plt.title("Recovered Signals (ICA)")
plt.tight_layout()
plt.show()
代码解析
-
数据生成
- 生成两个独立信号:一个 正弦波 和一个 方波。
- 通过 随机混合矩阵 将它们混合成两个观察信号。
-
数据预处理
- 去中心化 (Centering):减去均值,使数据零均值。
- 白化 (Whitening):对数据进行 PCA 变换,确保协方差矩阵为单位矩阵,提高 ICA 的效果。
-
ICA 训练
- 定义解混矩阵 WWW ,使用 PyTorch 梯度优化。
- 采用 非高斯性(Negentropy)最大化 原则来优化,使用
tanh
近似 Negentropy。 - 梯度更新 W ,并在训练过程中 保持 W 近似正交 以防止数值发散。
-
信号恢复
- 训练完成后,
W
将学习到 解混变换 ,将X
投影到独立信号空间,即可恢复原始信号。
- 训练完成后,