【决策树深度探索(一)】从零搭建:机器学习的“智慧之树”——决策树分类算法!



文章目录


为什么选择决策树?------直观易懂的"决策者"

在众多机器学习算法中,决策树以其独特的优势脱颖而出。它最大的魅力在于其直观性(Interpretability)易于理解性。想想看,我们日常生活中的决策过程,是不是也常常像一棵树一样?"如果今天不下雨,我就去公园;如果下雨,我就在家看书。"这不就是一棵简单的决策树吗?

决策树通过学习数据中的简单决策规则,构建一个像流程图一样的树结构。每个内部节点代表一个特征上的测试,每个分支代表一个测试结果,而每个叶节点则代表一个分类结果。这种结构使得我们能够清晰地追踪模型的决策路径,这在许多需要解释性的场景中(比如医疗诊断、金融风控)显得尤为重要。

决策树的智慧之源:信息增益(Information Gain)

要让决策树变得"智慧",它需要知道在每一步应该根据哪个特征来做判断,才能最有效地将数据进行分类。衡量这种"有效性"的核心指标,就是信息增益(Information Gain)

在解释信息增益之前,我们先要理解一个概念:熵(Entropy)。熵是信息论中用来衡量一个系统"混乱程度"或"不确定性"的指标。简单来说,一个数据集越混乱(即类别越混杂),它的熵就越高;反之,如果一个数据集中的样本都属于同一类别(非常"纯净"),那么它的熵就为零。

比如,一个班级里男同学和女同学各占一半,这个班级的性别构成就是最"不确定"的,熵值最高。如果一个班级全是男同学,那么这个班级的性别构成就是"确定"的,熵值为零。

信息增益 ,顾名思义,就是特征对数据集带来的信息量增加 的程度。当我们选择一个特征对数据集进行划分后,数据集的混乱程度会降低,这种混乱程度的下降,就是我们所获得的信息增益。决策树在每一步生长时,都会选择那个能带来最大信息增益的特征来划分数据,从而确保每次划分都是"最优决策"。

数学表达(不用慌,直观理解更重要!):

假设数据集 D D D 的熵为 H ( D ) H(D) H(D)。

用特征 A A A 划分数据集 D D D 后,根据特征 A A A 的不同取值 v v v,数据集被分成若干个子集 D v D_v Dv。

那么,特征 A A A 带来的信息增益 G a i n ( D , A ) Gain(D, A) Gain(D,A) 定义为:
G a i n ( D , A ) = H ( D ) − ∑ v ∈ V a l u e s ( A ) ∣ D v ∣ ∣ D ∣ H ( D v ) Gain(D, A) = H(D) - \sum_{v \in Values(A)} \frac{|D_v|}{|D|} H(D_v) Gain(D,A)=H(D)−v∈Values(A)∑∣D∣∣Dv∣H(Dv)

其中, H ( D v ) H(D_v) H(Dv) 是子集 D v D_v Dv 的熵, ∣ D v ∣ ∣ D ∣ \frac{|D_v|}{|D|} ∣D∣∣Dv∣ 是子集 D v D_v Dv 在整个数据集中的比例。

我们的目标,就是找到那个 G a i n ( D , A ) Gain(D, A) Gain(D,A) 最大的特征 A A A!

手撕代码:从零搭建你的第一个决策树分类器!

光说不练假把式!现在,让我们一起亲手实现一个基于信息增益的决策树分类器。为了简化,我们先处理离散型特征

我们将使用一个经典的"是否打网球"数据集来演示,它包含"天气"、"温度"、"湿度"、"风力"等特征,以及最终的决策"打网球"。

python 复制代码
import math
import collections

# 示例数据集:Outlook, Temperature, Humidity, Wind, PlayTennis
# 数据格式:[特征1, 特征2, ..., 目标变量]
dataset = [
    ['Sunny', 'Hot', 'High', 'Weak', 'No'],
    ['Sunny', 'Hot', 'High', 'Strong', 'No'],
    ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
    ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
    ['Sunny', 'Mild', 'High', 'Weak', 'No'],
    ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
    ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
    ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
    ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Strong', 'No']
]

# 特征索引 (为了方便理解,用字典映射)
# 0: Outlook, 1: Temperature, 2: Humidity, 3: Wind, 4: PlayTennis (Target)
feature_names = ['Outlook', 'Temperature', 'Humidity', 'Wind']
target_index = 4 # 目标变量在数据集中的索引

### 第一步:计算熵 (Entropy)
def calculate_entropy(data):
    """
    计算数据集的熵。
    参数:
        data: 列表的列表,每个子列表代表一个样本,最后一个元素是目标变量。
    返回:
        熵值 (float)。
    """
    num_samples = len(data)
    if num_samples == 0:
        return 0.0

    label_counts = collections.Counter(sample[target_index] for sample in data)
    entropy = 0.0
    for label in label_counts:
        prob = label_counts[label] / num_samples
        entropy -= prob * math.log2(prob)
    return entropy

### 第二步:计算信息增益 (Information Gain)
def calculate_information_gain(data, feature_idx):
    """
    计算给定特征的信息增益。
    参数:
        data: 数据集。
        feature_idx: 要计算信息增益的特征的索引。
    返回:
        信息增益 (float)。
    """
    initial_entropy = calculate_entropy(data)
    num_samples = len(data)
    
    # 按照特征值分组
    feature_values = collections.defaultdict(list)
    for sample in data:
        feature_values[sample[feature_idx]].append(sample)
    
    weighted_avg_entropy = 0.0
    for value_data in feature_values.values():
        prob = len(value_data) / num_samples
        weighted_avg_entropy += prob * calculate_entropy(value_data)
        
    return initial_entropy - weighted_avg_entropy

### 第三步:找到最佳划分特征
def find_best_split_feature(data, available_features_indices):
    """
    在可用特征中找到信息增益最大的特征。
    参数:
        data: 数据集。
        available_features_indices: 当前可用于划分的特征索引列表。
    返回:
        最佳特征的索引 (int) 或 None (如果没有可用特征)。
    """
    best_gain = -1
    best_feature = None
    
    for feature_idx in available_features_indices:
        gain = calculate_information_gain(data, feature_idx)
        if gain > best_gain:
            best_gain = gain
            best_feature = feature_idx
            
    return best_feature, best_gain

### 第四步:构建决策树 (核心递归函数)
# 树的结构:{特征名: {特征值: 子树}} 或 最终类别
class DecisionTreeNode:
    def __init__(self, feature_name=None, value=None, results=None):
        self.feature_name = feature_name  # 分裂特征的名称
        self.value = value              # 如果是分支节点,表示该分支对应的特征值
        self.results = results          # 如果是叶节点,表示最终的分类结果 (例如: 'Yes', 'No')
        self.children = {}              # 子节点字典 {特征值: DecisionTreeNode}

    def add_child(self, value, node):
        self.children[value] = node

    def __repr__(self):
        if self.results is not None:
            return f"LeafNode({self.results})"
        else:
            return f"FeatureNode({self.feature_name}, children={list(self.children.keys())})"

def build_tree(data, available_features_indices, max_depth=None, current_depth=0):
    """
    递归构建决策树。
    参数:
        data: 当前节点的数据集。
        available_features_indices: 当前可以考虑的特征索引。
        max_depth: 树的最大深度,防止过拟合。
        current_depth: 当前树的深度。
    返回:
        决策树的根节点 (DecisionTreeNode)。
    """
    # 停止条件1: 数据集为空
    if not data:
        return DecisionTreeNode(results="NoData") # 或者返回最常见的类别

    # 停止条件2: 所有样本都属于同一类别 (节点纯净)
    label_counts = collections.Counter(sample[target_index] for sample in data)
    if len(label_counts) == 1:
        return DecisionTreeNode(results=data[0][target_index])

    # 停止条件3: 没有更多特征可以用来划分
    if not available_features_indices or (max_depth is not None and current_depth >= max_depth):
        # 返回数据集中出现次数最多的类别作为叶节点结果
        majority_label = label_counts.most_common(1)[0][0]
        return DecisionTreeNode(results=majority_label)

    # 寻找最佳划分特征
    best_feature_idx, best_gain = find_best_split_feature(data, available_features_indices)

    # 停止条件4: 无法找到有效特征或信息增益太小 (这里简单判断best_feature是否存在)
    if best_feature_idx is None or best_gain <= 0: # 如果信息增益为0或负数,说明无法有效划分
        majority_label = label_counts.most_common(1)[0][0]
        return DecisionTreeNode(results=majority_label)

    root = DecisionTreeNode(feature_name=feature_names[best_feature_idx])

    # 按照最佳特征的不同取值进行分支
    feature_values = collections.defaultdict(list)
    for sample in data:
        feature_values[sample[best_feature_idx]].append(sample)

    # 从可用特征中移除当前已使用的特征
    new_available_features = [idx for idx in available_features_indices if idx != best_feature_idx]

    for value, subset_data in feature_values.items():
        # 递归构建子树
        child_node = build_tree(subset_data, new_available_features, max_depth, current_depth + 1)
        child_node.value = value # 标记这个子节点是通过哪个特征值连接的
        root.add_child(value, child_node)

    return root

### 第五步:进行预测
def predict(tree, sample):
    """
    使用构建好的决策树对单个样本进行预测。
    参数:
        tree: 决策树的根节点 (DecisionTreeNode)。
        sample: 待预测的样本 (列表)。
    返回:
        预测的类别。
    """
    if tree.results is not None:  # 如果是叶节点
        return tree.results
    
    feature_idx = feature_names.index(tree.feature_name)
    feature_value = sample[feature_idx]

    if feature_value in tree.children:
        return predict(tree.children[feature_value], sample)
    else:
        # 如果遇到训练集中没有出现过的特征值,可以返回多数类别或一个默认值
        # 这里为了简化,我们暂时返回None,实际应用中需要更健壮的处理
        print(f"Warning: Unknown feature value '{feature_value}' for feature '{tree.feature_name}'.")
        # 可以回溯到父节点结果,或者返回训练集中最常见的类别
        return None # 实际应用中可能需要更复杂的策略

### 树的可视化 (可选,辅助理解)
def print_tree(node, indent=""):
    if node.results is not None:
        print(f"{indent}--> Predict: {node.results}")
    else:
        for value, child in node.children.items():
            print(f"{indent}If {node.feature_name} is '{value}':")
            print_tree(child, indent + "  ")

# --- 运行你的决策树! ---
print("🚀 开始构建决策树...")
available_features = list(range(len(feature_names))) # 所有特征的索引
my_decision_tree = build_tree(dataset, available_features, max_depth=3) # 限制深度,防止过拟合

print("\n🌳 构建完成的决策树结构:")
print_tree(my_decision_tree)

# 进行预测
print("\n🔮 进行预测:")
test_sample1 = ['Sunny', 'Cool', 'High', 'Strong', '?'] # 新的样本
prediction1 = predict(my_decision_tree, test_sample1)
print(f"Sample: {test_sample1[:-1]} -> Prediction: {prediction1}")

test_sample2 = ['Overcast', 'Hot', 'Normal', 'Weak', '?']
prediction2 = predict(my_decision_tree, test_sample2)
print(f"Sample: {test_sample2[:-1]} -> Prediction: {prediction2}")

test_sample3 = ['Rain', 'Hot', 'High', 'Weak', '?']
prediction3 = predict(my_decision_tree, test_sample3)
print(f"Sample: {test_sample3[:-1]} -> Prediction: {prediction3}")
代码解读:
  1. calculate_entropy(data): 这是计算数据集纯度的基石。它统计目标变量的各类数量,然后根据公式计算熵。

  2. calculate_information_gain(data, feature_idx): 计算某个特征对熵的降低程度。它会根据该特征的不同取值,将数据集分成若干子集,分别计算子集的熵,然后加权求和,用原始熵减去这个加权平均熵,就得到了信息增益。

  3. find_best_split_feature(data, available_features_indices): 遍历所有可用特征,计算它们的信息增益,并返回信息增益最大的那个特征。这正是决策树每一步"做出最佳决策"的地方。

  4. DecisionTreeNode: 定义了树的节点结构,可以是内部节点(带有分裂特征和子节点)或叶节点(带有最终分类结果)。

  5. build_tree(data, available_features_indices, max_depth, current_depth): 这是整个决策树的核心!它是一个递归函数:

    • 首先判断停止条件:数据集是否纯净?是否没有更多特征可用?是否达到了最大深度限制?如果满足,就返回一个叶节点。
    • 如果不满足停止条件,就调用 find_best_split_feature 找到最佳划分特征。
    • 根据最佳特征将数据集划分为多个子集。
    • 对每个子集递归调用 build_tree 函数,构建子树,并将子树连接到当前节点。
  6. predict(tree, sample): 遍历已构建的树,根据样本的特征值,沿着树的分支向下移动,直到达到叶节点,返回叶节点的分类结果。

  7. print_tree(node, indent): 一个简单的辅助函数,帮助我们可视化决策树的结构,让你更直观地理解它的决策路径。

通过这个手撕代码,你不仅理解了决策树的理论,更亲自参与了它的构造过程,对信息增益和递归建树的理解一定会更加深刻!👍

结语与展望

在接下来的"深度探索"系列中,我们将继续升级我们的决策树,探索更多高级特性,例如:

  • 处理连续型特征:如何优雅地划分数值型数据?
  • 剪枝(Pruning):如何防止决策树过拟合,提高泛化能力?
  • 其他划分标准:除了信息增益,还有Gini不纯度、信息增益率等。
  • 决策树的集成学习:随机森林、GBDT、XGBoost等强大模型的基础。

机器学习的道路充满乐趣,保持这份好奇心和动手能力,你一定会走得更远!


相关推荐
程序员-King.1 小时前
day161—动态规划—最长递增子序列(LeetCode-300)
算法·leetcode·深度优先·动态规划·递归
西柚小萌新2 小时前
【计算机视觉CV:目标检测】--3.算法原理(SPPNet、Fast R-CNN、Faster R-CNN)
算法·目标检测·计算机视觉
高频交易dragon2 小时前
Hawkes LOB Market从论文到生产
人工智能·算法·金融
无心水2 小时前
2、Go语言源码文件组织与命令源码文件实战指南
开发语言·人工智能·后端·机器学习·golang·go·gopath
_OP_CHEN2 小时前
【算法基础篇】(五十)扩展中国剩余定理(EXCRT)深度精讲:突破模数互质限制
c++·算法·蓝桥杯·数论·同余方程·扩展欧几里得算法·acm/icpc
福楠2 小时前
C++ STL | set、multiset
c语言·开发语言·数据结构·c++·算法
enfpZZ小狗2 小时前
基于C++的反射机制探索
开发语言·c++·算法
炽烈小老头2 小时前
【每天学习一点算法 2026/01/22】杨辉三角
学习·算法
MicroTech20252 小时前
微算法科技(NASDAQ :MLGO)量子安全区块链:PQ-DPoL与Falcon签名的双重防御体系
科技·算法·安全