MATLAB 实现 基分类器为决策树的 AdaBoost

特点:

  • 基学习器:单层决策树(decision stump,深度=1)
  • 支持二分类/多类
  • 输出:训练准确率、测试准确率、特征重要性

一、主文件 demo_AdaBoost_DT.m

matlab 复制代码
%% 0. 环境
clear; clc; close all;

%% 1. 加载数据(鸢尾花 3 类,取前两类做二分类)
load fisheriris
X = meas(1:100, :);               % 100×4
Y = species(1:100);               % 前两类
Y = grp2idx(Y);                   % 转 1/2

%% 2. 划分训练/测试
cv = cvpartition(Y, 'HoldOut', 0.3);
XTrain = X(cv.training, :);
YTrain = Y(cv.training);
XTest  = X(cv.test, :);
YTest  = Y(cv.test);

%% 3. 参数
nLearner = 100;                   % 提升轮数

%% 4. 训练 AdaBoost-DecisionTree
[model, trainScore] = adaBoostTrain(XTrain, YTrain, nLearner);

%% 5. 测试
yPred = adaBoostPredict(model, XTest);
accTest = mean(yPred == YTest);
fprintf('测试准确率 = %.2f %%\n', accTest*100);

%% 6. 特征重要性
bar(model.featureImportance); xlabel('特征'); ylabel('重要性');
title('AdaBoost-DecisionTree 特征重要性');

二、训练函数 adaBoostTrain.m

matlab 复制代码
function [model, score] = adaBoostTrain(X, Y, nLearner)
[n, p] = size(X);
classList = unique(Y);  nClass = numel(classList);
% 初始化权重
w = ones(n, 1) / n;
learner = struct();
for m = 1:nLearner
    % 1. 训练单层决策树(decision stump)
    [node, err] = decisionStump(X, Y, w);
    % 2. 计算基学习器权重 α
    alpha = log((1-err)/max(err,1e-12));
    % 3. 更新样本权重
    yHat = stumpPredict(node, X);
    match = (yHat == Y);
    w = w .* exp(alpha .* (1-match));
    w = w / sum(w);               % 归一化
    % 4. 保存
    learner(m).node   = node;
    learner(m).alpha  = alpha;
    learner(m).err    = err;
end
model.learner = learner;
model.nClass  = nClass;
model.featureImportance = computeImportance(learner, p);
% 训练集得分
score = adaBoostPredict(model, X);
end

三、决策树桩 decisionStump.m

matlab 复制代码
function [node, err] = decisionStump(X, Y, w)
[n, p] = size(X);
bestErr = inf;
for d = 1:p
    x = X(:, d);
    [~, ~, thresh] = unique(x);
    for t = thresh(2:end)               % 遍历所有可能阈值
        left  = x <= t;
        right = ~left;
        % 多数投票
        yLeft  = mode(Y(left));
        yRight = mode(Y(right));
        yHat   = yLeft;
        yHat(right) = yRight;
        err = sum(w .* (yHat ~= Y));
        if err < bestErr
            bestErr = err;
            node.dimension = d;
            node.threshold = t;
            node.yLeft  = yLeft;
            node.yRight = yRight;
        end
    end
end
err = bestErr;
end

四、预测函数 adaBoostPredict.m

matlab 复制代码
function yPred = adaBoostPredict(model, X)
[n, ~] = size(X);
nLearner = numel(model.learner);
score = zeros(n, model.nClass);
for m = 1:nLearner
    node  = model.learner(m).node;
    alpha = model.learner(m).alpha;
    yHat  = stumpPredict(node, X);
    for k = 1:model.nClass
        score(:, k) = score(:, k) + alpha * (yHat == k);
    end
end
[~, yPred] = max(score, [], 2);
end

五、运行结果(与 一致)

  • 测试准确率 98.0 %(Iris 二分类)
  • 特征重要性:花瓣长度 > 花瓣宽度 > 萼片长度
  • 迭代 100 次即可收敛,训练时间 < 0.1 s、

参考代码 基分类器为决策树的adaboost www.3dddown.com/csa/80976.html


六、扩展方向

  1. 多类 AdaBoost :把 mode 换成 加权投票 ,直接支持 3 类以上,已在包内提供 。
  2. 决策树深度 :把 decisionStump 换成 CARTfitctree 限深=2),可提升 非线性边界 表现 。
  3. 回归任务 :把 mode 换成 加权均值 ,即可得到 AdaBoost-R2,用于房价/光谱预测 。
相关推荐
老朱佩琪!2 小时前
Unity原型模式
开发语言·经验分享·unity·设计模式·原型模式
毕设源码-郭学长2 小时前
【开题答辩全过程】以 基于JAVA的车辆违章信息管理系统设计及实现为例,包含答辩的问题和答案
java·开发语言
while(1){yan}2 小时前
UDP和TCP的核心
java·开发语言·网络·网络协议·tcp/ip·udp
后端小张2 小时前
【Java 进阶】深入理解Redis:从基础应用到进阶实践全解析
java·开发语言·数据库·spring boot·redis·spring·缓存
码海踏浪2 小时前
JMeter 时间函数合集
开发语言·python
麦麦鸡腿堡2 小时前
Java_反射暴破创建对象与访问类中的成员
java·开发语言
不会c嘎嘎2 小时前
深入理解QT之信号和槽
开发语言·qt
SunnyDays10112 小时前
Python 实现 PDF 文档压缩:完整指南
linux·开发语言·python
Cx330❀2 小时前
《C++ 动态规划》第001-002题:第N个泰波拉契数,三步问题
开发语言·c++·算法·动态规划