不使用递归的决策树生成算法

不使用递归的决策树生成算法

利用队列 queue ,实现层次遍历(广度优先遍历),逐步处理每个节点来建立子树结构。再构建一个辅助队列,将每个节点存储到 nodes_to_process 列表中,以便在树生成完成后可以反向遍历计算每个节点的 leaf_num(叶子节点数量)。对于每个节点,根据特征选择和树的条件构建子节点;如果达到叶节点条件,直接将其标记为叶节点。最后,逆序处理计算每个结点的叶节点数量:通过逆序遍历 nodes_to_process 列表(即从叶节点到根节点),每次更新父节点的 leaf_num 为其所有子节点 leaf_num 的总和。


在构建决策树的过程中,每个节点都会根据特征选择和树的构建条件来决定是否进一步分裂。以下是这个步骤的详细说明:
1 、当前节点的特征选择
对于每个节点 current_node ,需要从剩余的特征集合 A 中选择一个"最优特征" a ∗ ,用于将数据集 D 划分成不同的子集。这个"最优特征"由基尼指数、信息增益或信息增率等来确定,使得划分后的子集在类别上更加纯净。
2 、判断是否满足叶节点条件
在进一步构建子节点之前,检查当前节点是否满足叶节点条件。如果满足以下任一条件,则将 current_node 标记为叶节点,而不再继续分裂:
( 1 ) 单一类别 :如果数据集 D 中的所有样本都属于同一类 C ,则不再需要进一步划分。此时可以将 current_node 标记为叶节点,类别为 C 。
( 2 ) 属性集为空或样本在剩余特征上取值相同 :如果 A 为空(即没有剩余特征可以选择),或数据集 D 中样本在剩余特征上的取值都相同,那么即使进一步分裂也不能提供更多信息。在这种情况下,current_node 也被标记为叶节点,并根据 D 中的样本数最多的类别作为 current_node 的类别。
( 3 ) 达到最大深度 :如果当前节点的深度已经达到了预设的最大深度 MaxDepth ,则停止继续分裂,将 current_node 直接标记为叶节点,并将类别设为当前数据集中样本数最多的类别。
3 、构建子节点
如果不满足叶节点条件,则 current_node 将根据选择的特征 a ∗ 来生成子节点。分情况处理:
( 1 )当前特征为离散值:如果 a ∗ 是一个离散特征,节点会针对 a ∗ 的每个可能的取值创建一个子节点 child_node ,表示 a ∗ 取该值的样本子集。将数据集中所有在 a ∗ 上取值为 a ∗v的样本(记作Da ∗ =a ∗v)分配到 child_node ,并继续构建树。 如果D a ∗ =a ∗v为空,即该子集没有样本,说明该特征值在当前分支下没有样本。此时,将 child_node 标记为叶节点,并将其类别设为当前数据中出现次数最多的类别。如果D a ∗ =a ∗v不为空,则将 child_node 和该子集继续加入到构建队列中。
( 2 )当前特征为连续值:如果 a ∗ 是一个连续特征,则会根据分割点(采用二分法选取)将数据集划分为两个子集。构建两个子节点:一个子节点代表 a ∗ ≥ split_valuea 的样本子集;另一个子节点代表 a ∗ <split_valuea 的样本子集。将两个子节点及其对应的数据集加入到构建队列中,继续后续的树构建。
4 、将子节点添加到树中
每个 child_node 会作为 current_node 的子节点,存储在 current_node.subtree 中。通过这种方式,不断将子节点加入树中,直到所有节点都满足叶节点条件,不再继续分裂为止。
5 、完成子节点分裂后的后续处理
当队列中所有节点都处理完后,逆序遍历已处理的节点列表,计算每个节点的叶节点数。

不使用递归的建树算法的实现思路

创建两个队列,分别为 queue 与 nodes_to_process 。 queue = deque([(root, X, y)]) 用来存储节点和数据,queue 的结构为三元组,分别为根节点、当前节点的 X 值,即去除 a ∗ 属性后剩下的 X 值,以及 y 标签。 nodes_to_process = [] 记录所有节点以便后续计算 leaf_num 。遍历 queue 队列,创建根节点并将其放入队列,并将当前节点存入 nodes_to_process以记录节点。使用 queue 按层次处理每个节点。
每次处理时,首先检查是否达到叶节点条件(如最大深度或单一类别),如果是则标记为叶节点。如果不是叶节点,则选择最佳分割特征,并根据特征类型(离散或连续)生成对应的子节点。
queue 队列处理完毕后,通过 nodes_to_process 逆序遍历,每个节点的 leaf_num 设为其子节点的 leaf_num 总和。

代码实现

    def generate_tree(self, X, y):
        root = Node()
        root.high = 0  # 根节点的高度为0
        queue = deque([(root, X, y)])  # 使用队列来存储节点和数据
        nodes_to_process = []  # 记录所有节点以便后续计算 leaf_num

        while queue:
            current_node, current_X, current_y = queue.popleft()
            nodes_to_process.append(current_node)

            # 叶节点条件:达到最大深度或只有单一类别或没有特征
            if current_node.high >= self.MaxDepth or current_y.nunique() == 1 or current_X.empty:
                current_node.is_leaf = True
                current_node.leaf_class = current_y.mode()[0]
                current_node.leaf_num = 1  # 是叶子节点,叶子数量为 1
                continue

            # 选择最佳划分特征
            best_feature_name, best_impurity = self.choose_best_feature_to_split(current_X, current_y)
            current_node.feature_name = best_feature_name
            current_node.impurity = best_impurity[0]
            current_node.feature_index = self.columns.index(best_feature_name)
            feature_values = current_X[best_feature_name]

            if len(best_impurity) == 1:  # 离散值特征
                current_node.is_continuous = False
                unique_vals = feature_values.unique()
                sub_X = current_X.drop(best_feature_name, axis=1)

                for value in unique_vals:
                    child_node = Node()
                    child_node.high = current_node.high + 1
                    queue.append((child_node, sub_X[feature_values == value], current_y[feature_values == value]))
                    current_node.subtree[value] = child_node

            elif len(best_impurity) == 2:  # 连续值特征
                current_node.is_continuous = True
                current_node.split_value = best_impurity[1]
                up_part = '>= {:.3f}'.format(current_node.split_value)
                down_part = '< {:.3f}'.format(current_node.split_value)

                child_node_up = Node()
                child_node_down = Node()
                child_node_up.high = current_node.high + 1
                child_node_down.high = current_node.high + 1
                queue.append((child_node_up, current_X[feature_values >= current_node.split_value],
                              current_y[feature_values >= current_node.split_value]))
                queue.append((child_node_down, current_X[feature_values < current_node.split_value],
                              current_y[feature_values < current_node.split_value]))

                current_node.subtree[up_part] = child_node_up
                current_node.subtree[down_part] = child_node_down

        # 逆序遍历 nodes_to_process,计算每个节点的 leaf_num
        while nodes_to_process:
            node = nodes_to_process.pop()
            if node.is_leaf:
                node.leaf_num = 1
            else:
                node.leaf_num = sum(child.leaf_num for child in node.subtree.values())

        return root
相关推荐
今天吃饺子1 分钟前
2024年SCI一区最新改进优化算法——四参数自适应生长优化器,MATLAB代码免费获取...
开发语言·算法·matlab
是阿建吖!2 分钟前
【优选算法】二分查找
c++·算法
王燕龙(大卫)6 分钟前
leetcode 数组中第k个最大元素
算法·leetcode
不去幼儿园1 小时前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
Mr_Xuhhh1 小时前
重生之我在学环境变量
linux·运维·服务器·前端·chrome·算法
盼海2 小时前
排序算法(五)--归并排序
数据结构·算法·排序算法
网易独家音乐人Mike Zhou5 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
Swift社区9 小时前
LeetCode - #139 单词拆分
算法·leetcode·职场和发展
Kent_J_Truman10 小时前
greater<>() 、less<>()及运算符 < 重载在排序和堆中的使用
算法
IT 青年10 小时前
数据结构 (1)基本概念和术语
数据结构·算法