24/8/8算法笔记 条件筛选决策树根节点

筛选决策树的根节点是建立决策树过程中的一个重要步骤,主要原因包括:

  1. 减少计算量:选择合适的根节点可以减少树的深度,从而减少模型训练和预测时的计算量。

  2. 提高模型性能:选择最佳分裂点可以最大化模型的性能,通常通过减少误差或提高信息增益来实现。

  3. 防止过拟合:通过选择一个合适的分裂点,可以避免树生长得过于复杂,从而减少过拟合的风险。

  4. 提高泛化能力:一个好的根节点可以帮助模型更好地泛化到新的、未见过的数据上。

  5. 减少训练时间:选择一个好的根节点可以减少构建树所需的时间,特别是在处理大量数据时。

  6. 特征选择:在构建决策树时,选择根节点的过程也涉及到特征选择,即确定哪些特征对于预测目标变量最为重要。

  7. 信息增益:在ID3算法中,选择根节点是基于信息增益最大的原则。信息增益衡量了不纯度的减少,即在分裂前后数据集的不确定性的减少。

  8. 基尼不纯度:在CART算法中,选择根节点是基于最小化基尼不纯度的原则,基尼不纯度是衡量数据集纯度的一个指标。

  9. 模型解释性:一个好的根节点可以提高模型的可解释性,使得模型的决策过程更容易被理解和解释。

  10. 数据分布:选择根节点时考虑数据的分布,可以确保树的分裂更好地反映了数据的内在结构。

在实际应用中,选择根节点通常涉及到计算不同特征的分裂点,并评估每个分裂点的性能指标,如信息增益、基尼不纯度等,以确定哪个特征和哪个分裂点是最佳的。这个过程是构建决策树算法的核心部分。

导入要构建的数据

import numpy as np
import pandas as pd
y = np.array(list('NYYYYYNYYN'))
print(y)
X = pd.DataFrame({'日志密度':list('sslmlmmlms'),
                  '好友密度':list('slmmmlsmss'),
                  '真实头像':list('NYYYYNYYYY'),
                  '真实用户':y})
X

将分类数据替换为真实数据

X['日志密度'] = X['日志密度'].map({'s': 0, 'm': 1, 'l': 2})
X['好友密度'] = X['好友密度'].map({'s':0,'m':1,'l':2})
X['真实头像'] = X['真实头像'].map({'N':0,'Y':1})
X
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

import graphviz 

构建决策树

model = DecisionTreeClassifier(criterion='entropy')
model.fit(X.iloc[:,:-1],y)#二维切片

#dot点
dot_data = tree.export_graphviz(model,filled =True,rounded=True,feature_names=X.columns[:-1])

graphviz.Source(dot_data)

特征选择和最佳分裂点的确定。代码的目的是遍历给定的列列表 cols,计算每个特征的可能分裂点,并使用信息熵来评估这些分裂点,选择信息熵最小的分裂点作为最佳分裂点。

cols =['日志密度','好友密度','真实头像']
lower_entropy = 3#最小的信息熵
best_split = {}

for col in cols[:1]:
    x=X[col].unique()#返回去重之后的数据
    x.sort()#0,1,2
    print(x)
    #如何根据这一列划分
    for i in range(len(x)-1):#裂分点,裂分值
        split = x[i:i+2].mean()
        
        #裂分的概率分布
        cond = X[col]<=split
        p = cond.value_counts()/cond.size
        print(p)
        
        indexs = p.index
        entroy=0
        for index in indexs:
            user = X[cond ==index]['真实用户']
            p_user = user.value_counts()/user.size
            entroy += (p_user * np.log2(1/p_user)).sum() * p[index]
        if entroy<lower_entropy:
            lower_entropy = entroy
            best_split.clear()
            best_split[col] = split
print('最佳裂分条件:',best_split)
相关推荐
金星娃儿2 小时前
MATLAB基础知识笔记——(矩阵的运算)
笔记·matlab·矩阵
B20080116刘实5 小时前
CTF攻防世界小白刷题自学笔记13
开发语言·笔记·web安全·网络安全·php
静止了所有花开6 小时前
SpringMVC学习笔记(二)
笔记·学习
红中马喽9 小时前
JS学习日记(webAPI—DOM)
开发语言·前端·javascript·笔记·vscode·学习
huangkj-henan12 小时前
DA217应用笔记
笔记
Young_2022020212 小时前
学习笔记——KMP
笔记·学习
秀儿还能再秀12 小时前
机器学习——简单线性回归、逻辑回归
笔记·python·学习·机器学习
WCF向光而行12 小时前
Getting accurate time estimates from your tea(从您的团队获得准确的时间估计)
笔记·学习
Li_03040614 小时前
Java第十四天(实训学习整理资料(十三)Java网络编程)
java·网络·笔记·学习·计算机网络
啤酒泡泡_Lyla14 小时前
现代无线通信接收机架构:超外差、零中频与低中频的比较分析
笔记·信息与通信