机器学习算法系列专栏:决策树算法(初学者)

(一)决策树概念

决策树通过对训练样本的学习,并建立分类规则然后依据分类规则,对新样本数据进行分类预测,属于有监督学习

(二)决策树核心

所有数据从根节点一步一步落到叶子节点

  • 根节点:第一个节点
  • 非叶子节点:中间节点
  • 叶子节点:最终结果节点

(三)常见问题

1.哪个节点作为根节点?哪些节点作为中间节点?哪些节点作为叶子节点 ?

  • 根节点:由算法根据全局最优特征自动选择

  • 中间节点:由算法在满足分裂条件时递归生成

  • 叶子节点:由算法在满足停止条件时自动标记,存储最终预测结果

关键点:决策树的节点角色完全由数据和算法规则决定,无需人工干预

2.节点如何分裂 ?

在当前节点上,选择一个最优特征和一个最优切分点,把样本分到左右(或多路)子节点中,使得子节点的纯度最高(不纯度最低)

3.节点分裂标准的依据 ?

使分裂后子节点的"不纯度下降"最大(或等价地,使子节点纯度提高最多)

不同决策树算法只是用不同的数学指标 来量化"不纯度下降",而具体用哪个指标取决于你选的是 ID3、C4.5 还是 CART

(四)决策树分类标准

1.ID3算法 ― 信息增益 (Information Gain)

衡量标准:

熵值:表示随机变量不确定性的度量,或者说是物体内部的混乱程度

熵值计算公式:

A集合:[1,1,1,1,1,1,1,1,2,2]

B集合:[0,1,2,3,4,5,6,7,8,9]

A集合熵值:-2/10*log2(2/10)-8/10*log2(8/10)= 0.722

B集合熵值:-1/10*log2(1/10)*10= 3.322

显然B的熵值更大,更加混乱

2.C4.5算法 ― 信息增益比 (Gain Ratio)

衡量标准:信息增益率

3.CART决策树― Gini 指数下降 (分类树) 或 MSE 下降 (回归树)

(五)决策树剪枝

(5.1)剪枝原因

防止过拟合

(5.2)剪枝方法

预剪枝和后剪枝

(5.3)预剪枝策略

  1. 限制树的深度
  2. 限制叶子节点的个数以及叶子节点的样本数
  3. 基尼系数

(六)决策树的回归模型

(6.1)回归树概念

解决回归问题的决策树模型即为回归树

(6.2)回归树特点

必须是二叉树

(6.3)回归树实现步骤

(1)计算最优切分点

因为只有一个变量,所以切分变量必然是x

可以考虑如下9个切分点:[1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5]

原因:实际上考虑两个变量间任意一个位置为切分点均可

切分点1.5的计算

当s=1.5时,将数据分为两个部分:

第一部分:(1,5.56)

第二部分:(2,5.7)、(3,5.91)、(4,6.4)...(10,9.05)

核心:

1.节点切分依据?

使分裂后左右子节点的"均方误差(MSE)下降"最大

2.如何预测?

预测时,把待预测样本沿树一路分到某个叶子节点,用该叶子节点内训练样本的目标值均值作为输出

(2)计算损失

C1=5.56

C2=1/9(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.5

Loss =(5.56-5.56)^2+(5.7-7.5)^2+(5.91-7.5)^2+...+(9.05-7.5)^2=0+15.72=15.72

(3)同理计算其他分割点的损失

容易看出,当s=6.5时,loss=1.93最小,所以第一个划分点s=6.5

(4)对于小于6.5部分

切分点1.5的计算

当s=1.5时,将数据分为两个部分,第一部分:(1,5.56)第二部分:(2,5.7)、(3,5.91)、(4,6.4),(5,6.8)、(6,7.05)

C =5.56

C,=1/5(5.7+5.91+6.4+6.8+7.05)=6.37Loss =0+(5.7-6.37)^2+(5.91-6.37)^2+..+(7.05-6.37)^2=0+13087=13087

(5)因此得到:

容易看出:

  • 当s=3.5时,loss=0.2771最小,所以第一个划分点s=3.5
  • 当s=8.5时,ioss=0.021最小,所以第二个划分点s=8.5

(6)假设只分裂我们计算的这几次:

那么分段函数为:

  • 当x<=3.5时,1/3(5.56+5.7+5.91)=5.72
  • 当3.5<x<=6.5时,1/3(6.4+6.8+7.05)=6.75
  • 当6.5<x<=8.5时,1/2(8.9+8.7)=8.8
  • 当8.5<x时,1/2(9+9.05)=9.025

最终得到分段函数!

(7)对于预测来说:

特征x必然位于其中某个区间内,所以,即可得到回归的结果,比如说:

如果x=11,那么对应的回归值为9.025

  • 当x<=3.5时,1/3(5.56+5.7+5.91)=5.72
  • 当3.5<x<=6.5时,1/3(6.4+6.8+7.05)=6.75
  • 当6.5<x<=8.5时,1/2(8.9+8.7)=8.8
  • 当8.5<x时,1/2(9+9.05)=9.025

(8)决策树的构造:

(六)具体代码实现案例

复制代码
import pandas as pd
from sklearn import tree

data = pd.read_csv("data(1).csv")

x = data.iloc[:, :-1]
y = data.iloc[:, -1]

reg = tree.DecisionTreeRegressor()
reg = reg.fit(x, y)

y_pr = reg.predict(x)
print(y_pr)
score = reg.score(x, y)
print(score)
相关推荐
哈__3 分钟前
PromptPilot搭配Doubao-seed-1.6:定制你需要的AI提示prompt
大数据·人工智能·promptpilot
Godspeed Zhao16 分钟前
自动驾驶中的传感器技术23——Camera(14)
人工智能·机器学习·自动驾驶·isp算法
weixin_4640780739 分钟前
机器学习sklearn:编码、哑变量、二值化和分段
人工智能·机器学习·sklearn
java1234_小锋39 分钟前
【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 词云图-微博评论词云图实现
python·自然语言处理·flask·nlp·nlp舆情分析
CS创新实验室1 小时前
《机器学习数学基础》补充资料:泰勒定理与余项
人工智能·机器学习·概率论·泰勒定理·泰勒展开·余项
codists1 小时前
《AI-Assisted Programming》读后感
python
watersink1 小时前
最小VL视觉语言模型OmniVision-968M
人工智能·语言模型·自然语言处理
是乐谷1 小时前
阿里招AI产品运营
人工智能·程序人生·面试·职场和发展·产品运营·求职招聘
爱欲无极1 小时前
基于Flask的微博话题多标签情感分析系统设计
后端·python·flask
星空下的曙光1 小时前
React 虚拟 DOM Diff 算法详解,Vue、Snabbdom 与 React 算法对比
vue.js·算法·react.js