大数据-202 sklearn 决策树实战:criterion、Graphviz 可视化与剪枝防过拟合

TL;DR

  • 场景:用 sklearn 在 Wine 数据集上训练 DecisionTreeClassifier,并用 Graphviz 导出可视化
  • 结论:criterion 影响分裂度量(gini/entropy/log_loss),random_state 与 splitter 决定稳定性与随机性;剪枝靠 max_depth/min_samples_leaf/ccp_alpha
  • 产出:可复现实验流程 + export_graphviz 参数要点 + 环境版本矩阵 + 常见报错定位表

版本矩阵

组件 已验证说明
scikit-learn 1.8.0 PyPI 声明2025-12-10 发布;Requires: Python>=3.11;DecisionTreeClassifier 的 criterion 支持 gini/entropy/log_loss DecisionTreeClassifier API官方文档;splitter 支持 best/random;criterion 见上 export_graphviz官方文档 导出 DOT;配合 dot 命令可生成 PNG/SVG 等
python-graphviz 0.21 PyPI 声明Requires: Python>=3.9;渲染仍依赖系统 Graphviz(dot)与 PATH
系统 Graphviz(dot) 官方站点需单独安装(Windows/macOS/Linux 方式不同)

使用sklearn实现决策树

参数CRITERION

criterion 这个参数使用来决定不纯度的计算方法,sklearn提供了两种选择:

  • 输入 entropy,使用信息熵(Entropy)
  • 输入 gini,使用基尼系数(Gini Impurity)

在机器学习中,基尼系数和信息熵都是常用的不纯度度量指标,用于决策树等算法的特征选择和节点划分。两者虽然计算方式和敏感性不同,但在实际应用中往往能达到相似的效果。

信息熵的计算公式为:H(X)=-Σp(x)log₂p(x),其中p(x)是某个类别的概率。由于涉及对数运算,其计算复杂度比基尼系数要高。相比之下,基尼系数的计算公式Gini=1-Σp(x)²更加简单直接,不需要对数运算,因此计算速度更快。

具体来说,信息熵对数据不纯度的惩罚更强。例如:

  • 当某个节点中两个类别的样本比例为7:3时:
    • 基尼系数=1-(0.7²+0.3²)=0.42
    • 信息熵=-(0.7log₂0.7+0.3log₂0.3)≈0.88
  • 当比例变为8:2时:
    • 基尼系数=0.32
    • 信息熵≈0.72

这种敏感性差异导致:

  1. 在高维数据场景下(如特征数超过1000),信息熵容易产生过拟合,因为它会倾向于选择更细粒度的划分方式。此时基尼系数表现更稳健。
  2. 在噪声数据较多时(如标签错误率超过10%),基尼系数的抗干扰能力更强。
  3. 当模型欠拟合时(训练集和测试集准确率都低于60%),使用信息熵可能获得更好的效果,因为它能驱动模型学习更精细的决策边界。

实际应用建议:

  • 对于结构化数据(如表格数据),可以优先尝试基尼系数
  • 在计算资源受限的实时系统中,基尼系数是更好的选择
  • 当数据质量较高且维度适中时(如UCI数据集),可以比较两种指标的效果
  • 在集成学习中(如随机森林),两种指标的差异通常会被弱化

需要注意的是,这些规律并非绝对。在实践中,最好的方式是通过交叉验证来比较两种指标在特定数据集上的表现。

初步建模

python 复制代码
# 导入需要的算法库和模块
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
plt.rcParams['font.sans-serif']=['Simhei']
plt.rcParams['axes.unicode_minus']=False

加载数据

python 复制代码
wine = load_wine()
wine.data.shape
wine.target

执行结果如下所示: 如果是 wine 是一张表,应该长这样:

python 复制代码
wine_pd=pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1).head()
wine.feature_names.append("result")
wine_pd.columns=wine.feature_names
wine_pd

执行结果如下图所示: 编写代码查看形状:

python 复制代码
Xtrain, Xtest, Ytrain, ytest = train_test_split(wine.data,wine.target,test_size=0.3,random_state=420)
print(Xtrain.shape)
print(Xtest.shape)

执行结果如下所示:

建立模型

python 复制代码
clf = tree.DecisionTreeClassifier(criterion="gini")
clf = clf.fit(Xtrain, Ytrain)
clf.score(Xtest, ytest) #返回预测的准确度

执行结果如下图所示:

画决策树

我们可以利用 Graphviz 模块导出决策树模型,第一次使用 Graphviz 之前需要进行安装,若是使用从 pip 安装:

shell 复制代码
!pip install graphviz

执行结果如下所示: 对图案进行绘制:

python 复制代码
import matplotlib.pyplot as plt
import graphviz
from sklearn import tree


feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','od280/od315 稀释葡萄酒','脯氨酸']
dot_data = tree.export_graphviz(clf, out_file = None, feature_names= feature_name, class_names=["琴酒","雪莉","贝尔摩德"], filled=True, rounded=True)
graph = graphviz.Source(dot_data)
graph

执行结果如下图所示: 绘制的图片如下所示:

export_graphviz 生成了一个 DOT 格式的决策树:

  • feature_names:每个属性的名字
  • class_names:每个因变量类别的名字
  • label:是否显示不纯度信息的标签,默认为 all 表都显示,可以是root或 none
  • filled:是否给每个节点的主分类绘制不同的颜色,默认为 False
  • out_file:输出的 dot 文件的名字,默认为 None表示不输出文件,可以是自定义名字如"tree.dot"
  • rounded:默认为 True,表示对每个节点的边框加圆角,使用 Helvetica 字体

防止过拟合

在不加限制的情况下,一颗决策树会持续生长直到满足以下任一条件:

  1. 所有叶节点的基尼系数(Gini impurity)或信息熵(information entropy)等不纯度指标达到最优(通常为0)
  2. 没有更多的特征可用于进一步分割数据
  3. 每个叶节点仅包含单一类别的样本

这种不加限制的生长方式往往会导致严重的过拟合问题。具体表现为:

  • 在训练集上可能达到100%的准确率
  • 但在测试集上表现显著下降(如准确率骤降20-30个百分点)

造成这种现象的根本原因在于:

  1. 样本偏差问题:我们收集的训练数据不可能完全代表整体数据分布,总会存在抽样偏差
  2. 噪声敏感问题:当决策树过度生长时:
    • 会捕捉到训练数据中的随机噪声(如5%的错误标注样本)
    • 为这些噪声创建特定的分裂规则(例如"当特征X=3.14159时预测类别A")
    • 导致模型泛化能力下降

典型示例: 在房价预测任务中,一棵过拟合的决策树可能会:

  • 为训练集中某个异常低价样本(如输入错误的$1元房价)创建特殊分支
  • 在实际预测时,遇到类似特征组合的房屋就会错误预测极低价格

解决方法通常包括:

  1. 预剪枝(Pre-pruning):
    • 设置最大深度(如max_depth=5)
    • 设置叶节点最小样本数(如min_samples_leaf=10)
  2. 后剪枝(Post-pruning):
    • 使用代价复杂度剪枝(CCP)
    • 基于验证集性能进行剪枝
  3. 集成方法:
    • 随机森林(通过多棵树投票降低过拟合风险)
    • 梯度提升树(通过迭代优化减少单棵树的影响)
python 复制代码
#我们的树对训练集的拟合程度如何?
score_train = clf.score(Xtrain, Ytrain)
score_train

执行结果如下图所示: 为了让决策树有更好的泛化性,我们要对决策树进行剪枝。剪枝策略对决策树的影响巨大,正确的剪枝策略是优化决策树算法的核心。

random_state

如果我们改动了 random_state,画出来的每一颗树都不一样,它为什么不稳定呢?如果使用其他数据集,它还会不稳定吗? 我们之前提过,无论决策树模型如何进化,在分支上的本质都还是追求某个不纯度相关的指标的优化,而正如我们提到的,不纯度是基于节点计算出来的,也就是说,决策树在建树时,是靠优化节点来追求一棵优化的树,但是最优的节点是能够保证最优的树吗?

集成算法被用来解决这个问题:sklearn 表示,既然一棵树不能保证最优,那就建更多不同的树,然后从中取最好的。怎么样从一组数据集中建不同的树呢?在每次分支的时候,不使用全部特征,而是随机选取一部分特征,从中选取不纯度相关指标最优的作为分支用的节点。 这样,每次生成的树叶不同了。

random_state 用来设置分支中的随机模型的参数,默认是 None,在高维度时随机性会表现更明显,低维度的数据随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。

splitter

splitter 也是用来控制决策树中随机选项的,有两种输入值:

  • 输入 best,决策树在分支时虽然随机,但是还是会优先选择更重要的特征进行分支(重要性可以通过属性 feature_importance 查看)
  • 输入 random,决策树在分支时会更加随机,树会因为含有更多的不必要的信息而更深更大,并因为这些不必要信息而降低对训练集的拟合。这也是一种防止过拟合的方式。

当你预测到你的模型会过拟合,用这两个参数来帮助你降低树建成之后的过拟合的可能性,当然,树一旦建成,我们依然是使用剪枝参数来防止拟合的。

python 复制代码
clf = tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="random")
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, ytest)
print(score)
plt.rcParams['font.sans-serif']=['Simhei']
plt.rcParams['axes.unicode_minus']=False

代码的执行结果如下所示:

版本矩阵

症状 根因定位 修复
ModuleNotFoundError: No module named 'graphviz' 未安装 python-graphviz 包 python -c "import graphviz" pip install graphviz
ExecutableNotFound: failed to execute 'dot' / 渲染为空 只装了 python 包,未装系统 Graphviz 或 PATH 未配 dot -V 是否可用 安装系统 Graphviz,并把 dot 所在目录加入 PATH
ValueError: Length of feature_names ... feature_names 数量与 X 特征列不一致 对比 Xtrain.shape[1] 与 len(feature_name) 统一特征数;Wine 为 13 维时 feature_names 也应为 13
图里类别名/计数顺序不对 class_names 顺序与内部类别编码不一致 查看 tree 图首节点 value=[...] 顺序 class_names 按类别编码升序传入(如 0/1/2 对应的名称顺序)
NameError: name 'Ytrain' is not defined 变量名大小写不一致/拷贝时漏段 搜索 Ytrain/ytrain 统一命名(训练/测试标签用同一风格)
SyntaxError 出现在 %matplotlib inline 该语法仅适用于 Jupyter 运行环境是 .py / IDE 脚本 脚本环境移除该魔法命令
训练集 1.0、测试集明显下降 树无约束生长导致过拟合 对比 clf.score(Xtrain, Ytrain) vs clf.score(Xtest, ytest) 加 max_depth/min_samples_leaf/min_samples_split,或用 ccp_alpha 做后剪枝
InvalidParameterError: The 'criterion' parameter ... sklearn 版本差异或参数拼写 sklearn.version ;检查 criterion 值 用官方支持值;新版本支持 gini/entropy/log_loss

其他系列

🚀 AI篇持续更新中(长期更新)

AI炼丹日志-29 - 字节跳动 DeerFlow 深度研究框斜体样式架 私有部署 测试上手 架构研究 ,持续打造实用AI工具指南! AI研究-132 Java 生态前沿 2025:Spring、Quarkus、GraalVM、CRaC 与云原生落地

💻 Java篇持续更新中(长期更新)

Java-218 RocketMQ Java API 实战:同步/异步 Producer 与 Pull/Push Consumer MyBatis 已完结,Spring 已完结,Nginx已完结,Tomcat已完结,分布式服务已完结,Dubbo已完结,MySQL已完结,MongoDB已完结,Neo4j已完结,FastDFS 已完结,OSS已完结,GuavaCache已完结,EVCache已完结,RabbitMQ已完结,RocketMQ正在更新... 深入浅出助你打牢基础!

📊 大数据板块已完成多项干货更新(300篇):

包括 Hadoop、Hive、Kafka、Flink、ClickHouse、Elasticsearch 等二十余项核心组件,覆盖离线+实时数仓全栈! 大数据-278 Spark MLib - 基础介绍 机器学习算法 梯度提升树 GBDT案例 详解

相关推荐
豆豆5 分钟前
2026年建设网站的十个步骤
大数据·cms·网站建设·网站制作·低代码平台·建站·网站设计
一直在追8 分钟前
告别 WHERE id=1!大数据工程师的 AI 觉醒:手把手带你拆解向量数据库 (RAG 核心)
大数据·数据库
悟空码字9 分钟前
SpringBoot整合FFmpeg,打造你的专属视频处理工厂
java·spring boot·后端
独自归家的兔11 分钟前
Spring Boot 版本怎么选?2/3/4 深度对比 + 迁移避坑指南(含 Java 8→21 适配要点)
java·spring boot·后端
新芒14 分钟前
海尔智家加速全球体育营销
大数据·人工智能
aiguangyuan16 分钟前
CART算法简介
人工智能·python·机器学习
superman超哥21 分钟前
Rust 移动语义(Move Semantics)的工作原理:零成本所有权转移的深度解析
开发语言·后端·rust·工作原理·深度解析·rust移动语义·move semantics
金融街小单纯22 分钟前
2026年开年准备:“四要素”文件夹命名,有条理的整理
大数据
轻竹办公PPT24 分钟前
用 AI 制作 2026 年工作计划 PPT,需要准备什么
大数据·人工智能·python·powerpoint
jiaozi_zzq27 分钟前
2026中专大数据与会计专业可考证书与进阶指南
大数据