【第二章:机器学习与神经网络概述】04.回归算法理论与实践 -(3)决策树回归模型(Decision Tree Regression)

第二章: 机器学习与神经网络概述

第四部分:回归算法理论与实践

第三节:决策树回归模型

内容:剪枝方法、回归树结构与算法实现。

决策树回归模型是一种非参数的监督学习方法,通过将特征空间划分为多个区域,在每个区域内做常数预测,适合处理非线性回归问题、特征交互明显的数据集。


一、基本原理

决策树回归以CART(Classification and Regression Trees)算法为基础,通过不断划分特征空间,构建一棵回归树:

  • 每个内部节点表示对某一特征的判断;

  • 每个叶节点表示一个预测值(区域内样本均值);

  • 划分依据:最小化划分后区域内的均方误差(MSE)。


二、划分准则与误差计算

对样本集 D,假设以特征 的值 s 作为划分点,将样本划分为:

其目标是最小化总的平方误差:


三、剪枝策略(Pruning)

决策树容易过拟合,需通过剪枝来控制复杂度:

1. 预剪枝(Pre-Pruning)
  • 在构建过程中提前停止划分:

    • 达到最大深度 max_depth

    • 每个节点最小样本数 min_samples_split

    • MSE 减少小于阈值

2. 后剪枝(Post-Pruning)
  • 先生成整棵树,再从底向上剪去"收益小"的分支(如 sklearn 的 ccp_alpha 参数)

  • 剪枝目标:在保留预测能力的前提下降低模型复杂度


四、Python 实现示例(使用 sklearn)

python 复制代码
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_full = DecisionTreeRegressor()
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)

# 训练
reg_full.fit(X, y)
reg_pruned.fit(X, y)
reg_ccp.fit(X, y)

plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False

# 可视化
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, label="data", color="black")
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", linewidth=2)
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (depth=3)", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", linestyle=":")
plt.legend()
plt.title("回归树剪枝效果对比")
plt.xlabel("X")
plt.ylabel("y")
plt.grid(True)
plt.tight_layout()
plt.show()

五、优缺点分析

优点 缺点
逻辑简单、易理解 容易过拟合,需要剪枝
可处理非线性和多维特征交互 对微小变化敏感,稳定性差
不需标准化或归一化 对样本数量和分布较敏感
可解释性强(树结构明确) 難以推广:小数据表现好,大数据可能需集成优化

六、模型调参建议

参数 作用 建议
max_depth 限制树的最大深度 控制模型复杂度,避免过拟合
min_samples_split 拆分内部节点所需最小样本数 增大可减少模型复杂度
min_samples_leaf 每个叶子节点的最小样本数 增大有助于平滑预测结果
ccp_alpha 后剪枝惩罚系数(复杂度代价剪枝) 自动调节树结构,可结合验证集选择最佳值

七、典型应用场景

  • 房价预测(特征离散明显)

  • 电商销售量预测

  • 时间序列短期预测(可结合滑窗技术)

  • 特征交互复杂但不多的中小数据集建模


补充:

可视化决策树结构
python 复制代码
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_pruned = DecisionTreeRegressor(max_depth=3)

# 训练
reg_pruned.fit(X, y)

plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False


plt.figure(figsize=(12, 6))
plot_tree(reg_pruned, filled=True, feature_names=["X"], rounded=True)
plt.title("回归树结构(max_depth=3)")
plt.show()

回归树结构图(plot_tree)
python 复制代码
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt
import numpy as np

# 构造样本数据
X = np.array([[1], [2], [3], [4], [5], [6], [7], [8]])
y = np.array([5, 4.5, 4, 3.5, 3, 2.5, 2, 1.5])

# 创建并训练模型
tree = DecisionTreeRegressor(max_depth=3, random_state=42)
tree.fit(X, y)

# 可视化决策树结构
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plot_tree(tree, feature_names=["X"], filled=True, rounded=True)
plt.title("回归树结构图 (max_depth=3)")
plt.show()

剪枝前后预测曲线对比图
python 复制代码
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, X.shape[0])

# 不剪枝模型
reg_full = DecisionTreeRegressor()
reg_full.fit(X, y)

# 预剪枝模型(限制最大深度)
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_pruned.fit(X, y)

# 后剪枝模型(设置复杂度惩罚参数)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)
reg_ccp.fit(X, y)

# 测试数据
X_test = np.linspace(0, 5, 500).reshape(-1, 1)

# 可视化
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plt.scatter(X, y, label="Train Data", color="black", s=20)
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", color="blue")
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (max_depth=3)", color="green", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", color="red", linestyle=":")
plt.title("回归树剪枝对比图")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
相关推荐
贾全15 分钟前
第十章:HIL-SERL 真实机器人训练实战
人工智能·深度学习·算法·机器学习·机器人
GIS小天30 分钟前
AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年7月4日第128弹
人工智能·算法·机器学习·彩票
我是小哪吒2.042 分钟前
书籍推荐-《对抗机器学习:攻击面、防御机制与人工智能中的学习理论》
人工智能·深度学习·学习·机器学习·ai·语言模型·大模型
慕婉03071 小时前
深度学习前置知识全面解析:从机器学习到深度学习的进阶之路
人工智能·深度学习·机器学习
蓝婷儿2 小时前
Python 机器学习核心入门与实战进阶 Day 2 - KNN(K-近邻算法)分类实战与调参
python·机器学习·近邻算法
24毕业生从零开始学ai4 小时前
长短期记忆网络(LSTM):让神经网络拥有 “持久记忆力” 的神奇魔法
rnn·神经网络·lstm
中杯可乐多加冰5 小时前
【AI落地应用实战】AIGC赋能职场PPT汇报:从效率工具到辅助优化
人工智能·深度学习·神经网络·aigc·powerpoint·ai赋能
烟锁池塘柳05 小时前
【大模型】解码策略:Greedy Search、Beam Search、Top-k/Top-p、Temperature Sampling等
人工智能·深度学习·机器学习
Blossom.1187 小时前
机器学习在智能供应链中的应用:需求预测与物流优化
人工智能·深度学习·神经网络·机器学习·计算机视觉·机器人·语音识别