【机器学习】Lesson 6 - 决策树(分类)

目录

背景

一、适用数据集

[1. 数据集选择](#1. 数据集选择)

[2. 本文数据集介绍](#2. 本文数据集介绍)

二、算法原理

[1. 算法简介](#1. 算法简介)

[2. 决策树分类的应用](#2. 决策树分类的应用)

[3. 模型参数设置](#3. 模型参数设置)

[4. 决策树相关知识补充](#4. 决策树相关知识补充)

[4.1 节点含义](#4.1 节点含义)

[4.2 节点信息解释](#4.2 节点信息解释)

[4.3 每一层的含义](#4.3 每一层的含义)

三、代码

[1. 导入包&数据](#1. 导入包&数据)

[2. 数据预处理](#2. 数据预处理)

[3. 数据集划分](#3. 数据集划分)

[4. 建模评估并绘制决策树](#4. 建模评估并绘制决策树)

决策树图形信息解释


背景

决策树 (decision tree) 是机器学习中一类常见的监督学习算法,可用于分类和回归任务。决策树通过递归地对数据集进行分裂,最终形成一系列规则。这些规则根据输入数据的特征,将其分类到特定类别或预测值。

决策树就是不断根据某属性进行划分的过程(每次决策时都是在上次决策结果的基础之上进行),即 "if⋯⋯elif⋯⋯else⋯⋯" 的决策过程,最终得出一套有效的判断逻辑,便是学到的模型。

在前文《L1练习-鸢尾花数据集处理(分类/聚类)》中,分类部分涉及到决策树的正确率评估,在本文中代码和可视化将以展现决策树算法逻辑为重。随机森林(Random Forest)是基于决策树的一种集成学习算法,将在下一篇中进行介绍和代码演示。

一、适用数据集

决策树适用的数据集类型

  1. 结构化数据:决策树在处理表格型、结构化的数据时表现较好。
  2. 数值型和类别型特征:既能处理数值型特征,也能处理类别型特征,不需要特征标准化。
  3. 缺失值和噪声:对缺失值和噪声有一定的鲁棒性,但较复杂的数据可能导致过拟合。

1. 数据集选择

我们的老朋友 --- sklearn 中自带的数据集 --- 鸢尾花数据集。

2. 本文数据集介绍

鸢尾花数据集是通过四个维度数据统计了三种不同品种鸢尾花的数据集,其标签存储在 iris.target 里,将 150 行数据分别标为 0(Setosa)、1(Versicolor)、2(Virginica),各 50 个,并且升序排列。故,若取前一百条记录时,仅有 0(Setosa)、1(Versicolor)两种计入数据集。

取两个维度对标签进行可视化处理会出现以下图像:

数据量

  • 维度:4个
  • 记录行数:150行

二、算法原理

1. 算法简介

其结构类似于一棵树,由以下部分组成:

  • 根节点(Root Node):表示整个数据集的初始分裂点。
  • 内部节点(Internal Nodes):表示根据特定特征进行的决策。
  • 叶节点(Leaf Nodes):表示分类结果或回归输出。

2. 决策树分类的应用

决策树分类会相对来说更加适合满足以下条件的分析背景:

规则明确的场景:适合需要生成 "如果-那么" 决策规则的任务,如信贷审批、人事决策。

小规模数据:在特征数量较少、数据量较小的情况下表现良好,如学生成绩预测、医疗诊断。

简单特征关系:适用于特征和目标变量关系简单的问题,如客户分类、产品推荐。

无需特征工程:可以直接处理类别型和数值型特征,不需要标准化或归一化,如市场营销、欺诈检测。

可解释性要求高:模型透明易解释,适合法律、医学等需要可视化和逻辑透明的场景。

类别不均衡:通过调整分裂策略,可应对不均衡数据,如客户流失预测、罕见疾病筛查。

多类分类任务:天然支持多类别分类,无需额外调整,如植物识别、文本分类。

3. 模型参数设置

DecisionTreeClassifier() 是 scikit-learn 库中的决策树分类器,它有几个关键参数,用于控制模型的行为和性能。以下是其中一些重要的参数:

  1. criterion:用于划分节点的标准,可以是 'gini'(基尼指数)、'entropy'(信息增益),表示对数据纯度的衡量。默认值是 'gini'

  2. splitter:分割节点的方式,可以选择 'best'(选择最优特征划分)、'random'(随机选取)。 默认值是 'best'

  3. max_depth:树的最大深度,如果为 None 则不限制深度,直到所有叶子都是纯样本,即叶子节点中所有的样本点都属于同一个类别。或者每个叶子节点包含的样本数小于 min_samples_split。默认值是 None。

  4. min_samples_split:分裂一个内部节点所需的最小样本数。如果为整数,则min_samples_split就是最少样本数。如果为浮点数(0到1之间),则每次分裂最少样本数为 ceil(min_samples_split * n_samples)。默认值是 2。

  5. min_samples_leaf:每个叶节点至少需要的最小样本数。如果为整数,则min_samples_split 就是最少样本数。如果为浮点数(0到1之间),则每个叶子节点最少样本数为 ceil(min_samples_leaf * n_samples)。默认值是 1。

  6. max_features:若非 None,限制考虑的特征数。可以选择 'auto'(选择最优数量的特征),'all'(使用所有可用的特征)或一个整数值。如果为整数,每次分裂只考虑 max_features个特征;如果为浮点数(0到1之间),每次切分只考虑 int(max_features * n_features) 个特征。默认值是 None (和 'all' 一样,使用全部特征)

  7. random_state:用于随机化的种子,保证结果的可重复性。默认值是 None。

  8. class_weight:处理类别不平衡的选项,如 'default'、'balanced' 或自定义权重列表。默认值是 None (每个类别的权重都为1)

  9. presort:是否先对数据进行排序再进行划分,对于大样本集可能会提高效率。对于大数据集会减慢总体的训练过程。如果class_weight='balanced',则分类的权重与样本中每个类别出现的频率成反比:n_samples / (n_classes * np.bincount(y))。默认值是 False

4. 决策树相关知识补充

4.1 节点含义

每个节点表示一个特征的划分条件,树从根节点开始,逐步对样本进行分类。

内部节点:包含特征划分的条件,用来将数据集分为更小的子集。

叶节点:表示最终的分类结果,不再划分。

4.2 节点信息解释

在每个节点上通常会显示以下内容:

  1. 特征名称和划分条件

    • 格式:feature_name <= threshold
    • 解释:该节点按照某个特征的值与阈值进行比较。如果满足条件,样本被划分到左子节点;否则进入右子节点。

    例如:petal length <= 2.45 表示花瓣长度小于或等于 2.45 的样本进入左子节点。

  2. Gini系数 (gini):

    • 表示当前节点的不纯度 ,范围在 [0, 0.5]
    • 值越小,节点内样本越纯(即更倾向于某个类别)。

    例如:gini = 0.5 表示节点内样本均匀分布在两个或多个类别中,而 gini = 0.0 表示样本全属于一个类别。

  3. 样本数量 (samples):当前节点包含的样本总数。

  4. 类别分布 (value):

    • 格式:value = [n1, n2, n3],表示当前节点中属于每个类别的样本数量。
    • 例如:value = [50, 0, 0] 表示该节点中有 50 个样本,全部属于第一个类别。
  5. 类别名称 (class):决策结果,即当前节点的预测类别。

4.3 每一层的含义

根节点(第 0 层):包含所有样本,进行第一个特征的划分。通常选择最能区分样本的特征。

第 1 层及之后的层:根据上一层的划分结果,进一步细分数据。每一层表示决策过程中的一步细化,直到叶节点为止。

叶节点(最后一层):不再进行划分,直接输出最终类别。叶节点的 Gini 系数通常为 0,表示样本已完全纯化。

三、代码

1. 导入包&数据

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import export_graphviz
import graphviz
from sklearn import tree
import matplotlib.pyplot as plt

#导入内置的鸢尾花数据
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np


plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号


iris = load_iris()
print(iris)

2. 数据预处理

python 复制代码
iris_d = pd.DataFrame(data=iris.data,columns=['sepal length', 'sepal width', 'petal length', 'petal width'])

iris_d

#查看数据类型信息
iris_d.info()

iris_d["target"] = iris.target

iris_d

3. 数据集划分

python 复制代码
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

4. 建模评估并绘制决策树

python 复制代码
#训练模型
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train,)

# 评估模型
accuracy = clf.score(X_test, y_test)
print(f"分类准确率: {accuracy:.2f}")

plt.figure(figsize=(6, 4))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.title("决策树分类可视化")
plt.show()

决策树图形信息解释

如本决策树中第一分支上的信息

1.petal length (cm) <= 2.35

含义 :该节点根据 petal length (cm)(花瓣长度)是否小于或等于 2.35 进行数据划分。

解释:如果花瓣长度 ≤ 2.35,数据进入左子节点。如果花瓣长度 > 2.35,数据进入右子节点。

2.gini = 0.663

含义:此处的 Gini 系数是当前节点数据的不纯度度量。

解释 :Gini 系数范围是 [0, 1],0 表示节点中的样本完全属于一个类别,1 表示样本均匀分布在所有类别。gini = 0.663 表示当前节点中的样本分布不均匀但不完全纯。

3.samples = 105

含义:此节点包含 105 个样本。

解释:这些样本尚未被完全分类,仍需进一步划分。

4.value = [40, 32, 33]

含义:此节点中样本在每个类别中的分布情况。

解释 :类别 setosa 中有 40 个样本。类别 versicolor 中有 32 个样本。类别 virginica 中有 33 个样本。

5.class = setosa

含义 :此节点预测的类别是 setosa(鸢尾花的一个类别)。

解释 :在当前节点中,setosa 是占比最大的类别(尽管不是唯一类别),因此决策树暂时将该节点的样本归为 setosa

  1. 总结

由以上信息可以得到:

  • 该节点包含 105 个样本,通过 petal length 特征进行划分,Gini 系数表明节点内的样本分布仍有一定的不纯性。
  • 节点中样本分布显示 setosa 样本数量较多,因此暂时将其归类为 setosa
  • 此节点之后将继续分裂,进一步细化分类过程。
相关推荐
正在走向自律2 小时前
深度学习:重塑学校教育的未来
人工智能·深度学习·机器学习
dundunmm3 小时前
论文阅读:Statistical Comparisons of Classifiers over Multiple Data Sets
论文阅读·人工智能·算法·机器学习·评估方法
人类群星闪耀时3 小时前
机器学习在自动化运维中的应用:提升运维效率的新利器
运维·机器学习·自动化
WBingJ3 小时前
李宏毅机器学习深度学习:机器学习任务攻略
人工智能·深度学习·机器学习
**之火3 小时前
(五)机器学习 - 数据分布
人工智能·机器学习
martian6653 小时前
人工智能机器学习基本概念详解
人工智能·机器学习
小雄abc3 小时前
决定系数R2 浅谈三 : 决定系数R2与相关系数r的关系、决定系数R2是否等于相关系数r的平方
经验分享·笔记·深度学习·算法·机器学习·学习方法·论文笔记
知来者逆5 小时前
Layer-Condensed KV——利用跨层注意(CLA)减少 KV 缓存中的内存保持 Transformer 1B 和 3B 参数模型的准确性
人工智能·深度学习·机器学习·transformer
宸码6 小时前
【机器学习】手写数字识别的最优解:CNN+Softmax、Sigmoid与SVM的对比实战
人工智能·python·神经网络·算法·机器学习·支持向量机·cnn
睡觉狂魔er6 小时前
自动驾驶控制与规划——Project 1: 车辆纵向控制
人工智能·机器学习·自动驾驶