支持向量机多分类解决方案

支持向量机多分类解决方案

SVM多分类MATLAB实现

1. 主函数:SVM多分类训练与测试

matlab 复制代码
function [svm_model, test_accuracy, confusion_matrix] = svm_multiclass_train_test(X_train, y_train, X_test, y_test, varargin)
% SVM多分类训练与测试
% 输入:
%   X_train - 训练特征 (n_samples x n_features)
%   y_train - 训练标签 (n_samples x 1)
%   X_test - 测试特征
%   y_test - 测试标签
% 可选参数:
%   'kernel' - 核函数: 'linear', 'rbf', 'polynomial'
%   'box_constraint' - 惩罚参数C
%   'kernel_scale' - 核尺度
%   'standardize' - 是否标准化数据
% 输出:
%   svm_model - 训练好的SVM模型
%   test_accuracy - 测试集准确率
%   confusion_matrix - 混淆矩阵

    % 参数解析
    p = inputParser;
    addParameter(p, 'kernel', 'linear', @ischar);
    addParameter(p, 'box_constraint', 1, @isnumeric);
    addParameter(p, 'kernel_scale', 'auto', @(x) ischar(x) || isnumeric(x));
    addParameter(p, 'standardize', true, @islogical);
    addParameter(p, 'polynomial_order', 3, @isnumeric);
    addParameter(p, 'coding', 'onevsone', @ischar);
    addParameter(p, 'verbose', true, @islogical);
    
    parse(p, varargin{:});
    params = p.Results;
    
    if params.verbose
        fprintf('开始SVM多分类训练...\n');
        fprintf('训练集大小: %d 样本, %d 特征\n', size(X_train));
        fprintf('测试集大小: %d 样本\n', size(X_test, 1));
        fprintf('类别数量: %d\n', length(unique(y_train)));
    end
    
    % 数据预处理
    if params.standardize
        [X_train, mu, sigma] = zscore(X_train);
        X_test = (X_test - mu) ./ sigma;
        if params.verbose
            fprintf('数据标准化完成\n');
        end
    end
    
    % 模板设置
    switch params.kernel
        case 'linear'
            template = templateSVM(...
                'KernelFunction', 'linear', ...
                'BoxConstraint', params.box_constraint, ...
                'Standardize', false, ... % 已经在前面标准化了
                'SaveSupportVectors', true);
            
        case 'rbf'
            template = templateSVM(...
                'KernelFunction', 'rbf', ...
                'BoxConstraint', params.box_constraint, ...
                'KernelScale', params.kernel_scale, ...
                'Standardize', false, ...
                'SaveSupportVectors', true);
            
        case 'polynomial'
            template = templateSVM(...
                'KernelFunction', 'polynomial', ...
                'BoxConstraint', params.box_constraint, ...
                'KernelScale', params.kernel_scale, ...
                'PolynomialOrder', params.polynomial_order, ...
                'Standardize', false, ...
                'SaveSupportVectors', true);
            
        otherwise
            error('不支持的核函数: %s', params.kernel);
    end
    
    % 训练多分类SVM模型
    tic;
    svm_model = fitcecoc(X_train, y_train, ...
        'Learners', template, ...
        'Coding', params.coding, ...
        'Verbose', params.verbose);
    training_time = toc;
    
    if params.verbose
        fprintf('模型训练完成,耗时: %.2f 秒\n', training_time);
    end
    
    % 训练集预测
    y_train_pred = predict(svm_model, X_train);
    train_accuracy = sum(y_train_pred == y_train) / length(y_train);
    
    % 测试集预测
    y_test_pred = predict(svm_model, X_test);
    test_accuracy = sum(y_test_pred == y_test) / length(y_test);
    
    % 混淆矩阵
    confusion_matrix = confusionmat(y_test, y_test_pred);
    
    if params.verbose
        fprintf('训练集准确率: %.2f%%\n', train_accuracy * 100);
        fprintf('测试集准确率: %.2f%%\n', test_accuracy * 100);
        
        % 各类别精度
        class_report = compute_class_metrics(y_test, y_test_pred);
        display_classification_report(class_report);
    end
    
    % 保存模型信息
    svm_model.Info.TrainingTime = training_time;
    svm_model.Info.TrainAccuracy = train_accuracy;
    svm_model.Info.TestAccuracy = test_accuracy;
    svm_model.Info.Parameters = params;
end

function [X_norm, mu, sigma] = zscore(X)
% 标准化数据
    mu = mean(X, 1);
    sigma = std(X, 0, 1);
    sigma(sigma == 0) = 1; % 避免除零
    X_norm = (X - mu) ./ sigma;
end

function class_report = compute_class_metrics(y_true, y_pred)
% 计算各类别评估指标
    classes = unique(y_true);
    n_classes = length(classes);
    
    class_report = struct();
    
    for i = 1:n_classes
        class = classes(i);
        
        % 二值化标签
        true_pos = (y_true == class) & (y_pred == class);
        false_pos = (y_true ~= class) & (y_pred == class);
        false_neg = (y_true == class) & (y_pred ~= class);
        true_neg = (y_true ~= class) & (y_pred ~= class);
        
        TP = sum(true_pos);
        FP = sum(false_pos);
        FN = sum(false_neg);
        TN = sum(true_neg);
        
        % 计算指标
        precision = TP / (TP + FP + eps);
        recall = TP / (TP + FN + eps);
        f1_score = 2 * (precision * recall) / (precision + recall + eps);
        specificity = TN / (TN + FP + eps);
        accuracy = (TP + TN) / (TP + TN + FP + FN);
        
        class_report(i).Class = class;
        class_report(i).Precision = precision;
        class_report(i).Recall = recall;
        class_report(i).F1_Score = f1_score;
        class_report(i).Specificity = specificity;
        class_report(i).Accuracy = accuracy;
        class_report(i).Support = sum(y_true == class);
    end
end

function display_classification_report(class_report)
% 显示分类报告
    fprintf('\n分类详细报告:\n');
    fprintf('%-10s %-10s %-10s %-10s %-10s %-10s %-10s\n', ...
        'Class', 'Precision', 'Recall', 'F1-Score', 'Specificity', 'Accuracy', 'Support');
    fprintf('%-10s %-10s %-10s %-10s %-10s %-10s %-10s\n', ...
        '-----', '---------', '------', '--------', '----------', '--------', '-------');
    
    for i = 1:length(class_report)
        cr = class_report(i);
        fprintf('%-10d %-10.3f %-10.3f %-10.3f %-10.3f %-10.3f %-10d\n', ...
            cr.Class, cr.Precision, cr.Recall, cr.F1_Score, ...
            cr.Specificity, cr.Accuracy, cr.Support);
    end
end

2. 交叉验证与参数优化

matlab 复制代码
function [best_model, best_params, cv_results] = svm_hyperparameter_tuning(X, y, varargin)
% SVM超参数调优
% 使用交叉验证寻找最佳参数

    p = inputParser;
    addParameter(p, 'cv_folds', 5, @isnumeric);
    addParameter(p, 'kernel_types', {'linear', 'rbf'}, @iscell);
    addParameter(p, 'c_values', 2.^(-5:2:15), @isnumeric);
    addParameter(p, 'gamma_values', 2.^(-15:2:3), @isnumeric); % 用于RBF核
    addParameter(p, 'verbose', true, @islogical);
    
    parse(p, varargin{:});
    params = p.Results;
    
    if params.verbose
        fprintf('开始SVM超参数调优...\n');
        fprintf('交叉验证折数: %d\n', params.cv_folds);
    end
    
    % 准备参数组合
    param_combinations = [];
    idx = 1;
    
    for k = 1:length(params.kernel_types)
        kernel = params.kernel_types{k};
        
        switch kernel
            case 'linear'
                for c_idx = 1:length(params.c_values)
                    param_combinations(idx).Kernel = kernel;
                    param_combinations(idx).BoxConstraint = params.c_values(c_idx);
                    param_combinations(idx).KernelScale = 1;
                    idx = idx + 1;
                end
                
            case 'rbf'
                for c_idx = 1:length(params.c_values)
                    for g_idx = 1:length(params.gamma_values)
                        param_combinations(idx).Kernel = kernel;
                        param_combinations(idx).BoxConstraint = params.c_values(c_idx);
                        param_combinations(idx).KernelScale = 1/sqrt(2 * params.gamma_values(g_idx));
                        idx = idx + 1;
                    end
                end
        end
    end
    
    if params.verbose
        fprintf('参数组合总数: %d\n', length(param_combinations));
    end
    
    % 交叉验证
    cv_results = struct();
    best_accuracy = 0;
    
    for i = 1:length(param_combinations)
        param = param_combinations(i);
        
        if params.verbose && mod(i, 10) == 0
            fprintf('正在评估第 %d/%d 个参数组合...\n', i, length(param_combinations));
        end
        
        % 设置模板
        template = templateSVM(...
            'KernelFunction', param.Kernel, ...
            'BoxConstraint', param.BoxConstraint, ...
            'KernelScale', param.KernelScale, ...
            'Standardize', true);
        
        % 交叉验证
        cv_model = fitcecoc(X, y, ...
            'Learners', template, ...
            'KFold', params.cv_folds, ...
            'Verbose', false);
        
        % 计算准确率
        cv_accuracy = 1 - kfoldLoss(cv_model, 'LossFun', 'ClassifError');
        
        % 存储结果
        cv_results(i).Parameters = param;
        cv_results(i).CVAccuracy = cv_accuracy;
        cv_results(i).Std = std(cv_accuracy);
        
        % 更新最佳参数
        if cv_accuracy > best_accuracy
            best_accuracy = cv_accuracy;
            best_params = param;
            best_cv_model = cv_model;
        end
    end
    
    % 训练最终模型
    best_template = templateSVM(...
        'KernelFunction', best_params.Kernel, ...
        'BoxConstraint', best_params.BoxConstraint, ...
        'KernelScale', best_params.KernelScale, ...
        'Standardize', true);
    
    best_model = fitcecoc(X, y, ...
        'Learners', best_template, ...
        'Verbose', params.verbose);
    
    if params.verbose
        fprintf('超参数调优完成!\n');
        fprintf('最佳参数: Kernel=%s, C=%.4f, Scale=%.4f\n', ...
            best_params.Kernel, best_params.BoxConstraint, best_params.KernelScale);
        fprintf('最佳交叉验证准确率: %.2f%%\n', best_accuracy * 100);
    end
end

3. 可视化与结果分析

matlab 复制代码
function visualize_svm_results(svm_model, X_test, y_test, varargin)
% 可视化SVM分类结果

    p = inputParser;
    addParameter(p, 'feature_names', {}, @iscell);
    addParameter(p, 'class_names', {}, @iscell);
    addParameter(p, 'pca_visualize', true, @islogical);
    addParameter(p, 'plot_confusion', true, @islogical);
    addParameter(p, 'plot_roc', true, @islogical);
    
    parse(p, varargin{:});
    params = p.Results;
    
    % 预测
    y_pred = predict(svm_model, X_test);
    
    % 创建图形窗口
    figure('Position', [100, 100, 1400, 900]);
    
    % 1. 混淆矩阵
    if params.plot_confusion
        subplot(2, 3, 1);
        cm = confusionchart(y_test, y_pred);
        title('混淆矩阵');
        
        % 计算并显示准确率
        accuracy = sum(y_test == y_pred) / length(y_test);
        text(0.5, -0.1, sprintf('准确率: %.2f%%', accuracy * 100), ...
            'Units', 'normalized', 'HorizontalAlignment', 'center', ...
            'FontSize', 12, 'FontWeight', 'bold');
    end
    
    % 2. PCA可视化(如果特征维度>2)
    if params.pca_visualize && size(X_test, 2) > 2
        subplot(2, 3, 2);
        
        % PCA降维
        [coeff, score] = pca(X_test);
        X_pca = score(:, 1:2);
        
        % 绘制散点图
        gscatter(X_pca(:, 1), X_pca(:, 2), y_test);
        hold on;
        
        % 标记错误分类的点
        misclassified = y_test ~= y_pred;
        plot(X_pca(misclassified, 1), X_pca(misclassified, 2), 'kx', ...
            'MarkerSize', 10, 'LineWidth', 2, 'DisplayName', '错误分类');
        
        xlabel('第一主成分');
        ylabel('第二主成分');
        title('PCA可视化 (前两个主成分)');
        legend('show');
        grid on;
    end
    
    % 3. 特征重要性(仅对线性核)
    if isequal(svm_model.Info.Parameters.kernel, 'linear') && ~isempty(params.feature_names)
        subplot(2, 3, 3);
        
        % 提取权重(需要访问二分类器的权重)
        feature_weights = compute_feature_importance(svm_model);
        
        barh(feature_weights);
        yticks(1:length(params.feature_names));
        yticklabels(params.feature_names);
        xlabel('特征权重');
        title('特征重要性 (线性SVM)');
        grid on;
    end
    
    % 4. 准确率随训练样本数量的学习曲线
    subplot(2, 3, 4);
    plot_learning_curve(svm_model, X_test, y_test);
    
    % 5. 各类别性能指标
    subplot(2, 3, 5);
    class_report = compute_class_metrics(y_test, y_pred);
    
    metrics = zeros(length(class_report), 3);
    for i = 1:length(class_report)
        metrics(i, 1) = class_report(i).Precision;
        metrics(i, 2) = class_report(i).Recall;
        metrics(i, 3) = class_report(i).F1_Score;
    end
    
    bar(metrics);
    legend('精确率', '召回率', 'F1分数', 'Location', 'best');
    xlabel('类别');
    ylabel('分数');
    title('各类别性能指标');
    grid on;
    
    % 6. 模型比较(如果有多个模型)
    subplot(2, 3, 6);
    if isfield(svm_model.Info, 'ComparisonResults')
        plot_model_comparison(svm_model.Info.ComparisonResults);
    else
        % 显示模型信息
        info_str = sprintf('核函数: %s\nC参数: %.2f\n训练时间: %.2fs\n测试准确率: %.2f%%', ...
            svm_model.Info.Parameters.kernel, ...
            svm_model.Info.Parameters.box_constraint, ...
            svm_model.Info.TrainingTime, ...
            svm_model.Info.TestAccuracy * 100);
        
        text(0.1, 0.5, info_str, 'FontSize', 12, 'VerticalAlignment', 'middle');
        title('模型信息');
        axis off;
    end
    
    sgtitle('SVM多分类结果分析', 'FontSize', 14, 'FontWeight', 'bold');
end

function weights = compute_feature_importance(svm_model)
% 计算特征重要性(线性SVM)
    try
        % 获取所有二分类器
        binary_learners = svm_model.BinaryLearners;
        n_features = size(svm_model.X, 2);
        weights = zeros(n_features, 1);
        
        for i = 1:length(binary_learners)
            if isprop(binary_learners{i}, 'Beta')
                weights = weights + abs(binary_learners{i}.Beta);
            end
        end
        weights = weights / length(binary_learners);
    catch
        weights = [];
    end
end

function plot_learning_curve(svm_model, X_test, y_test)
% 绘制学习曲线
    X_train = svm_model.X;
    y_train = svm_model.Y;
    
    train_sizes = round(linspace(0.1 * length(y_train), length(y_train), 10));
    train_accuracies = zeros(size(train_sizes));
    test_accuracies = zeros(size(train_sizes));
    
    for i = 1:length(train_sizes)
        n_samples = train_sizes(i);
        
        % 随机选择子集
        idx = randperm(length(y_train), n_samples);
        X_subset = X_train(idx, :);
        y_subset = y_train(idx);
        
        % 训练模型
        temp_model = fitcecoc(X_subset, y_subset, ...
            'Learners', svm_model.BinaryLearners{1}.Template, ...
            'Coding', svm_model.CodingName);
        
        % 计算准确率
        train_pred = predict(temp_model, X_subset);
        test_pred = predict(temp_model, X_test);
        
        train_accuracies(i) = sum(train_pred == y_subset) / length(y_subset);
        test_accuracies(i) = sum(test_pred == y_test) / length(y_test);
    end
    
    plot(train_sizes, train_accuracies * 100, 'b-o', 'LineWidth', 2, 'DisplayName', '训练集');
    hold on;
    plot(train_sizes, test_accuracies * 100, 'r-o', 'LineWidth', 2, 'DisplayName', '测试集');
    xlabel('训练样本数量');
    ylabel('准确率 (%)');
    title('学习曲线');
    legend('show');
    grid on;
end

4. 完整使用示例

matlab 复制代码
% SVM多分类完整示例
function run_svm_multiclass_example()
% 运行SVM多分类示例

    fprintf('=== SVM多分类完整示例 ===\n\n');
    
    % 1. 生成示例数据(3类)
    rng(42); % 设置随机种子确保可重复性
    n_samples = 1000;
    n_features = 4;
    n_classes = 3;
    
    % 生成高斯分布数据
    [X, y] = generate_multiclass_data(n_samples, n_features, n_classes);
    
    fprintf('生成数据信息:\n');
    fprintf('样本数量: %d\n', n_samples);
    fprintf('特征数量: %d\n', n_features);
    fprintf('类别数量: %d\n', n_classes);
    fprintf('类别分布:\n');
    tabulate(y);
    
    % 2. 划分训练集和测试集
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(training(cv), :);
    y_train = y(training(cv), :);
    X_test = X(test(cv), :);
    y_test = y(test(cv), :);
    
    fprintf('\n数据划分:\n');
    fprintf('训练集: %d 样本\n', size(X_train, 1));
    fprintf('测试集: %d 样本\n', size(X_test, 1));
    
    % 3. 超参数调优
    fprintf('\n=== 超参数调优 ===\n');
    [best_model, best_params, cv_results] = svm_hyperparameter_tuning(...
        X_train, y_train, ...
        'kernel_types', {'linear', 'rbf'}, ...
        'c_values', 2.^(-1:3), ...
        'gamma_values', 2.^(-5:0), ...
        'cv_folds', 5, ...
        'verbose', true);
    
    % 4. 使用最佳参数训练最终模型
    fprintf('\n=== 最终模型训练 ===\n');
    final_model = svm_multiclass_train_test(...
        X_train, y_train, X_test, y_test, ...
        'kernel', best_params.Kernel, ...
        'box_constraint', best_params.BoxConstraint, ...
        'kernel_scale', best_params.KernelScale, ...
        'standardize', true, ...
        'verbose', true);
    
    % 5. 可视化结果
    fprintf('\n=== 结果可视化 ===\n');
    feature_names = {'特征1', '特征2', '特征3', '特征4'};
    class_names = {'类别A', '类别B', '类别C'};
    
    visualize_svm_results(final_model, X_test, y_test, ...
        'feature_names', feature_names, ...
        'class_names', class_names, ...
        'pca_visualize', true, ...
        'plot_confusion', true, ...
        'plot_roc', true);
    
    % 6. 模型比较
    fprintf('\n=== 不同核函数比较 ===\n');
    compare_different_kernels(X_train, y_train, X_test, y_test);
    
    fprintf('\n示例运行完成!\n');
end

function [X, y] = generate_multiclass_data(n_samples, n_features, n_classes)
% 生成多分类示例数据
    X = zeros(n_samples, n_features);
    y = zeros(n_samples, 1);
    
    samples_per_class = floor(n_samples / n_classes);
    
    for i = 1:n_classes
        start_idx = (i-1) * samples_per_class + 1;
        end_idx = i * samples_per_class;
        
        % 为每个类别生成不同的均值和协方差
        mean_val = (i-1) * 2;
        covariance = eye(n_features) * (0.5 + i*0.2);
        
        X_class = mvnrnd(repmat(mean_val, 1, n_features), covariance, samples_per_class);
        X(start_idx:end_idx, :) = X_class;
        y(start_idx:end_idx) = i;
    end
    
    % 随机打乱数据
    idx = randperm(n_samples);
    X = X(idx, :);
    y = y(idx);
end

function compare_different_kernels(X_train, y_train, X_test, y_test)
% 比较不同核函数的性能
    kernels = {'linear', 'rbf', 'polynomial'};
    results = struct();
    
    fprintf('%-12s %-12s %-12s %-12s\n', 'Kernel', 'Train Acc', 'Test Acc', 'Time(s)');
    fprintf('%-12s %-12s %-12s %-12s\n', '------', '---------', '--------', '-------');
    
    for i = 1:length(kernels)
        tic;
        [model, test_acc, ~] = svm_multiclass_train_test(...
            X_train, y_train, X_test, y_test, ...
            'kernel', kernels{i}, ...
            'box_constraint', 1, ...
            'standardize', true, ...
            'verbose', false);
        time_elapsed = toc;
        
        train_acc = model.Info.TrainAccuracy;
        
        fprintf('%-12s %-12.2f %-12.2f %-12.2f\n', ...
            kernels{i}, train_acc*100, test_acc*100, time_elapsed);
        
        results(i).Kernel = kernels{i};
        results(i).TrainAccuracy = train_acc;
        results(i).TestAccuracy = test_acc;
        results(i).Time = time_elapsed;
    end
    
    % 可视化比较
    figure('Position', [200, 200, 1000, 400]);
    
    subplot(1, 2, 1);
    accuracies = [results.TestAccuracy] * 100;
    bar(accuracies);
    set(gca, 'XTickLabel', {results.Kernel});
    ylabel('测试准确率 (%)');
    title('不同核函数性能比较');
    grid on;
    
    subplot(1, 2, 2);
    times = [results.Time];
    bar(times);
    set(gca, 'XTickLabel', {results.Kernel});
    ylabel('训练时间 (秒)');
    title('训练时间比较');
    grid on;
end

参考代码 支持向量机,用于分类,含训练集与测试集,用SVM进行多分类 www.3dddown.com/csa/64920.html

使用说明

1. 基本使用方法

matlab 复制代码
% 加载您的数据
% load('your_data.mat'); % X, y

% 简单调用
[model, accuracy, cm] = svm_multiclass_train_test(X_train, y_train, X_test, y_test);

% 自定义参数
[model, accuracy, cm] = svm_multiclass_train_test(X_train, y_train, X_test, y_test, ...
    'kernel', 'rbf', ...
    'box_constraint', 10, ...
    'kernel_scale', 'auto', ...
    'standardize', true);

2. 超参数调优

matlab 复制代码
% 自动寻找最佳参数
[best_model, best_params] = svm_hyperparameter_tuning(X, y, ...
    'kernel_types', {'linear', 'rbf'}, ...
    'c_values', [0.1, 1, 10, 100], ...
    'cv_folds', 5);

3. 关键参数说明

  • kernel: 核函数选择

    • linear: 线性核,适合线性可分数据
    • rbf: 高斯核,适合非线性数据
    • polynomial: 多项式核
  • box_constraint: 惩罚参数C

    • 较小值:允许更多误分类,模型更简单
    • 较大值:严格惩罚误分类,可能过拟合
  • kernel_scale: 核尺度

    • 较小值:复杂的决策边界
    • 较大值:平滑的决策边界
相关推荐
十三画者2 小时前
【文献分享】vConTACT3机器学习能够实现可扩展且系统的病毒分类体系的构建
人工智能·算法·机器学习·数据挖掘·数据分析
wfeqhfxz25887822 小时前
基于YOLOv10n的热带海洋蝴蝶鱼物种识别与分类系统_P3456数据集训练_1
yolo·分类·数据挖掘
爱看科技2 小时前
微美全息(WIMI.US)突破性精简经典-量子混合神经网络模型助力图像智能分类
人工智能·神经网络·分类
serve the people2 小时前
TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(三)
人工智能·分类·tensorflow
Boll09660002 小时前
开关柜设备状态识别与分类_YOLO11_C3k2_RetBlock实现
人工智能·分类·数据挖掘
serve the people2 小时前
TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(四)
人工智能·分类·tensorflow
TrueFurina(互关互赞)2 小时前
7-4 区间水仙花数 Python程序设计-MJU实验四(编程入门•多代码实现•测试均通过)
数据结构·算法·飞书·创业创新·学习方法·远程工作·改行学it
Amnesia0_02 小时前
Map和Set
算法
受伤的僵尸2 小时前
算法类复习(1)-非自注意力机制(图像处理中的注意力)
人工智能·算法