基于MATLAB的高斯混合模型(GMM)实现,包含自动生成混合数据、EM算法实现及可视化模块:
一、核心代码
1. 数据生成模块
matlab
function [X, trueParams] = generateGMMData(N, K, dim)
% generateGMMData  Draw N samples from a randomly parameterized Gaussian mixture.
%
% Inputs:
%   N   - total number of samples
%   K   - number of Gaussian components
%   dim - data dimensionality
% Outputs:
%   X          - data matrix (dim x N); columns are grouped by component
%   trueParams - struct with the generating parameters:
%                  weights (1 x K), means (dim x K), covs (dim x dim x K)
rng(0); % fixed seed so the generated data set is reproducible
trueParams = struct();
% Random mixing weights, normalized to sum to one
trueParams.weights = rand(1,K);
trueParams.weights = trueParams.weights / sum(trueParams.weights);
trueParams.means = 10*randn(dim,K);
% Isotropic covariances with a random scale in [0.5, 1]
trueParams.covs = repmat(eye(dim), [1,1,K]);
for k=1:K
    trueParams.covs(:,:,k) = trueParams.covs(:,:,k) * (0.5 + 0.5*rand);
end
% Per-component sample counts. Rounding can make the total differ from N,
% so the remainder is absorbed by the largest component to keep exactly N.
counts = round(N*trueParams.weights);
[~, iLargest] = max(counts);
counts(iLargest) = counts(iLargest) + (N - sum(counts));
% Draw the samples component by component
X = zeros(dim,N);
offset = 0; % number of columns of X already filled
for k=1:K
    Nk = counts(k);
    % mvnrnd takes a 1 x dim mean and returns an Nk x dim matrix, so the
    % mean is passed as a row and the result is transposed to dim x Nk.
    X(:, offset+(1:Nk)) = mvnrnd(trueParams.means(:,k)', trueParams.covs(:,:,k), Nk)';
    offset = offset + Nk;
end
end
2. GMM-EM算法实现
matlab
function [mu, cov, weight, logLik] = gmm_em(X, K, maxIter, tol)
% gmm_em  Fit a K-component Gaussian mixture to X with the EM algorithm.
%
% Inputs:
%   X       - data matrix (dim x N)
%   K       - number of Gaussian components
%   maxIter - maximum number of EM iterations (optional, default 100)
%   tol     - convergence threshold on the change in log-likelihood
%             (optional, default 1e-5)
% Outputs:
%   mu      - component means (dim x K)
%   cov     - component covariances (dim x dim x K)
%   weight  - mixing weights (1 x K)
%   logLik  - log-likelihood per iteration (truncated at convergence)
if nargin < 3 || isempty(maxIter), maxIter = 100; end
if nargin < 4 || isempty(tol), tol = 1e-5; end
[dim, N] = size(X);
logLik = zeros(maxIter,1);
% Initial parameters: random means, identity covariances, uniform weights
mu = 10*randn(dim,K);
cov = repmat(eye(dim), [1,1,K]);
weight = ones(1,K)/K;
for iter=1:maxIter
    % E-step: posterior responsibilities, computed in the log domain.
    % The log-weights are refreshed every iteration so that the weights
    % updated in the previous M-step are actually used.
    logWeight = log(weight);
    logPost = zeros(N,K);
    for k=1:K
        logPost(:,k) = logWeight(k) + log_mvnpdf(X, mu(:,k), cov(:,:,k));
    end
    logSum = logsumexp(logPost, 2);
    post = exp(logPost - logSum);
    % M-step: re-estimate weights, means and covariances
    Nk = sum(post,1);
    weight = Nk/N;
    for k=1:K
        % Guard against division by zero when a component owns no mass
        NkSafe = max(Nk(k), eps);
        mu(:,k) = (X * post(:,k)) / NkSafe;
        centered = X - mu(:,k);
        % Weighted scatter matrix; scaling the columns of `centered`
        % avoids materializing the N x N matrix diag(post(:,k)).
        cov(:,:,k) = ((centered .* post(:,k)') * centered') / NkSafe;
        % Symmetrize and add a small ridge to keep the matrix positive definite
        cov(:,:,k) = (cov(:,:,k) + cov(:,:,k)')/2 + 1e-6*eye(dim);
    end
    % Log-likelihood under the parameters used in this E-step
    logLik(iter) = sum(logSum);
    % Stop once the log-likelihood has stabilized
    if iter > 1 && abs(logLik(iter)-logLik(iter-1)) < tol
        logLik = logLik(1:iter);
        break;
    end
end
end
function y = log_mvnpdf(X, mu, Sigma)
% log_mvnpdf  Log-density of a multivariate Gaussian at each column of X.
%
% Inputs:
%   X     - data matrix (dim x N), one sample per column
%   mu    - mean vector (dim x 1)
%   Sigma - covariance matrix (dim x dim)
% Output:
%   y     - log-densities (1 x N); all -Inf if Sigma is not positive definite
[dim,N] = size(X);
X_centered = X - mu;
% Cholesky factor: Sigma = R'*R with R upper triangular; p ~= 0 signals failure
[R, p] = chol(Sigma);
if p ~= 0
    y = -inf(1,N);
    return;
end
% Mahalanobis term: x'*inv(Sigma)*x = ||R' \ x||^2, solved per column
z = R' \ X_centered;
% 0.5*log(det(Sigma)) equals sum(log(diag(R)))
y = -0.5 * sum(z.^2, 1) - 0.5*dim*log(2*pi) - sum(log(diag(R)));
end
function s = logsumexp(x, dim)
% logsumexp  Numerically stable log(sum(exp(x), dim)).
%
% Inputs:
%   x   - numeric array
%   dim - dimension along which to reduce
% Output:
%   s   - reduced array, size 1 along dimension dim
x_max = max(x,[],dim);
% A slice whose maximum is -Inf (or +Inf) would produce Inf - Inf = NaN
% below; shifting by 0 instead gives the correct -Inf (or +Inf) result.
x_max(~isfinite(x_max)) = 0;
s = x_max + log(sum(exp(x - x_max), dim));
end
二、完整示例Demo
1. 数据生成与可视化
matlab
% Synthetic data set: N points drawn from a K-component 2-D mixture
N = 1000;
K = 3;
[X, trueParams] = generateGMMData(N, K, 2);
% Scatter plot of the raw samples with the true component means overlaid
figure;
scatter(X(1,:), X(2,:), 10, 'filled');
hold('on');
for k = 1:K
    % black cross at the true mean of component k
    plot(trueParams.means(1,k), trueParams.means(2,k), 'kx', ...
        'LineWidth', 2, 'MarkerSize', 15);
end
title('原始数据分布');
hold('off');
2. 模型训练
matlab
% Training settings
maxIter = 100; % cap on the number of EM iterations
tol = 1e-5;    % convergence threshold on the log-likelihood change
% Run EM on the generated data
[mu, cov, weight, logLik] = gmm_em(X, K, maxIter, tol);
% Report the fitted parameters
disp('估计参数:');
disp('均值:');
disp(mu'); % one row per component
disp('协方差:');
% Show every component's covariance, not only the first one
for k = 1:size(cov, 3)
    disp(cov(:,:,k));
end
disp(['对数似然: ', num2str(logLik(end))]);
3. 结果可视化
matlab
% Overlay posterior-probability contours of each fitted component on the data
figure;
hold on;
scatter(X(1,:), X(2,:), 10, 'filled');
% 50 x 50 evaluation grid spanning the data range
[x1, x2] = meshgrid(linspace(min(X(1,:)), max(X(1,:)),50), ...
    linspace(min(X(2,:)), max(X(2,:)),50));
X_grid = [x1(:)'; x2(:)'];
post = zeros(size(X_grid,2), K);
% Weighted component densities at every grid point
for k=1:K
    % mvnpdf expects points as rows and the mean as a 1 x d row vector
    post(:,k) = weight(k) * mvnpdf(X_grid', mu(:,k)', cov(:,:,k));
end
% Normalize to posteriors; the floor avoids 0/0 where all densities underflow
post = post ./ max(sum(post,2), realmin);
% One contour set per component; colors cycle if K exceeds the palette
palette = {'r', 'g', 'b', 'm', 'c', 'k'};
for k=1:K
    lineColor = palette{mod(k-1, numel(palette)) + 1};
    contour(x1, x2, reshape(post(:,k), size(x1)), [0.01,0.1,0.5,1], 'LineColor', lineColor);
end
title('GMM概率密度估计');
hold off;
三、关键优化策略
1. 协方差矩阵正则化
matlab
% Add a small ridge when updating the covariance in the M-step.
% Scaling the columns of diff by the responsibilities gives the same
% weighted scatter matrix as diff*diag(post(:,k))*diff' without building
% the N x N diagonal matrix.
cov(:,:,k) = ((diff .* post(:,k)') * diff') / Nk(k) + 1e-6*eye(dim);
2. 并行计算加速
matlab
% Evaluate the E-step log-posteriors with a parfor loop over components.
% Each iteration writes only its own column logPost(:,k) and reads only
% logWeight(k), mu(:,k) and cov(:,:,k), so iterations are independent.
% NOTE(review): presumably needs a parallel pool to actually run in
% parallel; for small K the overhead may outweigh the gain - confirm by profiling.
parfor k=1:K
logPost(:,k) = logWeight(k) + log_mvnpdf(X, mu(:,k), cov(:,:,k));
end
3. 收敛性增强
matlab
% Additional stopping rule: monitor how much the means move per iteration.
% mu_prev holds the means from the previous iteration.
mu_prev = mu;
for iter=1:maxIter
% ... M-step parameter update goes here ...
% Stop when the Frobenius norm of the change in mu falls below tol
if norm(mu - mu_prev, 'fro') < tol
break;
end
% Remember the current means for the next comparison
mu_prev = mu;
end
四、性能评估
matlab
% Compare the estimated parameters against the generating ones
trueMu = trueParams.means;
estMu = mu;
trueCov = trueParams.covs;
estCov = cov;
% EM returns components in arbitrary order, so each true component is first
% matched to an estimated one (greedy nearest-mean assignment, one-to-one).
numComp = size(trueMu, 2);
dist = zeros(numComp, numComp); % dist(i,j) = ||trueMu(:,i) - estMu(:,j)||
for i = 1:numComp
    for j = 1:numComp
        dist(i,j) = norm(trueMu(:,i) - estMu(:,j));
    end
end
match = zeros(1, numComp); % match(i) = estimated index paired with true component i
for step = 1:numComp
    % pick the closest still-unassigned (true, estimated) pair
    [~, idx] = min(dist(:));
    [i, j] = ind2sub(size(dist), idx);
    match(i) = j;
    dist(i,:) = inf;
    dist(:,j) = inf;
end
% Mean error: average Euclidean distance between matched means
mu_err = mean(sqrt(sum((trueMu - estMu(:,match)).^2, 1)));
% Covariance error: average Frobenius norm between matched covariances
cov_errs = zeros(1, numComp);
for i = 1:numComp
    cov_errs(i) = norm(trueCov(:,:,i) - estCov(:,:,match(i)), 'fro');
end
cov_err = mean(cov_errs);
disp(['均值平均误差: ', num2str(mu_err)]);
disp(['协方差平均误差: ', num2str(cov_err)]);
五、扩展应用场景
1. 图像分割
matlab
% Load the image and convert it to grayscale intensities
img = imread('peppers.png');
img_gray = rgb2gray(img);
img_double = im2double(img_gray);
% Feature matrix (1 x numPixels): one intensity value per pixel.
% Pixel coordinates alone carry no image content, so intensity is what
% gets clustered.
img_data = img_double(:)';
% Fit the mixture to the pixel intensities
% NOTE(review): gmm_em as listed in this document draws its initial means
% from 10*randn, far outside the intensity range of im2double output;
% consider K-means initialization or rescaling - confirm results.
K = 5;
[mu_img, cov_img, weight_img, ~] = gmm_em(img_data, K, 100, 1e-5);
% Label each pixel with the component of highest weighted density
numPixels = size(img_data, 2);
resp = zeros(numPixels, K);
for k = 1:K
    % mvnpdf expects points as rows and the mean as a row vector
    resp(:,k) = weight_img(k) * mvnpdf(img_data', mu_img(:,k)', cov_img(:,:,k));
end
[~, labels] = max(resp, [], 2);
% The label map has the size of the grayscale image (not the RGB array)
segmented = reshape(labels, size(img_double));
imshow(label2rgb(segmented));
2. 语音特征聚类
matlab
% Load speech MFCC features
% assumes mfcc_data is (featureDim x numFrames), e.g. 20 x 1000, to match
% the dim x N layout gmm_em expects - TODO confirm against the .mat file
load('mfcc_data.mat');
% Fit the mixture
K = 10;
[mu, cov, weight, ~] = gmm_em(mfcc_data, K, 200, 1e-5);
% Soft assignments, computed in the log domain: in a high-dimensional
% feature space the raw densities can underflow to zero and give 0/0.
[featDim, numFrames] = size(mfcc_data);
logPost = zeros(numFrames, K);
for k=1:K
    R = chol(cov(:,:,k));                    % cov(:,:,k) = R'*R
    z = R' \ (mfcc_data - mu(:,k));          % whitened residuals
    logPost(:,k) = log(weight(k)) - 0.5*sum(z.^2, 1)' ...
        - 0.5*featDim*log(2*pi) - sum(log(diag(R)));
end
% Subtract the row maximum before exponentiating, then normalize each row
post = exp(logPost - max(logPost, [], 2));
post = post ./ sum(post,2);
六、注意事项
- 初始化策略:建议先用K-means初始化均值
matlab
[idx, kmeans_mu] = kmeans(X', K); mu = kmeans_mu';
- 维度灾难:高维数据建议使用对角协方差矩阵
matlab
cov(:,:,k) = diag(diag(cov(:,:,k))); % 强制对角
- 数值稳定性:计算概率密度时使用对数域
matlab
log_prob = log_mvnpdf(X, mu, cov);
七、参考
- Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.
- Murphy, K. P. (2012). Machine Learning: A Probabilistic Perspective. MIT Press.
- 代码 matlab 实现GMM------EM算法 www.youwenfan.com/contentcsi/63293.html
- MathWorks官方文档:fitgmdist 函数说明