Numpy 实现基尼指数算法的决策树

基尼系数实现决策树

基尼指数

Gini ⁡ ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 \operatorname{Gini}(D)=1-\sum_{k=1}^{K}\left(\frac{\left|C_{k}\right|}{|D|}\right)^{2} Gini(D)=1−k=1∑K(∣D∣∣Ck∣)2

特征 A A A条件下集合 D D D的基尼指数:

Gini ⁡ ( D , A ) = ∣ D 1 ∣ ∣ D ∣ Gini ⁡ ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ Gini ⁡ ( D 2 ) \operatorname{Gini}(D, A)=\frac{\left|D_{1}\right|}{|D|} \operatorname{Gini}\left(D_{1}\right)+\frac{\left|D_{2}\right|}{|D|} \operatorname{Gini}\left(D_{2}\right) Gini(D,A)=∣D∣∣D1∣Gini(D1)+∣D∣∣D2∣Gini(D2)

python 复制代码
import numpy as np

def calculate_gini(labels):
    # 计算标签的基尼系数
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    gini = 1 - np.sum(probabilities ** 2)
    return gini

def calculate_gini_index(data, labels, feature_index, threshold):
    # 根据给定的特征和阈值划分数据
    left_mask = data[:, feature_index] <= threshold
    right_mask = data[:, feature_index] > threshold
    left_labels = labels[left_mask]
    right_labels = labels[right_mask]

    # 计算左右子集的基尼系数
    left_gini = calculate_gini(left_labels)
    right_gini = calculate_gini(right_labels)

    # 计算基尼指数
    total_gini = calculate_gini(labels)
    left_weight = len(left_labels) / len(labels)
    right_weight = len(right_labels) / len(labels)
    gini_index = (left_weight * left_gini) + (right_weight * right_gini)
    return gini_index

def find_best_split(data, labels):
    num_features = data.shape[1]
    best_gini_index = float('inf')
    best_feature_index = -1
    best_threshold = None

    for feature_index in range(num_features):
        feature_values = data[:, feature_index]
        unique_values = np.unique(feature_values)

        for threshold in unique_values:
            gini_index = calculate_gini_index(data, labels, feature_index, threshold)
            if gini_index < best_gini_index:
                best_gini_index = gini_index
                best_feature_index = feature_index
                best_threshold = threshold

    return best_feature_index, best_threshold

def create_decision_tree(data, labels):
    # 基本情况:如果所有标签都相同,则返回一个叶节点,其中包含该标签
    if len(np.unique(labels)) == 1:
        return {'label': labels[0]}

    # 找到最佳的划分特征
    best_feature_index, best_threshold = find_best_split(data, labels)

    # 创建一个新的内部节点,其中包含最佳特征和阈值
    node = {
        'feature_index': best_feature_index,
        'threshold': best_threshold,
        'left': None,
        'right': None
    }

    # 根据最佳特征和阈值划分数据
    left_mask = data[:, best_feature_index] <= best_threshold
    right_mask = data[:, best_feature_index] > best_threshold
    left_data = data[left_mask]
    left_labels = labels[left_mask]
    right_data = data[right_mask]
    right_labels = labels[right_mask]

    # 递归创建左右子树
    node['left'] = create_decision_tree(left_data, left_labels)
    node['right'] = create_decision_tree(right_data, right_labels)

    return node

def predict(node, sample):
    if 'label' in node:
        return node['label']

    feature_value = sample[node['feature_index']]
    if feature_value <= node['threshold']:
        return predict(node['left'], sample)
    else:
        return predict(node['right'], sample)

# 示例数据集
data = np.array([
    [1, 2, 0],
    [1, 2, 1],
    [1, 3, 1],
    [2, 3, 1],
    [2, 3, 0],
    [2, 2, 0],
    [1, 1, 0],
    [1, 1, 1],
    [2, 1, 1],
    [1, 3, 0]
])

labels = np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])

# 创建决策树
decision_tree = create_decision_tree(data, labels)

# 测试数据
test_data = np.array([
    [1, 2, 0],
    [2, 1, 1],
    [1, 3, 1],
    [2, 3, 0]
])

# 预测结果
for sample in test_data:
    prediction = predict(decision_tree, sample)
    print(f"样本: {sample}, 预测标签: {prediction}")
相关推荐
CoderCodingNo4 小时前
【NOIP】2011真题解析 luogu-P1003 铺地毯 | GESP三、四级以上可练习
算法
iFlyCai4 小时前
C语言中的指针
c语言·数据结构·算法
查古穆5 小时前
栈-有效的括号
java·数据结构·算法
再一次等风来5 小时前
近场声全息(NAH)仿真实现:从阵列实值信号到波数域重建
算法·matlab·信号处理·近场声全息·nah
汀、人工智能5 小时前
16 - 高级特性
数据结构·算法·数据库架构·图论·16 - 高级特性
大熊背5 小时前
利用ISP离线模式进行分块LSC校正的方法
人工智能·算法·机器学习
XWalnut5 小时前
LeetCode刷题 day4
算法·leetcode·职场和发展
绛橘色的日落(。・∀・)ノ5 小时前
Numpy 第五章 数学函数
numpy
蒸汽求职5 小时前
机器人软件工程(Robotics SDE):特斯拉Optimus落地引发的嵌入式C++与感知算法人才抢夺战
大数据·c++·算法·职场和发展·机器人·求职招聘·ai-native
极梦网络无忧5 小时前
OpenClaw 基础使用说明(中文版)
python