《机器学习》——决策树

文章目录

决策树

决策树简介

  • 决策树是一种基于树结构(包括根节点、内部节点、叶节点)进行决策的机器学习算法。它通过将数据集逐步划分成不同的子集,每个节点代表一个特征或者属性上的测试,分支代表测试输出,叶节点代表类别或者值。

决策树结构

  • 根节点
    • 决策树从根节点开始,根节点包含整个训练数据集。例如,在一个判断水果是苹果还是橙子的决策树中,根节点可能包含所有待分类的水果样本。
  • 特征选择与划分
    • 然后在每个节点(从根节点开始)选择一个最佳的特征来划分数据集。这个最佳特征的选择通常基于一些指标,如信息增益、信息增益比或基尼指数。
    • 以判断水果为例,假设我们有颜色、形状、大小等特征。如果颜色这个特征对于区分苹果和橙子最有效(比如苹果大多是红色,橙子大多是橙色),那么根节点就根据颜色来划分数据集。红色的水果样本划分到一个分支,橙色的划分到另一个分支。
  • 内部节点与分支
    • 内部节点继续按照选定的特征进行划分。比如在红色水果分支中,我们可以继续用形状这个特征来划分,圆形的可能更倾向于苹果,椭圆形的可能需要进一步判断。
  • 叶子节点
    • 当一个节点中的样本都属于同一类别或者满足某个终止条件(如节点中的样本数量小于某个阈值、信息增益小于某个阈值等)时,该节点就成为叶节点。叶节点代表最终的分类结果。在水果分类的例子中,叶节点可能就是 "苹果" 或者 "橙子"。

决策树的主要算法

1、ID3算法

  • 它使用信息增益来选择特征。信息增益衡量了一个特征对于分类结果的确定性增加程度。例如,在一个邮件分类问题(分为垃圾邮件和非垃圾邮件)中,若某个特征(如邮件主题中是否包含 "促销" 字样)能够使分类的不确定性大大降低,那么这个特征就具有较高的信息增益。
  • 缺点是它倾向于选择具有较多取值的特征,容易导致过拟合。

2、C4.5算法

  • 是 ID3 算法的改进,它使用信息增益比来选择特征。信息增益比在信息增益的基础上考虑了特征自身的熵,克服了 ID3 算法偏向选择取值较多特征的问题。
  • 信息增益率=信息增益/自身熵值

CART决策树

  • 既可以用于分类也可以用于回归任务。在分类任务中,它使用基尼指数来选择特征。基尼指数衡量了从数据集中随机抽取两个样本,其类别标记不一致的概率。在构建决策树时,选择基尼指数最小的特征作为划分特征,使得划分后的子数据集尽可能纯净(即同一类别样本占比较高)。

决策树剪枝

  • 决策树剪枝是一种防止决策树过拟合的技术。过拟合是指模型在训练数据上表现很好,但在测试数据或新的数据上表现不佳的情况。通过剪枝,可以简化决策树的结构,提高模型的泛化能力。
  • 预剪枝和后剪枝对比来说,预剪枝使得决策树的构建过程相对简单,因为它可以避免生成一些过于复杂的子树,从而减少计算资源的消耗和训练时间,而后剪枝通常能够得到比预剪枝更好的结果,因为它是在完整的决策树基础上进行优化,能够更好地权衡模型的复杂度和准确性。

预剪枝

  • 预剪枝是在决策树构建过程中,在节点划分之前先进行估计,如果当前节点划分不能带来决策树泛化性能的提升,就停止划分并将当前节点标记为叶节点。
  • 预剪枝策略
    • 1、限制树的深度
    • 2、限制叶子节点的个数以及叶子节点的样本数
    • 3、基尼系数

后剪枝

  • 后剪枝是在决策树构建完成后,对已生成的决策树进行修剪。它从树的叶节点开始,逐步向上判断每个节点对应的子树是否可以被剪枝。
  • 一种常见的后剪枝方法是代价复杂度剪枝(Cost - Complexity Pruning)。它定义了一个代价函数,这个代价函数考虑了树的错误率和复杂度。通过调整一个复杂度参数,来找到一个最优的剪枝后的树。

决策树模型

参数

核心参数
  • criterion
    • 用于衡量分裂质量的函数,可选值为 'mse'(均方误差,默认)、'friedman_mse' 或 'mae'(平均绝对误差)。
    • 'mse':通常用于回归树,计算的是均方误差,即预测值与真实值差的平方的平均值。它倾向于将数据集划分为使得各个子数据集的均方误差最小的部分。
    • 'friedman_mse':由 Friedman 提出的改进的均方误差,在一些情况下可以提供更好的分裂效果。
    • 'mae':计算平均绝对误差,是预测值与真实值差的绝对值的平均值,在某些场景下可能对异常值更鲁棒。
  • splitter
    • 确定每个节点的分裂策略,可选值为 'best'(默认)或 'random'。
    • 'best':选择最优的分裂策略,会遍历所有可能的特征和分裂点,选择最佳的组合,这样可以找到最能降低不纯度的分裂,但计算开销较大。
    • 'random':在随机的特征中选择最佳分裂点,可能会更快,并且在某些情况下可以防止过拟合,因为它不会像 'best' 那样严格寻找最优分裂。
  • max_depth
    • 树的最大深度。如果不设置(默认为 None),则树会尽可能地扩展,直到所有叶子节点都是纯的或者满足其他停止条件。
    • 例如,设置 max_depth = 3 意味着决策树最多有 3 层(不包括根节点),这样可以防止树过深导致的过拟合,适用于防止模型过于复杂,尤其在样本量不大时。
  • min_samples_split
    • 内部节点(非叶节点)分裂所需的最小样本数。默认为 2。
    • 例如,设置 min_samples_split = 10 表示如果一个节点包含的样本数小于 10,则不会对该节点进行分裂,这有助于防止过拟合,尤其是在数据量较小或噪声较大时,避免因为少量样本而产生不稳定的分裂。
  • min_samples_leaf
    • 叶节点所需的最小样本数。默认为 1。
    • 例如,设置 min_samples_leaf = 5 表示每个叶节点至少需要有 5 个样本,这样可以避免叶节点样本过少而导致的过拟合,使叶节点更具代表性。
其他参数
  • min_weight_fraction_leaf
    • 叶节点所需样本权重的最小比例,默认为 0。
    • 可以根据样本的权重(如果有)来控制叶节点的大小,例如,在某些样本的权重比其他样本更重要的情况下,确保叶节点包含足够的重要样本权重。
  • max_features
    • 寻找最佳分裂时考虑的特征数量,可取值为整数、浮点数、'auto'、'sqrt'、'log2' 或 None(默认)。
    • 'auto':等价于 'sqrt',表示考虑 sqrt(n_features) 个特征,其中 n_features 是特征的数量。
    • 'sqrt':考虑特征数量的平方根个特征。
    • 'log2':考虑 log2(n_features) 个特征。
    • 整数:考虑指定数量的特征。
    • 浮点数:考虑占总特征数一定比例的特征,例如,max_features = 0.5 表示考虑一半的特征。
    • 此参数可用于加快训练速度和防止过拟合,通过限制每次分裂时考虑的特征数量,减少计算量和避免树对训练数据的过度拟合。
停止条件相关参数
  • max_leaf_nodes
    • 最大叶节点数量。如果设置,决策树会尝试构建具有不超过该数量叶节点的树,通过控制叶节点数量来防止过拟合。
    • 例如,设置 max_leaf_nodes = 10 会使决策树的最终叶节点数量不超过 10 个。

决策树实例

  • 对银行数据进行决策树分类
  • 共六百条样本数据,其中流失状态是标签,其余为特征。

实例步骤

  • 导入数据
  • 处理数据
  • 划分数据
  • 创建模型
  • 训练模型
  • 测试模型
  • 画图查看
导入数据
python 复制代码
datas = pd.read_excel("电信客户流失数据.xlsx")
处理数据
python 复制代码
# 选取除最后一列外的所有列作为特征数据
data = datas.iloc[:, :-1]
# 选取最后一列作为目标数据
target = datas.iloc[:, -1]
划分数据
python 复制代码
# 导入数据分割模块
from sklearn.model_selection import train_test_split

# 将数据划分为训练集和测试集,测试集占 20%,随机种子为 42
data_train, data_test, target_train, target_test = \
	train_test_split(data, target, test_size=0.2, random_state=42)
创建模型
python 复制代码
# 导入决策树分类器
from sklearn import tree
# 创建决策树分类器,使用基尼系数作为划分标准,最大深度为 10,随机种子为 0,最大叶子节点数为 10
dtr = tree.DecisionTreeClassifier(criterion='gini', max_depth=10, random_state=0, max_leaf_nodes=10)
训练模型
python 复制代码
# 使用训练集训练决策树
dtr.fit(data_train, target_train)
测试模型
python 复制代码
def cm_plot(y, yp):
    # 导入混淆矩阵和 matplotlib 库
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    # 计算混淆矩阵
    cm = confusion_matrix(y, yp)
    # 以蓝色调绘制混淆矩阵
    plt.matshow(cm, cmap=plt.cm.Blues)
    # 添加颜色条
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            # 在矩阵的每个单元格添加标注,显示对应的数值
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center',
                         verticalalignment='center')
    # 设置 y 轴标签
    plt.ylabel('True label')
    # 设置 x 轴标签
    plt.xlabel('Predicted label')
    return plt


# 对训练集进行预测
train_predicted = dtr.predict(data_train)


# 导入指标计算模块
from sklearn import metrics
# 打印训练集的分类报告
print(metrics.classification_report(target_train, train_predicted))


# 调用 cm_plot 函数绘制训练集的混淆矩阵,并显示
cm_plot(target_train, train_predicted).show()


# 对测试集进行预测
test_predicted = dtr.predict(data_test)


# 打印测试集的分类报告
print(metrics.classification_report(target_test, test_predicted))


# 调用 cm_plot 函数绘制测试集的混淆矩阵,并显示
cm_plot(target_test, test_predicted).show()


# 计算测试集上的准确率得分
dtr.score(data_test, target_test)

两次的分类报告和混淆矩阵。


画图查看
python 复制代码
# 导入 matplotlib 库
import matplotlib.pyplot as plt
# 导入绘制决策树的模块
from sklearn.tree import plot_tree


# 创建子图
fig, ax = plt.subplots(figsize=(8, 8))
# 绘制决策树,填充颜色
plot_tree(dtr, filled=True, ax=ax)
# 显示决策树
plt.show()


# 对测试集进行概率预测
y_pred_proba = dtr.predict_proba(data_test)
# 取预测为类别 1 的概率
a = y_pred_proba[:, 1]
# 计算 AUC 值
auc_result = metrics.roc_auc_score(target_test, a)


# 导入 matplotlib 库和 ROC 曲线绘制模块
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve


# 计算 ROC 曲线的相关指标
fpr, tpr, thresholds = roc_curve(target_test, a)
# 创建新的图形
plt.figure()
# 绘制 ROC 曲线
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve(area=%0.2f)' % auc_result)
# 绘制对角线
plt.plot([0, 1], [0, 1], color='navy', lw=2, ls='--')
# 设置 x 轴范围
plt.xlim([0.0, 1.0])
# 这里原代码可能是想设置 y 轴范围,修正为 plt.ylim([0.0, 1.05])
plt.ylim([0.0, 1.05])
# 设置 x 轴标签
plt.xlabel('false positive rate')
# 设置 y 轴标签
plt.ylabel('true positive rate')
# 设置图形标题
plt.title('receiver operating characteristic')
# 显示图例
plt.legend()
# 显示图形
plt.show()

决策树与ROC曲线


完整代码

python 复制代码
import pandas as pd


def cm_plot(y,yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y,yp)
    plt.matshow(cm,cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x,y],xy=(y,x),horizontalalignment='center',
                         verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

datas = pd.read_excel("电信客户流失数据.xlsx")
data = datas.iloc[:,:-1]
target = datas.iloc[:,-1]



from sklearn.model_selection import train_test_split

data_train,data_test,target_train,target_test =\
    train_test_split(data,target,test_size=0.2,random_state=42)


from sklearn import tree
dtr = tree.DecisionTreeClassifier(criterion='gini',max_depth=10,random_state=0,max_leaf_nodes=10)
dtr.fit(data_train,target_train)

train_predicted = dtr.predict(data_train)

from sklearn import metrics
print(metrics.classification_report(target_train,train_predicted))

cm_plot(target_train,train_predicted).show()

test_predicted = dtr.predict(data_test)

print(metrics.classification_report(target_test,test_predicted))

cm_plot(target_test,test_predicted).show()
dtr.score(data_test, target_test)

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

fig,ax = plt.subplots(figsize=(8,8))
plot_tree(dtr,filled= True,ax=ax)
plt.show()


y_pred_proba = dtr.predict_proba(data_test)
a = y_pred_proba[:,1]
auc_result= metrics.roc_auc_score(target_test,a)

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

fpr,tpr,thresholds =roc_curve(target_test,a)
plt.figure()
plt.plot(fpr,tpr,color='darkorange',lw=2,label='ROC curve(area=%0.2f)'%auc_result)
plt.plot([0,1],[0,1],color='navy',lw=2,ls='--')
plt.xlim([0.0,1.0])
plt.xlim([0.0,1.05])
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.title('receiver operating characteristic')
plt.legend()
plt.show()
相关推荐
科技与数码20 分钟前
倍思氮化镓充电器分享:Super GaN伸缩线快充35W
人工智能·神经网络·生成对抗网络
HUIBUR科技2 小时前
量子计算遇上人工智能:突破算力瓶颈的关键?
人工智能·量子计算
CES_Asia2 小时前
CES Asia 2025聚焦量子与空间技术
人工智能·科技·数码相机·金融·量子计算·智能手表
程序猿阿伟2 小时前
《量子比特:解锁人工智能并行计算加速的密钥》
人工智能·量子计算
来瓶霸王防脱发3 小时前
【C#深度学习之路】如何使用C#实现Yolo5/8/11全尺寸模型的训练和推理
深度学习·yolo·机器学习·c#
music&movie3 小时前
代码填空任务---自编码器模型
python·深度学习·机器学习
盖丽男3 小时前
机器学习的组成
人工智能·机器学习
风一样的树懒4 小时前
Python使用pip安装Caused by SSLError:certificate verify failed
人工智能·python
9命怪猫4 小时前
AI大模型-提示工程学习笔记5-零提示
人工智能·笔记·学习·ai·提示工程
cnbestec5 小时前
GelSight Mini视触觉传感器凝胶触头升级:增加40%耐用性,拓展机器人与触觉AI 应用边界
人工智能·机器人