机器学习 - 使用 ID3 算法从原理到实际举例理解决策树

一、什么是决策树

1.基本概念

决策树是一种树形结构,由结点(node)有向边(directed edge) 组成。其中结点分为两类:

  • 内部结点(internal node):表示一个属性(特征)
  • 叶结点(leaf node):表示一个类别

决策树是常用的分类机器学习方法。

2.实际举例说明

以 "相亲对象分类系统" 为例构建简单决策树:

  • 内部结点(长方形):特征 "有无房子""有无上进心"
  • 叶结点(椭圆形):类别 "值得考虑""备胎""Say Goodbye"
  • 分类逻辑:
  • 相亲对象有房子→划分为 "值得认真考虑"

  • 没有房子但有上进心→划分为 "备胎"既没有房子也没有上进心→划分为 "Say Goodbye

实际分类中存在多个特征量,可构建多种决策树,核心问题是如何筛选出最优决策树

二、介绍建立决策树的算法

决策树算法的核心差异在于特征选择指标,常见算法对比如下:

|------|--------|------------------------------------------------|
| 算法 | 特征选择指标 | 核心逻辑 |
| ID3 | 信息增益 | 信息增益越大,特征对降低数据不确定性的能力越强,优先作为上层结点 |
| C4.5 | 信息增益率 | 解决 ID3 对多值特征的偏好问题,通过 "增益率 = 信息增益 / 特征固有值" 平衡选择 |
| CART | 基尼指数 | 基尼指数越小,数据纯度越高,优先选择使基尼指数下降最多的特征 |

本文重点讲解ID3 算法,以下是其核心概念与公式:

1. 某个分类的信息

单个分类的信息表示该分类的不确定性,公式为:

其中,P(x_i) 是选择该分类的概率。

2. 熵(Entropy)

熵是随机变量不确定性的度量,定义为信息的期望值,公式为:

其中,n 是分类的数目;熵值越大,数据不确定性越高。

3. 经验熵(Empirical Entropy)

4. 条件熵(Conditional Entropy)

已知随机变量 X 的条件下,随机变量 Y 的不确定性,公式为:

其中,p_i 是 X=x_i 的概率,H(Y|X=x_i) 是 X=x_i 时 Y 的熵。

5. 信息增益(Information Gain)

样本集 D 的经验熵 H(D) 与特征 A 给定条件下 D 的经验条件熵 H(D|A) 之差,公式为:

关键结论:特征的信息增益值越大,该特征对分类的贡献越强,应优先作为决策树的上层结点。

三、决策树的一般流程

决策树构建分为 6 个步骤,适用于各类决策树算法:

  1. 收集数据:通过爬虫、问卷、数据库查询等方式获取原始数据,无固定方法。
  1. 准备数据 :树构造算法仅支持标称型数据(离散类别数据),需将数值型数据离散化(如将 "年龄 20-30" 划分为 "青年")。
  1. 分析数据:构建树后,通过可视化、误差分析等方式验证树结构是否符合预期。
  1. 训练算法:根据特征选择指标(如 ID3 的信息增益),递归构建决策树的数据结构。
  1. 测试算法:使用测试集计算决策树的错误率,评估模型性能。
  1. 使用算法:将训练好的决策树应用于实际场景(如贷款审批、客户分类),并持续迭代优化。

四、实际举例构建决策树

以 "贷款申请分类" 为例,使用 ID3 算法构建决策树。

1. 数据集准备

贷款申请样本数据表(原始)

|----|----|-----|--------|------|-----------|
| ID | 年龄 | 有工作 | 有自己的房子 | 信贷情况 | 类别(是否给贷款) |
| 1 | 青年 | 否 | 否 | 一般 | 否 |
| 2 | 青年 | 否 | 否 | 好 | 否 |
| 3 | 青年 | 是 | 否 | 好 | 是 |
| 4 | 青年 | 是 | 是 | 一般 | 是 |
| 5 | 青年 | 否 | 否 | 一般 | 否 |
| 6 | 中年 | 否 | 否 | 一般 | 否 |
| 7 | 中年 | 否 | 否 | 好 | 否 |
| 8 | 中年 | 是 | 是 | 好 | 是 |
| 9 | 中年 | 否 | 是 | 非常好 | 是 |
| 10 | 中年 | 否 | 是 | 非常好 | 是 |
| 11 | 老年 | 否 | 是 | 非常好 | 是 |
| 12 | 老年 | 否 | 是 | 好 | 是 |
| 13 | 老年 | 是 | 否 | 好 | 是 |
| 14 | 老年 | 是 | 否 | 非常好 | 是 |
| 15 | 老年 | 否 | 否 | 一般 | 否 |

数据编码(标称化处理)
  • 年龄:0 = 青年,1 = 中年,2 = 老年
  • 有工作:0 = 否,1 = 是
  • 有自己的房子:0 = 否,1 = 是
  • 信贷情况:0 = 一般,1 = 好,2 = 非常好
  • 类别:no = 否,yes = 是
数据集代码定义
python 复制代码
from math import log
def createDataSet():
    dataSet = [
        [0, 0, 0, 0, 'no'],    # 样本1
        [0, 0, 0, 1, 'no'],    # 样本2
        [0, 1, 0, 1, 'yes'],   # 样本3
        [0, 1, 1, 0, 'yes'],   # 样本4
        [0, 0, 0, 0, 'no'],    # 样本5
        [1, 0, 0, 0, 'no'],    # 样本6
        [1, 0, 0, 1, 'no'],    # 样本7
        [1, 1, 1, 1, 'yes'],   # 样本8
        [1, 0, 1, 2, 'yes'],   # 样本9
        [1, 0, 1, 2, 'yes'],   # 样本10
        [2, 0, 1, 2, 'yes'],   # 样本11
        [2, 0, 1, 1, 'yes'],   # 样本12
        [2, 1, 0, 1, 'yes'],   # 样本13
        [2, 1, 0, 2, 'yes'],   # 样本14
        [2, 0, 0, 0, 'no']     # 样本15
    ]
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']  # 特征标签
    labels1 = ['放贷', '不放贷']  # 分类标签
    return dataSet, labels, labels1  # 返回数据集、特征标签、分类标签

2. 计算经验熵 H (D)

数学计算

样本集 D 共 15 个样本,其中 "放贷(yes)"9 个,"不放贷(no)"6 个,经验熵为:

代码实现
python 复制代码
def calcShannonEnt(dataSet):
    numEntires = len(dataSet)  # 数据集行数(样本数)
    labelCounts = {}  # 存储每个标签的出现次数
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 提取最后一列(分类标签)
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1  # 标签计数
    shannonEnt = 0.0  # 初始化经验熵
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntires  # 标签出现概率
        shannonEnt -= prob * log(prob, 2)  # 计算经验熵
    return shannonEnt

# 测试代码
if __name__ == '__main__':
    dataSet, features, labels1 = createDataSet()
    print("数据集:", dataSet)
    print("经验熵H(D):", calcShannonEnt(dataSet))  # 输出:0.9709505944546686

3. 计算信息增益(选择最优特征)

数学计算(以 "有自己的房子" 为例)

设特征 A_3(有自己的房子),取值为 "是(1)" 和 "否(0)":

  • 子集 D_1(A_3=1):共 9 个样本,均为 "yes",经验熵 H(D_1)=0
  • 子集 D_2(A_3=0):共 6 个样本,"yes" 3 个、"no" 3 个
  • 经验熵
  • 条件熵
  • 信息增益 (注:原文计算结果为 0.420,此处以原文代码输出为准)

其他特征的信息增益计算结果:

  • 年龄(A_1):0.083
  • 有工作(A_2):0.324
  • 信贷情况(A_4):0.363

结论:特征 "有自己的房子(A_3)" 信息增益最大,作为决策树的根节点。

代码实现
python 复制代码
"""
函数:按照给定特征划分数据集
参数:
    dataSet - 待划分数据集
    axis - 特征索引
    value - 特征取值
返回:
    retDataSet - 划分后的子集
"""
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # 去掉当前特征列
            reducedFeatVec.extend(featVec[axis+1:])  # 拼接剩余列
            retDataSet.append(reducedFeatVec)
    return retDataSet

"""
函数:选择最优特征
参数:
    dataSet - 数据集
返回:
    bestFeature - 最优特征索引
"""
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 特征数量(减去分类列)
    baseEntropy = calcShannonEnt(dataSet)  # 基础经验熵
    bestInfoGain = 0.0  # 最优信息增益
    bestFeature = -1  # 最优特征索引
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 提取第i列特征
        uniqueVals = set(featList)  # 特征的唯一取值
        newEntropy = 0.0  # 条件熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 划分子集
            prob = len(subDataSet) / float(len(dataSet))  # 子集概率
            newEntropy += prob * calcShannonEnt(subDataSet)  # 累加条件熵
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        print(f"第{i}个特征({labels[i]})的增益为:{infoGain:.3f}")
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# 测试代码
if __name__ == '__main__':
    dataSet, labels, labels1 = createDataSet()
    bestFeature = chooseBestFeatureToSplit(dataSet)
    print(f"最优特征索引值:{bestFeature}(对应特征:{labels[bestFeature]})")
    # 输出:最优特征索引值:2(对应特征:有自己的房子)

4. 生成决策树(递归构建)

核心逻辑
  1. 若样本集所有样本属于同一类别,直接返回该类别(叶节点);
  2. 若无特征可划分或样本特征全相同,返回出现次数最多的类别(叶节点);
  3. 选择最优特征作为当前节点,按特征取值划分子集;
  4. 对每个子集递归执行上述步骤,生成子树。
代码实现
python 复制代码
import operator

"""
函数:统计出现次数最多的类别
参数:
    classList - 类别列表
返回:
    sortedClassCount[0][0] - 最多类别
"""
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # 按类别次数降序排序
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

"""
函数:创建决策树
参数:
    dataSet - 训练集
    labels - 特征标签
    featLabels - 存储选择的最优特征
返回:
    myTree - 决策树(字典结构)
"""
def createTree(dataSet, labels, featLabels):
    classList = [example[-1] for example in dataSet]  # 提取所有类别
    # 情况1:所有样本类别相同
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 情况2:无特征可划分或特征全相同
    if len(dataSet[0]) == 1 or len(labels) == 0:
        return majorityCnt(classList)
    # 情况3:递归构建树
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 最优特征索引
    bestFeatLabel = labels[bestFeat]  # 最优特征标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}  # 决策树字典
    del(labels[bestFeat])  # 删除已使用的特征标签
    featValues = [example[bestFeat] for example in dataSet]  # 最优特征的所有取值
    uniqueVals = set(featValues)  # 唯一取值
    for value in uniqueVals:
        subLabels = labels[:]  # 复制特征标签(避免递归修改原列表)
        # 递归生成子树
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree

# 测试代码
if __name__ == '__main__':
    dataSet, labels, labels
相关推荐
Q741_1475 小时前
C++ 前缀和 高频笔试考点 实用技巧 牛客 DP34 [模板] 前缀和 题解 每日一题
开发语言·c++·算法·前缀和·牛客网
却道天凉_好个秋5 小时前
计算机视觉(十二):人工智能、机器学习与深度学习
人工智能·深度学习·机器学习·计算机视觉
RTC老炮5 小时前
webrtc弱网-LossBasedBandwidthEstimation类源码分析与算法原理
网络·算法·webrtc
虚拟现实旅人6 小时前
【机器学习】通过tensorflow搭建神经网络进行气温预测
神经网络·机器学习
豆浩宇6 小时前
Conda环境隔离和PyCharm配置,完美同时运行PaddlePaddle和PyTorch
人工智能·pytorch·算法·计算机视觉·pycharm·conda·paddlepaddle
一只鱼^_6 小时前
牛客周赛 Round 108
数据结构·c++·算法·动态规划·图论·广度优先·推荐算法
破烂儿7 小时前
基于机器学习的缓存准入策略研究
人工智能·机器学习·缓存
小刘的AI小站7 小时前
leetcode hot100 二叉搜索树
算法·leetcode
自信的小螺丝钉7 小时前
Leetcode 876. 链表的中间结点 快慢指针
算法·leetcode·链表·指针