决策树:机器学习中的“智慧树”

在机器学习的广阔森林中,决策树Decision Tree)是一棵独特而强大的**"智慧树"**。

它是一种监督学习算法,既可以用于分类任务,也能用于回归任务,通过树形结构模拟人类决策过程。

这篇文章会带你了解决策树,从基础概念开始,一步步讲解如何构建决策树、常用的算法以及它的实际应用。

1. 概述

决策树Decision Tree)作为机器学习中的一种经典的监督学习算法,通过树形结构模拟人类决策过程。

它既可以处理分类问题(如判断邮件是否为垃圾邮件),也能解决回归问题(如预测房价)。

决策树通过一系列的决策规则,将数据划分到不同的类别或预测目标值。

它的结构就像一棵倒立的树,顶部是根节点,代表整个数据集,然后通过一系列的特征判断,不断分叉出分支,最终到达叶节点,每个叶节点代表一个决策结果。

决策树 的最大优点是可解释性强。它的决策过程就像人类的思考过程一样,通过一系列的"如果......那么......"的规则来做出判断。

例如,在一个简单的贷款审批决策树中,它可能会先判断申请人的收入是否高于某个阈值,如果高于,再判断信用记录是否良好,最终决定是否批准贷款。

这种直观的决策过程使得决策树在许多需要可解释性的场景中非常受欢迎,比如医疗诊断、金融风险评估等领域。

2. 基本流程

决策树的基本流程有4步:

  1. 特征选择:这是决策树构建的起点。我们需要从众多特征中选择一个最能区分数据的特征作为当前节点的划分依据。
  2. 创建节点:根据所选特征,创建一个决策节点。这个节点会根据特征的不同取值将数据划分为若干子集。
  3. 递归划分:对于每个子集,重复上述过程,继续选择特征并创建节点,直到满足停止条件。
  4. 剪枝可选):为了避免决策树过于复杂而导致过拟合,通常需要对树进行剪枝。

其中,关于剪枝 的步骤,后续其他文章中再详细介绍,本文主要讨论前3个步骤。

前3个步骤中,最重要的就是第一个步骤特征选择,也就是如何在决策树的每次分叉时划分数据集。

3. 划分数据集常用算法

生成决策树的过程中,划分数据集的常用算法主要有3个。

3.1. 基于信息增益(ID3)

信息增益是基于信息论的概念。它衡量的是使用某个特征进行划分后,数据的不确定性减少了多少。

具体来说,信息增益计算的是划分前数据的熵(表示数据的混乱程度)与划分后各子集熵的加权平均值之差。

如果一个特征的信息增益越大,说明使用这个特征划分后,数据的不确定性减少得越多,这个特征就越有价值。

它的计算公式: \\text{Gain}(D,A) = H(D) - \\sum_{v \\in Values(A)} \\frac{\|D\^v\|}{\|D\|} H(D\^v)

其中 H(D) = -\\sum p_i \\log p_i 为信息熵。

通过scikit-learn库使用它非常简单:

python 复制代码
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 使用信息增益构建决策树
clf = DecisionTreeClassifier(criterion='entropy', random_state=42)
clf.fit(X_train, y_train)

# 预测并计算准确率
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'基于信息增益的决策树准确率:{accuracy}')

## 输出结果:
'''
基于信息增益的决策树准确率:0.9777777777777777
'''

这里参数criterion='entropy'就表示使用ID3算法来划分决策树。

3.2. 基于增益率(C4.5)

增益率是信息增益的一个改进版本,它考虑了特征划分后产生的分支数量对信息增益的影响。

因为如果一个特征有很多取值,那么即使它划分后的熵减少得不多,也可能得到一个较大的信息增益,这可能会导致决策树偏向于选择取值多的特征。

增益率通过将信息增益除以该特征划分产生的信息熵(称为分裂信息),来修正这种偏向。

它的公式是: \\text{Gain_Ratio}(D,A) = \\frac{\\text{Gain}(D,A)}{\\text{IV}(A)}

其中 \\text{IV}(A) = -\\sum_{v} \\frac{\|D\^v\|}{\|D\|} \\log \\frac{\|D\^v\|}{\|D\|}

scikit-learn 中,没有直接实现增益率的选项,但可以通过调整参数来近似实现。

C4.5算法在ID3算法的基础上加入了增益率的概念,同时支持连续值的处理。

python 复制代码
# 使用 C4.5 算法(近似实现)
clf_c45 = DecisionTreeClassifier(criterion='entropy', splitter='best', random_state=42)
clf_c45.fit(X_train, y_train)

# 预测并计算准确率
y_pred_c45 = clf_c45.predict(X_test)
accuracy_c45 = accuracy_score(y_test, y_pred_c45)
print(f'基于增益率的决策树准确率:{accuracy_c45}')

## 输出结果:
'''
基于信息增益的决策树准确率:0.9777777777777777
'''

这里近似的实现方法得到的准确率和ID3算法一样。

3.3. 基于基尼系数(CART)

基尼指数是衡量数据不纯度的另一种方法,它的计算相对简单,基尼指数越小,表示数据的纯度越高。

对于一个特征,我们计算按照该特征划分后各子集的基尼指数加权平均值,然后选择基尼指数降低最多的特征作为划分特征。

基尼指数更倾向于选择那些能够将数据划分得更"纯净"的特征。

它的计算公式是: \\text{Gini}(D) = 1 - \\sum p_i\^2

选择使基尼指数最小的特征进行划分。

实现示例如下:

python 复制代码
# 使用基尼指数构建决策树
clf_cart = DecisionTreeClassifier(criterion='gini', random_state=42)
clf_cart.fit(X_train, y_train)

# 预测并计算准确率
y_pred_cart = clf_cart.predict(X_test)
accuracy_cart = accuracy_score(y_test, y_pred_cart)
print(f'基于基尼指数的决策树准确率:{accuracy_cart}')

## 输出结果:
'''
基于信息增益的决策树准确率:1.0
'''

从结果来看,CART算法的准确率比前两种要高。

4. 不同算法的比较

这三种算法各有优缺点和适用场景,使用时根据实际的数据情况来选择。

算法 优点 缺点 适用场景
信息增益 可解释性强 偏向于选择取值多的特征 特征取值较少的数据集
增益率 抑制过拟合 计算相对复杂 类别较多的分类任务
基尼指数 计算效率高 某些情况下可能会过于敏感,导致过拟合 大规模数据/需要快速训练的场景

5. 总结

决策树是一种强大而直观的机器学习算法,它通过一系列的决策规则来对数据进行分类或回归。

信息增益增益率基尼指数是三种常用的特征选择标准,它们各有优缺点,适用于不同的应用场景。

通过scikit-learn这个强大的机器学习库,我们可以轻松地实现基于这些标准的决策树模型,并应用于实际问题中。

相关推荐
JovaZou13 小时前
n8n 本地部署及实践应用,实现零成本自动化运营 Telegram 频道(保证好使)
运维·人工智能·docker·ai·自然语言处理·自动化·llama
晨航14 小时前
PromptPro|提示词生成和管理专家
人工智能·ai·aigc
weixin_4578858218 小时前
风暴之眼:在AI重构的数字世界重绘职业坐标系
人工智能·搜索引擎·ai·重构
Elastic 中国社区官方博客19 小时前
Elasticsearch:使用稀疏向量提升相关性
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
晨航21 小时前
DeepSeek轻松入门教程——从入门到精通
人工智能·ai·aigc
wang_yb21 小时前
『Plotly实战指南』--雷达图绘制与应用
plotly·databook
才思喷涌的小书虫2 天前
学术分享:基于 ARCADE 数据集评估 Grounding DINO、YOLO 和 DINO 在血管狭窄检测中的效果
人工智能·yolo·目标检测·计算机视觉·ai·语言模型·视觉检测
惊鸿Randy2 天前
AI模型多阶段调用进度追踪系统设计文档
java·spring boot·ai·ai编程
宝桥南山2 天前
Model Context Protocol (MCP) - 尝试创建和测试一下MCP Server
microsoft·ai·微软·c#·.net·.net core