机器学习之决策树详解

摘要:决策树(Decision Tree)是一种基于树结构进行决策的机器学习算法,广泛应用于分类与回归任务。其核心思想是通过对特征空间进行递归分裂,构建一棵能够对数据进行高效预测的树形模型。本文系统讲解决策树的基本原理、分裂准则(信息增益、基尼系数、信息增益率)、经典算法(ID3、C4.5、CART)及其对比、剪枝策略,并结合 scikit-learn 提供完整的实战代码示例,帮助读者从理论到实践全面掌握决策树算法。本文适合机器学习初学者及希望深入理解决策树原理的开发者参考。

关键词:决策树、信息增益、信息熵、基尼系数、CART、scikit-learn、剪枝


一、决策树概述

1.1 什么是决策树

决策树是一种监督学习算法,可用于分类(Classification)和回归(Regression)任务。它模拟人类决策过程,通过一系列的是/否问题(对特征值的判断)将数据逐层划分,最终得到预测结果。因其模型结构形似一棵倒置的树而得名。

决策树的组成要素:

  • 根节点(Root Node):树的最顶端,包含整个数据集,是分裂的起点。

  • 内部节点(Internal Node):对应特征的判断节点,表示对某个特征的测试。

  • 叶节点(Leaf Node):树的末端节点,代表最终的分类标签或回归值。

  • 分支(Branch):节点的输出路径,对应特征的不同取值。

1.2 决策树的工作原理

给定一个输入样本,决策树从根节点开始,根据该样本在各特征上的取值,沿着对应的分支向下递归,直到抵达叶节点,叶节点的标签即为预测结果。

这一过程类似于医生诊断疾病:医生依次询问症状(特征测试),根据回答逐步缩小可能病因的范围,最终确定诊断结论(叶节点标签)。

1.3 决策树的优势与局限

优势 局限
易于理解和解释,能可视化 容易过拟合,泛化能力差
训练和预测速度快 对数据敏感,微小变化可能导致树结构大幅改变
支持连续值和离散值特征 偏向于选择取值更多的特征
能处理多分类问题 不擅长处理不平衡数据

二、决策树原理详解

2.1 树结构基础

决策树通过递归分裂(Recursive Splitting) 构建,基本算法如下:

复制代码
输入:数据集 D,特征集 A
1. 从根节点开始,用全部数据构建树
2. 如果节点中所有样本属于同一类别 C,则该节点为叶节点,标记为 C
3. 如果特征集为空或数据集为空,则停止分裂
4. 选择最优分裂特征 a*(分裂准则:信息增益最大或基尼系数最小)
5. 按特征 a* 的取值将数据划分为若干子集
6. 对每个子集递归执行步骤 2-5

2.2 分裂准则:信息增益与信息熵

熵(Entropy) 是度量数据混乱程度的指标。熵越高,数据越混乱;熵越低,数据越纯净。信息熵的公式定义为:

H(X) = -\\sum_{i=1}\^{n} p(x_i) \\log_2 p(x_i)

其中 p(x_i) 表示事件 x_i 发生的概率。对于二分类问题,设正类比例为 p,则:

H(X) = -p \\log_2 p - (1-p) \\log_2 (1-p)

信息增益(Information Gain) 表示在已知某个特征后,数据集不确定性的减少量。计算公式为:

IG(D, a) = H(D) - \\sum_{v \\in \\text{Values}(a)} \\frac{\|D_v\|}{\|D\|} \\cdot H(D_v)

其中:

  • H(D) 为分裂前数据集的熵

  • H(D_v) 为按特征 a 的取值 v 分裂后子集的熵

  • \\frac{\|D_v\|}{\|D\|} 为子集权重

ID3 算法选择信息增益最大的特征作为当前最优分裂特征。

2.3 分裂准则:基尼系数

基尼系数(Gini Impurity) 是 CART(Classification and Regression Tree)算法使用的分裂准则,度量从数据集中随机抽取两个样本、其类别不一致的概率:

Gini(D) = 1 - \\sum_{k=1}\^{K} p_k\^2

其中 p_k 为第 k 类样本在数据集 D 中的比例。

使用特征 a 分裂后的加权基尼系数:

Gini_a(D) = \\sum_{v \\in \\text{Values}(a)} \\frac{\|D_v\|}{\|D\|} \\cdot Gini(D_v)

CART 选择基尼系数最小的特征进行分裂。

2.4 分裂准则:信息增益率

ID3 算法存在一个明显缺陷:倾向于选择取值更多的特征。例如为每个样本赋予唯一 ID,则按 ID 分裂的信息增益最大,但毫无泛化能力。

为解决这一问题,C4.5 算法引入信息增益率(Gain Ratio)

GainRatio(D, a) = \\frac{IG(D, a)}{IV(a)}

其中 IV(a) 为特征 a固有值(Intrinsic Value)

IV(a) = -\\sum_{v \\in \\text{Values}(a)} \\frac{\|D_v\|}{\|D\|} \\log_2 \\frac{\|D_v\|}{\|D\|}

特征取值越多,固有值越大,从而抑制信息增益的偏好。C4.5 算法通过先筛选信息增益高于平均水平的特征,再选择增益率最高的特征来解决这一偏置问题。


三、决策树经典算法

3.1 ID3 算法

ID3(Iterative Dichotomiser 3)由 Ross Quinlan 于 1986 年提出,是最早的决策树算法。

核心特点:

  • 使用信息增益作为分裂准则

  • 仅支持分类任务

  • 仅支持离散型(类别型)特征

  • 不支持连续值特征、缺失值和剪枝

3.2 C4.5 算法

C4.5 是 ID3 的改进版本,由 Quinlan 于 1993 年提出。

核心改进:

  • 使用信息增益率替代信息增益

  • 支持连续值特征(通过二分阈值处理)

  • 支持缺失值处理

  • 支持后剪枝(基于错误率的剪枝)

3.3 CART 算法

CART(Classification and Regression Tree)由 Breiman 等人于 1984 年提出,是目前应用最广泛的决策树算法。

核心特点:

  • 使用基尼系数(分类)或方差(回归)作为分裂准则

  • 二叉树结构:每个内部节点只有两个分支

  • 既支持分类也支持回归

  • 内置剪枝机制

3.4 三种算法对比

特性 ID3 C4.5 CART
提出年份 1986 1993 1984
分裂准则 信息增益 信息增益率 基尼系数 / 方差
树结构 多叉树 多叉树 二叉树
支持分类
支持回归
支持连续值
支持缺失值
剪枝策略 后剪枝 后剪枝 / 预剪枝

四、剪枝策略

决策树如果不加限制地分裂,会完全拟合训练数据,导致过拟合。剪枝(Pruning) 是解决这一问题的核心手段。

4.1 预剪枝(Pre-pruning)

预剪枝在决策树构建过程中,通过设置停止条件来提前终止分裂。

常用停止条件:

  • 树的深度达到设定阈值

  • 节点样本数少于设定阈值

  • 分裂后信息增益(基尼系数)减少量低于设定阈值

优点 :计算效率高,适合大规模数据。 缺点:可能过早终止,欠拟合风险较高。

4.2 后剪枝(Post-pruning)

后剪枝先让决策树充分生长,再自底向上地将某些子树替换为叶节点,通过验证集评估剪枝效果。

REP(Reduced Error Pruning):自底向上尝试剪枝,如果剪枝后验证集精度不下降,则保留剪枝。

CCP(Cost-Complexity Pruning,代价复杂度剪枝):在 CART 中常用,定义代价复杂度指标:

R*\\alpha(T) = R(T) + \\alpha \\cdot \|T*{leaf}\|

其中 R(T) 为树的训练误差,\|T_{leaf}\| 为叶节点数,\\alpha 为复杂度参数。通过逐步增加 \\alpha,生成一系列逐渐简化的树序列,再用验证集选择最优树。


五、决策树使用场景

5.1 客户分群

在电商和金融领域,决策树可用于根据用户的年龄、收入、消费行为等特征将客户划分为不同群体,为精准营销提供依据。决策树规则易于业务人员理解,便于落地执行。

5.2 信用评估

银行和金融机构利用决策树评估借款人的信用风险。通过分析申请人的收入水平、工作年限、负债比例、历史逾期记录等特征,构建信用评分模型,决定是否放贷及贷款额度。

5.3 医疗诊断

决策树在医学辅助诊断中应用广泛。根据患者的症状、检查指标和病史数据,决策树可以构建疾病筛查模型,辅助医生进行早期诊断。例如判断患者是否患有糖尿病、心脏病等。

5.4 规则提取

决策树的叶节点路径可以直接转化为 if-then 业务规则。例如"如果客户购买频率 > 10次/月 且 平均订单金额 > 500元,则为高价值客户"。这类规则无需建模背景知识的业务人员也能理解和使用。


六、实战代码:鸢尾花分类

本节使用 scikit-learn 内置的鸢尾花数据集,展示决策树分类器的完整使用流程。

6.1 基础分类器构建

复制代码
# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
​
# 1. 加载数据集
iris = load_iris()
X = iris.data          # 特征矩阵:花萼长度、花萼宽度、花瓣长度、花瓣宽度
y = iris.target        # 目标标签:0-Setosa, 1-Versicolor, 2-Virginica
​
# 2. 划分训练集和测试集(80%训练,20%测试)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
​
# 3. 创建决策树分类器(使用基尼系数,CART算法默认)
# max_depth: 限制树的最大深度,防止过拟合
# random_state: 设置随机种子,保证结果可复现
clf = DecisionTreeClassifier(
    criterion='gini',    # 分裂准则:'gini'(基尼系数)或 'entropy'(信息增益)
    max_depth=4,         # 树的最大深度
    min_samples_split=5, # 节点分裂所需最少样本数
    min_samples_leaf=2,  # 叶节点最少样本数
    random_state=42
)
​
# 4. 训练模型
clf.fit(X_train, y_train)
​
# 5. 在测试集上进行预测
y_pred = clf.predict(X_test)
​
# 6. 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率:{accuracy:.4f}")
print("\n混淆矩阵:")
print(confusion_matrix(y_test, y_pred))
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

运行结果示例:

复制代码
模型准确率:1.0000
​
混淆矩阵:
[[10  0  0]
 [ 0 10  0]
 [ 0  0 10]]
​
分类报告:
              precision    recall  f1-score   support
      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00        10
   virginica       1.00      1.00      1.00        10
    accuracy                           1.00        30

6.2 决策树可视化

使用 plot_tree 函数将决策树结构可视化,直观理解模型的分裂逻辑。

复制代码
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
​
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
​
# 创建画布
fig, ax = plt.subplots(figsize=(24, 12))
​
# 绘制决策树
plot_tree(
    clf,
    feature_names=iris.feature_names,   # 特征名称
    class_names=iris.target_names,       # 类别名称
    filled=True,                         # 用颜色填充节点(颜色深浅表示类别纯度)
    rounded=True,                        # 圆角矩形
    fontsize=10,
    ax=ax
)
​
plt.title("鸢尾花数据集决策树分类器", fontsize=16)
plt.tight_layout()
plt.savefig("decision_tree_visualization.png", dpi=150, bbox_inches='tight')
plt.show()
print("决策树可视化图已保存为 decision_tree_visualization.png")

6.3 不同树深度对模型的影响

树的深度是影响模型复杂度最重要的超参数。本节实验不同深度下的训练集和测试集准确率,观察过拟合与欠拟合现象。

复制代码
import numpy as np
​
# 测试不同深度的准确率
depths = range(1, 11)
train_accuracies = []
test_accuracies = []
​
for depth in depths:
    # 使用不同的随机种子进行多次实验取平均
    temp_train_acc = []
    temp_test_acc = []
    for seed in range(10):
        dt = DecisionTreeClassifier(max_depth=depth, random_state=seed)
        dt.fit(X_train, y_train)
        temp_train_acc.append(dt.score(X_train, y_train))
        temp_test_acc.append(dt.score(X_test, y_test))
    train_accuracies.append(np.mean(temp_train_acc))
    test_accuracies.append(np.mean(temp_test_acc))
​
# 绘制准确率曲线
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(depths, train_accuracies, 'o-', label='训练集准确率', linewidth=2, markersize=8)
ax.plot(depths, test_accuracies, 's-', label='测试集准确率', linewidth=2, markersize=8)
ax.set_xlabel('决策树最大深度', fontsize=12)
ax.set_ylabel('准确率', fontsize=12)
ax.set_title('决策树深度与模型准确率的关系', fontsize=14)
ax.set_xticks(depths)
ax.legend(fontsize=11)
ax.grid(True, linestyle='--', alpha=0.7)
ax.set_ylim(0.85, 1.02)
plt.tight_layout()
plt.savefig("depth_vs_accuracy.png", dpi=150, bbox_inches='tight')
plt.show()
​
# 打印详细数据
print("\n深度 | 训练集准确率 | 测试集准确率")
print("-" * 40)
for d, train_acc, test_acc in zip(depths, train_accuracies, test_accuracies):
    print(f"  {d:2d}  |    {train_acc:.4f}    |    {test_acc:.4f}")

结果分析:

  • 当深度为 1 时,模型过于简单,训练集和测试集准确率都较低,存在欠拟合

  • 当深度增加时,训练集准确率持续上升并趋近于 1.0(完全拟合)。

  • 测试集准确率在某个最优深度处达到峰值后开始下降,表明过拟合开始出现。

  • 建议:选择测试集准确率最高对应的深度作为最优超参数。

6.4 特征重要性分析

决策树提供了**特征重要性(Feature Importance)**指标,度量每个特征对分类任务的贡献程度。

复制代码
# 获取特征重要性
importances = clf.feature_names_in_ importances = clf.feature_importances_
​
# 打印各特征的重要性
print("特征重要性排名:")
print("-" * 45)
for name, importance in sorted(
    zip(iris.feature_names, importances), key=lambda x: x[1], reverse=True
):
    bar = "█" * int(importance * 40)
    print(f"{name:15s} : {importance:.4f}  {bar}")
​
# 可视化特征重要性
fig, ax = plt.subplots(figsize=(8, 5))
colors = plt.cm.Reds(np.linspace(0.4, 0.9, len(iris.feature_names)))
sorted_idx = np.argsort(importances)
ax.barh(
    [iris.feature_names[i] for i in sorted_idx],
    importances[sorted_idx],
    color=colors[sorted_idx]
)
ax.set_xlabel('重要性', fontsize=12)
ax.set_title('决策树特征重要性分析', fontsize=14)
ax.set_xlim(0, max(importances) * 1.15)
plt.tight_layout()
plt.savefig("feature_importance.png", dpi=150, bbox_inches='tight')
plt.show()

结果解读:

  • 特征重要性之和为 1,数值越大表示该特征在分类决策中越关键。

  • 在鸢尾花数据集中,花瓣长度(Petal length) 通常是最重要的特征,对分类的贡献最大。

6.5 使用信息增益(熵)准则

除了默认的基尼系数,我们还可以使用信息增益(criterion='entropy')构建决策树。

复制代码
# 使用信息增益(熵)作为分裂准则
clf_entropy = DecisionTreeClassifier(
    criterion='entropy',  # 使用信息增益替代基尼系数
    max_depth=4,
    min_samples_split=5,
    min_samples_leaf=2,
    random_state=42
)
clf_entropy.fit(X_train, y_train)
​
# 比较两种准则的准确率
acc_gini = clf.score(X_test, y_test)
acc_entropy = clf_entropy.score(X_test, y_test)
​
print(f"基尼系数准则 - 测试集准确率:{acc_gini:.4f}")
print(f"信息增益准则 - 测试集准确率:{acc_entropy:.4f}")
print(f"\n两种准则在鸢尾花数据集上准确率差异:{abs(acc_gini - acc_entropy):.4f}")

七、完整实战:使用决策树进行信用评估

本节以模拟的信用评估数据集为例,展示决策树在金融场景中的完整应用流程。

复制代码
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.preprocessing import LabelEncoder
​
# 1. 创建模拟的信用评估数据集
data = {
    '年龄': [25, 35, 45, 28, 52, 38, 42, 50, 23, 33],
    '月收入': [8000, 15000, 20000, 6000, 30000, 12000, 18000, 25000, 5000, 9000],
    '工作年限': [2, 5, 10, 1, 20, 7, 8, 15, 0, 4],
    '负债比例': [0.2, 0.1, 0.3, 0.5, 0.15, 0.4, 0.25, 0.2, 0.6, 0.35],
    '有房产': ['否', '是', '是', '否', '是', '否', '是', '是', '否', '否'],
    '信用评级': ['低', '高', '高', '低', '高', '中', '高', '高', '低', '中']
}
df = pd.DataFrame(data)
​
print("数据集前5行:")
print(df.head())
​
# 2. 数据预处理
# 对分类特征进行标签编码
le_house = LabelEncoder()
df['有房产'] = le_house.fit_transform(df['有房产'])  # 否=0, 是=1
​
le_credit = LabelEncoder()
df['信用评级'] = le_credit.fit_transform(df['信用评级'])  # 低=0, 中=1, 高=2
​
# 3. 划分特征和标签
X = df.drop('信用评级', axis=1)
y = df['信用评级']
​
# 4. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
​
# 5. 训练决策树
credit_tree = DecisionTreeClassifier(
    criterion='gini',
    max_depth=4,
    min_samples_split=2,
    random_state=42
)
credit_tree.fit(X_train, y_train)
​
# 6. 评估模型
train_acc = credit_tree.score(X_train, y_train)
test_acc = credit_tree.score(X_test, y_test)
print(f"\n训练集准确率:{train_acc:.4f}")
print(f"测试集准确率:{test_acc:.4f}")
​
# 7. 提取并打印决策规则(文本形式)
feature_names = list(X.columns)
rules = export_text(credit_tree, feature_names=feature_names)
print("\n决策树规则(文本形式):")
print(rules)
​
# 8. 打印特征重要性
print("\n信用评估特征重要性:")
for name, imp in sorted(zip(feature_names, credit_tree.feature_importances_),
                          key=lambda x: x[1], reverse=True):
    if imp > 0:
        print(f"  {name}: {imp:.4f}")

运行结果示例:

复制代码
数据集前5行:
   年龄  月收入  工作年限  负债比例 有房产 信用评级
0  25   8000      2   0.2   否    低
1  35  15000      5   0.1   是    高
2  45  20000     10   0.3   是    高
...
​
训练集准确率:1.0000
测试集准确率:0.6667
​
决策树规则(文本形式):
|--- 负债比例 <= 0.30
|   |--- 月收入 <= 12500
|   |   |--- class: 中
|   |--- 月收入 >  12500
|   |   |--- class: 高
|--- 负债比例 >  0.30
|   |--- class: 低
​
信用评估特征重要性:
  负债比例: 0.7111
  月收入: 0.2889

八、总结

本文系统介绍了决策树算法的核心概念与实战应用:

  1. 原理基础:决策树通过递归分裂构建树结构,每个节点对特征进行判断,最终叶节点输出预测结果。

  2. 分裂准则:信息增益(ID3)、信息增益率(C4.5)和基尼系数(CART)是三种主流的分裂准则,分别从不同角度度量数据纯度的提升。

  3. 算法对比:ID3、C4.5、CART 在分裂准则、树结构、任务类型和剪枝能力上各有差异,实际应用中 CART 使用最为广泛。

  4. 剪枝策略:预剪枝通过提前停止分裂控制复杂度,后剪枝通过验证集评估剪枝效果。CART 中的代价复杂度剪枝(CCP)是经典的后剪枝方法。

  5. 实战要点 :使用 scikit-learn 构建决策树时,max_depthmin_samples_splitmin_samples_leaf 是防止过拟合的关键超参数。特征重要性分析有助于理解模型决策依据。

决策树不仅是高效的机器学习模型,更是理解更复杂集成算法(如随机森林、梯度提升树)的重要基础。希望本文能帮助读者建立对决策树的完整认知,并在实际项目中灵活运用。


参考库版本(本文代码适用):

复制代码
scikit-learn >= 1.0
matplotlib >= 3.5
pandas >= 1.3
numpy >= 1.20
相关推荐
GitFun2 小时前
7.4 万 Star 的开源记忆系统,让 AI 编程助手不再“失忆
人工智能
数据门徒2 小时前
神经网络原理 第七章:委员会机器
人工智能·神经网络·机器学习
HyperAI超神经2 小时前
Token使用量降低30%,以「阿凡达」为灵感的异构智能体框架Eywa,高效结合语言模型与领域专用基础模型
人工智能·语言模型
xiaoxiaoxiaolll2 小时前
《Nature Communications》论文解读:皮秒级单光子偏振测量如何绘制多模光纤中的模态动态图谱
网络·人工智能
Inhand陈工2 小时前
城投公司地面与停车场监控改造实战:映翰通IR302 + GRE隧道实现RFID与视频数据远程汇聚
网络·人工智能·物联网·网络安全·智能路由器·信息与通信
速易达网络2 小时前
YOLO26为AI而生
人工智能·机器学习
扬帆破浪2 小时前
免费开源AI软件.桌面单机版,可移动的AI知识库,察元 AI桌面版:本地离线知识库的folder-sync 第一次把文件夹挂成知识库
人工智能·知识图谱
夜影风2 小时前
给AI装上记忆系统:AI记忆机制与上下文管理实战
人工智能·langchain·ai记忆系统
深度学习lover2 小时前
<数据集>yolo食物分类检测<目标检测>
人工智能·深度学习·yolo·目标检测·计算机视觉·食物分类识别