24/8/8算法笔记 条件筛选决策树根节点

筛选决策树的根节点是建立决策树过程中的一个重要步骤,主要原因包括:

  1. 减少计算量:选择合适的根节点可以减少树的深度,从而减少模型训练和预测时的计算量。

  2. 提高模型性能:选择最佳分裂点可以最大化模型的性能,通常通过减少误差或提高信息增益来实现。

  3. 防止过拟合:通过选择一个合适的分裂点,可以避免树生长得过于复杂,从而减少过拟合的风险。

  4. 提高泛化能力:一个好的根节点可以帮助模型更好地泛化到新的、未见过的数据上。

  5. 减少训练时间:选择一个好的根节点可以减少构建树所需的时间,特别是在处理大量数据时。

  6. 特征选择:在构建决策树时,选择根节点的过程也涉及到特征选择,即确定哪些特征对于预测目标变量最为重要。

  7. 信息增益:在ID3算法中,选择根节点是基于信息增益最大的原则。信息增益衡量了不纯度的减少,即在分裂前后数据集的不确定性的减少。

  8. 基尼不纯度:在CART算法中,选择根节点是基于最小化基尼不纯度的原则,基尼不纯度是衡量数据集纯度的一个指标。

  9. 模型解释性:一个好的根节点可以提高模型的可解释性,使得模型的决策过程更容易被理解和解释。

  10. 数据分布:选择根节点时考虑数据的分布,可以确保树的分裂更好地反映了数据的内在结构。

在实际应用中,选择根节点通常涉及到计算不同特征的分裂点,并评估每个分裂点的性能指标,如信息增益、基尼不纯度等,以确定哪个特征和哪个分裂点是最佳的。这个过程是构建决策树算法的核心部分。

导入要构建的数据

复制代码
import numpy as np
import pandas as pd
y = np.array(list('NYYYYYNYYN'))
print(y)
X = pd.DataFrame({'日志密度':list('sslmlmmlms'),
                  '好友密度':list('slmmmlsmss'),
                  '真实头像':list('NYYYYNYYYY'),
                  '真实用户':y})
X

将分类数据替换为真实数据

复制代码
X['日志密度'] = X['日志密度'].map({'s': 0, 'm': 1, 'l': 2})
X['好友密度'] = X['好友密度'].map({'s':0,'m':1,'l':2})
X['真实头像'] = X['真实头像'].map({'N':0,'Y':1})
X
复制代码
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

import graphviz 

构建决策树

复制代码
model = DecisionTreeClassifier(criterion='entropy')
model.fit(X.iloc[:,:-1],y)#二维切片

#dot点
dot_data = tree.export_graphviz(model,filled =True,rounded=True,feature_names=X.columns[:-1])

graphviz.Source(dot_data)

特征选择和最佳分裂点的确定。代码的目的是遍历给定的列列表 cols,计算每个特征的可能分裂点,并使用信息熵来评估这些分裂点,选择信息熵最小的分裂点作为最佳分裂点。

复制代码
cols =['日志密度','好友密度','真实头像']
lower_entropy = 3#最小的信息熵
best_split = {}

for col in cols[:1]:
    x=X[col].unique()#返回去重之后的数据
    x.sort()#0,1,2
    print(x)
    #如何根据这一列划分
    for i in range(len(x)-1):#裂分点,裂分值
        split = x[i:i+2].mean()
        
        #裂分的概率分布
        cond = X[col]<=split
        p = cond.value_counts()/cond.size
        print(p)
        
        indexs = p.index
        entroy=0
        for index in indexs:
            user = X[cond ==index]['真实用户']
            p_user = user.value_counts()/user.size
            entroy += (p_user * np.log2(1/p_user)).sum() * p[index]
        if entroy<lower_entropy:
            lower_entropy = entroy
            best_split.clear()
            best_split[col] = split
print('最佳裂分条件:',best_split)
相关推荐
老王熬夜敲代码4 分钟前
解决IP不够用的问题
linux·网络·笔记
polarislove021418 分钟前
8.1 时钟树-嵌入式铁头山羊STM32笔记
笔记·stm32·嵌入式硬件
QT 小鲜肉28 分钟前
【Linux命令大全】001.文件管理之file命令(实操篇)
linux·运维·前端·网络·chrome·笔记
hetao17338372 小时前
2025-12-21~22 hetao1733837的刷题笔记
c++·笔记·算法
一声沧海笑2 小时前
【GEE学习笔记】GEE中如何上传矢量图?
笔记·学习
呱呱巨基2 小时前
Linux 进程控制
linux·c++·笔记·学习
阿恩.7703 小时前
前沿科技计算机国际期刊征稿:电子、AI与网络计算
人工智能·经验分享·笔记·计算机网络·考研·云计算
代码游侠3 小时前
应用——MPlayer 媒体播放器系统代码详解
linux·运维·笔记·学习·算法
悠哉悠哉愿意4 小时前
【EDA学习笔记】电子技术基础知识:基本元件
笔记·嵌入式硬件·学习·eda
不解风水4 小时前
【教程笔记】KalmanFilter
笔记·学习·算法·矩阵·ekf