决策树:可解释的 if-else 规则

摘要 :前面讲的算法(线性回归、逻辑回归、SVM)本质都是"数学公式"------输入特征乘以权重再通过函数变换。决策树完全不同:它是一组嵌套的 if-else 规则。这种"提问→回答→再提问"的结构让决策树成为最可解释的机器学习算法------你不仅能知道预测结果,还能完整说出模型做决策的逻辑链条。这篇文章讲清楚决策树的分裂原理、如何防止过拟合、以及它的优缺点。


一、决策树的直观理解

一个生活场景

你想决定"今天要不要出门跑步"。你的决策过程可能是这样的:

复制代码
天气晴朗吗?
├── 否 → 下雨或太冷 → 不出门 🛋️
└── 是 → 温度超过 25°C 吗?
    ├── 是 → 太热了 → 不出门 🛋️
    └── 否 → 空气质量好吗?
        ├── 否 → 雾霾 → 不出门 🛋️
        └── 是 → 出门跑步!🏃

这就是一棵决策树------通过一系列的是/否问题,最终到达一个决策(叶子节点)

在机器学习中

对于分类任务,决策树就是自动从数据中学习出这样一套"提问规则":

复制代码
决策树分类示意:

           petal length ≤ 2.45?
           ╱              ╲
         是                 否
        ╱                    ╲
    setosa            petal width ≤ 1.75?
                      ╱              ╲
                    是                 否
                   ╱                    ╲
              versicolor            virginica

每个内部节点是一个"特征测试",每个叶子节点是一个"分类结果"。从根到叶子的路径就是一条完整的 if-else 规则链。


二、如何构建一棵决策树?

构建决策树的核心问题是:在每一个节点,应该选择哪个特征、在什么阈值上分裂?

答案:选择能让分裂后的"纯度"提升最大的方式。

纯度的度量

"纯度"衡量一个集合中样本的"整齐程度"。如果所有样本都属于同一类,纯度为 1;如果各类各占一半,纯度最低。

分类:基尼系数(Gini Impurity)
复制代码
Gini = 1 - Σ(pᵢ²)

其中 pᵢ 是第 i 类样本在集合中的比例

例子(二分类):
  10 个样本全是 A:    Gini = 1 - 1² = 0          ← 最纯
  5 个 A + 5 个 B:     Gini = 1 - (0.5² + 0.5²) = 0.5  ← 最不纯
  
  Gini 越小 → 越纯 → 越好
分类:信息增益(Information Gain)

基于信息熵(Entropy)的概念:

复制代码
Entropy = -Σ(pᵢ × log₂(pᵢ))

例子(二分类):
  10 个样本全是 A:    Entropy = -1×log₂(1) = 0          ← 最纯
  5 个 A + 5 个 B:     Entropy = -(0.5×log₂0.5 + 0.5×log₂0.5) = 1.0  ← 最不纯

信息增益 = 分裂前的熵 - 分裂后加权平均的熵

复制代码
# 分裂前后纯度的变化
# 父节点:10 个样本 [5A, 5B] → Entropy = 1.0

# 按特征 X 分裂:
#   左子节点:7 个样本 [6A, 1B] → Entropy ≈ 0.59
#   右子节点:3 个样本 [2A, 1B] → Entropy ≈ 0.92

# 信息增益 = 1.0 - (7/10 × 0.59 + 3/10 × 0.92)
#          = 1.0 - 0.689 = 0.311

# 信息增益越大 → 分裂越好
回归:均方误差(MSE)

对于回归任务,分裂目标是最小化子节点的 MSE

复制代码
MSE = Σ(yᵢ - ȳ)² / N

在每个候选分裂点,计算左右子节点的 MSE 加权和
选择让 MSE 最小的分裂方式

分裂过程

复制代码
# 伪代码:构建决策树
def build_tree(data, max_depth=5):
    # 如果所有样本属于同一类 → 停止
    if all_same_class(data):
        return LeafNode(class=当前类别)
    
    # 如果达到最大深度 → 停止
    if depth >= max_depth:
        return LeafNode(class=多数类)
    
    # 遍历所有特征和所有可能的分裂点
    best_gain = -inf
    for feature in features:
        for threshold in possible_thresholds(feature):
            gain = information_gain(data, feature, threshold)
            if gain > best_gain:
                best_gain = gain
                best_feature = feature
                best_threshold = threshold
    
    # 按最佳分裂点分裂
    left_data, right_data = split(data, best_feature, best_threshold)
    
    # 递归构建子树
    left_child = build_tree(left_data, depth+1)
    right_child = build_tree(right_data, depth+1)
    
    return DecisionNode(feature=best_feature, 
                        threshold=best_threshold,
                        left=left_child, 
                        right=right_child)

Gini vs Entropy

指标 计算速度 性能差异 sklearn 默认
基尼系数 快(无 log 计算) 几乎无差异 criterion='gini'
信息熵 略慢(有 log 计算) 几乎无差异 criterion='entropy'

在实际应用中,两者通常表现非常接近。基尼系数是 sklearn 的默认选择。


三、剪枝:防止过拟合的关键

不剪枝的决策树会长成什么样?

如果不加限制,决策树会一直分裂直到每个叶子节点只包含一个样本------训练准确率 100%,但对新数据的泛化能力极差。

复制代码
不剪枝的决策树(深度=10):
  ┌──────────────────────────────────────┐
  │ 根节点                                │
  │   ├── 分裂                           │
  │   │   ├── 分裂                       │
  │   │   │   ├── 分裂                   │
  │   │   │   │   ├── 叶子(1个样本)     │
  │   │   │   │   └── 叶子(1个样本)     │
  │   │   │   └── ...                    │
  │   │   └── ...                        │
  │   └── ...                            │
  └──────────────────────────────────────┘
  → 完美记住每个训练样本 → 严重过拟合 ❌

剪枝方法

预剪枝(Pre-pruning):在构建时提前停止
复制代码
from sklearn.tree import DecisionTreeClassifier

# 通过限制参数来预剪枝
tree = DecisionTreeClassifier(
    max_depth=5,              # 最大深度(最常用)
    min_samples_split=10,     # 内部节点最少样本数
    min_samples_leaf=5,       # 叶子节点最少样本数
    max_leaf_nodes=20,        # 最大叶子节点数
)
参数 作用 值越大 值越小
max_depth 树的最大深度 更简单 更复杂
min_samples_split 分裂所需最少样本数 更早停止 更晚停止
min_samples_leaf 叶子最少样本数 更平滑 更细致
max_leaf_nodes 最多叶子数 更复杂 更简单

经验法则

复制代码
# 默认参数(几乎一定会过拟合)
tree_default = DecisionTreeClassifier()  # ❌

# 合理的剪枝起点
tree_pruned = DecisionTreeClassifier(
    max_depth=5,
    min_samples_leaf=5,
    min_samples_split=10
)  # ✅

# 用交叉验证选参
from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_leaf': [1, 3, 5, 10],
}
后剪枝(Post-pruning):先长全再剪

成本复杂度剪枝(CCP, Cost-Complexity Pruning):sklearn 提供的后剪枝方法。

核心思想:对树的大小加一个惩罚项

复制代码
剪枝后的损失 = 原始损失 + α × 叶子节点数

α 越大 → 叶子节点越少 → 树越简单

tree = DecisionTreeClassifier()
path = tree.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas  # 可选的 α 值

# 对每个 α 训练一棵树,用交叉验证选最优
best_tree = None
best_score = 0
for alpha in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha)
    tree.fit(X_train, y_train)
    score = tree.score(X_val, y_val)
    if score > best_score:
        best_score = score
        best_tree = tree

四、决策树的完整实战

复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report

# ===== 1. 加载数据 =====
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
class_names = iris.target_names

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# ===== 2. 训练(默认参数 vs 剪枝) =====
tree_default = DecisionTreeClassifier(random_state=42)
tree_default.fit(X_train, y_train)

tree_pruned = DecisionTreeClassifier(
    max_depth=3, min_samples_leaf=4, random_state=42
)
tree_pruned.fit(X_train, y_train)

# ===== 3. 对比 =====
for name, model in [("默认(未剪枝)", tree_default), ("剪枝后", tree_pruned)]:
    train_acc = model.score(X_train, y_train)
    test_acc = model.score(X_test, y_test)
    depth = model.get_depth()
    leaves = model.get_n_leaves()
    print(f"{name}:")
    print(f"  深度={depth}, 叶子数={leaves}")
    print(f"  训练准确率={train_acc:.3f}, 测试准确率={test_acc:.3f}\n")

# ===== 4. 可视化决策树 =====
plt.figure(figsize=(16, 8))
plot_tree(
    tree_pruned,
    feature_names=feature_names,
    class_names=class_names.tolist(),
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title("决策树可视化(鸢尾花分类)")
plt.show()

# ===== 5. 特征重要性 =====
importance = pd.DataFrame({
    'feature': feature_names,
    'importance': tree_pruned.feature_importances_
}).sort_values('importance', ascending=False)

print("特征重要性:")
print(importance)

输出示例

复制代码
默认(未剪枝):
  深度=6, 叶子数=9
  训练准确率=1.000, 测试准确率=0.933

剪枝后(max_depth=3, min_samples_leaf=4):
  深度=3, 叶子数=4
  训练准确率=0.958, 测试准确率=0.967 ✅

特征重要性:
          feature  importance
  petal length (cm)    0.546
   petal width (cm)    0.442
  sepal length (cm)    0.012
   sepal width (cm)    0.000

五、决策树的可解释性

提取决策规则

决策树的每一条从根到叶子的路径都是一条可读的规则:

复制代码
from sklearn.tree import export_text

rules = export_text(tree_pruned, feature_names=feature_names)
print(rules)

输出

复制代码
|--- petal length (cm) <= 2.45
|   |--- class: setosa
|--- petal length (cm) > 2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- class: versicolor
|   |   |--- petal length (cm) > 4.95
|   |   |   |--- class: virginica
|   |--- petal width (cm) > 1.75
|   |   |--- class: virginica

这就是规则的可读性------你可以把这条规则写进业务文档,给非技术人员看,甚至作为监管合规的依据。

特征重要性

决策树可以输出特征重要性------每个特征在树中被用来分裂时,带来的纯度提升的加权和

复制代码
# 归一化的特征重要性
importance = tree_pruned.feature_importances_

# 重要性之和为 1
# 越重要的特征被用来分裂的次数越多、位置越靠近根节点

特性重要性解读

复制代码
petal length 最重要(0.546):出现在根节点,一次分裂就把 setosa 完全分开
petal width 其次(0.442):在第二层分裂,把 versicolor 和 virginica 分开
sepal 特征几乎不重要(<0.02):在最佳分裂中几乎用不到

六、决策树的优缺点

优点

优点 说明
可解释性最强 可以输出完整的 if-else 规则
不需要特征缩放 决策树基于阈值,不是距离------标准化不影响结果
处理混合特征 数值型+类别型特征可以直接使用
捕捉非线性关系 天然支持非线性决策边界
特征重要性 内置的特征选择机制
对缺失值不敏感 可以处理缺失数据

缺点

缺点 说明
容易过拟合 不剪枝的决策树几乎一定过拟合
对数据微小变化敏感 换几个样本,整棵树可能完全不同
贪婪搜索 每步只看当前最优分裂,不能保证全局最优
偏向多值特征 取值多的特征更容易被选中分裂
决策边界是轴对齐的 只能水平/垂直分裂,不能斜着分裂

决策树 vs 之前算法

复制代码
# 对比:四种分类器的决策边界特点

# 逻辑回归 / 线性 SVM    → 一条直线切分
# RBF 核 SVM              → 任意形状的平滑曲线
# KNN                     → 由数据密度决定的分段边界
# 决策树                  → 轴对齐的矩形切分(水平/垂直线)

七、总结

概念 一句话理解
决策树 一系列 if-else 问题的嵌套------像医生问诊一样逐步诊断
基尼系数/信息熵 衡量数据"纯度"的指标------越纯越好
信息增益 分裂后纯度提升了多少------用来选择最佳分裂特征
剪枝 限制树的生长以防止过拟合------深度、叶子大小、α 参数
特征重要性 每个特征在分裂中的贡献------告诉你哪些特征真正有用

核心三句话

  1. 决策树是最可解释的 ML 算法------你能写出完整的决策规则,向任何人解释
  2. 决策树必须剪枝------不加限制的决策树一定会过拟合
  3. 单棵决策树能力有限------但它是更强大算法(随机森林、XGBoost)的基础构件

下一篇文章:集成学习------把多棵"弱"决策树组合起来,得到比任何单棵树都强的模型。这是经典机器学习在最前沿的延伸。


相关推荐
小糯米6011 小时前
JS 数组
数据结构·算法·排序算法
宝贝儿好2 小时前
【LLM】第一章:知识体系框架概览
人工智能·深度学习·机器学习·自然语言处理
拳里剑气2 小时前
C++算法:链表
c++·算法·链表
凌波粒2 小时前
LeetCode--90.子集II(回溯算法)
数据结构·算法·leetcode
苏州邦恩精密2 小时前
GOM三维扫描在制造中的真实价值:让“修模”从经验动作变成数据动作
人工智能·科技·机器学习·3d·自动化·制造
旖-旎2 小时前
《LeetCode 417 太平洋大西洋水流问题 FloodFill DFS 解法》
c++·算法·深度优先·力扣·floodfill
凌波粒2 小时前
LeetCode--46.全排列(回溯算法)
数据结构·算法·leetcode
2zcode2 小时前
项目文档:基于MATLAB语音信号变声算法设计与实现
算法·matlab·语音识别
指令集梦境2 小时前
图解:单调栈算法模板(Java语言)
java·开发语言·算法