LDA算法核心思想
线性判别分析(LDA) 是一种经典的监督降维方法,与PCA不同,LDA利用类别信息寻找能够最大化类间区分度的特征投影方向。
算法优势
- 保留类别判别信息:投影后不同类别样本尽可能分开
- 提升分类性能:降维后的特征更有利于后续分类任务
- 处理多类别问题:可扩展到多个类别的情况
LDA算法原理
数学基础
LDA寻找投影向量 w,使得投影后的数据满足:
- 类间散度最大化
- 类内散度最小化
目标函数:
J(w) = (wᵀS_B w) / (wᵀS_W w)
其中:
S_B:类间散度矩阵S_W:类内散度矩阵
MATLAB实现代码
完整LDA降维函数
matlab
function [Y, W, eigenvalues] = myLDA(X, labels, targetDim)
% MYLDA 线性判别分析降维
% 输入:
% X - 数据矩阵 (n×d),n个样本,d个特征
% labels - 类别标签 (n×1)
% targetDim - 目标维度
% 输出:
% Y - 降维后的数据 (n×targetDim)
% W - 投影矩阵 (d×targetDim)
% eigenvalues - 特征值
[n, d] = size(X);
classes = unique(labels);
k = length(classes);
% 检查目标维度是否合理
maxDim = k - 1;
if targetDim > maxDim
warning('目标维度不能超过类别数-1,自动调整为%d', maxDim);
targetDim = maxDim;
end
% 计算总体均值
totalMean = mean(X, 1);
% 初始化散度矩阵
S_W = zeros(d, d); % 类内散度矩阵
S_B = zeros(d, d); % 类间散度矩阵
% 计算每个类的统计量
for i = 1:k
classIdx = (labels == classes(i));
X_class = X(classIdx, :);
n_class = sum(classIdx);
% 类内散度
classMean = mean(X_class, 1);
S_W = S_W + (X_class - classMean)' * (X_class - classMean);
% 类间散度
meanDiff = (classMean - totalMean)';
S_B = S_B + n_class * (meanDiff * meanDiff');
end
% 解决广义特征值问题:S_B * w = λ * S_W * w
[eigenvectors, eigenvalues] = eig(S_B, S_W);
% 排序特征值(降序)
[eigenvalues, idx] = sort(real(diag(eigenvalues)), 'descend');
W = real(eigenvectors(:, idx(1:targetDim)));
% 数据投影
Y = X * W;
end
可视化与评估函数
matlab
function visualizeLDAresults(X_original, Y_lda, labels, originalDim, targetDim)
% 可视化LDA降维结果
figure('Position', [100, 100, 1200, 400]);
% 原始数据可视化(前2个特征)
subplot(1, 3, 1);
if originalDim >= 2
gscatter(X_original(:,1), X_original(:,2), labels);
title('原始数据(前两个特征)');
xlabel('特征1'); ylabel('特征2');
else
histogram(X_original);
title('原始数据分布');
xlabel('特征值'); ylabel('频数');
end
grid on;
% LDA降维结果可视化
subplot(1, 3, 2);
if targetDim >= 2
gscatter(Y_lda(:,1), Y_lda(:,2), labels);
title('LDA降维结果(前两个判别方向)');
xlabel('LDA方向1'); ylabel('LDA方向2');
else
gscatter(Y_lda, zeros(size(Y_lda)), labels);
title('LDA降维结果(一维)');
xlabel('LDA方向1'); ylabel('');
end
grid on;
% 类别分离度分析
subplot(1, 3, 3);
classSeparation = evaluateClassSeparation(Y_lda, labels);
bar(classSeparation);
title('各类别在LDA空间的分离度');
xlabel('类别'); ylabel('分离度指标');
grid on;
end
function separation = evaluateClassSeparation(Y, labels)
% 评估各类别在降维空间的分离程度
classes = unique(labels);
k = length(classes);
separation = zeros(k, 1);
for i = 1:k
classIdx = (labels == classes(i));
otherIdx = ~classIdx;
% 计算类内距离和类间距离的比值
intraDist = pdist2(Y(classIdx,:), Y(classIdx,:));
interDist = pdist2(Y(classIdx,:), Y(otherIdx,:));
separation(i) = mean(interDist(:)) / mean(intraDist(:));
end
end
实际应用示例
示例1:鸢尾花数据集降维
matlab
% 加载经典鸢尾花数据集
load fisheriris;
X = meas; % 150个样本,4个特征
labels = species; % 3个类别
% 将标签转换为数值
[~, ~, numericLabels] = unique(labels);
% LDA降维到2维
targetDim = 2;
[Y_lda, W, eigenvalues] = myLDA(X, numericLabels, targetDim);
% 可视化结果
visualizeLDAresults(X, Y_lda, numericLabels, size(X,2), targetDim);
% 显示特征值信息
fprintf('前%d个特征值:\n', targetDim);
for i = 1:targetDim
fprintf(' 特征值%d: %.4f (解释度: %.2f%%)\n', ...
i, eigenvalues(i), 100*eigenvalues(i)/sum(eigenvalues));
end
示例2:手写数字识别降维
matlab
% 生成模拟的手写数字数据
rng(42); % 设置随机种子保证可重复性
% 生成3类手写数字的模拟特征(64维)
nSamples = 300;
nFeatures = 64;
nClasses = 3;
X_digits = [];
labels_digits = [];
for i = 1:nClasses
% 每类数字有不同的特征分布
classMean = randn(1, nFeatures) * 2 + i;
classCov = diag(rand(1, nFeatures) + 0.5);
X_class = mvnrnd(classMean, classCov, nSamples);
X_digits = [X_digits; X_class];
labels_digits = [labels_digits; i*ones(nSamples, 1)];
end
% LDA降维
targetDim = 2;
[Y_digits, W_digits, eig_digits] = myLDA(X_digits, labels_digits, targetDim);
% 可视化
figure;
gscatter(Y_digits(:,1), Y_digits(:,2), labels_digits);
title('手写数字数据LDA降维');
xlabel('LDA方向1'); ylabel('LDA方向2');
legend('数字1', '数字2', '数字3');
grid on;
性能评估与比较
LDA vs PCA 对比分析
matlab
function compareLDA_PCA(X, labels)
% 比较LDA和PCA的降维效果
% LDA降维
[Y_lda, ~, eig_lda] = myLDA(X, labels, 2);
% PCA降维
[coeff, score, latent] = pca(X);
Y_pca = score(:,1:2);
% 可视化对比
figure('Position', [100, 100, 1000, 400]);
subplot(1, 2, 1);
gscatter(Y_pca(:,1), Y_pca(:,2), labels);
title('PCA降维结果');
xlabel('主成分1'); ylabel('主成分2');
grid on;
subplot(1, 2, 2);
gscatter(Y_lda(:,1), Y_lda(:,2), labels);
title('LDA降维结果');
xlabel('判别方向1'); ylabel('判别方向2');
grid on;
% 计算类别分离度
separation_pca = evaluateClassSeparation(Y_pca, labels);
separation_lda = evaluateClassSeparation(Y_lda, labels);
fprintf('平均类别分离度:\n');
fprintf(' PCA: %.4f\n', mean(separation_pca));
fprintf(' LDA: %.4f\n', mean(separation_lda));
end
注意事项与最佳实践
1. 数据预处理
matlab
% 建议的数据预处理步骤
X_normalized = zscore(X); % 标准化
% 或
X_scaled = (X - mean(X)) ./ std(X); % 标准化
2. 维度选择策略
matlab
function optimalDim = selectOptimalDimension(eigenvalues, threshold)
% 基于特征值选择最优维度
explained = cumsum(eigenvalues) / sum(eigenvalues);
optimalDim = find(explained >= threshold, 1);
fprintf('建议维度: %d (累计解释度: %.2f%%)\n', ...
optimalDim, 100*explained(optimalDim));
end
3. 异常情况处理
- 小样本问题:当特征维度大于样本数时,S_W可能奇异
- 类别不平衡:考虑加权散度矩阵
- 数值稳定性:加入正则化项
参考代码 利用LDA算法,实现数据降维 www.3dddown.com/csa/79531.html
扩展应用
多分类问题处理
LDA天然支持多分类,但目标维度受限于min(特征数, 类别数-1)
与其他算法结合
matlab
% LDA + SVM 分类管道
ldaModel = @(X, labels) myLDA(X, labels, 2);
svmModel = fitcsvm(Y_lda, labels, 'KernelFunction', 'rbf');
结果解读指南
- 特征值大小:反映对应判别方向的判别能力
- 投影向量:显示原始特征在判别方向中的贡献
- 类别分离度:量化评估降维后类别的可分性