【漫话机器学习系列】033.决策树回归(Decision Tree Regression)

决策树回归(Decision Tree Regression)

决策树回归是一种基于树状结构进行回归分析的监督学习方法。它将输入空间递归地划分为多个区域,并在每个区域内拟合一个简单的常数值,从而对目标变量进行预测。


决策树回归的原理

  1. 树的构建

    • 决策树以树的形式对数据进行划分。
    • 每次划分选择一个特征及其阈值,将数据集分为两个子集。
    • 目标是找到最佳划分,使得子集内的目标变量尽可能一致(即减少误差)。
  2. 划分准则

    通常采用**均方误差(MSE, Mean Squared Error)**作为划分的评价指标:

    其中, 是真实值,​ 是预测值。

  3. 停止条件

    • 达到最大树深度。
    • 叶节点的样本数少于预设值。
    • 划分后误差改善不足。
  4. 预测

    对于新输入数据,沿着决策树从根节点到叶节点,根据划分规则找到其对应的叶节点,返回叶节点中目标变量的均值作为预测值。


构建过程

  1. 根节点的初始化

    将所有数据视为一个整体,计算均值作为预测值,计算当前数据集的均方误差。

  2. 递归划分

    • 遍历每个特征及其所有可能的划分点,计算划分后的均方误差。
    • 选择能最大程度减少误差的特征及阈值进行划分。
  3. 停止划分

    • 当树的深度达到预设值。
    • 当叶节点的样本数小于预设阈值。
    • 当划分后误差改善不足。

优点

  1. 可解释性强

    决策树的结构直观清晰,易于可视化和理解。

  2. 非线性建模能力

    决策树能有效捕获数据中的非线性关系。

  3. 无需特征缩放

    决策树对特征的数值范围不敏感,不需要标准化或归一化。


缺点

  1. 易过拟合

    决策树在深度较大时可能会过拟合,导致泛化能力差。

  2. 对数据分布敏感

    对于小的样本噪声或异常值,可能会导致不稳定的划分。

  3. 无法捕获连续目标变量的平滑关系

    决策树只能在区域内拟合常数值,难以捕获目标变量的连续变化。


改进方法

  1. 剪枝(Pruning)

    • 预剪枝:设置树的最大深度、叶节点的最小样本数等参数,限制树的规模。
    • 后剪枝:先构建一棵完整的树,然后通过去掉不重要的分支来减少过拟合。
  2. 集成学习

    • 随机森林:构建多棵决策树并取平均值。
    • 梯度提升树(GBDT):通过串联多个决策树逐步减小误差。
    • 极端随机树(Extra Trees):进一步随机化特征和划分点选择,降低过拟合风险。

评价指标

  • 均方误差(MSE)

  • 平均绝对误差(MAE, Mean Absolute Error)

  • 决定系数(R2R^2R2)

    衡量模型对目标变量的解释程度:

    其中 是目标变量的均值。


代码实现

以下是使用 Python 中的 Scikit-learn 实现决策树回归的代码示例:

python 复制代码
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
import numpy as np

# 生成模拟数据
np.random.seed(0)
X = np.sort(np.random.rand(100, 1), axis=0)
y = np.sin(2 * np.pi * X).ravel() + np.random.randn(100) * 0.1

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树回归模型
model = DecisionTreeRegressor(max_depth=4)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 评价模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse:.3f}")
print(f"R^2 Score: {r2:.3f}")

# 可视化结果
import matplotlib.pyplot as plt
plt.scatter(X_test, y_test, color="blue", label="True Values")
plt.scatter(X_test, y_pred, color="red", label="Predicted Values")
plt.legend()
plt.title("Decision Tree Regression")
plt.xlabel("X")
plt.ylabel("y")
plt.show()

输出结果

Matlab 复制代码
Mean Squared Error: 0.038
R^2 Score: 0.939

应用场景

  1. 房地产价格预测

    根据特征(面积、位置、房龄等)预测房价。

  2. 市场营销分析

    根据用户行为数据预测用户对产品的需求。

  3. 时间序列分析

    使用历史数据预测未来值。


总结

决策树回归是简单易用的回归模型,特别适合处理非线性和非参数问题。然而,单独使用决策树可能会过拟合或欠拟合,因此需要通过剪枝或集成方法进一步提升模型的鲁棒性和泛化能力。

相关推荐
Fishel-3 分钟前
线性回归api再介绍
算法·回归·线性回归
勤劳的进取家12 分钟前
协方差矩阵
线性代数·算法·机器学习·矩阵
nuise_2 小时前
李宏毅机器学习课程笔记02 | 机器学习任务攻略General Guide
人工智能·笔记·机器学习
Kai HVZ3 小时前
《机器学习》——逻辑回归基本介绍
人工智能·机器学习
我爱一根柴哈3 小时前
人工智能机器学习从入门到高级学习资料(含资料下载地址)
人工智能·深度学习·机器学习
B站计算机毕业设计超人4 小时前
计算机毕业设计Python+Spark中药推荐系统 中药识别系统 中药数据分析 中药大数据 中药可视化 中药爬虫 中药大数据 大数据毕业设计 大
大数据·python·深度学习·机器学习·课程设计·数据可视化·推荐算法
金书世界5 小时前
自动驾驶ADAS算法--测试工程环境搭建
人工智能·机器学习·自动驾驶
笔写落去5 小时前
统计学习方法(第二版) 第五章
人工智能·深度学习·机器学习
魔理沙偷走了BUG7 小时前
【简博士统计学习方法】第1章:1. 统计学习的定义与分类
机器学习·统计学习方法
dundunmm7 小时前
【数据挖掘】深度高斯过程
python·深度学习·机器学习·数据挖掘·高斯过程·深度高斯过程