决策树的调参,比较直观的代码显示

复制代码
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_digits, load_wine
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt

# 第一步,获取数据集
# 获取数据集的特征值x 和 目标属性y
dataset = load_digits()

x = dataset.data

y = dataset.target

# 第二步,确定是哪种划分标准好 criterion
# 这一步是找出基尼划分标准好,还是熵划分标准好
DT = DecisionTreeClassifier(random_state=66)
score = cross_val_score(DT, x, y, cv=10).mean()
print("使用基尼划分标准获取到的值为%0.4f"%score)
# 使用基尼划分标准获取到的值为0.8252

DT = DecisionTreeClassifier(criterion="entropy", random_state=66)
score = cross_val_score(DT, x, y, cv=10).mean()
print("使用熵划分标准获取到的值为%0.4f"%score)
# 使用熵划分标准获取到的值为0.8036

# 第三步,查找出模型深度的大概范围 max_depth
# 这一步是大范围先找出模型的最大深度值大概多少最好
# 在大范围内画出max_depth这个参数变化曲线  深度从10开始,每隔10验证一次,获取模型的精确率
ScoreAll = []
for i in range(10, 100, 10):
    DT = DecisionTreeClassifier(max_depth=i, random_state=66)
    score = cross_val_score(DT, dataset.data, dataset.target, cv=10).mean()
    ScoreAll.append([i, score])  # 将模型的深度和精确率保存到数组里面
ScoreAll = np.array(ScoreAll)

max_score = np.where(ScoreAll == np.max(ScoreAll[:, 1]))[0][0]  # 找出ScoreAll里面得分最高的深度
print("大概最优深度参数对应的最高分为:", ScoreAll[max_score])
# 大概最优深度参数对应的最高分为: [20.          0.82524209]
# print(ScoreAll[0])
plt.figure(figsize=[20, 5])
plt.plot(ScoreAll[:, 0], ScoreAll[:, 1])
plt.show()

# 第四步,查找出模型深度的具体参数 max_depth
# 从大范围获取到具体的大概值之后,从如图可以获知是10-25之间
# 从大范围找到最大深度的大概值之后,在这个范围开始找具体的深度值。可以直接复用上面的方法,但是这里为了比较直观,我多写一遍
ScoreAll1 = []
for i in range(10, 30):
    DT = DecisionTreeClassifier(max_depth=i, random_state=66)
    score = cross_val_score(DT, dataset.data, dataset.target, cv=10).mean()
    ScoreAll1.append([i, score])  # 将模型的深度和精确率保存到数组里面
ScoreAll1 = np.array(ScoreAll1)

max_score = np.where(ScoreAll1 == np.max(ScoreAll1[:, 1]))[0][0]  # 找出ScoreAll里面得分最高的深度
print("具体最优深度参数对应的最高分为:", ScoreAll1[max_score])
# 具体最优深度参数对应的最高分为: [13.          0.82691186]
# print(ScoreAll[0])
plt.figure(figsize=[20, 5])
plt.plot(ScoreAll1[:, 0], ScoreAll1[:, 1])
plt.show()

# 第五步, 查找出分割内部节点所需的最小样本数 min_samples_split
# 使用得到划分标准和深度值 去查找分割内部节点所需的最小样本数  目前先从范围(2,30)里面去找
ScoreAll2 = []
for i in range(2, 30):
    DT = DecisionTreeClassifier(max_depth=13, random_state=66, min_samples_split=i)
    score = cross_val_score(DT, dataset.data, dataset.target, cv=10).mean()
    ScoreAll2.append([i, score])  # 将模型分割内部节点所需的最小样本数和精确率保存到数组里面
ScoreAll2 = np.array(ScoreAll2)

max_score = np.where(ScoreAll2 == np.max(ScoreAll2[:, 1]))[0][0]  # 找出ScoreAll里面得分最高的深度
print("具体分割内部节点所需的最小样本数为:", ScoreAll2[max_score])
# 具体分割内部节点所需的最小样本数为: [5.         0.82692117]
# print(ScoreAll[0])
plt.figure(figsize=[20, 5])
plt.plot(ScoreAll2[:, 0], ScoreAll2[:, 1])
plt.show()

# 第六步 查找出叶子节点上的最小样本数 min_samples_leaf
ScoreAll3 = []
for i in range(1, 30):
    DT = DecisionTreeClassifier(min_samples_leaf=i, max_depth=13, random_state=66, min_samples_split=5)
    score = cross_val_score(DT, dataset.data, dataset.target, cv=10).mean()
    ScoreAll3.append([i, score])  # 将模型叶子节点上的最小样本数和精确率保存到数组里面
ScoreAll3 = np.array(ScoreAll3)

max_score = np.where(ScoreAll3 == np.max(ScoreAll3[:, 1]))[0][0]  # 找出ScoreAll里面得分最高的深度
print("具体叶子节点上的最小样本数为:", ScoreAll3[max_score])
# 具体叶子节点上的最小样本数为: [2.         0.83417132]
# print(ScoreAll[0])
plt.figure(figsize=[20, 5])
plt.plot(ScoreAll3[:, 0], ScoreAll3[:, 1])
plt.show()

# 利用网格搜索 在小范围联调max_depth,min_samples_split,min_samples_leaf三个参数
# 从上面获取到的值max_depth=13,min_samples_split=5,min_samples_leaf=2  根据这几个值做个小范围参数设置
# max_depth,min_samples_split,min_sample_leaf 一块调整
param_grid = {
    'max_depth': np.arange(10, 15),
    'min_samples_leaf': np.arange(1, 8),
    'min_samples_split': np.arange(2, 8)
}

rfc = DecisionTreeClassifier(random_state=66)
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(dataset.data, dataset.target)
# 最后拿到最佳的参数值
print(GS.best_params_)
print(GS.best_score_)

以上是使用我自己能理解的角度去敲了一次代码,比较直观的看到现象。归根到底还是参考了别人想法

参考了:机器学习超详细实践攻略(9):决策树算法使用及小白都能看懂的调参指南 - 知乎 (zhihu.com)

相关推荐
喵叔哟13 分钟前
02-YOLO-v8-v9-v10工程差异对比
人工智能·yolo·机器学习
团子的二进制世界27 分钟前
G1垃圾收集器是如何工作的?
java·jvm·算法
白日做梦Q31 分钟前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
吃杠碰小鸡31 分钟前
高中数学-数列-导数证明
前端·数学·算法
故事不长丨31 分钟前
C#线程同步:lock、Monitor、Mutex原理+用法+实战全解析
开发语言·算法·c#
long31632 分钟前
Aho-Corasick 模式搜索算法
java·数据结构·spring boot·后端·算法·排序算法
近津薪荼32 分钟前
dfs专题4——二叉树的深搜(验证二叉搜索树)
c++·学习·算法·深度优先
熊文豪41 分钟前
探索CANN ops-nn:高性能哈希算子技术解读
算法·哈希算法·cann
熊猫_豆豆1 小时前
YOLOP车道检测
人工智能·python·算法
艾莉丝努力练剑1 小时前
【Linux:文件】Ext系列文件系统(初阶)
大数据·linux·运维·服务器·c++·人工智能·算法