24/8/8算法笔记 条件筛选决策树根节点

筛选决策树的根节点是建立决策树过程中的一个重要步骤,主要原因包括:

  1. 减少计算量:选择合适的根节点可以减少树的深度,从而减少模型训练和预测时的计算量。

  2. 提高模型性能:选择最佳分裂点可以最大化模型的性能,通常通过减少误差或提高信息增益来实现。

  3. 防止过拟合:通过选择一个合适的分裂点,可以避免树生长得过于复杂,从而减少过拟合的风险。

  4. 提高泛化能力:一个好的根节点可以帮助模型更好地泛化到新的、未见过的数据上。

  5. 减少训练时间:选择一个好的根节点可以减少构建树所需的时间,特别是在处理大量数据时。

  6. 特征选择:在构建决策树时,选择根节点的过程也涉及到特征选择,即确定哪些特征对于预测目标变量最为重要。

  7. 信息增益:在ID3算法中,选择根节点是基于信息增益最大的原则。信息增益衡量了不纯度的减少,即在分裂前后数据集的不确定性的减少。

  8. 基尼不纯度:在CART算法中,选择根节点是基于最小化基尼不纯度的原则,基尼不纯度是衡量数据集纯度的一个指标。

  9. 模型解释性:一个好的根节点可以提高模型的可解释性,使得模型的决策过程更容易被理解和解释。

  10. 数据分布:选择根节点时考虑数据的分布,可以确保树的分裂更好地反映了数据的内在结构。

在实际应用中,选择根节点通常涉及到计算不同特征的分裂点,并评估每个分裂点的性能指标,如信息增益、基尼不纯度等,以确定哪个特征和哪个分裂点是最佳的。这个过程是构建决策树算法的核心部分。

导入要构建的数据

复制代码
import numpy as np
import pandas as pd
y = np.array(list('NYYYYYNYYN'))
print(y)
X = pd.DataFrame({'日志密度':list('sslmlmmlms'),
                  '好友密度':list('slmmmlsmss'),
                  '真实头像':list('NYYYYNYYYY'),
                  '真实用户':y})
X

将分类数据替换为真实数据

复制代码
X['日志密度'] = X['日志密度'].map({'s': 0, 'm': 1, 'l': 2})
X['好友密度'] = X['好友密度'].map({'s':0,'m':1,'l':2})
X['真实头像'] = X['真实头像'].map({'N':0,'Y':1})
X
复制代码
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

import graphviz 

构建决策树

复制代码
model = DecisionTreeClassifier(criterion='entropy')
model.fit(X.iloc[:,:-1],y)#二维切片

#dot点
dot_data = tree.export_graphviz(model,filled =True,rounded=True,feature_names=X.columns[:-1])

graphviz.Source(dot_data)

特征选择和最佳分裂点的确定。代码的目的是遍历给定的列列表 cols,计算每个特征的可能分裂点,并使用信息熵来评估这些分裂点,选择信息熵最小的分裂点作为最佳分裂点。

复制代码
cols =['日志密度','好友密度','真实头像']
lower_entropy = 3#最小的信息熵
best_split = {}

for col in cols[:1]:
    x=X[col].unique()#返回去重之后的数据
    x.sort()#0,1,2
    print(x)
    #如何根据这一列划分
    for i in range(len(x)-1):#裂分点,裂分值
        split = x[i:i+2].mean()
        
        #裂分的概率分布
        cond = X[col]<=split
        p = cond.value_counts()/cond.size
        print(p)
        
        indexs = p.index
        entroy=0
        for index in indexs:
            user = X[cond ==index]['真实用户']
            p_user = user.value_counts()/user.size
            entroy += (p_user * np.log2(1/p_user)).sum() * p[index]
        if entroy<lower_entropy:
            lower_entropy = entroy
            best_split.clear()
            best_split[col] = split
print('最佳裂分条件:',best_split)
相关推荐
跃龙客34 分钟前
atomic笔记
笔记·算法
中屹指纹浏览器2 小时前
2026指纹浏览器环境隔离技术:进程、网络、存储三维深度隔离架构
经验分享·笔记
Smoothcloud润云2 小时前
Google DeepMind 学习系列笔记(3):Design And Train Neural Networks
数据库·人工智能·笔记·深度学习·学习·数据分析·googlecloud
【数据删除】3483 小时前
计算机复试学习笔记 Day26【补】
笔记·学习
clear sky .3 小时前
[bootloader]使用笔记
笔记
myloveasuka3 小时前
寻址方式笔记
汇编·笔记·计算机组成原理
johnny2333 小时前
编辑器和笔记软件汇总:Typst、Reminds、Memos、Editor、MDX Notes、Jotty
笔记·编辑器
Helibo443 小时前
数论中的整除
笔记·学习
ding_zhikai4 小时前
【Web应用开发笔记】Django笔记3-2:部署我的简陋网页
笔记·后端·python·django
山岚的运维笔记4 小时前
SQL Server笔记 -- 第86章:查询存储
笔记·python·sql·microsoft·sqlserver·flask