机器学习入门(三)——决策树(Decision Tree)

决策树 (Decision Tree) 学习笔记

核心思想 :通过一系列的规则对数据进行分类或回归。
本质:从根节点开始,利用特征对数据进行递归划分,直到数据"足够纯"或无法再分。


一、 基本概念

1.1 什么是决策树?

**决策树的本质:**通过不断筛选"最好的特征"来切分数据,直到数据分无可分(或者够纯了)为止;

那么什么叫"最好"?就是切分后,数据"纯度"提升最大;后续我们将讲到三种衡量"纯度/不纯度"的指标。
决策树的建立过程:

  • 特征选择:选取具有较强分类能力的特征
  • 决策树生成:根据选择的特征生成决策树
  • 过拟合处理:采用模型剪枝的方法环节过拟合,包括预剪枝和后剪枝

1.2 树的结构

  • 根节点 (Root Node):包含样本全集。
  • 内部节点 (Internal Node):对应一个特征测试(Condition)。
  • 叶节点 (Leaf Node):对应一个决策结果(类别或数值)。

1.3 核心目标:纯度 (Purity)

决策树生成的关键在于:如何选择最优特征进行分裂?

  • 目标:每次分裂后,让子节点的"纯度"比父节点更高(即混乱度降低)。
  • 不纯度 (Impurity):衡量样本集合混合程度的指标。

二、 核心度量指标 (数学原理)

2.1 信息熵 (Entropy)------ID3算法核心

描述信息的混乱程度。值越大,越混乱;值越小,越纯。

  • 公式
    H(D)=−∑k=1Kpklog⁡2pkH(D) = - \sum_{k=1}^{K} p_k \log_2 p_kH(D)=−k=1∑Kpklog2pk
    (其中 pkp_kpk 是第 kkk 类样本的占比)
  • 极值
    • 全是一个类别 →H(D)=0\rightarrow H(D) = 0→H(D)=0 (最纯)
    • 均匀分布 →H(D)\rightarrow H(D)→H(D) 最大 (最乱)
  • 目标:我们要找一个特征,切分后让**信息增益(Information Gain)**最大(也就是熵下降得最快)

2.2 信息增益 (Information Gain) - ID3 算法

分裂前的信息熵减去分裂后的加权平均信息熵。表示"得知该特征后,不确定性减少了多少"。

  • 公式
    Gain(D,A)=H(D)−∑∣Dv∣∣D∣H(Dv)Gain(D, A) = H(D) - \sum \frac{|D_v|}{|D|} H(D_v)Gain(D,A)=H(D)−∑∣D∣∣Dv∣H(Dv)
  • 缺点 :偏向于选择取值较多的特征(如身份证号、日期),容易过拟合。

2.3 信息增益率 (Gain Ratio) - C4.5 算法

在信息增益的基础上,除以特征本身的"固有值" (Intrinsic Value, IV),作为惩罚项。

  • 公式
    Gain_Ratio(D,A)=Gain(D,A)IV(A)Gain\_Ratio(D, A) = \frac{Gain(D, A)}{IV(A)}Gain_Ratio(D,A)=IV(A)Gain(D,A)
  • 改进:在信息增益的基础上除以一个"惩罚项"(IV,Intrinsic Value),修正ID3 对多值特征的偏好。

2.4 基尼系数 (Gini Index) - CART 算法

描述从数据集中随机抽取两个样本,其类别不一致的概率。

  • 公式
    Gini(p)=∑k=1Kpk(1−pk)=1−∑k=1Kpk2Gini(p) = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2Gini(p)=k=1∑Kpk(1−pk)=1−k=1∑Kpk2
  • 特点
    • Gini 指数越小,纯度越高。
    • 计算速度快 (不需要算 log⁡\loglog 对数),Sklearn 默认使用此指标。

三、 三大经典算法对比

特性 ID3 C4.5 CART (主流)
全称 Iterative Dichotomiser 3 Classifier 4.5 Classification and Regression Tree
支持任务 仅分类 仅分类 分类 & 回归
划分标准 信息熵增益 信息增益率 Gini 系数 (分类) / MSE (回归)
树结构 多叉树 多叉树 二叉树 (Binary Tree)
连续值 不支持 支持 (二分法) 支持 (二分法)
缺失值 不支持 支持 支持
缺点 偏好多值特征 对数计算慢 容易过拟合(需剪枝)

注意CART 是目前最常用的决策树算法,也是 XGBoost、RandomForest 的基石。它强制构建二叉树(即每个节点只有 Yes/No 两个分支)。


四、 关键技术细节

4.1 连续值的处理 (Continuous Features)

  • 策略:二分法 (Bi-partition)。
  • 步骤
    1. 将特征值从小到大排序。
    2. 取相邻两数的平均值作为候选切分点。
    3. 计算每个切分点的收益(Gain/Gini)。
    4. 选择最优切分点,将数据分为 >t>t>t 和 ≤t\le t≤t 两部分。

4.2 回归树 (Regression Tree)

  • 决策树不仅能做分类,也能做回归(预测房价、气温)。
  • 预测值:叶子节点内所有样本均值。
  • 损失函数:均方误差 (MSE)。即寻找切分点,使得切分后两边数据的 MSE 之和最小。

4.3 剪枝 (Pruning) - 防止过拟合

决策树很容易把训练集"背"下来(过拟合),导致在新数据上表现差。

  • 预剪枝 (Pre-Pruning) :在生长过程中提前停止。
    • 限制最大深度 (max_depth)
    • 限制叶节点最少样本数 (min_samples_leaf)
    • 限制分裂所需最少样本数 (min_samples_split)
    • 优点 :训练快;缺点:可能导致欠拟合。
  • 后剪枝 (Post-Pruning) :先长成完全树,再自底向上剪掉对泛化性能无贡献的分支。
    • 优点 :保留更多信息,泛化好;缺点:计算代价大。

五、 面试常问

Q1: 简述一下 ID3、C4.5 和 CART 的区别?

它们的区别主要在于划分标准和树的结构:

  1. ID3 使用信息增益,倾向于选择取值较多的特征(有偏好),且只能处理离散变量,构建的是多叉树。
  2. C4.5改进了ID3算法,使用信息增益率,解决了特征取值偏好问题,且能通过划分阈值来处理连续变量。
  3. CART既能做分类,也能做回归。分类时使用Gini系数,回归时使用均方误差(MSE)。最重要的是,GART构建的是二叉树,计算速度比用对数的熵要快,是目前的主流方法。

Q2: 决策树如何处理连续值(比如身高、房价)?

决策树(如 C4.5 和 CART)会采用二分法。

  1. 先把连续特征从小到大排序。
  2. 尝试取相邻两个数的中间值作为切分点(Threshold)。
  3. 计算每一个切分点的收益(Gini 或 增益)。
  4. 选择收益最大的那个点,把数据切成">x"和"<=x"两部分。

Q3: 决策树的熵和 Gini 系数有什么区别?为什么 CART 用 Gini?

  • 从趋势上讲,熵和 Gini 的曲线形状非常接近,效果差别不大。
  • 区别在于计算速度:Gini求平方比熵求对数计算速度快,所以在数据量大时,CART倾向于Gini

Q4: 决策树不需要进行特征缩放(归一化/标准化),为什么?

  • 因为决策树是基于概率和信息论的 ,它只关心特征的排序和分布,而不关心具体的数值大小。

六、 总结与实战建议

优点

  1. 可解释性强:生成的规则(If-Then)非常直观,非技术人员也能看懂(白盒模型)。
  2. 数据预处理简单不需要归一化/标准化(对数值大小不敏感,只看排序)。
  3. 能处理非线性数据。

缺点

  1. 容易过拟合:不剪枝的树通常效果很差。
  2. 不稳定性:数据的一点小变化可能导致树结构剧烈变化。
  3. 局部最优:贪心算法(只顾当前最好),不一定是全局最优解。

Sklearn 代码速查

python 复制代码
# 导包
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,confusion_matrix,classification_report
import matplotlib.pyplot as plt
# 数据预处理
data = pd.read_csv('./data/train.csv',sep=',')
# 提取特征
print(data.columns)
print(data.shape)
print(data.isna().sum())
x = data[['Sex', 'Age','Pclass','Fare']]
y = data['Survived']
print(x.head(5))
print(x.shape)
print(y.shape)
# 特征工程
# 缺失值处理
x.loc[:,'Age'] =  x['Age'].fillna(x['Age'].mean())
# one-hot  字符转数值
new_x = pd.get_dummies(x,drop_first=True)
print(new_x.shape)
# 数据划分
X_train, X_test, y_train, y_test = train_test_split(new_x,y,test_size=0.2 ,shuffle=True,random_state=33)
# 数据规范化
ss = StandardScaler()
x_train = ss.fit_transform(X_train)
x_test = ss.transform(X_test)
# 模型准备
#model = DecisionTreeClassifier(criterion='entropy') 熵------信息增益
model = DecisionTreeClassifier(criterion='gini')
# 模型训练
model.fit(x_train,y_train)
# 模型预测
pred = model.predict(x_test)
# 模型评估
print(f'准确率:{accuracy_score(y_test, pred)}')
print(f'精确率:{precision_score(y_test, pred)}')
print(f'召回率:{recall_score(y_test, pred)}')
print(f'f1分数:{f1_score(y_test, pred)}')
print(f'混淆矩阵:\n{confusion_matrix(y_test,pred)}')
print(f'分类报告:\n{classification_report(y_test,pred)}')

# 画图
plt.figure(figsize=[80,60])
plot_tree(
    model,
    filled=True,
)
plt.show()
相关推荐
GAOJ_K2 小时前
滚珠花键的无预压、间隙调整与过盈配合“场景适配型”
人工智能·科技·机器人·自动化·制造
ai_xiaogui2 小时前
【开源探索】Panelai:重新定义AI服务器管理面板,助力团队私有化算力部署与模型运维
人工智能·开源·私有化部署·docker容器化·panelai·ai服务器管理面板·comfyui集群管理
源于花海2 小时前
迁移学习的前沿知识(AI与人类经验结合、传递式、终身、在线、强化、可解释性等)
人工智能·机器学习·迁移学习·迁移学习前沿
机 _ 长2 小时前
YOLO26 改进 | 基于特征蒸馏 | 知识蒸馏 (Response & Feature-based Distillation)
python·深度学习·机器学习
king of code porter2 小时前
百宝箱企业版搭建智能体应用-平台概述
人工智能·大模型·智能体
愚公搬代码2 小时前
【愚公系列】《AI短视频创作一本通》004-AI短视频的准备工作(创作AI短视频的基本流程)
人工智能·音视频
物联网软硬件开发-轨物科技2 小时前
【轨物洞见】告别“被动维修”!预测性运维如何重塑老旧电站的资产价值?
运维·人工智能
电商API_180079052472 小时前
第三方淘宝商品详情 API 全维度调用指南:从技术对接到生产落地
java·大数据·前端·数据库·人工智能·网络爬虫
梁辰兴3 小时前
百亿美元赌注变数,AI军备竞赛迎来转折点?
人工智能·ai·大模型·openai·英伟达·梁辰兴·ai军备竞赛