摘要 :前面讲的算法(线性回归、逻辑回归、SVM)本质都是"数学公式"------输入特征乘以权重再通过函数变换。决策树完全不同:它是一组嵌套的 if-else 规则。这种"提问→回答→再提问"的结构让决策树成为最可解释的机器学习算法------你不仅能知道预测结果,还能完整说出模型做决策的逻辑链条。这篇文章讲清楚决策树的分裂原理、如何防止过拟合、以及它的优缺点。
一、决策树的直观理解
一个生活场景
你想决定"今天要不要出门跑步"。你的决策过程可能是这样的:
天气晴朗吗?
├── 否 → 下雨或太冷 → 不出门 🛋️
└── 是 → 温度超过 25°C 吗?
├── 是 → 太热了 → 不出门 🛋️
└── 否 → 空气质量好吗?
├── 否 → 雾霾 → 不出门 🛋️
└── 是 → 出门跑步!🏃
这就是一棵决策树------通过一系列的是/否问题,最终到达一个决策(叶子节点)。
在机器学习中
对于分类任务,决策树就是自动从数据中学习出这样一套"提问规则":
决策树分类示意:
petal length ≤ 2.45?
╱ ╲
是 否
╱ ╲
setosa petal width ≤ 1.75?
╱ ╲
是 否
╱ ╲
versicolor virginica
每个内部节点是一个"特征测试",每个叶子节点是一个"分类结果"。从根到叶子的路径就是一条完整的 if-else 规则链。
二、如何构建一棵决策树?
构建决策树的核心问题是:在每一个节点,应该选择哪个特征、在什么阈值上分裂?
答案:选择能让分裂后的"纯度"提升最大的方式。
纯度的度量
"纯度"衡量一个集合中样本的"整齐程度"。如果所有样本都属于同一类,纯度为 1;如果各类各占一半,纯度最低。
分类:基尼系数(Gini Impurity)
Gini = 1 - Σ(pᵢ²)
其中 pᵢ 是第 i 类样本在集合中的比例
例子(二分类):
10 个样本全是 A: Gini = 1 - 1² = 0 ← 最纯
5 个 A + 5 个 B: Gini = 1 - (0.5² + 0.5²) = 0.5 ← 最不纯
Gini 越小 → 越纯 → 越好
分类:信息增益(Information Gain)
基于信息熵(Entropy)的概念:
Entropy = -Σ(pᵢ × log₂(pᵢ))
例子(二分类):
10 个样本全是 A: Entropy = -1×log₂(1) = 0 ← 最纯
5 个 A + 5 个 B: Entropy = -(0.5×log₂0.5 + 0.5×log₂0.5) = 1.0 ← 最不纯
信息增益 = 分裂前的熵 - 分裂后加权平均的熵
# 分裂前后纯度的变化
# 父节点:10 个样本 [5A, 5B] → Entropy = 1.0
# 按特征 X 分裂:
# 左子节点:7 个样本 [6A, 1B] → Entropy ≈ 0.59
# 右子节点:3 个样本 [2A, 1B] → Entropy ≈ 0.92
# 信息增益 = 1.0 - (7/10 × 0.59 + 3/10 × 0.92)
# = 1.0 - 0.689 = 0.311
# 信息增益越大 → 分裂越好
回归:均方误差(MSE)
对于回归任务,分裂目标是最小化子节点的 MSE:
MSE = Σ(yᵢ - ȳ)² / N
在每个候选分裂点,计算左右子节点的 MSE 加权和
选择让 MSE 最小的分裂方式
分裂过程
# 伪代码:构建决策树
def build_tree(data, max_depth=5):
# 如果所有样本属于同一类 → 停止
if all_same_class(data):
return LeafNode(class=当前类别)
# 如果达到最大深度 → 停止
if depth >= max_depth:
return LeafNode(class=多数类)
# 遍历所有特征和所有可能的分裂点
best_gain = -inf
for feature in features:
for threshold in possible_thresholds(feature):
gain = information_gain(data, feature, threshold)
if gain > best_gain:
best_gain = gain
best_feature = feature
best_threshold = threshold
# 按最佳分裂点分裂
left_data, right_data = split(data, best_feature, best_threshold)
# 递归构建子树
left_child = build_tree(left_data, depth+1)
right_child = build_tree(right_data, depth+1)
return DecisionNode(feature=best_feature,
threshold=best_threshold,
left=left_child,
right=right_child)
Gini vs Entropy
| 指标 | 计算速度 | 性能差异 | sklearn 默认 |
|---|---|---|---|
| 基尼系数 | 快(无 log 计算) | 几乎无差异 | ✅ criterion='gini' |
| 信息熵 | 略慢(有 log 计算) | 几乎无差异 | criterion='entropy' |
在实际应用中,两者通常表现非常接近。基尼系数是 sklearn 的默认选择。
三、剪枝:防止过拟合的关键
不剪枝的决策树会长成什么样?
如果不加限制,决策树会一直分裂直到每个叶子节点只包含一个样本------训练准确率 100%,但对新数据的泛化能力极差。
不剪枝的决策树(深度=10):
┌──────────────────────────────────────┐
│ 根节点 │
│ ├── 分裂 │
│ │ ├── 分裂 │
│ │ │ ├── 分裂 │
│ │ │ │ ├── 叶子(1个样本) │
│ │ │ │ └── 叶子(1个样本) │
│ │ │ └── ... │
│ │ └── ... │
│ └── ... │
└──────────────────────────────────────┘
→ 完美记住每个训练样本 → 严重过拟合 ❌
剪枝方法
预剪枝(Pre-pruning):在构建时提前停止
from sklearn.tree import DecisionTreeClassifier
# 通过限制参数来预剪枝
tree = DecisionTreeClassifier(
max_depth=5, # 最大深度(最常用)
min_samples_split=10, # 内部节点最少样本数
min_samples_leaf=5, # 叶子节点最少样本数
max_leaf_nodes=20, # 最大叶子节点数
)
| 参数 | 作用 | 值越大 | 值越小 |
|---|---|---|---|
max_depth |
树的最大深度 | 更简单 | 更复杂 |
min_samples_split |
分裂所需最少样本数 | 更早停止 | 更晚停止 |
min_samples_leaf |
叶子最少样本数 | 更平滑 | 更细致 |
max_leaf_nodes |
最多叶子数 | 更复杂 | 更简单 |
经验法则:
# 默认参数(几乎一定会过拟合)
tree_default = DecisionTreeClassifier() # ❌
# 合理的剪枝起点
tree_pruned = DecisionTreeClassifier(
max_depth=5,
min_samples_leaf=5,
min_samples_split=10
) # ✅
# 用交叉验证选参
from sklearn.model_selection import GridSearchCV
param_grid = {
'max_depth': [3, 5, 7, 10, None],
'min_samples_leaf': [1, 3, 5, 10],
}
后剪枝(Post-pruning):先长全再剪
成本复杂度剪枝(CCP, Cost-Complexity Pruning):sklearn 提供的后剪枝方法。
核心思想:对树的大小加一个惩罚项
剪枝后的损失 = 原始损失 + α × 叶子节点数
α 越大 → 叶子节点越少 → 树越简单
tree = DecisionTreeClassifier()
path = tree.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas # 可选的 α 值
# 对每个 α 训练一棵树,用交叉验证选最优
best_tree = None
best_score = 0
for alpha in alphas:
tree = DecisionTreeClassifier(ccp_alpha=alpha)
tree.fit(X_train, y_train)
score = tree.score(X_val, y_val)
if score > best_score:
best_score = score
best_tree = tree
四、决策树的完整实战
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report
# ===== 1. 加载数据 =====
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
class_names = iris.target_names
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# ===== 2. 训练(默认参数 vs 剪枝) =====
tree_default = DecisionTreeClassifier(random_state=42)
tree_default.fit(X_train, y_train)
tree_pruned = DecisionTreeClassifier(
max_depth=3, min_samples_leaf=4, random_state=42
)
tree_pruned.fit(X_train, y_train)
# ===== 3. 对比 =====
for name, model in [("默认(未剪枝)", tree_default), ("剪枝后", tree_pruned)]:
train_acc = model.score(X_train, y_train)
test_acc = model.score(X_test, y_test)
depth = model.get_depth()
leaves = model.get_n_leaves()
print(f"{name}:")
print(f" 深度={depth}, 叶子数={leaves}")
print(f" 训练准确率={train_acc:.3f}, 测试准确率={test_acc:.3f}\n")
# ===== 4. 可视化决策树 =====
plt.figure(figsize=(16, 8))
plot_tree(
tree_pruned,
feature_names=feature_names,
class_names=class_names.tolist(),
filled=True,
rounded=True,
fontsize=10
)
plt.title("决策树可视化(鸢尾花分类)")
plt.show()
# ===== 5. 特征重要性 =====
importance = pd.DataFrame({
'feature': feature_names,
'importance': tree_pruned.feature_importances_
}).sort_values('importance', ascending=False)
print("特征重要性:")
print(importance)
输出示例:
默认(未剪枝):
深度=6, 叶子数=9
训练准确率=1.000, 测试准确率=0.933
剪枝后(max_depth=3, min_samples_leaf=4):
深度=3, 叶子数=4
训练准确率=0.958, 测试准确率=0.967 ✅
特征重要性:
feature importance
petal length (cm) 0.546
petal width (cm) 0.442
sepal length (cm) 0.012
sepal width (cm) 0.000
五、决策树的可解释性
提取决策规则
决策树的每一条从根到叶子的路径都是一条可读的规则:
from sklearn.tree import export_text
rules = export_text(tree_pruned, feature_names=feature_names)
print(rules)
输出:
|--- petal length (cm) <= 2.45
| |--- class: setosa
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- petal length (cm) <= 4.95
| | | |--- class: versicolor
| | |--- petal length (cm) > 4.95
| | | |--- class: virginica
| |--- petal width (cm) > 1.75
| | |--- class: virginica
这就是规则的可读性------你可以把这条规则写进业务文档,给非技术人员看,甚至作为监管合规的依据。
特征重要性
决策树可以输出特征重要性------每个特征在树中被用来分裂时,带来的纯度提升的加权和。
# 归一化的特征重要性
importance = tree_pruned.feature_importances_
# 重要性之和为 1
# 越重要的特征被用来分裂的次数越多、位置越靠近根节点
特性重要性解读:
petal length 最重要(0.546):出现在根节点,一次分裂就把 setosa 完全分开
petal width 其次(0.442):在第二层分裂,把 versicolor 和 virginica 分开
sepal 特征几乎不重要(<0.02):在最佳分裂中几乎用不到
六、决策树的优缺点
优点
| 优点 | 说明 |
|---|---|
| ✅ 可解释性最强 | 可以输出完整的 if-else 规则 |
| ✅ 不需要特征缩放 | 决策树基于阈值,不是距离------标准化不影响结果 |
| ✅ 处理混合特征 | 数值型+类别型特征可以直接使用 |
| ✅ 捕捉非线性关系 | 天然支持非线性决策边界 |
| ✅ 特征重要性 | 内置的特征选择机制 |
| ✅ 对缺失值不敏感 | 可以处理缺失数据 |
缺点
| 缺点 | 说明 |
|---|---|
| ❌ 容易过拟合 | 不剪枝的决策树几乎一定过拟合 |
| ❌ 对数据微小变化敏感 | 换几个样本,整棵树可能完全不同 |
| ❌ 贪婪搜索 | 每步只看当前最优分裂,不能保证全局最优 |
| ❌ 偏向多值特征 | 取值多的特征更容易被选中分裂 |
| ❌ 决策边界是轴对齐的 | 只能水平/垂直分裂,不能斜着分裂 |
决策树 vs 之前算法
# 对比:四种分类器的决策边界特点
# 逻辑回归 / 线性 SVM → 一条直线切分
# RBF 核 SVM → 任意形状的平滑曲线
# KNN → 由数据密度决定的分段边界
# 决策树 → 轴对齐的矩形切分(水平/垂直线)
七、总结
| 概念 | 一句话理解 |
|---|---|
| 决策树 | 一系列 if-else 问题的嵌套------像医生问诊一样逐步诊断 |
| 基尼系数/信息熵 | 衡量数据"纯度"的指标------越纯越好 |
| 信息增益 | 分裂后纯度提升了多少------用来选择最佳分裂特征 |
| 剪枝 | 限制树的生长以防止过拟合------深度、叶子大小、α 参数 |
| 特征重要性 | 每个特征在分裂中的贡献------告诉你哪些特征真正有用 |
核心三句话:
- 决策树是最可解释的 ML 算法------你能写出完整的决策规则,向任何人解释
- 决策树必须剪枝------不加限制的决策树一定会过拟合
- 单棵决策树能力有限------但它是更强大算法(随机森林、XGBoost)的基础构件
下一篇文章:集成学习------把多棵"弱"决策树组合起来,得到比任何单棵树都强的模型。这是经典机器学习在最前沿的延伸。