机器学习——决策树节点生成算法

机器学习------决策树节点生成算法

决策树是一种常用的机器学习模型,它能够根据数据特征的不同进行分类或回归。决策树的关键在于节点的生成算法,不同的生成算法会影响决策树的结构和性能。本篇博客将介绍三种常用的决策树节点生成算法:ID3算法、C4.5算法和CART算法,包括详细的理论介绍、算法公式和Python实现,并对三种算法进行对比与总结。

1. ID3算法(Iterative Dichotomiser 3)

ID3算法是一种基于信息增益的决策树节点生成算法,由Ross Quinlan在1986年提出。它通过选择信息增益最大的特征来进行节点划分。

算法步骤:

  1. 若数据集属于同一类别,则将当前节点标记为叶节点,类别为该类别。
  2. 若特征集为空,则将当前节点标记为叶节点,类别为数据集中出现次数最多的类别。
  3. 计算每个特征的信息增益,选择信息增益最大的特征作为当前节点的划分特征。
  4. 根据选定的特征进行节点划分,生成子节点,并递归地对子节点进行以上步骤。

算法公式:

信息增益的计算公式为:

Gain ( D , A ) = H ( D ) − H ( D ∣ A ) \text{Gain}(D, A) = H(D) - H(D|A) Gain(D,A)=H(D)−H(D∣A)

其中, D D D是数据集, A A A是特征, H ( D ) H(D) H(D)是数据集 D D D的熵, H ( D ∣ A ) H(D|A) H(D∣A)是在已知特征 A A A的条件下,数据集 D D D的条件熵。

2. C4.5算法

C4.5算法是ID3算法的改进版本,由Ross Quinlan在1993年提出。相比于ID3算法,C4.5算法解决了ID3算法不能处理连续特征、样本缺失值和过拟合问题。

算法步骤:

  1. 若数据集属于同一类别或特征集为空,则将当前节点标记为叶节点,类别为该类别或数据集中出现次数最多的类别。
  2. 计算每个特征的信息增益率,选择信息增益率最大的特征作为当前节点的划分特征。
  3. 根据选定的特征进行节点划分,生成子节点,并递归地对子节点进行以上步骤。

连续特征离散化:

C4.5算法通过二分法对连续特征进行离散化,选择最佳的划分点作为该特征的分裂点。

信息增益率:

信息增益率的计算公式为:

Gain_ratio ( D , A ) = Gain ( D , A ) H A ( D ) \text{Gain\_ratio}(D, A) = \frac{\text{Gain}(D, A)}{H_A(D)} Gain_ratio(D,A)=HA(D)Gain(D,A)

其中, H A ( D ) H_A(D) HA(D)是特征 A A A的熵。

概率权重方法处理样本缺失值:

C4.5算法采用概率权重方法处理样本缺失值,将缺失值样本按照各类别出现的概率进行赋权。

通过剪枝解决过拟合问题:

C4.5算法采用后剪枝(post-pruning)的方法来解决过拟合问题,通过剪枝操作来减小树的复杂度。

3. CART算法(Classification And Regression Tree)

CART算法既可以用于分类问题,也可以用于回归问题,是一种十分灵活的决策树生成算法,由Breiman等人在1984年提出。

算法步骤:

  1. 若数据集属于同一类别或特征集为空,则将当前节点标记为叶节点,类别为该类别或数据集中出现次数最多的类别;
  2. 针对每个特征,计算每个可能的分裂点,选择使得划分后的数据集基尼指数最小的特征和切分点;
  3. 根据选定的特征和切分点进行节点划分,生成子节点,并递归地对子节点进行以上步骤。

基尼指数:

基尼指数是一种衡量数据集不纯度的指标,定义为:

G i n i ( D ) = 1 − ∑ k = 1 ∣ Y ∣ ( p k ) 2 Gini(D) = 1 - \sum_{k=1}^{|\mathcal{Y}|} (p_k)^2 Gini(D)=1−k=1∑∣Y∣(pk)2

其中, ∣ Y ∣ |\mathcal{Y}| ∣Y∣是类别的个数, p k p_k pk是数据集 D D D中属于类别 k k k的样本的比例。

Python实现:

接下来,让用Python实现以上三种算法,并在相同数据集上进行比较。将使用DecisionTreeClassifier类来实现ID3算法和C4.5算法,以及DecisionTreeRegressor类来实现CART算法。

python 复制代码
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error
import matplotlib.pyplot as plt

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建ID3决策树模型
id3_clf = DecisionTreeClassifier(criterion='entropy', random_state=42)
id3_clf.fit(X_train, y_train)

# 构建C4.5决策树模型
c45_clf = DecisionTreeClassifier(criterion='entropy', random_state=42)
c45_clf.fit(X_train, y_train)

# 构建CART决策树模型
cart_clf = DecisionTreeRegressor(random_state=42)
cart_clf.fit(X_train, y_train)

# 绘制ID3决策树可视化图形
plt.figure(figsize=(12, 8))
plot_tree(id3_clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.title("ID3 Decision Tree")
plt.show()

# 绘制C4.5决策树可视化图形
plt.figure(figsize=(12, 8))
plot_tree(c45_clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.title("C4.5 Decision Tree")
plt.show()

# 绘制CART决策树可视化图形
plt.figure(figsize=(12, 8))
plot_tree(cart_clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.title("CART Decision Tree")
plt.show()



通过以上代码,分别构建了ID3、C4.5和CART三种决策树模型,并在相同的数据集上进行了训练和评估。最后,绘制了ID3算法生成的决策树的可视化图形。

总结

通过以上代码实现和对比,可以得出以下结论:

  1. ID3算法:ID3算法是一种基于信息增益的决策树节点生成算法,它简单易懂,但不能处理连续特征和样本缺失值,且对于类别较多的特征容易产生过拟合。

  2. C4.5算法:C4.5算法是ID3算法的改进版,它解决了ID3算法的不足,可以处理连续特征和样本缺失值,同时引入了信息增益率准则和剪枝操作来降低过拟合风险。

  3. CART算法:CART算法既可以用于分类问题,也可以用于回归问题,具有更广泛的适用性。CART算法采用基尼指数来选择特征和切分点,生成的树更加简洁,但它也容易过拟合。

通过对比三种算法在相同数据集上的性能表现,可以选择适合具体问题的决策树生成算法。如果需要处理连续特征和样本缺失值,且希望避免过拟合问题,可以选择C4.5算法;如果需要同时处理分类和回归问题,并且对树的简洁性要求较高,可以选择CART算法。

总的来说,决策树是一种简单且有效的机器学习模型,它易于理解和解释,适用于各种类型的数据。在实际应用中,需要根据具体问题的特点和数据集的情况选择合适的决策树生成算法,并通过调参和剪枝等方法来优化模型性能。

相关推荐
林开落L8 分钟前
前缀和算法习题篇(上)
c++·算法·leetcode
远望清一色9 分钟前
基于MATLAB边缘检测博文
开发语言·算法·matlab
tyler_download11 分钟前
手撸 chatgpt 大模型:简述 LLM 的架构,算法和训练流程
算法·chatgpt
封步宇AIGC17 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_5236742118 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
SoraLuna31 分钟前
「Mac玩转仓颉内测版7」入门篇7 - Cangjie控制结构(下)
算法·macos·动态规划·cangjie
我狠狠地刷刷刷刷刷34 分钟前
中文分词模拟器
开发语言·python·算法
鸽鸽程序猿34 分钟前
【算法】【优选算法】前缀和(上)
java·算法·前缀和
九圣残炎40 分钟前
【从零开始的LeetCode-算法】2559. 统计范围内的元音字符串数
java·算法·leetcode
YSRM1 小时前
Experimental Analysis of Dedicated GPU in Virtual Framework using vGPU 论文分析
算法·gpu算力·vgpu·pci直通