机器学习入门<2>决策树算法

最近在准备考研,突然想整理一下之前学过的机器学习算法。今天先从决策树开始,这个算法就像我们生活中的决策过程一样直观易懂。我搭建了Pycharm+Python3.10环境,下面让我们一起用代码实现一个完整的决策树项目。

一、决策树

1.决策树的含义

决策树通过一系列的判断规则来进行分类或回归预测的算法。

想象一下,你家的智能空调就像一棵会思考的决策树,它需要根据多种环境因素自动选择最适合的运行模式。让我们用这个生动例子来理解决策树算法:

假设你的智能空调有4种工作模式:

  1. 制冷模式 - 快速降温

  2. 制热模式 - 温暖房间

  3. 除湿模式 - 降低湿度

  4. 送风模式 - 仅通风循环

空调需要根据以下环境数据自动决策:

2.决策树的核心概念

  • 根节点:树的起点,包含所有样本

  • 内部节点:决策点,对应一个特征测试

  • 叶节点:决策结果,对应分类标签或回归值

  • 分支:特征的取值范围

3.决策树的核心原理

决策树最核心的问题就是:每次应该用哪个特征来划分数据?

常用标准有三种:

3.1信息增益(ID3)

熵衡量不确定性,熵越大越混乱

python 复制代码
信息增益 = 父节点的熵 - 子节点的加权平均熵

熵 = -Σ p(x) * log₂p(x)
# 熵衡量不确定性,熵越大越混乱

3.2信息增益比(也叫C4.5算法)

解决了信息增益偏好取值多的特征的问题。

python 复制代码
信息增益比 = 信息增益 / 特征本身的熵
# 解决了信息增益偏好取值多的特征的问题

3.3基尼指数(CART算法)

表示随机抽两个样本,类别不一致的概率。基尼指数越小,数据越纯。

python 复制代码
基尼指数 = 1 - Σ p(x)²
# 表示随机抽两个样本,类别不一致的概率
# 基尼指数越小,数据越纯

4.决策树的构建过程

下面是写的一个伪代码,它解释了决策树的构建过程。

python 复制代码
def 构建决策树(数据集):
    if 满足停止条件:  # 比如所有样本属于同一类
        return 叶节点
    
    1. 选择最佳划分特征  # 使用信息增益/基尼指数
    2. 根据特征值划分数据集
    3. 创建当前节点
    4. for 每个划分:
        子树 = 构建决策树(子数据集)
        将子树添加到当前节点
    
    return 当前节点

5.决策树的停止条件

决策树的停止条件主要有以下4种:

  • 节点中样本属于同一类别

  • 没有特征可用

  • 达到预设的最大深度

  • 节点样本数小于阈值

二、鸢尾花分类项目

1.准备环境

在Pycharm终端中输入:

python 复制代码
# 所需库
pip install numpy matplotlib scikit-learn

2.完整代码实现

python 复制代码
# ========== 1. 导入必要的库 ==========
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree

# ========== 2. 设置中文字体 ==========
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# ========== 3. 加载和准备数据 ==========
# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.3,
    random_state=42,  # 设置随机种子保证结果可复现
    stratify=y       # 保持类别分布
)

# ========== 4. 创建决策树模型 ==========
clf = DecisionTreeClassifier(
    max_depth=3,            # 最大深度限制,防止过拟合
    criterion='gini',       # 使用基尼指数作为划分标准
    min_samples_split=2,    # 内部节点再划分所需最小样本数
    min_samples_leaf=1,     # 叶节点最少样本数
    random_state=42         # 随机种子
)

# 训练模型
clf.fit(X_train, y_train)

# 评估模型
train_accuracy = clf.score(X_train, y_train)
test_accuracy = clf.score(X_test, y_test)
print("=" * 50)
print(f"训练集准确率: {train_accuracy:.2%}")
print(f"测试集准确率: {test_accuracy:.2%}")
print("=" * 50)

# 特征重要性
feature_importance = dict(zip(iris.feature_names, clf.feature_importances_))
print("特征重要性:")
for feature, importance in feature_importance.items():
    print(f"  {feature}: {importance:.2%}")
print("=" * 50)

# ========== 5. 设置中文标签 ==========
# 中文特征名(对应原英文特征名)
chinese_feature_names = ['花萼长度(cm)', '花萼宽度(cm)', '花瓣长度(cm)', '花瓣宽度(cm)']

# 中文类别名
chinese_class_names = ['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾']

# ========== 6. 可视化决策树(基础中文版) ==========
# 创建图形
plt.figure(figsize=(16, 10))

# 绘制决策树
tree.plot_tree(
    clf,
    feature_names=chinese_feature_names,  # 使用中文特征名
    class_names=chinese_class_names,      # 使用中文类别名
    filled=True,           # 填充颜色
    rounded=True,          # 圆角边框
    fontsize=10,           # 字体大小
    impurity=True,         # 显示不纯度(基尼指数)
    proportion=True,       # 显示样本比例
    precision=2            # 小数位数
)

# 添加标题
plt.title('鸢尾花分类决策树可视化', fontsize=16, fontweight='bold', pad=20)

# 显示图形
plt.tight_layout()
plt.show()

# ========== 7. 增强版可视化(带统计信息) ==========
print("\n正在生成增强版可视化...")

# 创建更大的图形
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# 左图:决策树
tree.plot_tree(
    clf,
    feature_names=chinese_feature_names,
    class_names=chinese_class_names,
    filled=True,
    rounded=True,
    fontsize=9,
    ax=axes[0]  # 绘制在左图
)
axes[0].set_title('决策树结构', fontsize=14, fontweight='bold')

# 右图:特征重要性柱状图
features = chinese_feature_names
importance = clf.feature_importances_

# 按重要性排序
sorted_idx = np.argsort(importance)
sorted_features = [features[i] for i in sorted_idx]
sorted_importance = importance[sorted_idx]

# 创建水平条形图
bars = axes[1].barh(range(len(sorted_features)), sorted_importance, color='skyblue')
axes[1].set_yticks(range(len(sorted_features)))
axes[1].set_yticklabels(sorted_features, fontsize=11)
axes[1].set_xlabel('重要性', fontsize=12)
axes[1].set_title('特征重要性排名', fontsize=14, fontweight='bold')

# 在条形图右侧添加数值标签
for i, bar in enumerate(bars):
    width = bar.get_width()
    axes[1].text(width + 0.01, bar.get_y() + bar.get_height()/2,
                f'{width:.2%}',
                ha='left', va='center', fontsize=10)

# 添加整体标题
fig.suptitle('鸢尾花决策树分类模型分析', fontsize=18, fontweight='bold', y=0.98)

# 调整布局并显示
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()

# ========== 8. 导出决策规则 ==========
print("\n" + "="*60)
print("决策树分类规则(中文版):")
print("="*60)

from sklearn.tree import export_text

# 导出文本格式的决策规则
tree_rules = export_text(
    clf,
    feature_names=chinese_feature_names,
    decimals=2
)
print(tree_rules)
print("="*60)

# ========== 9. 示例预测 ==========
print("\n示例预测:")
print("-"*40)

# 创建一些测试样本
test_samples = np.array([
    [5.1, 3.5, 1.4, 0.2],  # 样本1: 山鸢尾
    [6.7, 3.0, 5.2, 2.3],  # 样本2: 维吉尼亚鸢尾
    [5.9, 3.0, 4.2, 1.5],  # 样本3: 变色鸢尾
])

print("测试样本特征值:")
for i, sample in enumerate(test_samples, 1):
    print(f"样本{i}: 花萼长度={sample[0]:.1f}cm, 花萼宽度={sample[1]:.1f}cm, "
          f"花瓣长度={sample[2]:.1f}cm, 花瓣宽度={sample[3]:.1f}cm")

# 进行预测
predictions = clf.predict(test_samples)
probabilities = clf.predict_proba(test_samples)

print("\n预测结果:")
for i, (pred, prob) in enumerate(zip(predictions, probabilities), 1):
    print(f"\n样本{i}:")
    print(f"  预测类别: {chinese_class_names[pred]}")
    print(f"  各类别概率:")
    for j, class_name in enumerate(chinese_class_names):
        print(f"    {class_name}: {prob[j]:.2%}")
print("-"*40)

# ========== 10. 保存图像到文件 ==========
print("\n保存图像到文件...")

# 保存决策树图
plt.figure(figsize=(16, 10))
tree.plot_tree(
    clf,
    feature_names=chinese_feature_names,
    class_names=chinese_class_names,
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title('鸢尾花分类决策树可视化', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('决策树可视化.png', dpi=300, bbox_inches='tight')
print("决策树图已保存为 '决策树可视化.png'")

# 保存特征重要性图
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.barh(range(len(sorted_features)), sorted_importance, color='lightgreen')
ax.set_yticks(range(len(sorted_features)))
ax.set_yticklabels(sorted_features, fontsize=11)
ax.set_xlabel('重要性', fontsize=12)
ax.set_title('决策树特征重要性', fontsize=14, fontweight='bold')

for i, bar in enumerate(bars):
    width = bar.get_width()
    ax.text(width + 0.01, bar.get_y() + bar.get_height()/2,
           f'{width:.2%}',
           ha='left', va='center', fontsize=10)

plt.tight_layout()
plt.savefig('特征重要性.png', dpi=300, bbox_inches='tight')
print("特征重要性图已保存为 '特征重要性.png'")

print("\n 所有任务完成!")

3.运行结果解读

运行上面的代码,你会得到:

3.1决策树的可视化的图片

  1. 每个节点显示划分条件

  2. 基尼指数表示纯度

  3. 样本数量和分布

  4. 颜色深浅表示节点纯度

3.2特征分析

  1. 条形图显示各特征的重要性

  2. 可以看到花瓣尺寸是最重要的特征

3.3控制台输出分析

三、总结

决策树作为机器学习中最直观的算法之一,是入门机器学习的绝佳起点。

本篇文章解释了:

  1. 决策树的基本原理和构建过程

  2. 如何用Python实现决策树分类

  3. 如何可视化决策树并解释结果

相关推荐
lxmyzzs40 分钟前
vLLM、SGLang 与 TensorRT-LLM 综合对比分析报告
人工智能·自然语言处理
Blossom.11841 分钟前
基于多智能体协作的AIGC内容风控系统:从单点检测到可解释裁决链
人工智能·python·深度学习·机器学习·设计模式·aigc·transformer
阿杰学AI42 分钟前
AI核心知识30——大语言模型之CoT(简洁且通俗易懂版)
人工智能·语言模型·自然语言处理·aigc·agi·cot·思维链
风途知识百科43 分钟前
光伏板便捷式iv曲线测试仪:怎么给电站号脉?
人工智能
LeeZhao@43 分钟前
【狂飙全模态】狂飙AGI-智能图文理解助手
数据库·人工智能·redis·语言模型·机器人·agi
逻极44 分钟前
从“炼丹”到“炼钢”:我们如何将机器学习推理服务吞吐量提升300%
机器学习·ai·scikit-learn
AI架构师易筋44 分钟前
机器学习中的熵、信息量、交叉熵和 KL 散度:从入门到严谨
人工智能·机器学习
serve the people1 小时前
TensorFlow 模型的 “完整保存与跨环境共享” 方案
人工智能·tensorflow·neo4j
曲幽1 小时前
Flask项目目录结构指南:从单文件到模块化
python·web·model·route·项目结构