【机器学习】回归树

回归树是一种用于数值型目标变量的监督学习算法,通过将特征空间划分为多个区域,并在每个区域内使用简单的预测模型(如区域均值)来进行回归。回归树以"递归划分-计算区域均值"的方式逐层生成树节点,最终形成叶节点预测值。相比于线性回归,回归树更适合处理非线性和复杂数据结构。

回归树的基本原理

在回归树中,每个节点执行以下操作:

  • 选择最优特征及分割点:通过最小化均方误差(Mean Squared Error, MSE)等标准选择最佳分割特征和分割点。
  • 分割数据:根据选择的分割特征将数据划分成两部分,形成左子节点和右子节点。
  • 递归分割:对子节点进行递归分割,直至满足停止条件(如最大深度或最小样本数)。

分割准则

均方误差(MSE)

在回归树中,常用均方误差(MSE)作为分割准则:
MSE = 1 N ∑ i = 1 N ( y i − y ˉ ) 2 \text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \bar{y})^2 MSE=N1i=1∑N(yi−yˉ)2

其中,( y_i ) 是样本 ( i ) 的实际值,( \bar{y} ) 是区域内样本的平均值。分割点选择通过最小化分割前后数据的 MSE 来完成。

回归树的构建步骤

  1. 选择最佳分割特征与分割点:遍历每个特征和可能的分割点,计算分割后的MSE,选择使MSE最小的分割特征和点。
  2. 递归分割数据:在左、右子节点递归执行上述过程,形成新的分支节点。
  3. 生成叶节点:一旦满足停止条件,将当前节点的预测值设为该区域中所有样本的均值。

用 Numpy 实现回归树

以下代码展示了如何用 Numpy 实现一个基本的回归树,并通过均方误差来确定分割点。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

# 计算均方误差(MSE)
def mean_squared_error(y):
    return np.var(y) * len(y)

# 数据集分割
def split_dataset(X, y, feature, threshold):
    left_mask = X[:, feature] <= threshold
    right_mask = ~left_mask
    return X[left_mask], y[left_mask], X[right_mask], y[right_mask]

# 查找最佳分割特征和分割点
def best_split(X, y):
    best_mse = float("inf")
    best_feature, best_threshold = None, None
    for feature in range(X.shape[1]):
        thresholds = np.unique(X[:, feature])
        for threshold in thresholds:
            _, y_left, _, y_right = split_dataset(X, y, feature, threshold)
            if len(y_left) == 0 or len(y_right) == 0:
                continue
            mse_split = mean_squared_error(y_left) + mean_squared_error(y_right)
            if mse_split < best_mse:
                best_mse = mse_split
                best_feature = feature
                best_threshold = threshold
    return best_feature, best_threshold

# 回归树类
class RegressionTree:
    def __init__(self, max_depth=3, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def fit(self, X, y, depth=0):
        if len(y) < self.min_samples_split or depth >= self.max_depth:
            return np.mean(y)
        
        feature, threshold = best_split(X, y)
        if feature is None:
            return np.mean(y)
        
        left_X, left_y, right_X, right_y = split_dataset(X, y, feature, threshold)
        left_node = self.fit(left_X, left_y, depth + 1)
        right_node = self.fit(right_X, right_y, depth + 1)
        
        self.tree = {"feature": feature, "threshold": threshold, "left": left_node, "right": right_node}
        return self.tree

    def predict_sample(self, x, tree):
        if not isinstance(tree, dict):
            return tree
        if x[tree["feature"]] <= tree["threshold"]:
            return self.predict_sample(x, tree["left"])
        else:
            return self.predict_sample(x, tree["right"])

    def predict(self, X):
        return np.array([self.predict_sample(x, self.tree) for x in X])

# 生成示例数据
np.random.seed(0)
X = np.random.rand(100, 1) * 10  # 特征数据
y = 2 * X.flatten() + np.random.randn(100) * 2  # 标签数据

# 训练回归树
tree = RegressionTree(max_depth=4, min_samples_split=5)
tree.fit(X, y)

# 预测并可视化
X_test = np.linspace(0, 10, 100).reshape(-1, 1)
y_pred = tree.predict(X_test)

plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred, color="red", label="回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("回归树预测示意图")
plt.legend()
plt.show()

在代码中,我们首先通过遍历各个特征和分割点来选择最优分割点,使得均方误差最小。然后在每个节点递归进行分割,直至达到设定的深度或最小样本数。最终通过构建的树结构进行预测。

使用 Sklearn 的回归树

Scikit-Learn 提供了 DecisionTreeRegressor 来实现回归树模型,可以大大简化建模过程。

python 复制代码
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

# 训练回归树
regressor = DecisionTreeRegressor(max_depth=4, min_samples_split=5)
regressor.fit(X, y)

# 预测
y_pred_sklearn = regressor.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y, regressor.predict(X))
print("均方误差:", mse)

# 可视化
plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred_sklearn, color="red", label="Sklearn 回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("Sklearn 回归树预测示意图")
plt.legend()
plt.show()

总结

本文介绍了回归树的基本概念与实现,包括回归树的分割准则、MSE 计算、最佳分割点选择等细节。通过 Numpy 手动实现了一个简单的回归树模型,并展示了如何在 Scikit-Learn 中快速实现和使用回归树。

相关推荐
蒙奇D索大1 分钟前
【11408学习记录】考研数学攻坚:行列式本质、性质与计算全突破
笔记·学习·线性代数·考研·机器学习·改行学it
Blossom.1182 分钟前
基于机器学习的智能故障预测系统:构建与优化
人工智能·python·深度学习·神经网络·机器学习·分类·tensorflow
DisonTangor18 分钟前
【字节拥抱开源】字节团队开源视频模型 ContentV: 有限算力下的视频生成模型高效训练
人工智能·开源·aigc
春末的南方城市36 分钟前
腾讯开源视频生成工具 HunyuanVideo-Avatar,上传一张图+一段音频,就能让图中的人物、动物甚至虚拟角色“活”过来,开口说话、唱歌、演相声!
人工智能·计算机视觉·自然语言处理·aigc·音视频·视频生成
UQI-LIUWJ38 分钟前
论文笔记:Urban Computing in the Era of Large Language Models
人工智能·语言模型·自然语言处理
张较瘦_39 分钟前
[论文阅读] 人工智能+软件工程 | MemFL:给大模型装上“项目记忆”,让软件故障定位又快又准
论文阅读·人工智能·软件工程
yzx99101342 分钟前
基于 PyTorch 和 OpenCV 的实时表情检测系统
人工智能·pytorch·opencv
ICscholar1 小时前
生成对抗网络(GAN)损失函数解读
人工智能·机器学习·生成对抗网络
我不是小upper1 小时前
L1和L2核心区别 !!--part 2
人工智能·深度学习·算法·机器学习
geneculture1 小时前
融智学本体论体系全景图
人工智能·数学建模·融智学的重要应用·道函数·三类思维坐标