机器学习之监督学习(四)决策树和随机森林
- [0. 文章传送](#0. 文章传送)
- [1. 决策树 Decision Tree](#1. 决策树 Decision Tree)
- [2. 随机森林 Random Forest](#2. 随机森林 Random Forest)
- [3. 决策树 vs 神经网络](#3. 决策树 vs 神经网络)
- [4. 代码实现](#4. 代码实现)
- [5. 案例](#5. 案例)
0. 文章传送
机器学习之监督学习(一)线性回归、多项式回归、算法优化[巨详细笔记]
机器学习之监督学习(二)二元逻辑回归
机器学习之监督学习(三)神经网络基础
机器学习之实战篇------预测二手房房价(线性回归)
机器学习之实战篇------肿瘤良性/恶性分类器(二元逻辑回归)
机器学习之实战篇------MNIST手写数字0~9识别(全连接神经网络模型)
1. 决策树 Decision Tree
案例引入
前面的文章系列已经介绍了几种监督学习算法,包括线性回归、逻辑回归、神经网络,现在介绍另一种截然不同的算法------决策树模型(Decison Tree)。决策树是一种简单、高效、可解释的机器学习模型,可以用于分类问题和回归问题。
下面展现了一个猫分类的案例,输入耳朵形状、脸形状、胡须有无,输出1(猫)/0(非猫),图中包含10个样本。
下面展现了一棵决策树,学过数据结构的我们对树结构(递归结构)并不陌生。最顶端的结点称为根结点(root node),根结点左右分叉形成左子树和右子树,最底部的结点称为叶结点(leaf node)。在这棵决策树中,在根结点中选取耳朵形状作为特征,然后根据耳朵尖状或松散分叉成左右子树,左孩子结点选择脸型作为特征,右孩子结点选择胡须作为特征,再进一步分叉后第三层的叶子结点输出类别(是猫/不是猫)。
在构建了决策树后,对于新的输入,我们便可以预测其类别。例如一只耳尖、脸圆、有须的猫,从根结点出发进行一系列特征选择(左->左),将其分类为猫。
构建过程
如何从训练集中构建出一棵最优决策树?接下来让我们回答几个核心问题,探究决策树的构建过程:
Q1:如何选择节点特征进行分裂?
答:最大化纯度 (purity)
构建决策树第一个核心问题是如何在诸多特征中进行选择,让决策树进一步分裂。
看上图,当选择cat DNA作为特征时,可以发现左边都是猫,右边都是狗,即完全分门别类,左右边纯度都是100%。而在下面的三种特征选取中,我们发现按第一个特征分裂,左边猫占80%,右边非猫占80%;按第二个特征分裂,左边猫占4/7,右边非猫占2/3;按第三个特征分裂,左边猫占3/4,右边非猫占2/3,可以发现两个比例都是按第一个特征分裂最高,因此选取第一个特征最优。
如何表征纯度?引入熵 (entropy)和熵函数 的概念,定义p表示样本中正类(例如是猫)的占比,则熵值为H(p),其中H为熵函数,表达式为:
H ( p ) = − p l o g 2 ( p ) − ( 1 − p ) l o g 2 ( 1 − p ) H(p)=-plog_2(p)-(1-p)log_2(1-p) H(p)=−plog2(p)−(1−p)log2(1−p),
函数图像如下:可以看到函数图像关于p=0.5对称,p(0.5)=1,p(0)=p(1)=1
熵表征的是非纯度 (impurity),可以看到p=0.5时非纯度最高,p趋近0或1时表示样本趋于正/负类,纯度高,因此非纯度趋于0.
节点特征选择 的策略就是选择对应信息增益 (information gain)最大的特征。
何为信息增益?表达式为
I n f o r m a t i o n g a i n = H ( p r o o t ) − ( w l e f t H ( p l e f t ) + w r i g h t H ( p r i g h t ) ) Information~gain=H(p_{root})-(w_{left}H(p_{left})+w_{right}H(p_{right})) Information gain=H(proot)−(wleftH(pleft)+wrightH(pright))
参照下图示例理解信息增益,需要计算根结点熵值、左右结点熵值,如何计算左、右结点熵值乘权重后相加值与根结点熵值相减,计算出信息增益。
下图中,耳朵特征信息增益最大,因此选取其为根结点分裂特征。
由于树模型属于典型的递归结构,在确定了根结点后,剩下按一样的思路递归构造左子树和右子树即可,但仍有一个问题需要解决,何时结束递归?即何时结束决策树分类,完成决策树的构造?
Q2:何时停止节点分裂?
答:节点停止分裂的标准并不单一,以下是常见的几个基本标准:
①当一个结点纯度达到100%时
②当决策树分裂达到最大深度时
③当信息增益小于某个阈值时
④当某个结点样本数小于某个阈值时
Q3:上面案例中的特征都是二元特征,那该如何处理多元特征?
答:采用独热编码。
上面案例中的特征都是二元特征,也就是只有两个取值,这样构建的决策树属于二叉树。当某一个特征有k个有限取值时,可以创建k个二元特征(0/1取值),保证决策树是二叉树。使用独热编码的特征,既适用于决策树模型,也适用于逻辑回归、线性回归、神经网络模型。
Q4:上面案例中的特征都是离散型特征, 如何处理连续取值特征?
答:二分法(Binary Split),这是最常用的方法:
- 对连续特征的所有不同取值进行排序
- 取相邻两个值的中点作为可能的分割点
- 对每个可能的分割点,计算分割后的信息增益
- 选择信息增益最大的点作为该特征的分割点
例如,如果一个特征有值[1, 3, 4, 5, 7],则可能的分割点为1.5, 3.5, 4.5, 6。
Q5:上面案例中决策树用于处理二元分类问题,那如何处理多元分类问题呢?
答:
①有时会采用One-vs-Rest策略,为每个类别训练一个二分类决策树,然后选择置信度最高的类别作为最终预测结果。
②在多分类问题中,我们仍然使用信息增益作为分裂标准,但计算方式略有不同:
对于熵,公式变为: H = − Σ ( p i × l o g 2 ( p i ) ) H = -Σ(p_i \times log_2(p_i)) H=−Σ(pi×log2(pi)),其中pi是第i类的概率。
在多分类问题中,叶节点的类别通常由该节点中样本最多的类别决定。
为了理解多分类问题中决策树的构建过程,下面是某ai大模型生成的案例:
完整的决策树构建过程:
初始数据集:
颜色 形状 类别
红 圆 苹果
黄 长 香蕉
橙 圆 橙子
红 圆 苹果
黄 长 香蕉
红 圆 苹果
橙 圆 橙子
黄 圆 苹果
根节点的熵计算:
苹果: 4/8, 香蕉: 2/8, 橙子: 2/8
H = -(4/8 * log2(4/8) + 2/8 * log2(2/8) + 2/8 * log2(2/8)) ≈ 1.5
第一次分裂(使用"颜色"特征):
红色节点:3苹果
黄色节点:2香蕉,1苹果
橙色节点:2橙子
计算每个子节点的熵:
H_红 = 0
H_黄 ≈ 0.92
H_橙 = 0
信息增益 = 1.5 - (3/8 * 0 + 3/8 * 0.92 + 2/8 * 0) ≈ 1.15
继续分裂黄色节点(使用"形状"特征):
长形节点:2香蕉
圆形节点:1苹果
计算这次分裂的信息增益:
H_before ≈ 0.92
H_after = 0
信息增益 ≈ 0.92
最终的叶节点决定:
红色节点:3苹果 -> 类别为苹果
橙色节点:2橙子 -> 类别为橙子
黄色长形节点:2香蕉 -> 类别为香蕉
黄色圆形节点:1苹果 -> 类别为苹果
决策树图示:
根节点
(颜色)
/ | \
/ | \
/ | \
红色 黄色 橙色
(苹果) (形状) (橙子)
/ \
/ \
长形 圆形
(香蕉) (苹果)
这个决策树的解释:
首先看水果的颜色
如果是红色,分类为苹果
如果是橙色,分类为橙子
如果是黄色,再看形状:
如果是长形,分类为香蕉
如果是圆形,分类为苹果
Q6:上面的案例中决策树都是分类树,那如何处理回归问题呢?
答:构建回归树。
例如不再是对猫进行分类,而是预测动物的重量,那属于回归问题,需要构造回归决策树。在分类树中,节点特征选择的依据是获得最大化纯度或最大信息增益,而回归树不同,分裂特征选择依据是最小化方差。
将熵替换为方差,其他不变,示例如下:
分裂停止后,使用叶结点的平均值作为最终的预测数值。
2. 随机森林 Random Forest
单棵决策树存在一些明显的缺点:容易对训练数据过度拟合,特别是当树的深度较大时;对训练数据的小变化非常敏感。
为了增加算法健壮性,可以构建多棵决策树构建决策森林(tree ensemble),综合多棵决策树预测结果做出最终预测。
如下图,构建了一片由三棵决策树构成的森林,对于这只猫预测结果为猫:非猫=2:1,因此最终预测结果为猫。
随机森林(Random Forest)是一种集成学习方法,通过组合多个决策树来提高模型的性能和稳定性。它是由Leo Breiman和Adele Cutler在2001年提出的。随机森林结合了决策树的简单性和集成学习的优势,能够显著减少单棵决策树的缺点。
随机森林的基本原理
①有放回抽样构建B棵决策树:
从训练数据集中随机抽取多个子集(有放回抽样)。
每个子集用于训练一棵决策树。
②随机特征选择:
在每个节点分裂时,随机选择一部分(k<n)特征(推荐k= n \sqrt{n} n )进行评估。
这有助于减少特征之间的相关性,增加模型的多样性。
③组合决策树:
对于分类问题,最终的预测结果是所有决策树投票的结果。
对于回归问题,最终的预测结果是所有决策树预测值的平均值。
3. 决策树 vs 神经网络
决策树和随机森林:
①擅长处理表格型数据
②不适合处理非表格型数据,例如图像、音频、文本
③较简单的决策树具备可解释性
神经网络:
①适合处理表格型数据和非表格型数据(图像、音频、文本)
②运行效率比决策树模型慢
③支持迁移学习
4. 代码实现
手写版本
sklearn版本
从sklearn.ensemble导入随机森林分类器RandomForestClassifier或随机森林回归器RandomForestRegressor
python
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
创建随机森林分类器,主要参数包括:
①n_estimators: int, default=100:森林中树的数量。
②criterion: {"gini", "entropy", "log_loss"}, default="gini":衡量分裂质量的功能,支持 gini(基尼不纯度),entropy(信息增益)和 log_loss(对数损失)。
③min_samples_split: int or float, default=2:分裂内部节点所需的最小样本数。如果是整数,则将该数值作为最小样本数。如果是浮点数,则将其视为比例。
④max_depth: int, default=None:决策树的最大深度。如果没有设置,节点将扩展,直到所有叶子都是纯的,或者直到所有叶子包含的样本少于 min_samples_split。
⑤andom_state: int, RandomState instance or None, default=None:随机数生成器的种子。
python
rf_classifier = RandomForestClassifier(n_estimators=100, criterion='entropy',random_state=42)
5. 案例
Iris数据集介绍
Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
实验代码
导入模块
python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score,classification_report
from sklearn.model_selection import train_test_split
获取数据集
python
#获取iris数据集
data=load_iris()
X=data.data
y=data.target
print(f'X.shape:{X.shape},y.shape:{y.shape}')
print(f'feature_names:{data.feature_names}')
print(f'target_names:{data.target_names}')
X.shape:(150, 4),y.shape:(150,)
feature_names:['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
target_names:['setosa' 'versicolor' 'virginica']
分割数据集
python
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
训练随机森林模型并进行预测
python
#创建随机森林分类器
rf_classifier=RandomForestClassifier(n_estimators=100,criterion='entropy',random_state=23)
#拟合数据进行训练
rf_classifier.fit(X_train,y_train)
# 在测试集上进行预测
y_pred = rf_classifier.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
Accuracy: 1.00
python
# 打印分类报告
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=data.target_names))
Classification Report:
precision recall f1-score support
setosa 1.00 1.00 1.00 19
versicolor 1.00 1.00 1.00 13
virginica 1.00 1.00 1.00 13
accuracy 1.00 45
macro avg 1.00 1.00 1.00 45
weighted avg 1.00 1.00 1.00 45