不使用递归的决策树生成算法

不使用递归的决策树生成算法

利用队列 queue ,实现层次遍历(广度优先遍历),逐步处理每个节点来建立子树结构。再构建一个辅助队列,将每个节点存储到 nodes_to_process 列表中,以便在树生成完成后可以反向遍历计算每个节点的 leaf_num(叶子节点数量)。对于每个节点,根据特征选择和树的条件构建子节点;如果达到叶节点条件,直接将其标记为叶节点。最后,逆序处理计算每个结点的叶节点数量:通过逆序遍历 nodes_to_process 列表(即从叶节点到根节点),每次更新父节点的 leaf_num 为其所有子节点 leaf_num 的总和。


在构建决策树的过程中,每个节点都会根据特征选择和树的构建条件来决定是否进一步分裂。以下是这个步骤的详细说明:
1 、当前节点的特征选择
对于每个节点 current_node ,需要从剩余的特征集合 A 中选择一个"最优特征" a ∗ ,用于将数据集 D 划分成不同的子集。这个"最优特征"由基尼指数、信息增益或信息增率等来确定,使得划分后的子集在类别上更加纯净。
2 、判断是否满足叶节点条件
在进一步构建子节点之前,检查当前节点是否满足叶节点条件。如果满足以下任一条件,则将 current_node 标记为叶节点,而不再继续分裂:
( 1 ) 单一类别 :如果数据集 D 中的所有样本都属于同一类 C ,则不再需要进一步划分。此时可以将 current_node 标记为叶节点,类别为 C 。
( 2 ) 属性集为空或样本在剩余特征上取值相同 :如果 A 为空(即没有剩余特征可以选择),或数据集 D 中样本在剩余特征上的取值都相同,那么即使进一步分裂也不能提供更多信息。在这种情况下,current_node 也被标记为叶节点,并根据 D 中的样本数最多的类别作为 current_node 的类别。
( 3 ) 达到最大深度 :如果当前节点的深度已经达到了预设的最大深度 MaxDepth ,则停止继续分裂,将 current_node 直接标记为叶节点,并将类别设为当前数据集中样本数最多的类别。
3 、构建子节点
如果不满足叶节点条件,则 current_node 将根据选择的特征 a ∗ 来生成子节点。分情况处理:
( 1 )当前特征为离散值:如果 a ∗ 是一个离散特征,节点会针对 a ∗ 的每个可能的取值创建一个子节点 child_node ,表示 a ∗ 取该值的样本子集。将数据集中所有在 a ∗ 上取值为 a ∗v的样本(记作Da ∗ =a ∗v)分配到 child_node ,并继续构建树。 如果D a ∗ =a ∗v为空,即该子集没有样本,说明该特征值在当前分支下没有样本。此时,将 child_node 标记为叶节点,并将其类别设为当前数据中出现次数最多的类别。如果D a ∗ =a ∗v不为空,则将 child_node 和该子集继续加入到构建队列中。
( 2 )当前特征为连续值:如果 a ∗ 是一个连续特征,则会根据分割点(采用二分法选取)将数据集划分为两个子集。构建两个子节点:一个子节点代表 a ∗ ≥ split_valuea 的样本子集;另一个子节点代表 a ∗ <split_valuea 的样本子集。将两个子节点及其对应的数据集加入到构建队列中,继续后续的树构建。
4 、将子节点添加到树中
每个 child_node 会作为 current_node 的子节点,存储在 current_node.subtree 中。通过这种方式,不断将子节点加入树中,直到所有节点都满足叶节点条件,不再继续分裂为止。
5 、完成子节点分裂后的后续处理
当队列中所有节点都处理完后,逆序遍历已处理的节点列表,计算每个节点的叶节点数。

不使用递归的建树算法的实现思路

创建两个队列,分别为 queue 与 nodes_to_process 。 queue = deque([(root, X, y)]) 用来存储节点和数据,queue 的结构为三元组,分别为根节点、当前节点的 X 值,即去除 a ∗ 属性后剩下的 X 值,以及 y 标签。 nodes_to_process = [] 记录所有节点以便后续计算 leaf_num 。遍历 queue 队列,创建根节点并将其放入队列,并将当前节点存入 nodes_to_process以记录节点。使用 queue 按层次处理每个节点。
每次处理时,首先检查是否达到叶节点条件(如最大深度或单一类别),如果是则标记为叶节点。如果不是叶节点,则选择最佳分割特征,并根据特征类型(离散或连续)生成对应的子节点。
queue 队列处理完毕后,通过 nodes_to_process 逆序遍历,每个节点的 leaf_num 设为其子节点的 leaf_num 总和。

代码实现

    def generate_tree(self, X, y):
        root = Node()
        root.high = 0  # 根节点的高度为0
        queue = deque([(root, X, y)])  # 使用队列来存储节点和数据
        nodes_to_process = []  # 记录所有节点以便后续计算 leaf_num

        while queue:
            current_node, current_X, current_y = queue.popleft()
            nodes_to_process.append(current_node)

            # 叶节点条件:达到最大深度或只有单一类别或没有特征
            if current_node.high >= self.MaxDepth or current_y.nunique() == 1 or current_X.empty:
                current_node.is_leaf = True
                current_node.leaf_class = current_y.mode()[0]
                current_node.leaf_num = 1  # 是叶子节点,叶子数量为 1
                continue

            # 选择最佳划分特征
            best_feature_name, best_impurity = self.choose_best_feature_to_split(current_X, current_y)
            current_node.feature_name = best_feature_name
            current_node.impurity = best_impurity[0]
            current_node.feature_index = self.columns.index(best_feature_name)
            feature_values = current_X[best_feature_name]

            if len(best_impurity) == 1:  # 离散值特征
                current_node.is_continuous = False
                unique_vals = feature_values.unique()
                sub_X = current_X.drop(best_feature_name, axis=1)

                for value in unique_vals:
                    child_node = Node()
                    child_node.high = current_node.high + 1
                    queue.append((child_node, sub_X[feature_values == value], current_y[feature_values == value]))
                    current_node.subtree[value] = child_node

            elif len(best_impurity) == 2:  # 连续值特征
                current_node.is_continuous = True
                current_node.split_value = best_impurity[1]
                up_part = '>= {:.3f}'.format(current_node.split_value)
                down_part = '< {:.3f}'.format(current_node.split_value)

                child_node_up = Node()
                child_node_down = Node()
                child_node_up.high = current_node.high + 1
                child_node_down.high = current_node.high + 1
                queue.append((child_node_up, current_X[feature_values >= current_node.split_value],
                              current_y[feature_values >= current_node.split_value]))
                queue.append((child_node_down, current_X[feature_values < current_node.split_value],
                              current_y[feature_values < current_node.split_value]))

                current_node.subtree[up_part] = child_node_up
                current_node.subtree[down_part] = child_node_down

        # 逆序遍历 nodes_to_process,计算每个节点的 leaf_num
        while nodes_to_process:
            node = nodes_to_process.pop()
            if node.is_leaf:
                node.leaf_num = 1
            else:
                node.leaf_num = sum(child.leaf_num for child in node.subtree.values())

        return root
相关推荐
爱吃生蚝的于勒1 小时前
C语言内存函数
c语言·开发语言·数据结构·c++·学习·算法
ChoSeitaku6 小时前
链表循环及差集相关算法题|判断循环双链表是否对称|两循环单链表合并成循环链表|使双向循环链表有序|单循环链表改双向循环链表|两链表的差集(C)
c语言·算法·链表
我爱工作&工作love我6 小时前
1435:【例题3】曲线 一本通 代替三分
c++·算法
白-胖-子7 小时前
【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-统计数字
开发语言·c++·算法·蓝桥杯·等考·13级
workflower7 小时前
数据结构练习题和答案
数据结构·算法·链表·线性回归
好睡凯7 小时前
c++写一个死锁并且自己解锁
开发语言·c++·算法
Sunyanhui17 小时前
力扣 二叉树的直径-543
算法·leetcode·职场和发展
一个不喜欢and不会代码的码农7 小时前
力扣105:从先序和中序序列构造二叉树
数据结构·算法·leetcode
前端郭德纲7 小时前
浏览器是加载ES6模块的?
javascript·算法