Numpy 实现基尼指数算法的决策树

基尼系数实现决策树

基尼指数

Gini ⁡ ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 \operatorname{Gini}(D)=1-\sum_{k=1}^{K}\left(\frac{\left|C_{k}\right|}{|D|}\right)^{2} Gini(D)=1−k=1∑K(∣D∣∣Ck∣)2

特征 A A A条件下集合 D D D的基尼指数:

Gini ⁡ ( D , A ) = ∣ D 1 ∣ ∣ D ∣ Gini ⁡ ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ Gini ⁡ ( D 2 ) \operatorname{Gini}(D, A)=\frac{\left|D_{1}\right|}{|D|} \operatorname{Gini}\left(D_{1}\right)+\frac{\left|D_{2}\right|}{|D|} \operatorname{Gini}\left(D_{2}\right) Gini(D,A)=∣D∣∣D1∣Gini(D1)+∣D∣∣D2∣Gini(D2)

python 复制代码
import numpy as np

def calculate_gini(labels):
    # 计算标签的基尼系数
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    gini = 1 - np.sum(probabilities ** 2)
    return gini

def calculate_gini_index(data, labels, feature_index, threshold):
    # 根据给定的特征和阈值划分数据
    left_mask = data[:, feature_index] <= threshold
    right_mask = data[:, feature_index] > threshold
    left_labels = labels[left_mask]
    right_labels = labels[right_mask]

    # 计算左右子集的基尼系数
    left_gini = calculate_gini(left_labels)
    right_gini = calculate_gini(right_labels)

    # 计算基尼指数
    total_gini = calculate_gini(labels)
    left_weight = len(left_labels) / len(labels)
    right_weight = len(right_labels) / len(labels)
    gini_index = (left_weight * left_gini) + (right_weight * right_gini)
    return gini_index

def find_best_split(data, labels):
    num_features = data.shape[1]
    best_gini_index = float('inf')
    best_feature_index = -1
    best_threshold = None

    for feature_index in range(num_features):
        feature_values = data[:, feature_index]
        unique_values = np.unique(feature_values)

        for threshold in unique_values:
            gini_index = calculate_gini_index(data, labels, feature_index, threshold)
            if gini_index < best_gini_index:
                best_gini_index = gini_index
                best_feature_index = feature_index
                best_threshold = threshold

    return best_feature_index, best_threshold

def create_decision_tree(data, labels):
    # 基本情况:如果所有标签都相同,则返回一个叶节点,其中包含该标签
    if len(np.unique(labels)) == 1:
        return {'label': labels[0]}

    # 找到最佳的划分特征
    best_feature_index, best_threshold = find_best_split(data, labels)

    # 创建一个新的内部节点,其中包含最佳特征和阈值
    node = {
        'feature_index': best_feature_index,
        'threshold': best_threshold,
        'left': None,
        'right': None
    }

    # 根据最佳特征和阈值划分数据
    left_mask = data[:, best_feature_index] <= best_threshold
    right_mask = data[:, best_feature_index] > best_threshold
    left_data = data[left_mask]
    left_labels = labels[left_mask]
    right_data = data[right_mask]
    right_labels = labels[right_mask]

    # 递归创建左右子树
    node['left'] = create_decision_tree(left_data, left_labels)
    node['right'] = create_decision_tree(right_data, right_labels)

    return node

def predict(node, sample):
    if 'label' in node:
        return node['label']

    feature_value = sample[node['feature_index']]
    if feature_value <= node['threshold']:
        return predict(node['left'], sample)
    else:
        return predict(node['right'], sample)

# 示例数据集
data = np.array([
    [1, 2, 0],
    [1, 2, 1],
    [1, 3, 1],
    [2, 3, 1],
    [2, 3, 0],
    [2, 2, 0],
    [1, 1, 0],
    [1, 1, 1],
    [2, 1, 1],
    [1, 3, 0]
])

labels = np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])

# 创建决策树
decision_tree = create_decision_tree(data, labels)

# 测试数据
test_data = np.array([
    [1, 2, 0],
    [2, 1, 1],
    [1, 3, 1],
    [2, 3, 0]
])

# 预测结果
for sample in test_data:
    prediction = predict(decision_tree, sample)
    print(f"样本: {sample}, 预测标签: {prediction}")
相关推荐
砍树+c+v几秒前
3a 感知机训练过程示例(手算拆解,代码实现)
人工智能·算法·机器学习
zy_destiny2 分钟前
【工业场景】用YOLOv26实现4种输电线隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·输电线隐患识别
乔江seven6 分钟前
【python轻量级Web框架 Flask 】2 构建稳健 API:集成 MySQL 参数化查询与 DBUtils 连接池
前端·python·mysql·flask·web
智驱力人工智能11 分钟前
货车违规变道检测 高速公路安全治理的工程实践 货车变道检测 高速公路货车违规变道抓拍系统 城市快速路货车压实线识别方案
人工智能·opencv·算法·安全·yolo·目标检测·边缘计算
乾元12 分钟前
实战案例:解析某次真实的“AI vs. AI”攻防演练
运维·人工智能·安全·web安全·机器学习·架构
罗湖老棍子15 分钟前
【例9.18】合并石子(信息学奥赛一本通- P1274)从暴搜到区间 DP:石子合并的四种写法
算法·动态规划·区间dp·区间动态规划
2301_8107301024 分钟前
python第四次作业
数据结构·python·算法
马剑威(威哥爱编程)25 分钟前
Libvio.link爬虫技术解析:搞定反爬机制
爬虫·python
adam_life27 分钟前
区间动态# P1880 [NOI1995] 石子合并】
算法
坠金32 分钟前
递归、递归和回溯的区别
算法