Support Vector Machine Multiclass Classification Solution
A MATLAB Implementation of SVM Multiclass Classification
1. Main Function: SVM Multiclass Training and Testing
matlab
function [svm_model, test_accuracy, confusion_matrix] = svm_multiclass_train_test(X_train, y_train, X_test, y_test, varargin)
% SVM multiclass training and testing
% Inputs:
%   X_train - training features (n_samples x n_features)
%   y_train - training labels (n_samples x 1)
%   X_test  - test features
%   y_test  - test labels
% Optional name-value parameters:
%   'kernel'         - kernel function: 'linear', 'rbf', 'polynomial'
%   'box_constraint' - penalty parameter C
%   'kernel_scale'   - kernel scale
%   'standardize'    - whether to standardize the data
% Outputs:
%   svm_model        - struct with fields .Model (the trained ECOC classifier)
%                      and .Info (training metadata); a fitcecoc model object
%                      does not accept extra fields, so metadata lives in a
%                      wrapper struct instead
%   test_accuracy    - accuracy on the test set
%   confusion_matrix - confusion matrix on the test set

% Parse parameters
p = inputParser;
addParameter(p, 'kernel', 'linear', @ischar);
addParameter(p, 'box_constraint', 1, @isnumeric);
addParameter(p, 'kernel_scale', 'auto', @(x) ischar(x) || isnumeric(x));
addParameter(p, 'standardize', true, @islogical);
addParameter(p, 'polynomial_order', 3, @isnumeric);
addParameter(p, 'coding', 'onevsone', @ischar);
addParameter(p, 'verbose', true, @islogical);
parse(p, varargin{:});
params = p.Results;

if params.verbose
    fprintf('Starting SVM multiclass training...\n');
    fprintf('Training set: %d samples, %d features\n', size(X_train, 1), size(X_train, 2));
    fprintf('Test set: %d samples\n', size(X_test, 1));
    fprintf('Number of classes: %d\n', length(unique(y_train)));
end

% Preprocessing: compute standardization statistics on the training set
% only, then apply the same transform to the test set
mu = [];
sigma = [];
if params.standardize
    [X_train, mu, sigma] = safe_zscore(X_train);
    X_test = (X_test - mu) ./ sigma;
    if params.verbose
        fprintf('Data standardization complete\n');
    end
end

% Binary-learner template
switch params.kernel
    case 'linear'
        template = templateSVM(...
            'KernelFunction', 'linear', ...
            'BoxConstraint', params.box_constraint, ...
            'Standardize', false, ... % already standardized above
            'SaveSupportVectors', true);
    case 'rbf'
        template = templateSVM(...
            'KernelFunction', 'rbf', ...
            'BoxConstraint', params.box_constraint, ...
            'KernelScale', params.kernel_scale, ...
            'Standardize', false, ...
            'SaveSupportVectors', true);
    case 'polynomial'
        template = templateSVM(...
            'KernelFunction', 'polynomial', ...
            'BoxConstraint', params.box_constraint, ...
            'KernelScale', params.kernel_scale, ...
            'PolynomialOrder', params.polynomial_order, ...
            'Standardize', false, ...
            'SaveSupportVectors', true);
    otherwise
        error('Unsupported kernel function: %s', params.kernel);
end

% Train the multiclass SVM model (ECOC decomposition)
t_train = tic;
ecoc_model = fitcecoc(X_train, y_train, ...
    'Learners', template, ...
    'Coding', params.coding, ...
    'Verbose', double(params.verbose));
training_time = toc(t_train);

if params.verbose
    fprintf('Model training complete, elapsed time: %.2f seconds\n', training_time);
end

% Training-set prediction
y_train_pred = predict(ecoc_model, X_train);
train_accuracy = sum(y_train_pred == y_train) / length(y_train);

% Test-set prediction
y_test_pred = predict(ecoc_model, X_test);
test_accuracy = sum(y_test_pred == y_test) / length(y_test);

% Confusion matrix
confusion_matrix = confusionmat(y_test, y_test_pred);

if params.verbose
    fprintf('Training accuracy: %.2f%%\n', train_accuracy * 100);
    fprintf('Test accuracy: %.2f%%\n', test_accuracy * 100);
    % Per-class metrics
    class_report = compute_class_metrics(y_test, y_test_pred);
    display_classification_report(class_report);
end

% Package the model and its metadata; the standardization statistics and
% the template are kept so downstream code can transform new data and
% retrain consistently
svm_model.Model = ecoc_model;
svm_model.Info.TrainingTime = training_time;
svm_model.Info.TrainAccuracy = train_accuracy;
svm_model.Info.TestAccuracy = test_accuracy;
svm_model.Info.Parameters = params;
svm_model.Info.Template = template;
svm_model.Info.Mu = mu;
svm_model.Info.Sigma = sigma;
end
function [X_norm, mu, sigma] = safe_zscore(X)
% Standardize data column-wise (named to avoid shadowing MATLAB's zscore)
mu = mean(X, 1);
sigma = std(X, 0, 1);
sigma(sigma == 0) = 1; % avoid division by zero on constant features
X_norm = (X - mu) ./ sigma;
end
function class_report = compute_class_metrics(y_true, y_pred)
% Compute per-class evaluation metrics (one-vs-rest)
classes = unique(y_true);
n_classes = length(classes);
class_report = struct();
for i = 1:n_classes
    class = classes(i);
    % Binarize the labels for the current class
    true_pos = (y_true == class) & (y_pred == class);
    false_pos = (y_true ~= class) & (y_pred == class);
    false_neg = (y_true == class) & (y_pred ~= class);
    true_neg = (y_true ~= class) & (y_pred ~= class);
    TP = sum(true_pos);
    FP = sum(false_pos);
    FN = sum(false_neg);
    TN = sum(true_neg);
    % Compute the metrics (eps guards against division by zero)
    precision = TP / (TP + FP + eps);
    recall = TP / (TP + FN + eps);
    f1_score = 2 * (precision * recall) / (precision + recall + eps);
    specificity = TN / (TN + FP + eps);
    accuracy = (TP + TN) / (TP + TN + FP + FN);
    class_report(i).Class = class;
    class_report(i).Precision = precision;
    class_report(i).Recall = recall;
    class_report(i).F1_Score = f1_score;
    class_report(i).Specificity = specificity;
    class_report(i).Accuracy = accuracy;
    class_report(i).Support = sum(y_true == class);
end
end
function display_classification_report(class_report)
% Print a detailed per-class classification report
fprintf('\nDetailed classification report:\n');
fprintf('%-10s %-12s %-12s %-12s %-12s %-12s %-10s\n', ...
    'Class', 'Precision', 'Recall', 'F1-Score', 'Specificity', 'Accuracy', 'Support');
fprintf('%-10s %-12s %-12s %-12s %-12s %-12s %-10s\n', ...
    '-----', '---------', '------', '--------', '-----------', '--------', '-------');
for i = 1:length(class_report)
    cr = class_report(i);
    fprintf('%-10d %-12.3f %-12.3f %-12.3f %-12.3f %-12.3f %-10d\n', ...
        cr.Class, cr.Precision, cr.Recall, cr.F1_Score, ...
        cr.Specificity, cr.Accuracy, cr.Support);
end
end
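The per-class report produced above can be collapsed into macro-averaged summary figures when a single number per metric is needed. A minimal sketch, using the class_report struct returned by compute_class_metrics:
matlab
% Macro-averaged metrics: the unweighted mean over classes
macro_precision = mean([class_report.Precision]);
macro_recall    = mean([class_report.Recall]);
macro_f1        = mean([class_report.F1_Score]);
fprintf('Macro precision: %.3f, recall: %.3f, F1: %.3f\n', ...
    macro_precision, macro_recall, macro_f1);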
2. Cross-Validation and Hyperparameter Tuning
matlab
function [best_model, best_params, cv_results] = svm_hyperparameter_tuning(X, y, varargin)
% SVM hyperparameter tuning
% Uses k-fold cross-validation to search for the best parameters
p = inputParser;
addParameter(p, 'cv_folds', 5, @isnumeric);
addParameter(p, 'kernel_types', {'linear', 'rbf'}, @iscell);
addParameter(p, 'c_values', 2.^(-5:2:15), @isnumeric);
addParameter(p, 'gamma_values', 2.^(-15:2:3), @isnumeric); % for the RBF kernel
addParameter(p, 'verbose', true, @islogical);
parse(p, varargin{:});
params = p.Results;

if params.verbose
    fprintf('Starting SVM hyperparameter tuning...\n');
    fprintf('Cross-validation folds: %d\n', params.cv_folds);
end

% Build the grid of parameter combinations
param_combinations = struct('Kernel', {}, 'BoxConstraint', {}, 'KernelScale', {});
idx = 1;
for k = 1:length(params.kernel_types)
    kernel = params.kernel_types{k};
    switch kernel
        case 'linear'
            for c_idx = 1:length(params.c_values)
                param_combinations(idx).Kernel = kernel;
                param_combinations(idx).BoxConstraint = params.c_values(c_idx);
                param_combinations(idx).KernelScale = 1;
                idx = idx + 1;
            end
        case 'rbf'
            for c_idx = 1:length(params.c_values)
                for g_idx = 1:length(params.gamma_values)
                    param_combinations(idx).Kernel = kernel;
                    param_combinations(idx).BoxConstraint = params.c_values(c_idx);
                    % MATLAB's RBF kernel is exp(-||x-z||^2 / s^2); for a
                    % libsvm-style gamma in exp(-gamma*||x-z||^2), s = 1/sqrt(gamma)
                    param_combinations(idx).KernelScale = 1/sqrt(params.gamma_values(g_idx));
                    idx = idx + 1;
                end
            end
    end
end

if params.verbose
    fprintf('Total parameter combinations: %d\n', length(param_combinations));
end

% Cross-validate each combination
cv_results = struct();
best_accuracy = -inf;
best_params = param_combinations(1);
for i = 1:length(param_combinations)
    param = param_combinations(i);
    if params.verbose && mod(i, 10) == 0
        fprintf('Evaluating parameter combination %d/%d...\n', i, length(param_combinations));
    end
    % Binary-learner template for this combination
    template = templateSVM(...
        'KernelFunction', param.Kernel, ...
        'BoxConstraint', param.BoxConstraint, ...
        'KernelScale', param.KernelScale, ...
        'Standardize', true);
    % k-fold cross-validated ECOC model
    cv_model = fitcecoc(X, y, ...
        'Learners', template, ...
        'KFold', params.cv_folds, ...
        'Verbose', 0);
    % Per-fold misclassification rates, so both the mean accuracy and its
    % spread across folds are available (std of a single scalar is always 0)
    fold_losses = kfoldLoss(cv_model, 'LossFun', 'classiferror', 'Mode', 'individual');
    cv_accuracy = 1 - mean(fold_losses);
    % Store the results
    cv_results(i).Parameters = param;
    cv_results(i).CVAccuracy = cv_accuracy;
    cv_results(i).Std = std(1 - fold_losses);
    % Track the best parameters
    if cv_accuracy > best_accuracy
        best_accuracy = cv_accuracy;
        best_params = param;
    end
end

% Train the final model on the full data with the best parameters
best_template = templateSVM(...
    'KernelFunction', best_params.Kernel, ...
    'BoxConstraint', best_params.BoxConstraint, ...
    'KernelScale', best_params.KernelScale, ...
    'Standardize', true);
best_model = fitcecoc(X, y, ...
    'Learners', best_template, ...
    'Verbose', double(params.verbose));

if params.verbose
    fprintf('Hyperparameter tuning complete!\n');
    fprintf('Best parameters: Kernel=%s, C=%.4f, Scale=%.4f\n', ...
        best_params.Kernel, best_params.BoxConstraint, best_params.KernelScale);
    fprintf('Best cross-validation accuracy: %.2f%%\n', best_accuracy * 100);
end
end
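The grid above converts a libsvm-style gamma into MATLAB's KernelScale via s = 1/sqrt(gamma), since MATLAB's RBF kernel is exp(-||x-z||^2 / s^2) while the libsvm convention is exp(-gamma*||x-z||^2). A quick sketch verifying the conversion on one pair of points (illustration only, not part of the tuning function):
matlab
% Check that KernelScale s = 1/sqrt(gamma) matches the libsvm RBF kernel
gamma = 0.125;
s = 1 / sqrt(gamma);
x = [1 2 3]; z = [2 0 1];
k_libsvm = exp(-gamma * sum((x - z).^2));  % exp(-gamma*||x-z||^2)
k_matlab = exp(-sum(((x - z) / s).^2));    % exp(-||x-z||^2 / s^2)
fprintf('libsvm form: %.6f, MATLAB form: %.6f\n', k_libsvm, k_matlab);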
3. Visualization and Results Analysis
matlab
function visualize_svm_results(svm_model, X_test, y_test, varargin)
% Visualize SVM classification results
% svm_model is the wrapper struct returned by svm_multiclass_train_test
p = inputParser;
addParameter(p, 'feature_names', {}, @iscell);
addParameter(p, 'class_names', {}, @iscell);
addParameter(p, 'pca_visualize', true, @islogical);
addParameter(p, 'plot_confusion', true, @islogical);
parse(p, varargin{:});
params = p.Results;

% Apply the training-set standardization before predicting
if ~isempty(svm_model.Info.Mu)
    X_test = (X_test - svm_model.Info.Mu) ./ svm_model.Info.Sigma;
end

% Predict
y_pred = predict(svm_model.Model, X_test);

% Create the figure window
figure('Position', [100, 100, 1400, 900]);

% 1. Confusion matrix
if params.plot_confusion
    subplot(2, 3, 1);
    cm = confusionchart(y_test, y_pred);
    % Report accuracy in the chart title (a ConfusionMatrixChart has no
    % regular axes, so title()/text() do not apply)
    accuracy = sum(y_test == y_pred) / length(y_test);
    cm.Title = sprintf('Confusion Matrix (accuracy %.2f%%)', accuracy * 100);
end

% 2. PCA visualization (when there are more than 2 features)
if params.pca_visualize && size(X_test, 2) > 2
    subplot(2, 3, 2);
    % Project onto the first two principal components
    [~, score] = pca(X_test);
    X_pca = score(:, 1:2);
    gscatter(X_pca(:, 1), X_pca(:, 2), y_test);
    hold on;
    % Mark the misclassified points
    misclassified = y_test ~= y_pred;
    plot(X_pca(misclassified, 1), X_pca(misclassified, 2), 'kx', ...
        'MarkerSize', 10, 'LineWidth', 2, 'DisplayName', 'Misclassified');
    xlabel('First principal component');
    ylabel('Second principal component');
    title('PCA Visualization (first two components)');
    legend('show');
    grid on;
end

% 3. Feature importance (linear kernel only)
if strcmp(svm_model.Info.Parameters.kernel, 'linear') && ~isempty(params.feature_names)
    subplot(2, 3, 3);
    % Average absolute weights across the binary learners
    feature_weights = compute_feature_importance(svm_model.Model);
    if ~isempty(feature_weights)
        barh(feature_weights);
        yticks(1:length(params.feature_names));
        yticklabels(params.feature_names);
        xlabel('Feature weight');
        title('Feature Importance (linear SVM)');
        grid on;
    end
end

% 4. Learning curve: accuracy vs. number of training samples
subplot(2, 3, 4);
plot_learning_curve(svm_model, X_test, y_test);

% 5. Per-class performance metrics
subplot(2, 3, 5);
class_report = compute_class_metrics(y_test, y_pred);
metrics = zeros(length(class_report), 3);
for i = 1:length(class_report)
    metrics(i, 1) = class_report(i).Precision;
    metrics(i, 2) = class_report(i).Recall;
    metrics(i, 3) = class_report(i).F1_Score;
end
bar(metrics);
if ~isempty(params.class_names)
    set(gca, 'XTickLabel', params.class_names);
end
legend('Precision', 'Recall', 'F1 score', 'Location', 'best');
xlabel('Class');
ylabel('Score');
title('Per-Class Performance Metrics');
grid on;

% 6. Model information
subplot(2, 3, 6);
info_str = sprintf('Kernel: %s\nC parameter: %.2f\nTraining time: %.2fs\nTest accuracy: %.2f%%', ...
    svm_model.Info.Parameters.kernel, ...
    svm_model.Info.Parameters.box_constraint, ...
    svm_model.Info.TrainingTime, ...
    svm_model.Info.TestAccuracy * 100);
text(0.1, 0.5, info_str, 'FontSize', 12, 'VerticalAlignment', 'middle');
title('Model Information');
axis off;

sgtitle('SVM Multiclass Results Analysis', 'FontSize', 14, 'FontWeight', 'bold');
end
function weights = compute_feature_importance(ecoc_model)
% Feature importance for a linear SVM: mean absolute weight per feature,
% averaged over all binary learners of the ECOC model
try
    binary_learners = ecoc_model.BinaryLearners;
    n_features = size(ecoc_model.X, 2);
    weights = zeros(n_features, 1);
    for i = 1:length(binary_learners)
        if isprop(binary_learners{i}, 'Beta') && ~isempty(binary_learners{i}.Beta)
            weights = weights + abs(binary_learners{i}.Beta);
        end
    end
    weights = weights / length(binary_learners);
catch
    weights = [];
end
end
function plot_learning_curve(svm_model, X_test, y_test)
% Plot the learning curve by retraining on growing subsets of the
% training data stored in the fitted ECOC model
X_train = svm_model.Model.X;
y_train = svm_model.Model.Y;
train_sizes = round(linspace(0.1 * length(y_train), length(y_train), 10));
train_accuracies = zeros(size(train_sizes));
test_accuracies = zeros(size(train_sizes));
for i = 1:length(train_sizes)
    n_samples = train_sizes(i);
    % Randomly select a subset
    idx = randperm(length(y_train), n_samples);
    X_subset = X_train(idx, :);
    y_subset = y_train(idx);
    % Retrain with the same template and coding as the original model
    % (the template is kept in Info; binary learners carry no Template property)
    temp_model = fitcecoc(X_subset, y_subset, ...
        'Learners', svm_model.Info.Template, ...
        'Coding', svm_model.Model.CodingName);
    % Compute accuracies
    train_pred = predict(temp_model, X_subset);
    test_pred = predict(temp_model, X_test);
    train_accuracies(i) = sum(train_pred == y_subset) / length(y_subset);
    test_accuracies(i) = sum(test_pred == y_test) / length(y_test);
end
plot(train_sizes, train_accuracies * 100, 'b-o', 'LineWidth', 2, 'DisplayName', 'Training set');
hold on;
plot(train_sizes, test_accuracies * 100, 'r-o', 'LineWidth', 2, 'DisplayName', 'Test set');
xlabel('Number of training samples');
ylabel('Accuracy (%)');
title('Learning Curve');
legend('show');
grid on;
end
4. Complete Usage Example
matlab
% Complete SVM multiclass example
function run_svm_multiclass_example()
% Run the SVM multiclass example
fprintf('=== Complete SVM Multiclass Example ===\n\n');

% 1. Generate example data (3 classes)
rng(42); % fix the random seed for reproducibility
n_samples = 1000;
n_features = 4;
n_classes = 3;

% Generate Gaussian-distributed data
[X, y] = generate_multiclass_data(n_samples, n_features, n_classes);
fprintf('Generated data:\n');
fprintf('Number of samples: %d\n', n_samples);
fprintf('Number of features: %d\n', n_features);
fprintf('Number of classes: %d\n', n_classes);
fprintf('Class distribution:\n');
tabulate(y);

% 2. Split into training and test sets
cv = cvpartition(y, 'HoldOut', 0.3);
X_train = X(training(cv), :);
y_train = y(training(cv), :);
X_test = X(test(cv), :);
y_test = y(test(cv), :);
fprintf('\nData split:\n');
fprintf('Training set: %d samples\n', size(X_train, 1));
fprintf('Test set: %d samples\n', size(X_test, 1));

% 3. Hyperparameter tuning
fprintf('\n=== Hyperparameter Tuning ===\n');
[best_model, best_params, cv_results] = svm_hyperparameter_tuning(...
    X_train, y_train, ...
    'kernel_types', {'linear', 'rbf'}, ...
    'c_values', 2.^(-1:3), ...
    'gamma_values', 2.^(-5:0), ...
    'cv_folds', 5, ...
    'verbose', true);

% 4. Train the final model with the best parameters
fprintf('\n=== Final Model Training ===\n');
final_model = svm_multiclass_train_test(...
    X_train, y_train, X_test, y_test, ...
    'kernel', best_params.Kernel, ...
    'box_constraint', best_params.BoxConstraint, ...
    'kernel_scale', best_params.KernelScale, ...
    'standardize', true, ...
    'verbose', true);

% 5. Visualize the results
fprintf('\n=== Results Visualization ===\n');
feature_names = {'Feature 1', 'Feature 2', 'Feature 3', 'Feature 4'};
class_names = {'Class A', 'Class B', 'Class C'};
visualize_svm_results(final_model, X_test, y_test, ...
    'feature_names', feature_names, ...
    'class_names', class_names, ...
    'pca_visualize', true, ...
    'plot_confusion', true);

% 6. Model comparison
fprintf('\n=== Comparing Different Kernels ===\n');
compare_different_kernels(X_train, y_train, X_test, y_test);
fprintf('\nExample run complete!\n');
end
function [X, y] = generate_multiclass_data(n_samples, n_features, n_classes)
% Generate synthetic multiclass data from Gaussian clusters
X = zeros(n_samples, n_features);
y = zeros(n_samples, 1);
samples_per_class = floor(n_samples / n_classes);
for i = 1:n_classes
    start_idx = (i-1) * samples_per_class + 1;
    if i == n_classes
        % The last class absorbs any remainder; otherwise the trailing
        % rows would stay all-zero with a spurious label 0
        end_idx = n_samples;
    else
        end_idx = i * samples_per_class;
    end
    n_class_samples = end_idx - start_idx + 1;
    % Each class gets a different mean and covariance
    mean_val = (i-1) * 2;
    covariance = eye(n_features) * (0.5 + i*0.2);
    X_class = mvnrnd(repmat(mean_val, 1, n_features), covariance, n_class_samples);
    X(start_idx:end_idx, :) = X_class;
    y(start_idx:end_idx) = i;
end
% Shuffle the data
idx = randperm(n_samples);
X = X(idx, :);
y = y(idx);
end
function compare_different_kernels(X_train, y_train, X_test, y_test)
% Compare the performance of different kernel functions
kernels = {'linear', 'rbf', 'polynomial'};
results = struct();
fprintf('%-12s %-12s %-12s %-12s\n', 'Kernel', 'Train Acc', 'Test Acc', 'Time(s)');
fprintf('%-12s %-12s %-12s %-12s\n', '------', '---------', '--------', '-------');
for i = 1:length(kernels)
    % Use a timer handle: the training function calls tic internally,
    % which would reset the default stopwatch
    t_start = tic;
    [model, test_acc, ~] = svm_multiclass_train_test(...
        X_train, y_train, X_test, y_test, ...
        'kernel', kernels{i}, ...
        'box_constraint', 1, ...
        'standardize', true, ...
        'verbose', false);
    time_elapsed = toc(t_start);
    train_acc = model.Info.TrainAccuracy;
    fprintf('%-12s %-12.2f %-12.2f %-12.2f\n', ...
        kernels{i}, train_acc*100, test_acc*100, time_elapsed);
    results(i).Kernel = kernels{i};
    results(i).TrainAccuracy = train_acc;
    results(i).TestAccuracy = test_acc;
    results(i).Time = time_elapsed;
end

% Visual comparison
figure('Position', [200, 200, 1000, 400]);
subplot(1, 2, 1);
accuracies = [results.TestAccuracy] * 100;
bar(accuracies);
set(gca, 'XTickLabel', {results.Kernel});
ylabel('Test accuracy (%)');
title('Kernel Performance Comparison');
grid on;

subplot(1, 2, 2);
times = [results.Time];
bar(times);
set(gca, 'XTickLabel', {results.Kernel});
ylabel('Training time (s)');
title('Training Time Comparison');
grid on;
end
Usage Instructions
1. Basic Usage
matlab
% Load your data
% load('your_data.mat'); % X, y

% Simple call
[model, accuracy, cm] = svm_multiclass_train_test(X_train, y_train, X_test, y_test);

% Custom parameters
[model, accuracy, cm] = svm_multiclass_train_test(X_train, y_train, X_test, y_test, ...
    'kernel', 'rbf', ...
    'box_constraint', 10, ...
    'kernel_scale', 'auto', ...
    'standardize', true);
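The function returns the classifier wrapped in a struct, so prediction on new data reuses the stored standardization statistics. A minimal sketch, assuming a hypothetical matrix X_new with the same feature columns as X_train:
matlab
% Predict labels for new samples; X_new is a hypothetical matrix with
% the same feature columns as X_train
if ~isempty(model.Info.Mu)
    X_new_std = (X_new - model.Info.Mu) ./ model.Info.Sigma; % training-set stats
else
    X_new_std = X_new; % model was trained without standardization
end
y_new = predict(model.Model, X_new_std);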
2. Hyperparameter Tuning
matlab
% Automatically search for the best parameters
[best_model, best_params, cv_results] = svm_hyperparameter_tuning(X, y, ...
    'kernel_types', {'linear', 'rbf'}, ...
    'c_values', [0.1, 1, 10, 100], ...
    'cv_folds', 5);
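The third output, cv_results, holds one entry per grid point, so more than just the winning combination can be inspected. A small sketch, assuming the cv_results struct returned above, that ranks the top combinations by cross-validated accuracy:
matlab
% Rank the evaluated parameter combinations by CV accuracy
[~, order] = sort([cv_results.CVAccuracy], 'descend');
top_k = min(5, numel(order));
for r = order(1:top_k)
    pr = cv_results(r).Parameters;
    fprintf('Kernel=%-6s C=%8.3f Scale=%8.3f  CV acc=%.2f%% (+/- %.2f%%)\n', ...
        pr.Kernel, pr.BoxConstraint, pr.KernelScale, ...
        cv_results(r).CVAccuracy * 100, cv_results(r).Std * 100);
end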
3. Key Parameter Notes
- kernel: kernel function choice
  - linear: linear kernel, suited to linearly separable data
  - rbf: Gaussian (RBF) kernel, suited to nonlinear data
  - polynomial: polynomial kernel
- box_constraint: penalty parameter C
  - Smaller values: tolerate more misclassification, simpler model
  - Larger values: penalize misclassification strictly, risk of overfitting
- kernel_scale: kernel scale
  - Smaller values: more complex decision boundary
  - Larger values: smoother decision boundary
The short sweep below illustrates the box_constraint tradeoff in practice.
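A minimal sketch of the C tradeoff described above, sweeping box_constraint with 5-fold cross-validation (assumes a feature matrix X and label vector y are already in the workspace):
matlab
% Sweep the penalty parameter C on an RBF kernel and report CV accuracy;
% small C tends toward simpler models, large C toward overfitting
c_grid = [0.01, 0.1, 1, 10, 100];
for c = c_grid
    t = templateSVM('KernelFunction', 'rbf', 'BoxConstraint', c, ...
        'KernelScale', 'auto', 'Standardize', true);
    cv_model = fitcecoc(X, y, 'Learners', t, 'KFold', 5);
    fprintf('C = %6.2f  ->  CV accuracy = %.2f%%\n', ...
        c, (1 - kfoldLoss(cv_model)) * 100);
end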