[机器学习]11-基于CART决策树算法的西瓜数据集分类

  • CART决策树算法:使用Gini指数选择最优划分特征(选取划分后加权Gini指数最小的特征)。

    • 标准 CART 递归构建二叉树;但本代码按所选特征的每个取值分别建立分支,实际生成的是多叉树,与标准 CART 不同。

    • 终止条件为节点样本全属同一类别(返回该类)或者无剩余特征可用(返回多数类)。

程序代码:

python 复制代码
import math
import json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import random

# Load the watermelon dataset: skip the first line of the sheet and read 17 rows.
file_path = '西瓜数据集.xlsx'
label_list = []
data = pd.read_excel(file_path, skiprows=1, nrows=17)
data_dict = data.to_dict(orient='list')

# Each sample is a tuple of the six attribute values followed by the label column.
data_list = list(zip(*(data[column] for column in
                       ('色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '好瓜'))))

# Hold out 2 samples for testing. The seed itself is drawn at random, so the
# split (and therefore the reported accuracy) changes from run to run.
train_set, test_set = train_test_split(
    data_list, test_size=2, random_state=random.randint(1, 1000))

print("data_list:", len(data_list), data_list)
print("train_list:", len(train_set), train_set)
print("test_list:", len(test_set), test_set)

# Attribute names: every column except the sample id and the label.
keys = list(data.keys())
keys.remove('编号')
keys.remove('好瓜')
print('keys:', keys)

def predict(tree, sample, default=None):
    """Classify one sample by walking a tree produced by cart_decision_tree.

    Args:
        tree: Either a leaf label, or a single-key dict
            ``{feature_index: {feature_value: subtree, ...}}``.
        sample: Indexable row (e.g. a tuple); ``sample[feature_index]`` is the
            value tested at each internal node.
        default: Value returned when the sample's feature value has no branch
            at some node (the value never occurred there during training).
            Defaults to None, matching the previous implicit behavior.

    Returns:
        The predicted label, or ``default`` for an unseen branch value.
    """
    node = tree
    while isinstance(node, dict):
        feature, branches = next(iter(node.items()))
        value = sample[feature]
        if value not in branches:
            # No training data reached this branch: make the fallback explicit
            # instead of silently falling off the end of the function.
            return default
        node = branches[value]
    return node

def evaluate(tree, test_set, label_index):
    """Return the fraction of ``test_set`` rows that ``predict`` labels correctly.

    Args:
        tree: Tree (or leaf label) as accepted by ``predict``.
        test_set: Sequence of rows; ``row[label_index]`` is the true label.
        label_index: Column index of the label in each row.

    Returns:
        Accuracy in [0, 1]. An empty ``test_set`` yields 0.0 instead of
        raising ZeroDivisionError.
    """
    total_samples = len(test_set)
    if total_samples == 0:
        return 0.0

    correct_predictions = sum(
        1 for sample in test_set if predict(tree, sample) == sample[label_index]
    )
    return correct_predictions / total_samples


def calculate_gini(data, label_index):
    """Return the Gini impurity 1 - sum(p_k ** 2) of the labels in ``data``.

    Args:
        data: Sequence of rows; ``row[label_index]`` is the class label.
        label_index: Column index of the label.

    Returns:
        The impurity as a float. Empty ``data`` gives 1.0, because no class
        probabilities are subtracted.
    """
    labels = [row[label_index] for row in data]
    n_rows = len(data)
    class_probs = (labels.count(cls) / n_rows for cls in set(labels))

    impurity = 1.0
    for p in class_probs:
        impurity -= p ** 2
    return impurity


def calculate_feature_gini(data, feature_index, label_index):
    """Return the weighted Gini index of splitting ``data`` on one column.

    Every distinct value v of ``row[feature_index]`` forms its own branch D_v
    (a multiway split). The result is sum(|D_v| / |D| * Gini(D_v)).

    Args:
        data: Sequence of rows.
        feature_index: Column index of the feature to split on.
        label_index: Column index of the class label.

    Returns:
        The weighted Gini index (0 for empty ``data``).
    """
    n_rows = len(data)
    weighted_gini = 0
    for branch_value in {row[feature_index] for row in data}:
        branch_rows = [row for row in data if row[feature_index] == branch_value]
        weight = len(branch_rows) / n_rows
        weighted_gini += weight * calculate_gini(branch_rows, label_index)
    return weighted_gini


def choose_best_feature_cart(data, features, label_index):
    """Pick the feature whose split gives the lowest weighted Gini index.

    Args:
        data: Rows (tuples) reaching the current node.
        features: Column indices into each row that are still available for
            splitting, e.g. ``[0, 1, 2, 4, 5]`` after column 3 was used.
        label_index: Column index of the class label.

    Returns:
        The column index with the smallest Gini index (the first one wins on
        ties), or None when ``features`` is empty.
    """
    best_gini = float('inf')
    best_feature = None

    for feature in features:
        # `feature` already is the column index. Mapping it through
        # `features.index(feature)` would give its *position* in the list,
        # which reads the wrong column once earlier features were removed.
        feature_gini = calculate_feature_gini(data, feature, label_index)

        print(f"Feature: {feature}, Gini: {feature_gini}")

        if feature_gini < best_gini:
            best_gini = feature_gini
            best_feature = feature

    return best_feature


def _majority_label(data, label_index):
    """Return the most frequent label in ``data`` (first seen wins on ties)."""
    from collections import Counter

    return Counter(entry[label_index] for entry in data).most_common(1)[0][0]


def cart_decision_tree(data, features, label_index):
    """Recursively build a decision tree, choosing splits by Gini index.

    Unlike textbook CART, which makes binary splits, every internal node
    here gets one branch per distinct value of the chosen feature.

    Args:
        data: Non-empty list of rows (tuples) for this node.
        features: Column indices still available for splitting.
        label_index: Column index of the class label.

    Returns:
        A leaf label, or ``{feature_index: {value: subtree, ...}}``.
    """
    node_labels = {entry[label_index] for entry in data}
    if len(node_labels) == 1:
        return data[0][label_index]

    if not features:
        # Count labels, not whole rows: the old `data.count(label)` compared
        # tuples to a label string, was always 0, and picked an arbitrary class.
        return _majority_label(data, label_index)

    root_feature = choose_best_feature_cart(data, features, label_index)
    if root_feature is None:
        return _majority_label(data, label_index)

    print(root_feature)
    tree = {root_feature: {}}
    # Use the feature itself as the column index; `features.index(...)` gives
    # its list position, which is the wrong column after earlier splits.
    remaining_features = [feat for feat in features if feat != root_feature]
    for value in {entry[root_feature] for entry in data}:
        subset_data = [entry for entry in data if entry[root_feature] == value]
        tree[root_feature][value] = cart_decision_tree(
            subset_data, remaining_features, label_index)

    return tree

# Test: build the tree on the training split and score it on the held-out rows.
# Columns 0-5 of each sample are the attributes in data_list order
# (色泽, 根蒂, 敲声, 纹理, 脐部, 触感); column 6 is the label 好瓜.
selected_features = [0, 1, 2, 3, 4, 5]
label_index = 6
# Store the result under its own name: rebinding `cart_decision_tree` would
# shadow the function and break any later attempt to build another tree.
cart_tree = cart_decision_tree(train_set, selected_features, label_index)
print(json.dumps(cart_tree, indent=2, ensure_ascii=False))

test_accuracy_cart = evaluate(cart_tree, test_set, label_index)
print("CART Test Accuracy: {:.2%}".format(test_accuracy_cart))

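运行结果(注意:以下输出由存在缺陷的代码生成。`features.index(feature)` 返回的是特征在列表中的位置而不是列号,删去已用特征后会读取错误的列,例如树中特征 "2" 下出现了色泽取值"乌黑";多数类投票中的 `data.count(x)` 统计的是整行等于标签字符串的次数,恒为 0。修正后的输出会与此不同):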

data_list: 17 [('青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'), ('乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'), ('浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'), ('乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'), ('乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'), ('乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'), ('青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'), ('浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'), ('浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'), ('青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'), ('浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'), ('乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'), ('浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'), ('青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否')]

train_list: 15 [('浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'), ('青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'), ('青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'), ('乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'), ('浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'), ('浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'), ('青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'), ('浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'), ('青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'), ('乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'), ('乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'), ('乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'), ('乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否')]

test_list: 2 [('乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'), ('青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否')]

keys: ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']

Feature: 0, Gini: 0.4266666666666665

Feature: 1, Gini: 0.4063492063492064

Feature: 2, Gini: 0.40888888888888886

Feature: 3, Gini: 0.30000000000000004

Feature: 4, Gini: 0.3377777777777778

Feature: 5, Gini: 0.4933333333333333

3

Feature: 0, Gini: 0.25

Feature: 1, Gini: 0.375

Feature: 2, Gini: 0.25

Feature: 4, Gini: 0.375

Feature: 5, Gini: 0.25

0

Feature: 1, Gini: 0.5

Feature: 2, Gini: 0.5

Feature: 4, Gini: 0.0

Feature: 5, Gini: 0.5

4

Feature: 0, Gini: 0.3541666666666667

Feature: 1, Gini: 0.16666666666666669

Feature: 2, Gini: 0.20833333333333326

Feature: 4, Gini: 0.375

Feature: 5, Gini: 0.16666666666666669

1

Feature: 0, Gini: 0.3333333333333333

Feature: 2, Gini: 0.4444444444444445

Feature: 4, Gini: 0.4444444444444445

Feature: 5, Gini: 0.4444444444444445

0

Feature: 2, Gini: 0.5

Feature: 4, Gini: 0.5

Feature: 5, Gini: 0.5

2

Feature: 4, Gini: 0.5

Feature: 5, Gini: 0.5

4

Feature: 5, Gini: 0.5

5

{

"3": {

"稍糊": {

"0": {

"乌黑": {

"4": {

"浊响": "是",

"沉闷": "否"

}

},

"青绿": "否",

"浅白": "否"

}

},

"模糊": "否",

"清晰": {

"1": {

"硬挺": "否",

"稍蜷": {

"0": {

"乌黑": {

"2": {

"乌黑": {

"4": {

"乌黑": {

"5": {

"乌黑": "是"

}

}

}

}

}

},

"青绿": "是"

}

},

"蜷缩": "是"

}

}

}

}

CART Test Accuracy: 100.00%