【漫话机器学习系列】033.决策树回归(Decision Tree Regression)

决策树回归(Decision Tree Regression)

决策树回归是一种基于树状结构进行回归分析的监督学习方法。它将输入空间递归地划分为多个区域,并在每个区域内拟合一个简单的常数值,从而对目标变量进行预测。


决策树回归的原理

  1. 树的构建

    • 决策树以树的形式对数据进行划分。
    • 每次划分选择一个特征及其阈值,将数据集分为两个子集。
    • 目标是找到最佳划分,使得子集内的目标变量尽可能一致(即减少误差)。
  2. 划分准则

    通常采用**均方误差(MSE, Mean Squared Error)**作为划分的评价指标:

    其中, 是真实值,​ 是预测值。

  3. 停止条件

    • 达到最大树深度。
    • 叶节点的样本数少于预设值。
    • 划分后误差改善不足。
  4. 预测

    对于新输入数据,沿着决策树从根节点到叶节点,根据划分规则找到其对应的叶节点,返回叶节点中目标变量的均值作为预测值。


构建过程

  1. 根节点的初始化

    将所有数据视为一个整体,计算均值作为预测值,计算当前数据集的均方误差。

  2. 递归划分

    • 遍历每个特征及其所有可能的划分点,计算划分后的均方误差。
    • 选择能最大程度减少误差的特征及阈值进行划分。
  3. 停止划分

    • 当树的深度达到预设值。
    • 当叶节点的样本数小于预设阈值。
    • 当划分后误差改善不足。

优点

  1. 可解释性强

    决策树的结构直观清晰,易于可视化和理解。

  2. 非线性建模能力

    决策树能有效捕获数据中的非线性关系。

  3. 无需特征缩放

    决策树对特征的数值范围不敏感,不需要标准化或归一化。


缺点

  1. 易过拟合

    决策树在深度较大时可能会过拟合,导致泛化能力差。

  2. 对数据分布敏感

    对于小的样本噪声或异常值,可能会导致不稳定的划分。

  3. 无法捕获连续目标变量的平滑关系

    决策树只能在区域内拟合常数值,难以捕获目标变量的连续变化。


改进方法

  1. 剪枝(Pruning)

    • 预剪枝:设置树的最大深度、叶节点的最小样本数等参数,限制树的规模。
    • 后剪枝:先构建一棵完整的树,然后通过去掉不重要的分支来减少过拟合。
  2. 集成学习

    • 随机森林:构建多棵决策树并取平均值。
    • 梯度提升树(GBDT):通过串联多个决策树逐步减小误差。
    • 极端随机树(Extra Trees):进一步随机化特征和划分点选择,降低过拟合风险。

评价指标

  • 均方误差(MSE)

  • 平均绝对误差(MAE, Mean Absolute Error)

  • 决定系数(R2R^2R2)

    衡量模型对目标变量的解释程度:

    其中 是目标变量的均值。


代码实现

以下是使用 Python 中的 Scikit-learn 实现决策树回归的代码示例:

python 复制代码
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
import numpy as np

# 生成模拟数据
np.random.seed(0)
X = np.sort(np.random.rand(100, 1), axis=0)
y = np.sin(2 * np.pi * X).ravel() + np.random.randn(100) * 0.1

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树回归模型
model = DecisionTreeRegressor(max_depth=4)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 评价模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse:.3f}")
print(f"R^2 Score: {r2:.3f}")

# 可视化结果
import matplotlib.pyplot as plt
plt.scatter(X_test, y_test, color="blue", label="True Values")
plt.scatter(X_test, y_pred, color="red", label="Predicted Values")
plt.legend()
plt.title("Decision Tree Regression")
plt.xlabel("X")
plt.ylabel("y")
plt.show()

输出结果

Matlab 复制代码
Mean Squared Error: 0.038
R^2 Score: 0.939

应用场景

  1. 房地产价格预测

    根据特征(面积、位置、房龄等)预测房价。

  2. 市场营销分析

    根据用户行为数据预测用户对产品的需求。

  3. 时间序列分析

    使用历史数据预测未来值。


总结

决策树回归是简单易用的回归模型,特别适合处理非线性和非参数问题。然而,单独使用决策树可能会过拟合或欠拟合,因此需要通过剪枝或集成方法进一步提升模型的鲁棒性和泛化能力。

相关推荐
格林威13 分钟前
传送带上运动模糊图像复原:提升动态成像清晰度的 6 个核心方案,附 OpenCV+Halcon 实战代码!
人工智能·opencv·机器学习·计算机视觉·ai·halcon·工业相机
Aurora-Borealis.1 小时前
Day27 机器学习流水线
人工智能·机器学习
zhangfeng11332 小时前
数据分析 医学分析中线性回归、Cox回归、Logistic回归的定义和区别,原理和公式,适用场景
数据分析·回归·线性回归
黑符石3 小时前
【论文研读】Madgwick 姿态滤波算法报告总结
人工智能·算法·机器学习·imu·惯性动捕·madgwick·姿态滤波
JQLvopkk3 小时前
智能AI“学习功能”在程序开发部分的逻辑
人工智能·机器学习·计算机视觉
jiayong233 小时前
model.onnx 深度分析报告(第2篇)
人工智能·机器学习·向量数据库·向量模型
张祥6422889044 小时前
数理统计基础一
人工智能·机器学习·概率论
悟乙己4 小时前
使用TimeGPT进行时间序列预测案例解析
机器学习·大模型·llm·时间序列·预测
云和数据.ChenGuang4 小时前
人工智能实践之基于CNN的街区餐饮图片识别案例实践
人工智能·深度学习·神经网络·机器学习·cnn
人工智能培训5 小时前
什么是马尔可夫决策过程(MDP)?马尔可夫性的核心含义是什么?
人工智能·深度学习·机器学习·cnn·智能体·马尔可夫决策