不使用递归的决策树生成算法
利用队列 queue ,实现层次遍历(广度优先遍历),逐步处理每个节点来建立子树结构。再构建一个辅助队列,将每个节点存储到 nodes_to_process 列表中,以便在树生成完成后可以反向遍历计算每个节点的 leaf_num(叶子节点数量)。对于每个节点,根据特征选择和树的条件构建子节点;如果达到叶节点条件,直接将其标记为叶节点。最后,逆序处理计算每个结点的叶节点数量:通过逆序遍历 nodes_to_process 列表(即从叶节点到根节点),每次更新父节点的 leaf_num 为其所有子节点 leaf_num 的总和。
在构建决策树的过程中,每个节点都会根据特征选择和树的构建条件来决定是否进一步分裂。以下是这个步骤的详细说明:
1 、当前节点的特征选择
对于每个节点 current_node ,需要从剩余的特征集合 A 中选择一个"最优特征" a ∗ ,用于将数据集 D 划分成不同的子集。这个"最优特征"由基尼指数、信息增益或信息增率等来确定,使得划分后的子集在类别上更加纯净。
2 、判断是否满足叶节点条件
在进一步构建子节点之前,检查当前节点是否满足叶节点条件。如果满足以下任一条件,则将 current_node 标记为叶节点,而不再继续分裂:
( 1 ) 单一类别 :如果数据集 D 中的所有样本都属于同一类 C ,则不再需要进一步划分。此时可以将 current_node 标记为叶节点,类别为 C 。
( 2 ) 属性集为空或样本在剩余特征上取值相同 :如果 A 为空(即没有剩余特征可以选择),或数据集 D 中样本在剩余特征上的取值都相同,那么即使进一步分裂也不能提供更多信息。在这种情况下,current_node 也被标记为叶节点,并根据 D 中的样本数最多的类别作为 current_node 的类别。
( 3 ) 达到最大深度 :如果当前节点的深度已经达到了预设的最大深度 MaxDepth ,则停止继续分裂,将 current_node 直接标记为叶节点,并将类别设为当前数据集中样本数最多的类别。
3 、构建子节点
如果不满足叶节点条件,则 current_node 将根据选择的特征 a ∗ 来生成子节点。分情况处理:
( 1 )当前特征为离散值:如果 a ∗ 是一个离散特征,节点会针对 a ∗ 的每个可能的取值创建一个子节点 child_node ,表示 a ∗ 取该值的样本子集。将数据集中所有在 a ∗ 上取值为 a ∗v的样本(记作Da ∗ =a ∗v)分配到 child_node ,并继续构建树。 如果D a ∗ =a ∗v为空,即该子集没有样本,说明该特征值在当前分支下没有样本。此时,将 child_node 标记为叶节点,并将其类别设为当前数据中出现次数最多的类别。如果D a ∗ =a ∗v不为空,则将 child_node 和该子集继续加入到构建队列中。
( 2 )当前特征为连续值:如果 a ∗ 是一个连续特征,则会根据分割点(采用二分法选取)将数据集划分为两个子集。构建两个子节点:一个子节点代表 a ∗ ≥ split_valuea 的样本子集;另一个子节点代表 a ∗ <split_valuea 的样本子集。将两个子节点及其对应的数据集加入到构建队列中,继续后续的树构建。
4 、将子节点添加到树中
每个 child_node 会作为 current_node 的子节点,存储在 current_node.subtree 中。通过这种方式,不断将子节点加入树中,直到所有节点都满足叶节点条件,不再继续分裂为止。
5 、完成子节点分裂后的后续处理
当队列中所有节点都处理完后,逆序遍历已处理的节点列表,计算每个节点的叶节点数。
不使用递归的建树算法的实现思路
创建两个队列,分别为 queue 与 nodes_to_process 。 queue = deque([(root, X, y)]) 用来存储节点和数据,queue 的结构为三元组,分别为根节点、当前节点的 X 值,即去除 a ∗ 属性后剩下的 X 值,以及 y 标签。 nodes_to_process = [] 记录所有节点以便后续计算 leaf_num 。遍历 queue 队列,创建根节点并将其放入队列,并将当前节点存入 nodes_to_process以记录节点。使用 queue 按层次处理每个节点。
每次处理时,首先检查是否达到叶节点条件(如最大深度或单一类别),如果是则标记为叶节点。如果不是叶节点,则选择最佳分割特征,并根据特征类型(离散或连续)生成对应的子节点。
queue 队列处理完毕后,通过 nodes_to_process 逆序遍历,每个节点的 leaf_num 设为其子节点的 leaf_num 总和。
代码实现
def generate_tree(self, X, y):
root = Node()
root.high = 0 # 根节点的高度为0
queue = deque([(root, X, y)]) # 使用队列来存储节点和数据
nodes_to_process = [] # 记录所有节点以便后续计算 leaf_num
while queue:
current_node, current_X, current_y = queue.popleft()
nodes_to_process.append(current_node)
# 叶节点条件:达到最大深度或只有单一类别或没有特征
if current_node.high >= self.MaxDepth or current_y.nunique() == 1 or current_X.empty:
current_node.is_leaf = True
current_node.leaf_class = current_y.mode()[0]
current_node.leaf_num = 1 # 是叶子节点,叶子数量为 1
continue
# 选择最佳划分特征
best_feature_name, best_impurity = self.choose_best_feature_to_split(current_X, current_y)
current_node.feature_name = best_feature_name
current_node.impurity = best_impurity[0]
current_node.feature_index = self.columns.index(best_feature_name)
feature_values = current_X[best_feature_name]
if len(best_impurity) == 1: # 离散值特征
current_node.is_continuous = False
unique_vals = feature_values.unique()
sub_X = current_X.drop(best_feature_name, axis=1)
for value in unique_vals:
child_node = Node()
child_node.high = current_node.high + 1
queue.append((child_node, sub_X[feature_values == value], current_y[feature_values == value]))
current_node.subtree[value] = child_node
elif len(best_impurity) == 2: # 连续值特征
current_node.is_continuous = True
current_node.split_value = best_impurity[1]
up_part = '>= {:.3f}'.format(current_node.split_value)
down_part = '< {:.3f}'.format(current_node.split_value)
child_node_up = Node()
child_node_down = Node()
child_node_up.high = current_node.high + 1
child_node_down.high = current_node.high + 1
queue.append((child_node_up, current_X[feature_values >= current_node.split_value],
current_y[feature_values >= current_node.split_value]))
queue.append((child_node_down, current_X[feature_values < current_node.split_value],
current_y[feature_values < current_node.split_value]))
current_node.subtree[up_part] = child_node_up
current_node.subtree[down_part] = child_node_down
# 逆序遍历 nodes_to_process,计算每个节点的 leaf_num
while nodes_to_process:
node = nodes_to_process.pop()
if node.is_leaf:
node.leaf_num = 1
else:
node.leaf_num = sum(child.leaf_num for child in node.subtree.values())
return root