决策树 (Decision Tree) 学习笔记
核心思想 :通过一系列的规则对数据进行分类或回归。
本质:从根节点开始,利用特征对数据进行递归划分,直到数据"足够纯"或无法再分。
一、 基本概念
1.1 什么是决策树?
**决策树的本质:**通过不断筛选"最好的特征"来切分数据,直到数据分无可分(或者够纯了)为止;
那么什么叫"最好"?就是切分后,数据"纯度"提升最大;后续我们将讲到三种衡量"纯度/不纯度"的指标。
决策树的建立过程:
- 特征选择:选取具有较强分类能力的特征
- 决策树生成:根据选择的特征生成决策树
- 过拟合处理:采用模型剪枝的方法环节过拟合,包括预剪枝和后剪枝
1.2 树的结构
- 根节点 (Root Node):包含样本全集。
- 内部节点 (Internal Node):对应一个特征测试(Condition)。
- 叶节点 (Leaf Node):对应一个决策结果(类别或数值)。
1.3 核心目标:纯度 (Purity)
决策树生成的关键在于:如何选择最优特征进行分裂?
- 目标:每次分裂后,让子节点的"纯度"比父节点更高(即混乱度降低)。
- 不纯度 (Impurity):衡量样本集合混合程度的指标。
二、 核心度量指标 (数学原理)
2.1 信息熵 (Entropy)------ID3算法核心
描述信息的混乱程度。值越大,越混乱;值越小,越纯。
- 公式 :
H(D)=−∑k=1Kpklog2pkH(D) = - \sum_{k=1}^{K} p_k \log_2 p_kH(D)=−k=1∑Kpklog2pk
(其中 pkp_kpk 是第 kkk 类样本的占比) - 极值 :
- 全是一个类别 →H(D)=0\rightarrow H(D) = 0→H(D)=0 (最纯)
- 均匀分布 →H(D)\rightarrow H(D)→H(D) 最大 (最乱)
- 目标:我们要找一个特征,切分后让**信息增益(Information Gain)**最大(也就是熵下降得最快)
2.2 信息增益 (Information Gain) - ID3 算法
分裂前的信息熵减去分裂后的加权平均信息熵。表示"得知该特征后,不确定性减少了多少"。
- 公式 :
Gain(D,A)=H(D)−∑∣Dv∣∣D∣H(Dv)Gain(D, A) = H(D) - \sum \frac{|D_v|}{|D|} H(D_v)Gain(D,A)=H(D)−∑∣D∣∣Dv∣H(Dv) - 缺点 :偏向于选择取值较多的特征(如身份证号、日期),容易过拟合。
2.3 信息增益率 (Gain Ratio) - C4.5 算法
在信息增益的基础上,除以特征本身的"固有值" (Intrinsic Value, IV),作为惩罚项。
- 公式 :
Gain_Ratio(D,A)=Gain(D,A)IV(A)Gain\_Ratio(D, A) = \frac{Gain(D, A)}{IV(A)}Gain_Ratio(D,A)=IV(A)Gain(D,A) - 改进:在信息增益的基础上除以一个"惩罚项"(IV,Intrinsic Value),修正ID3 对多值特征的偏好。
2.4 基尼系数 (Gini Index) - CART 算法
描述从数据集中随机抽取两个样本,其类别不一致的概率。
- 公式 :
Gini(p)=∑k=1Kpk(1−pk)=1−∑k=1Kpk2Gini(p) = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2Gini(p)=k=1∑Kpk(1−pk)=1−k=1∑Kpk2 - 特点 :
- Gini 指数越小,纯度越高。
- 计算速度快 (不需要算 log\loglog 对数),Sklearn 默认使用此指标。
三、 三大经典算法对比
| 特性 | ID3 | C4.5 | CART (主流) |
|---|---|---|---|
| 全称 | Iterative Dichotomiser 3 | Classifier 4.5 | Classification and Regression Tree |
| 支持任务 | 仅分类 | 仅分类 | 分类 & 回归 |
| 划分标准 | 信息熵增益 | 信息增益率 | Gini 系数 (分类) / MSE (回归) |
| 树结构 | 多叉树 | 多叉树 | 二叉树 (Binary Tree) |
| 连续值 | 不支持 | 支持 (二分法) | 支持 (二分法) |
| 缺失值 | 不支持 | 支持 | 支持 |
| 缺点 | 偏好多值特征 | 对数计算慢 | 容易过拟合(需剪枝) |
注意 :CART 是目前最常用的决策树算法,也是 XGBoost、RandomForest 的基石。它强制构建二叉树(即每个节点只有 Yes/No 两个分支)。
四、 关键技术细节
4.1 连续值的处理 (Continuous Features)
- 策略:二分法 (Bi-partition)。
- 步骤 :
- 将特征值从小到大排序。
- 取相邻两数的平均值作为候选切分点。
- 计算每个切分点的收益(Gain/Gini)。
- 选择最优切分点,将数据分为 >t>t>t 和 ≤t\le t≤t 两部分。
4.2 回归树 (Regression Tree)
- 决策树不仅能做分类,也能做回归(预测房价、气温)。
- 预测值:叶子节点内所有样本均值。
- 损失函数:均方误差 (MSE)。即寻找切分点,使得切分后两边数据的 MSE 之和最小。
4.3 剪枝 (Pruning) - 防止过拟合
决策树很容易把训练集"背"下来(过拟合),导致在新数据上表现差。
- 预剪枝 (Pre-Pruning) :在生长过程中提前停止。
- 限制最大深度 (
max_depth) - 限制叶节点最少样本数 (
min_samples_leaf) - 限制分裂所需最少样本数 (
min_samples_split) - 优点 :训练快;缺点:可能导致欠拟合。
- 限制最大深度 (
- 后剪枝 (Post-Pruning) :先长成完全树,再自底向上剪掉对泛化性能无贡献的分支。
- 优点 :保留更多信息,泛化好;缺点:计算代价大。
五、 面试常问
Q1: 简述一下 ID3、C4.5 和 CART 的区别?
它们的区别主要在于划分标准和树的结构:
- ID3 使用信息增益,倾向于选择取值较多的特征(有偏好),且只能处理离散变量,构建的是多叉树。
- C4.5改进了ID3算法,使用信息增益率,解决了特征取值偏好问题,且能通过划分阈值来处理连续变量。
- CART既能做分类,也能做回归。分类时使用Gini系数,回归时使用均方误差(MSE)。最重要的是,GART构建的是二叉树,计算速度比用对数的熵要快,是目前的主流方法。
Q2: 决策树如何处理连续值(比如身高、房价)?
决策树(如 C4.5 和 CART)会采用二分法。
- 先把连续特征从小到大排序。
- 尝试取相邻两个数的中间值作为切分点(Threshold)。
- 计算每一个切分点的收益(Gini 或 增益)。
- 选择收益最大的那个点,把数据切成">x"和"<=x"两部分。
Q3: 决策树的熵和 Gini 系数有什么区别?为什么 CART 用 Gini?
- 从趋势上讲,熵和 Gini 的曲线形状非常接近,效果差别不大。
- 区别在于计算速度:Gini求平方比熵求对数计算速度快,所以在数据量大时,CART倾向于Gini
Q4: 决策树不需要进行特征缩放(归一化/标准化),为什么?
- 因为决策树是基于概率和信息论的 ,它只关心特征的排序和分布,而不关心具体的数值大小。
六、 总结与实战建议
优点
- 可解释性强:生成的规则(If-Then)非常直观,非技术人员也能看懂(白盒模型)。
- 数据预处理简单 :不需要归一化/标准化(对数值大小不敏感,只看排序)。
- 能处理非线性数据。
缺点
- 容易过拟合:不剪枝的树通常效果很差。
- 不稳定性:数据的一点小变化可能导致树结构剧烈变化。
- 局部最优:贪心算法(只顾当前最好),不一定是全局最优解。
Sklearn 代码速查
python
# 导包
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,confusion_matrix,classification_report
import matplotlib.pyplot as plt
# 数据预处理
data = pd.read_csv('./data/train.csv',sep=',')
# 提取特征
print(data.columns)
print(data.shape)
print(data.isna().sum())
x = data[['Sex', 'Age','Pclass','Fare']]
y = data['Survived']
print(x.head(5))
print(x.shape)
print(y.shape)
# 特征工程
# 缺失值处理
x.loc[:,'Age'] = x['Age'].fillna(x['Age'].mean())
# one-hot 字符转数值
new_x = pd.get_dummies(x,drop_first=True)
print(new_x.shape)
# 数据划分
X_train, X_test, y_train, y_test = train_test_split(new_x,y,test_size=0.2 ,shuffle=True,random_state=33)
# 数据规范化
ss = StandardScaler()
x_train = ss.fit_transform(X_train)
x_test = ss.transform(X_test)
# 模型准备
#model = DecisionTreeClassifier(criterion='entropy') 熵------信息增益
model = DecisionTreeClassifier(criterion='gini')
# 模型训练
model.fit(x_train,y_train)
# 模型预测
pred = model.predict(x_test)
# 模型评估
print(f'准确率:{accuracy_score(y_test, pred)}')
print(f'精确率:{precision_score(y_test, pred)}')
print(f'召回率:{recall_score(y_test, pred)}')
print(f'f1分数:{f1_score(y_test, pred)}')
print(f'混淆矩阵:\n{confusion_matrix(y_test,pred)}')
print(f'分类报告:\n{classification_report(y_test,pred)}')
# 画图
plt.figure(figsize=[80,60])
plot_tree(
model,
filled=True,
)
plt.show()