【机器学习】决策树分类#1基于Scikit-Learn的简单实现

主要参考学习资料:

《机器学习算法的数学解析与Python实现》莫凡 著

前置知识:概率论与数理统计-Python

博主采用了由浅入深、循序渐进的学习过程,因此本系列不会将内容一次性吃透,而是先整体后细节,后续再探索每个模型更深入的数学知识和具体实现。

目录

数学模型

分类方法

决策树是一种机器学习算法的框架,采用由if-else嵌套形成的树形结构进行分类。即每一次对样本的某个特征是否满足某一条件进行判断后分成两类,再将分出来的每一类继续按照新的判断条件分类下去。

条件选择

纯度

有了基本的树形结构,我们还需要知道每个节点应填入的判别条件,这正是决策树分类算法进行学习的内容。

首先可以确定判别条件来源于特征,不同的特征维度构成一个集合,称为特征维度集。数据样本的特征维度都可能与最终的类别存在某种关联关系,而决策树的判别条件正是从中产生的。

其次,对于选择哪些特征维度作为判别条件、选择哪里作为一个特征维度的分界线,我们需要一个标准来比较不同判别条件的优劣,故引入纯度的概念。

若一个集合中属于同一类别的样本占比越高,这个集合的纯度就越高。以二元分类问题为例,一个集合只包含一个类别(不包含另一个类别)时纯度达到最高,两个类别各占一半时纯度最低。而分类的目的就是要达到提纯的效果,若子集的纯度比原来的集合的纯度要高,说明分类起到了正面的作用。

纯度度量

为了用数学模型给出纯度的值,同时让其符合损失函数"数值越小越好"的性质,不同的决策树分类算法给出了不同的纯度度量方法,这也是决策树分类算法的核心问题。

信息熵

是热力学中表示系统无序程度的概念,而信息熵则用于描述信息源各种可能事件的不确定性,即情况越乱,信息熵越大,与纯度的概念有异曲同工之处。

信息熵的数学表达式如下:

H ( X ) = − ∑ k = 1 N p k log ⁡ 2 ( p k ) H(X)=-\displaystyle\sum^N_{k=1}p_k\log_2(p_k) H(X)=−k=1∑Npklog2(pk)

X X X表示进行信息熵计算的集合, p p p为概率,在纯度度量中即为不同类别的占比。

将二元分类问题的两种极端情况代入该公式结果如下:

只有一种类别占比100%:

H ( X ) = − ( 1 × log ⁡ 2 ( 1 ) + 0 ) = 0 H(X)=-(1\times\log_2(1)+0)=0 H(X)=−(1×log2(1)+0)=0

两种类别各占50%:

H ( X ) = − ( 0.5 × log ⁡ 2 ( 0.5 ) + 0.5 × log ⁡ 2 ( 0.5 ) ) = 1 H(X)=-(0.5\times\log_2(0.5)+0.5\times\log_2(0.5))=1 H(X)=−(0.5×log2(0.5)+0.5×log2(0.5))=1

信息熵符合纯度和损失函数的要求。

信息增益

为了用信息熵进一步衡量提纯效果,ID3算法给出了信息增益

G ( D , a ) = H ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ H ( D v ) G(D,a)=H(D)-\displaystyle\sum^V_{v=1}\frac{|D^v|}{|D|}H(D^v) G(D,a)=H(D)−v=1∑V∣D∣∣Dv∣H(Dv)

G ( D , a ) G(D,a) G(D,a)表示集合 D D D选择特征属性 a a a划分子集时的信息增益, H ( D ) H(D) H(D)即分类前的信息熵。减数中, V V V表示 D D D按特征维度 a a a划分成了几个子集,上标 v v v选取其中的某一个子集,符号"| |"求的是对应集合中元素的个数,则 ∣ D v ∣ ∣ D ∣ \displaystyle\frac{|D^v|}{|D|} ∣D∣∣Dv∣表示某个子集的元素个数占总元素个数的比例,占比越大,该子集信息熵的权重越高。用原集合的信息熵减去划分所得各子集信息熵的加权和就是按特征维度 a a a进行划分的信息增益。

信息增益比

信息增益的缺陷在于只要子集被切分得越细,极端情况下每个子集只有一个元素,也能让纯度相对提升,但这使决策树的设计过于冗杂。C4.5算法提出信息增益比以应对该缺陷:

G r = G ( D , a ) I V ( a ) G_r=\displaystyle\frac{G(D,a)}{IV(a)} Gr=IV(a)G(D,a)

其中 I V ( a ) = − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ log ⁡ 2 ∣ D v ∣ ∣ D ∣ IV(a)=\displaystyle-\sum^V_{v=1}\frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|} IV(a)=−v=1∑V∣D∣∣Dv∣log2∣D∣∣Dv∣

特征维度划分出的子集越多,该固有值越大,消除了其不利影响。

基尼指数

CART算法在决策条件的选择上采用了和信息熵原理相似的基尼指数

G i n i ( D ) = 1 − ∑ k = 1 N p k 2 \mathrm{Gini}(D)=1-\displaystyle\sum^N_{k=1}p_k^2 Gini(D)=1−k=1∑Npk2

D D D表示进行基尼指数计算的集合,相比信息熵,基尼指数采取了占比 p p p直接相乘的算法,计算更为简单。

用基尼指数衡量提纯效果的方法也与信息增益一致,将公式中的信息熵替换为基尼指数即可。

停止条件

决策树的第二个问题是停止分裂问题,即不能让判别条件的分支永无止境地生成下去。停止条件有以下三种:

①数据集已经完成了分类,即当前集合的样本都属于同一类;

②所有的特征维度都已被采用。若分类没有完成,决策树会以占比最大的类别作为当前节点的归属类别;

③不同特征维度的提纯效果完全一样,无法做出选择。

剪枝

训练集数据可能出现假性关联问题,即一些特征维度与样本分类实际上并不存在关联关系,而决策树容易将所有的特征维度都考虑进去,产生过于细枝末节的判别分支,进而出现过拟合的问题,于是剪枝算法应运而生。

剪枝分为预剪枝和后剪枝。预剪枝在分支划分前先判断是否需要剪枝,如果需要则停止该分支的划分,即扼杀在萌芽状态;后剪枝在各个判别分支已经形成后再进行剪枝判断。

譬如经过一个冗余的特征分类后,分出的两个类别的纯度还不如分类前的纯度高,则判断为进行剪枝。

代码实现

python 复制代码
#导入决策树分类算法 
from sklearn.tree import DecisionTreeClassifier  
from sklearn.datasets import load_iris  
import numpy as np  
import matplotlib.pyplot as plt  
iris = load_iris()  
X = iris.data[:, :2]
y = iris.target  
clf = DecisionTreeClassifier().fit(X, y)  
def plot_decision_boundary(X, y, model):  
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1  
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1  
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))  
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])  
    Z = Z.reshape(xx.shape)  
    plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.coolwarm)  
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o', cmap=plt.cm.coolwarm)  
    plt.xlabel('Feature 1')  
    plt.ylabel('Feature 2')  
    plt.title('Decision Tree Decision Boundary')  
    plt.show()  
plot_decision_boundary(X, y, clf)

运行结果:

图中可以清晰地看出决策树分类算法较高的拟合程度和判别条件产生的分界线。

算法特点

优点:逻辑清晰,树形结构容易可视化,能够比较直观地观察分类过程。

缺点:容易过拟合,特征维度存在关联关系时也会对预测结果产生明显影响。

应用领域:商业决策、管理决策等。

相关推荐
阿正的梦工坊3 小时前
变分扩散模型 ELBO 重构推导详解
人工智能·深度学习·算法·机器学习
扫地僧9854 小时前
基于提示驱动的潜在领域泛化的医学图像分类方法(Python实现代码和数据分析)
人工智能·分类·数据挖掘
GIS小天5 小时前
AI预测体彩排3新模型百十个定位预测+胆码预测+杀和尾+杀和值2025年3月4日第9弹
人工智能·算法·机器学习·彩票
Dann Hiroaki5 小时前
文献分享: ConstBERT固定数目向量编码文档
数据库·机器学习·自然语言处理·nlp
Shockang6 小时前
机器学习数学通关指南
人工智能·数学·机器学习
CS创新实验室6 小时前
《机器学习数学基础》补充资料:描述性统计
人工智能·机器学习·机器学习数学基础
@心都6 小时前
机器学习数学基础:40.结构方程模型(SEM)中卡方值与卡方自由度比
人工智能·算法·机器学习
进阶的小蜉蝣6 小时前
[machine learning] MACS、MACs、FLOPS、FLOPs
人工智能·机器学习
@心都6 小时前
机器学习数学基础:39.样本和隐含和残差协方差矩阵
算法·机器学习·矩阵