前言
决策树是分类与回归问题中常用的方法之一。其实不仅是机器学习领域,在每天的日常决策中,我们都在使用决策树。流程图实际上就是决策树的可视化表示。
一、基本概念
-
决策树原理
-
通过一系列 逻辑规则(特征分割条件) 构建树形结构,用于分类或回归。
-
每个内部节点表示一个特征判断,分支表示判断结果,叶节点表示最终类别或数值。
-
-
关键术语
-
熵(Entropy) :衡量系统不确定性
\( S = -\sum_{i=1}^N p_i \log_2 p_i \)
\( p_i \) 为第 \( i \) 类样本的比例。
-
信息增益(Information Gain) :分割后熵的减少量
\( IG(Q) = S_0 - \sum_{i=1}^q \frac{N_i}{N} S_i \)
-
基尼系数(Gini Index) :衡量数据不纯度
\( G = 1 - \sum_{k} (p_k)^2 \)
-
-
分割质量指标对比
指标 公式 特点 信息增益(熵) \( IG = S_0 - \sum \frac{N_i}{N} S_i \) 对类别分布敏感 基尼系数 \( G = 1 - \sum p_k^2 \) 计算更高效,与熵效果相似 错分率 \( E = 1 - \max p_k \) 不推荐使用,对概率变化不敏感
二、决策树构建算法
-
核心思想
-
贪婪算法 :每一步选择 信息增益最大(或基尼系数最小)的特征进行分割。
-
递归分裂:重复分割直到满足停止条件(如节点纯度达标、深度限制等)。
-
-
常用算法
算法 特点 ID3 使用信息增益,仅支持分类,无法处理连续特征和缺失值 C4.5 改进 ID3,支持连续特征、缺失值处理,引入信息增益率防止过拟合 CART 使用基尼系数(分类)或均方误差(回归),支持分类和回归,生成二叉树 -
停止条件
-
节点样本数小于阈值
-
节点纯度达到要求(如熵/基尼系数接近 0)
-
树达到预设最大深度
-
三、分类与回归应用
-
分类树
-
目标:预测离散类别标签。
-
质量指标:熵或基尼系数。
-
示例代码(sklearn):
pythonfrom sklearn.tree import DecisionTreeClassifier clf = DecisionTreeClassifier(criterion='gini', max_depth=3) clf.fit(X_train, y_train)
-
-
回归树
-
目标:预测连续数值。
-
质量指标 :均方误差(MSE)
\( D = \frac{1}{\ell} \sum_{i=1}^\ell (y_i - \bar{y})^2 \)
\( \bar{y} \) 为叶节点样本均值。
-
示例代码(sklearn):
pythonfrom sklearn.tree import DecisionTreeRegressor reg = DecisionTreeRegressor(max_depth=3) reg.fit(X_train, y_train)
-
四、过拟合与剪枝
-
过拟合表现
- 决策树过深,叶节点样本过少,模型在训练集上完美拟合但在测试集上效果差。
-
解决方法
-
预剪枝(Pre-pruning):提前限制模型复杂度
-
max_depth
:树的最大深度 -
min_samples_leaf
:叶节点最少样本数 -
max_features
:分割时考虑的最大特征数
-
-
后剪枝(Post-pruning):先构建完整树,再自底向上合并节点(如 CCP 方法)。
-
五、决策树的优缺点
优点 | 缺点 |
---|---|
可解释性强,规则可视化 | 对噪声敏感,易过拟合 |
支持数值和类别特征 | 边界为轴平行超平面,可能不如其他模型灵活 |
训练和预测速度快 | 数据微小变化可能导致树结构剧变(不稳定) |
无需特征标准化 | 无法外推(只能预测训练集特征范围内的值) |
六、实战注意事项
-
参数调优
-
使用交叉验证选择最佳
max_depth
、min_samples_leaf
等参数。 -
示例代码:
pythonfrom sklearn.model_selection import GridSearchCV params = {'max_depth': [3, 5, 7], 'min_samples_leaf': [1, 5, 10]} grid = GridSearchCV(DecisionTreeClassifier(), params, cv=5) grid.fit(X, y)
-
-
可视化决策树
-
使用
sklearn.tree.plot_tree
或第三方库(如 Graphviz)生成树结构图。 -
示例代码:
pythonfrom sklearn.tree import plot_tree plt.figure(figsize=(20, 10)) plot_tree(clf, filled=True, feature_names=X.columns) plt.show()
-
七、应用
-
分类问题:客户信用评估、疾病诊断、垃圾邮件识别。
-
回归问题:房价预测、销量趋势分析。
-
特征重要性分析:通过节点分裂次数或信息增益量评估特征重要性。