【机器学习】机器学习的基本分类-监督学习-决策树-CART(Classification and Regression Tree)

CART(Classification and Regression Tree)

CART(分类与回归树)是一种用于分类和回归任务的决策树算法,提出者为 Breiman 等人。它的核心思想是通过二分法 递归地将数据集划分为子集,从而构建一棵树。CART 算法既可以生成分类树 ,也可以生成回归树


1. CART 的特点

  1. 二叉树结构:CART 始终生成二叉树,每个节点只有两个分支(左子树和右子树)。
  2. 分裂标准不同
    • 对于分类任务,CART 使用**基尼指数(Gini Index)**作为分裂标准。
    • 对于回归任务,CART 使用**最小均方误差(MSE)**作为分裂标准。
  3. 支持剪枝:通过后剪枝减少过拟合。
  4. 处理连续和离散数据:支持连续特征的划分点选择。

2. CART 的基本流程

  1. 输入:训练数据集 D,目标变量类型(分类或回归)。
  2. 递归分裂
    • 按照基尼指数(分类)或均方误差(回归)选择最佳划分点。
    • 对数据集划分为两个子集,递归构造子树。
  3. 停止条件
    • 节点样本数量小于阈值。
    • 划分后不再能显著降低误差。
  4. 剪枝
    • 通过校验集性能优化,剪去不显著的分支。
  5. 输出:最终的二叉决策树。

3. 分类树

(1) 基尼指数

基尼指数(Gini Index)用于衡量一个节点的"纯度",越小表示越纯:

其中:

  • :类别 k 的样本数量。
  • K:类别的总数。

节点分裂的基尼指数计算为:

最佳划分点 是使 最小的特征和对应的划分点。


(2) 示例:分类树

数据集

天气 温度 湿度 风力 是否运动
晴天 30
晴天 32
阴天 28
雨天 24 正常
雨天 20 正常
  1. 计算每个特征的基尼指数

    • 对离散特征(如天气),分别计算不同类别划分后的基尼指数。
    • 对连续特征(如温度),尝试所有划分点,计算每个划分点的基尼指数。
  2. 选择最优特征和划分点

    • 选择基尼指数最小的划分点。
  3. 生成子树

    • 对每个子集递归分裂,直到满足停止条件。

4. 回归树

(1) 分裂标准

对于回归任务,CART 使用**均方误差(MSE)**作为分裂标准:

其中:

  • :第 i 个样本的目标值。
  • :节点中所有样本目标值的均值。

节点分裂的误差计算为:

最佳划分点 是使 最小的特征和对应的划分点。


(2) 示例:回归树

假设我们有如下数据集(目标值为房价):

面积(平方米) 房价(万元)
50 150
60 180
70 210
80 240
90 270
  1. 尝试划分点

    • 例如,划分点为 656565。
    • 左子集:{50,60},右子集:{70, 80, 90}。
  2. 计算误差

    • 左子集的均值:
    • 右子集的均值:
    • 计算分裂后的总均方误差。
  3. 选择最佳划分点

    • 选择误差最小的划分点,继续构造子树。

5. 剪枝

CART 使用后剪枝来防止过拟合:

  1. 生成完全生长的决策树

  2. 计算子树的损失函数

    其中:

    • :第 i 个叶子节点。
    • :叶子节点的数量。
    • α:正则化参数,控制树的复杂度。
  3. 剪去对验证集性能提升不大的分支


6. CART 的优缺点

优点
  1. 生成二叉树,逻辑清晰,易于实现。
  2. 支持分类和回归任务。
  3. 支持连续特征和缺失值处理。
  4. 剪枝机制增强了泛化能力。
缺点
  1. 易受数据噪声影响,可能生成复杂的树。
  2. 对高维数据表现一般,无法处理稀疏特征。
  3. 生成的边界是轴对齐的,可能不适用于复杂分布。

7. 与其他决策树算法的比较

特点 ID3 C4.5 CART
划分标准 信息增益 信息增益比 基尼指数 / MSE
支持连续特征
树结构 多叉树 多叉树 二叉树
剪枝 后剪枝 后剪枝
应用 分类 分类 分类与回归

8. 代码实现

以下是一个简单的 CART 分类树实现:

python 复制代码
import numpy as np

# 计算基尼指数
def gini_index(groups, classes):
    total_samples = sum(len(group) for group in groups)
    gini = 0.0
    for group in groups:
        size = len(group)
        if size == 0:
            continue
        score = 0.0
        for class_val in classes:
            proportion = [row[-1] for row in group].count(class_val) / size
            score += proportion ** 2
        gini += (1 - score) * (size / total_samples)
    return gini

# 划分数据集
def split_data(data, index, value):
    left, right = [], []
    for row in data:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# 示例数据
dataset = [
    [2.771244718, 1.784783929, 0],
    [1.728571309, 1.169761413, 0],
    [3.678319846, 2.81281357, 0],
    [3.961043357, 2.61995032, 0],
    [2.999208922, 2.209014212, 1],
]

# 计算基尼指数
split = split_data(dataset, 0, 2.5)
gini = gini_index(split, [0, 1])
print("基尼指数:", gini)

输出结果

bash 复制代码
基尼指数: 0.30000000000000004

CART 是机器学习中非常经典的算法,同时也是随机森林、梯度提升决策树等模型的基础。

相关推荐
知识分享小能手21 小时前
React学习教程,从入门到精通, React 属性(Props)语法知识点与案例详解(14)
前端·javascript·vue.js·学习·react.js·vue·react
Christo31 天前
TFS-2018《On the convergence of the sparse possibilistic c-means algorithm》
人工智能·算法·机器学习·数据挖掘
非门由也1 天前
《sklearn机器学习——管道和复合估计器》回归中转换目标
机器学习·回归·sklearn
茯苓gao1 天前
STM32G4 速度环开环,电流环闭环 IF模式建模
笔记·stm32·单片机·嵌入式硬件·学习
是誰萆微了承諾1 天前
【golang学习笔记 gin 】1.2 redis 的使用
笔记·学习·golang
DKPT1 天前
Java内存区域与内存溢出
java·开发语言·jvm·笔记·学习
aaaweiaaaaaa1 天前
HTML和CSS学习
前端·css·学习·html
看海天一色听风起雨落1 天前
Python学习之装饰器
开发语言·python·学习
小憩-1 天前
【机器学习】吴恩达机器学习笔记
人工智能·笔记·机器学习
THMAIL1 天前
深度学习从入门到精通 - 生成对抗网络(GAN)实战:创造逼真图像的魔法艺术
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·cnn