【机器学习】DecisionTree - 决策树中的数学原理

前言

决策树:一种树形结构,其中每个内部节点表示一个属性上的判断,每个分支代表一个判断结果的输出,最后每个叶节点代表一种分类结果。

请确保已经了解决策树的基本工作流程再进行阅读。

最典型的决策树算法是Hunt算法,该算法是由Hunt等人提出的最早的决策树算法。现代,Hunt算法是许多决策树算 法的基础,包括ID3、C4.5和CART等。Hunt算法诞生时间较早,且基础理论并非特别完善,此处以应用较广、理论 基础较为完善的ID3算法的基本原理进行讲解。

决策树算法的核心是要解决几个问题:

  • 如何从数据表中找出最合适特征进行分支 ?
  • 什么时候让决策树停止生长,防止过拟合 ?

文章内容若不特别标注则默认以分类树为例来进行说明。

不纯度的衡量指标

为了要将数据表转化为一棵树,决策树需要从所有特征中找到最佳特征用于分支,而衡量某个特征好不好的指标就是不纯度

决策树的每个节点中都包含一组数据,在这组数据中,如果有某一类标签占有较大的比例,我们就说这组数据或节点比较 "纯"。某一类标签占的比例越大,叶子就越纯,不纯度就越低,分枝就越好。如果没有哪一类标签的比例很大,各类标签都相对平均,则说该组数据或节点 "不纯",不纯度高,分枝不好。

分类型决策树在叶子节点上的决策规则是少数服从多数,当叶子节点比较纯时,表示可信度越高,当叶子节点不纯时,结果可信度就相对更低。例如某叶子节点中A标签数据有10个,B标签数据有9个,按照少数服从多数,将该叶子节点的某测试数据归为A标签显然并没有多大把握,分类错误的概率就高了。

分类误差率(Classification error)

不纯度的计算或度量方法一般由误差率衍生而来,误差率的计算非常简单粗暴:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C l a s s i f i c a t i o n e r r o r ( t ) = 1 − m a x [ p ( i ∣ t ) ] \begin{align} Classification\space error(t) = 1-max[p(i|t)] \end{align} </math>Classification error(t)=1−max[p(i∣t)]

  • 取值范围: <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 ≤ C l a s s i f i c a t i o n e r r o r ( t ) ≤ 0.05 0\le Classification\space error(t) \le 0.05 </math>0≤Classification error(t)≤0.05
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:表示某组数据
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( i ∣ t ) p(i|t) </math>p(i∣t):表示某组数据( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t) 下类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 所占的比例

分类误差率越小代表越纯,不纯度越低,反之不纯度越高。

例如一个笼子里有3只小狗,4只小猫,5只小猪,那这组数据的误差率就为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − 5 3 + 4 + 5 = 7 12 1-\frac{5}{3+4+5}=\frac {7}{12} </math>1−3+4+55=127。

误差率在计算的时候可以看出,将数量最多的类别以外的其它类别的数据都视为 "错误" 的数据,然后计算这些 "错误" 的数据在所有数据中所占的比例,将这个比例作为计算结果。

信息熵:Entropy

信息熵公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E n t r o p y ( t ) = − ∑ i = 0 c − 1 p ( i ∣ t ) l o g 2 p ( i ∣ t ) \begin{align} Entropy(t)=-\sum^{c-1}_{i=0}p(i|t)log_2p(i|t) \end{align} </math>Entropy(t)=−i=0∑c−1p(i∣t)log2p(i∣t)

  • 取值范围: <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 ≤ E n t r o p y ( t ) ≤ 1 0\le Entropy(t) \le 1 </math>0≤Entropy(t)≤1
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:表示某组数据
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( i ∣ t ) p(i|t) </math>p(i∣t):表示某组数据( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t) 下类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 所占的比例
  • c:样本数据中的类别个数

信息熵越小代表越纯,不纯度越低,反之不纯度越高。

信息中信息量的大小跟随机事件的概率有关。越小概率的事情产生的信息量越大,越大概率的事情发生了产生的信息量越小。

  • 某地发生了地震:小概率事情, 信息量大
  • 太阳从东方升起,西方落下:大概率视屏,信息量小

信息量:信息量是对信息的度量,多少信息用信息量来衡量。

因此一个具体事件的信息量应该是随着其发生概率而递减的,且不能为负

假设存在两个不相关的事件 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y,根据常识我们可以很容易理解当两个事件同时发生时产生的信息量等于每个事件各自发生时产生的信息量之和,即: <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( p x , p y ) = f ( p x ) + f ( p y ) f(p_x,p_y) = f(p_x)+f(p_y) </math>f(px,py)=f(px)+f(py)

由于两事件不相关,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> p x y = p x ∗ p y p_{xy}=p_x*p_y </math>pxy=px∗py,也就是事件 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 同时发生的概率等于两者发生的概率的积,则有 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( p x y ) = f ( p x ) + f ( p y ) f(p_{xy})=f(p_x)+f(p_y) </math>f(pxy)=f(px)+f(py)

我们要找到满足这样一个关系的函数,很容易就能联想到 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( p x ) f(p_x) </math>f(px) 与 <math xmlns="http://www.w3.org/1998/Math/MathML"> p x p_x </math>px 的对数有关,因此得到信息量公式。fe

  • 信息量公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> h ( x ) = − l o g 2 p ( x ) h(x)=-log_2p(x) </math>h(x)=−log2p(x)
  • 例如: <math xmlns="http://www.w3.org/1998/Math/MathML"> h ( x ) + h ( y ) = − ( l o g 2 p ( x ) + l o g 2 p ( y ) ) = − l o g 2 p ( x ) p ( y ) = − l o g 2 p ( x y ) h(x)+h(y)=-(log_2p(x)+log_2p(y))=-log_2p(x)p(y)=-log_2p(xy) </math>h(x)+h(y)=−(log2p(x)+log2p(y))=−log2p(x)p(y)=−log2p(xy)
    • 即: <math xmlns="http://www.w3.org/1998/Math/MathML"> h ( x y ) = h ( x ) + h ( y ) h(xy)=h(x)+h(y) </math>h(xy)=h(x)+h(y)
  • 为什么要加负号:概率小于等于1,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g 2 p ( x ) ≤ 0 log_2p(x) \le 0 </math>log2p(x)≤0,信息量要求大于等于0
  • 为什么底数是2:我们只需要信息量满足低概率事件 x 对应于高的信息量。那么对数的选择是任意的。我们只是遵循信息论的普遍传统,使用 2 作为对数的底

信息熵(Entropy):信息量度量的是一个具体事件发生了所带来的信息,而熵则是在结果出来之前对可能产生的信息量的期望------考虑该随机变量的所有可能取值,即所有可能发生事件所带来的信息量的期望,这时候再看信息熵公式应该会清楚很多。

此外,信息熵还可以作为一个系统复杂程度的度量,系统越复杂,出现不同情况的种类越多,则信息熵越大,反之越小。对于决策树中的某节点也是如此。

信息增益(Information gain):决策树中的信息增益是指子节点的信息熵减去父节点的信息熵,值一定是正数,信息增益越大,代表使用某特征进行分支的效果越好。

基尼系数:Gini

Gini系数公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G i n i ( t ) = ∑ i = 0 c − 1 p ( i ∣ t ) [ 1 − p ( i ∣ t ) ] = 1 − ∑ i = 0 c − 1 [ p ( i ∣ t ) ] 2 \begin{align} Gini(t)&=\sum^{c-1}{i=0}p(i|t)[1-p(i|t)]\\ &=1-\sum^{c-1}{i=0}[p(i|t)]^2 \end{align} </math>Gini(t)=i=0∑c−1p(i∣t)[1−p(i∣t)]=1−i=0∑c−1[p(i∣t)]2

  • 取值范围: <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 ≤ G i n i ( t ) ≤ 0.05 0\le Gini(t) \le 0.05 </math>0≤Gini(t)≤0.05
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:表示某组数据
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( i ∣ t ) p(i|t) </math>p(i∣t):表示某组数据( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t) 下类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 所占的比例
  • c:样本数据中的类别个数

Gini系数越小代表越纯,不纯度越低,反之不纯度越高。

Gini系数是衡量不纯度或不平等度量的一种指标,常用于衡量贫富差距,与信息熵计算公式相比,只需将信息熵计算公式中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g 2 p ( i ∣ t ) log_2p(i|t) </math>log2p(i∣t) 替换为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − p ( i ∣ t ) 1-p(i|t) </math>1−p(i∣t) 即可。

与分类误差率进行比较,分类误差率只考虑除占比最大的类以外的类所占的比例,但由公式可以看出,Gini系数则是考虑全部事件发生的概率求期望,并且相比信息熵的计算更加简单。

误差率、信息熵、Gini系数 比较

分类误差率的局限性:

  • 对样本分布不敏感:分类误差率只考虑被错误分类的样本比例,而忽略了样本在各个类别中的分布情况。这意味着在处理不均衡数据集时,分类误差率可能无法准确反映数据的真实情况。
  • 不具备连续性:分类误差率只关注样本的分类结果,而不考虑类别之间的关系和连续性。这在处理连续型特征或特征之间存在明显顺序关系的情况下,可能导致信息损失和决策不准确。
  • 不支持特征权重:分类误差率对所有特征都是等权重考虑,无法通过给予不同特征不同的权重来增强对某些特征的重要性。这在某些问题中可能限制了决策树的表达能力。
  • 不适用于多类别问题:分类误差率在处理多类别问题时存在问题。它只考虑了被错误分类的样本比例,而忽略了其他类别的信息。在多类别问题中,信息熵和Gini系数通常更常用和更可靠。

信息熵和Gini系数之间差别不大,大多情况都是使用的信息熵或Gini系数来作为不纯度的衡量标准。

ID3算法

特征选择

决策树最终的优化目标是使得叶节点的总不纯度最低,即对应衡量不纯度的指标最低。

ID3使用的不纯度衡量方法是信息熵,最优条件是叶节点的总信息熵最小,因此ID3决策树在决定是否对某节点进行切分的时候,会尽可能选取使得该节点对应的子节点信息熵最小的特征进行切分。换言之,就是要求父节点信息熵和子节点总信息熵之差要最大。对于ID3而言,二者之差就是信息增益,即Information gain。

但这里需要注意,一个父节点下可能有多个子节点,而每个子节点又有自己的信息熵,所以父节点信息熵和子节点信息熵之差,应该是父节点的信息熵 - 所有子节点信息熵的加权平均。其中,权重是使用单个叶子节点上所占的样本量比上父节点上的总样本量。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> I ( c h i l d ) = ∑ j = 1 k N ( v j ) N I ( v j ) I(child)=\sum^k_{j=1}\frac{N(v_j)}{N}I(v_j) </math>I(child)=j=1∑kNN(vj)I(vj)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> v j v_j </math>vj:子节点
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( v j ) N(v_j) </math>N(vj):该子节点的样本量
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N:总样本量
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I:不纯度,impurity

父节点和子节点的不纯度下降数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Δ I = I ( p a r e n t ) − I ( c h i l d ) \Delta I = I(parent) - I(child) </math>ΔI=I(parent)−I(child)

这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> I ( v i ) I(vi) </math>I(vi) 是某节点的不纯度度量,可以是Gini系数、信息熵等方法,由父节点不纯度和子节点不纯度均值的差得到不纯度下降数,若使用信息熵方法,则该指标就是信息增益,不纯度下降数越大说明效果越好。

对于某数据下的所有特征,我们在选择特征产生分支时会对所有特征计算不纯度下降数,然后选择不纯度下降数最大的特征,选择后该特征会被消费掉,不会再次对其进行计算和选择。ID3算法使用的不纯度衡量指标是信息熵,其局部最优化条件也就是信息增益。

当特征较多时,选择所有特征来计算不纯度下降数再进行比较会大大增加计算量,因此使用局部最优的方法,也就是从原本所有特征中随机选取一部分特征,把这些特征作为全部特征来从中进行计算和选择。

总的来说,决策树模型是一个典型的贪心模型,总目标是一个全局最优解,即一整套合理的分类规则使得最终叶节点的纯度最高,但全局最优解在随特征增加而呈现指数级增加的搜索空间内很难高效获取,因此我们退而求其次,考虑采用局部最优来一步步推导结果------只要保证信息增益最大,我们就能得到次最优的模型。当然,局部最优不一定等于全局最优。

ID3 的局限性

ID3局限主要源于局部最优化条件,即信息增益的计算方法,其局限性主要有以下几点:

  • 分支度越高(分类水平越多)的离散变量往往子节点的总信息熵会更小,ID3是按照某一列进行切分,有一些列的分类可能对结果没有足够好的指示。例如一个比较极端的例子:ID3会选取样本的ID作为切分字段,因为每个样本的ID都是不唯一的,这样的话每个分类(子节点)的不纯度都是0,信息增益最大,但这样的分类方式是没有任何效益的
  • 不能直接处理连续型变量,若要使用ID3处理连续型变量,则首先需要对连续变量进行离散化
  • 对缺失值较为敏感,使用ID3之前需要提前对缺失值进行处理
  • 没有剪枝的设置,容易导致过拟合,即在训练集上表现很好,测试集上表现很差

分支度/分类变量水平:例如性别这个特征有两个取值,分别是男、女,产生两个分支,分类变量水平就是2;再例如温度这个特征有三个取值,分别是高、中、低,产生三个分支,分类变量水平就是3。

C4.5算法 / CART算法

修改局部最优化条件

在C4.5中,首先通过引入分支度 (IV:Information Value)的概念,来对信息增益的计算方法进行修正,简而言之,就是在信息增益计算方法的子节点总信息熵的计算方法中添加了 随着分类变量水平的惩罚项 。而分支度的计算公式仍然是基于熵的算法,只是将信息熵计算公式中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ( i ∣ t ) p(i|t) </math>p(i∣t) (某类别样本占总样本数) 改成了 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( v i ) P(v_i) </math>P(vi) (某子节点的总样本数占父节点总样本数的比例),这其实就是我们加权求和时的 "权重"。这样的一个分支度指标,让我们在切分的时候,自动避免那些分类水平太多,信息熵减小过快的特征影响模型,减少过拟合情况。

分支度计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> I n f o r m a t i o n V a l u e = − ∑ i = 1 k P ( v i ) l o g 2 P ( v i ) Information\space Value = -\sum^{k}_{i=1}P(v_i)log_2P(v_i) </math>Information Value=−i=1∑kP(vi)log2P(vi)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i:父节点的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个子节点
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> v i v_i </math>vi:第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个子节点样本数
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( v i ) P(v_i) </math>P(vi):第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个子节点拥有样本数占父节点总样本数的比例

由前面讲过的信息熵来理解,就是分类变量水平越高,越复杂,对应的信息熵越高,反之越低,因此可以作为其惩罚项,也就是说一个特征中如果标签分类太多,那对应的分支度就会相应增大。

在C4.5中,使用分支度作为惩罚项,也就是将信息增益除以分支度这个指标作为参考标准,该指标被称为Gain Ratio (获利比例/增益率),计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G a i n R a t i o = I n f o r a t i o n G a i n I n f o r m a t i o n V a l u e Gain Ratio = \frac{Inforation\space Gain}{Information\space Value} </math>GainRatio=Information ValueInforation Gain

  • 翻译成中文就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 获利比率 = 信息增益 分支度 获利比率 = \frac{信息增益}{分支度} </math>获利比率=分支度信息增益

这样就解决了ID3算法中使用信息增益作为局部最优化条件带来的问题。

处理连续型变量

在C4.5中,还增加了针对连续变量的处理方法。若输入特征字段是连续型变量,则有下列步骤:

  1. 算法首先会对这一列数进行从小到大的排序
  2. 选取相邻的两个数的中间数作为切分数据集的备选点,若一个连续变量有N个值,则在C4.5的处理过程中将产生 N-1个备选切分点,并且每个切分点都代表着一种二叉树的切分方案

这里需要注意的是,此时针对连续变量的处理并非是将其转化为一个拥有N-1个分类水平的分类变量,而是将其转化成了N-1个二分方案,而在进行下一次的切分过程中,这N-1个方案都要单独带入考虑,其中每一个切分方案和一个离散变量的地位均相同(一个离散变量就是一个单独的多路切分方案);和ID3一样,当选择后,该特征将会被消费掉,不会再次进行计算和选择。

例如某年龄字段数据:11,12,13,14,15,16,17,18,19,20,共10个值

将会对每两个数进行切分,产生9种方案:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> a g e < 11.5 , a g e > 11.5 age < 11.5,age > 11.5 </math>age<11.5,age>11.5
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> a g e < 12.5 , a g e > 12.5 age < 12.5, age > 12.5 </math>age<12.5,age>12.5
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ⋯ \cdots </math>⋯
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> a g e < 19.5 , a g e > 19.5 age < 19.5, age > 19.5 </math>age<19.5,age>19.5

从上述论述能够看出,在对于包含连续变量的数据集进行树模型构建的过程中要消耗更多的运算资源。但与此同时, 我们也会发现,当连续变量的某中间点参与到决策树的二分过程中,往往代表该点对于最终分类结果有较大影响,这 也为我们连续变量的分箱压缩提供了指导性意见。

CART算法

现在被大量使用的是C4.5的改进CART树,CART树本质其实和C4.5区别不大,只不过CART树所有的层都是二叉树, 也就是每层只有两个分枝。

让我们使用CART树来过一遍C4.5的流程,假设年龄是特征,性别是标签:

  1. 首先将年龄从小到大依此进行排列
  2. 然后计算两两相邻的年龄的均值
  3. 按均值所在的点,对连续性变量进行二分(变成数字形式的,第一类是>均值,第二类是<均值), 二分得到的点叫做决策树的 "树桩"。
  4. 计算 n-1个 二分切分方案的获利比例,获利比例最大的切分点,就是切点
  5. 切完之后,计算加权信息熵,计算信息增益,引入分支度,可以计算出相应获利比例了

这个流程和C4.5算法的流程基本一致,但需要注意的是在一些方面存在不同。

改进方案

  • CART算法 对连续型变量的这种切分方法,每次切分之后,并没有消费掉一整个特征,而是只消费掉了一个备选点。而 C4.5 算法每次进行特征选择就会消费掉一整个特征。
  • 实际上,我们可以只关注每次分类的时候,对信息熵的减少贡献最多的那个分类。按我们的例子来说,我们在分类 age 的时候,最关注的是 31 - 40 岁的那一个分类,我们完全可以实现 31 - 40 为一类,其他算一类,然后我们再对"其它"这个类别进行相似的二分类。这就是CART的核心原理,大量地减少了计算的量。

总结

决策树的基本流程其实可以简单概括如下:

  1. 计算全部特征的不纯度指标
  2. 选取不纯度指标最优的特征进行分支
  3. <math xmlns="http://www.w3.org/1998/Math/MathML"> ⋯ \cdots </math>⋯ 直到没有更多的特征可用,或整体的不纯度指标已经最优,或达到设置的某些阈值 (如最大深度),决策树就会停止生长。

对于KNN算法,我们有一个假设:就是每一个特征对于我们的推断的重要性是一样的。这也是KNN最大的缺陷。而决策树天生就认为每个特征对于推断的重要性不是一样的,而CART则是进一步认为,每个特征下的每个分类对于推断的重要性也不是一样的。

Reference

相关推荐
泰迪智能科技012 小时前
高校深度学习视觉应用平台产品介绍
人工智能·深度学习
盛派网络小助手2 小时前
微信 SDK 更新 Sample,NCF 文档和模板更新,更多更新日志,欢迎解锁
开发语言·人工智能·后端·架构·c#
算法小白(真小白)2 小时前
低代码软件搭建自学第二天——构建拖拽功能
python·低代码·pyqt
唐小旭2 小时前
服务器建立-错误:pyenv环境建立后python版本不对
运维·服务器·python
007php0072 小时前
Go语言zero项目部署后启动失败问题分析与解决
java·服务器·网络·python·golang·php·ai编程
Eric.Lee20213 小时前
Paddle OCR 中英文检测识别 - python 实现
人工智能·opencv·计算机视觉·ocr检测
cd_farsight3 小时前
nlp初学者怎么入门?需要学习哪些?
人工智能·自然语言处理
AI明说3 小时前
评估大语言模型在药物基因组学问答任务中的表现:PGxQA
人工智能·语言模型·自然语言处理·数智药师·数智药学
Chinese Red Guest3 小时前
python
开发语言·python·pygame
Focus_Liu3 小时前
NLP-UIE(Universal Information Extraction)
人工智能·自然语言处理