基于weka手工实现ID3决策树

一、决策树ID3算法

相比于logistic回归、BP网络、支持向量机等基于超平面的方法,决策树更像一种算法,里面的数学原理并不是很多,较好理解。

决策树就是一个不断地属性选择、属性划分地过程,直到满足某一情况就停止划分。

  1. 当前样本全部属于同一类别了(信息增益为0);
  2. 已经是空叶子了(没有样本了);
  3. 当前叶子节点所有样本所有属性上取值相同,无法划分了(信息增益为0)。

信息增益如何计算?根据信息熵地变化量,信息熵减少最大地属性就是我们要选择地属性。

信息熵定义:

E n t ( D ) = − ∑ k = 1 ∣ y ∣ p k l o g 2 p k Ent(D)=-\sum_{k=1}^{|y|}p_klog_2p_k Ent(D)=−k=1∑∣y∣pklog2pk

信息增益定义:

G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 v ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D,a)=Ent(D)-\sum_{v=1}^v\frac{|D^v|}{|D|}Ent(D^v) Gain(D,a)=Ent(D)−v=1∑v∣D∣∣Dv∣Ent(Dv)

信息增益越大,则意味着属性a来划分所获得的"纯度提升"越大。

ID3就是以信息增益作为属性选择和划分的标准的。有了决策树生长和停止生长的条件,剩下的其实就是一些编程技巧了,我们就可以进行编码了。

除此之外,决策树还有C4.5等其它实现的算法,包括基尼系数、增益率、剪枝、预剪枝等防止过拟合的方法,但决策树最本质、朴素的思想还是在ID3中体现的最好。

具体可以参考这篇博客:机器学习06:决策树学习.

二、基于weka平台实现ID3决策树

java 复制代码
package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.*;

/**
 * @author YFMan
 * @Description 自定义的 ID3 分类器
 * @Date 2023/5/25 18:07
 */
public class myId3 extends Classifier {

    // 当前节点 的 后续节点
    private myId3[] m_Successors;

    // 当前节点的划分属性 (如果为空,说明当前节点是叶子节点;否则,说明当前节点是中间节点)
    private Attribute m_Attribute;

    // 当前节点的类别分布 (如果为中间节点,全为 0;为叶子节点,为类别分布)
    private double[] m_Distribution;

    // 当前节点的类别 (如果为中间节点,为 0;为叶子节点,为类别分布)
    // (用于获取类别的索引,对于算法本身没用,但对于可视化 决策树有用)
    private double m_ClassValue;

    // 当前节点的类别属性 (如果为中间节点,为 null;为叶子节点,为类别属性)
    // (用于获取类别的名称,对于算法本身没用,但对于可视化 决策树有用)
    private Attribute m_ClassAttribute;

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 决策树
     * @Date 2023/5/25 18:43
     * @Param [data]
     * @return void
     **/
    public void buildClassifier(Instances data) throws Exception {
        // 建树
        makeTree(data);
    }

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 决策树
     * @Date 2023/5/25 18:43
     * @Param [data] 训练数据
     * @return void
     **/
    private void makeTree(Instances data) throws Exception {

        // 如果是空叶子,拒绝建树 (拒判)
        if (data.numInstances() == 0) {
            m_Attribute = null;
            m_ClassValue = Instance.missingValue();
            m_Distribution = new double[data.numClasses()];
            return;
        }

        // 计算 所有属性的 信息增益
        double[] infoGains = new double[data.numAttributes()];
        // 遍历所有属性
        for(int i = 0; i < data.numAttributes(); i++) {
            // 如果是类别属性,跳过
            if (i == data.classIndex()) {
                infoGains[i] = 0;
            } else {
                // 计算信息增益
                infoGains[i] = computeInfoGain(data, data.attribute(i));
            }
        }

        // 选择信息增益最大的属性
        m_Attribute = data.attribute(Utils.maxIndex(infoGains));

        // 如果信息增益为 0,说明当前节点包含的样例都属于同一类别,直接设置为叶子节点
        if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
            // 设置为叶子节点
            m_Attribute = null;
            m_Distribution = new double[data.numClasses()];
            // 遍历所有样例
            for (int i = 0; i < data.numInstances(); i++) {
                // 获取当前样例的类别
                Instance inst = data.instance(i);
                // 统计类别分布
                m_Distribution[(int) inst.classValue()]++;
            }
            // 归一化
            Utils.normalize(m_Distribution);
            // 设置类别
            m_ClassValue = Utils.maxIndex(m_Distribution);
            m_ClassAttribute = data.classAttribute();
        } else { // 否则,递归建树
            // 划分数据集
            Instances[] splitData = splitData(data, m_Attribute);
            // 创建叶子
            m_Successors = new myId3[m_Attribute.numValues()];
            // 叶子再去长叶子,递归调用
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                m_Successors[j] = new myId3();
                m_Successors[j].makeTree(splitData[j]);
            }
        }


    }

    /*
     * @Author YFMan
     * @Description 根据 instance 进行分类
     * @Date 2023/5/25 18:33
     * @Param [instance] 待分类的实例
     * @return double[] 类别分布
     **/
    public double[] distributionForInstance(Instance instance)
            throws NoSupportForMissingValuesException {
        // 如果到达叶子节点,返回类别分布
        if (m_Attribute == null) {
            // 如果 m_Distribution 全为 0(是空叶子),随机返回一个类别分布
            if (Utils.eq(Utils.sum(m_Distribution), 0)) {
                // 在 0~类别数-1 之间随机选择一个类别
                m_Distribution = new double[m_ClassAttribute.numValues()];
                m_Distribution[(int) Math.round(Math.random() * m_ClassAttribute.numValues())] = 1.0;
            }
            return m_Distribution;
        } else {
            // 否则,递归调用
            return m_Successors[(int) instance.value(m_Attribute)].
                    distributionForInstance(instance);
        }
    }

    /*
     * @Author YFMan
     * @Description 计算当前数据集 选择某个属性的 信息增益
     * @Date 2023/5/25 18:29
     * @Param [data, att] 当前数据集,选择的属性
     * @return double 信息增益
     **/
    private double computeInfoGain(Instances data, Attribute att)
            throws Exception {
        // 计算 data 的信息熵
        double infoGain = computeEntropy(data);
        // 计算 data 按照 att 属性进行划分的信息熵
        // 划分数据集
        Instances[] splitData = splitData(data, att);
        // 遍历划分后的数据集
        for (Instances instances : splitData) {
            // 计算概率
            double probability = (double) instances.numInstances() / data.numInstances();
            // 计算信息熵
            infoGain -= probability * computeEntropy(instances);
        }
        // 返回信息增益
        return infoGain;
    }

    /*
     * @Author YFMan
     * @Description 计算信息熵
     * @Date 2023/5/25 18:18
     * @Param [data] 计算的数据集
     * @return double 信息熵
     **/
    private double computeEntropy(Instances data) throws Exception {
        // 计不同类别的数量
        double[] classCounts = new double[data.numClasses()];
        // 遍历数据集
        for(int i=0;i<data.numInstances();i++){
            // 获取类别
            int classIndex = (int) data.instance(i).classValue();
            // 数量加一
            classCounts[classIndex]++;
        }
        // 计算信息熵
        double entropy = 0;
        // 遍历类别
        for (double classCount : classCounts) {
            // 注意:这里是大于 0,因为 log2(0) = -Infinity;
            // 如果是等于 0,那么计算结果就是 NaN,熵就出错了
            if(classCount > 0){
                // 计算概率
                double probability = classCount / data.numInstances();
                // 计算信息熵
                entropy -= probability * Utils.log2(probability);
            }
        }
        // 返回信息熵
        return entropy;
    }

    /*
     * @Author YFMan
     * @Description 根据属性划分数据集
     * @Date 2023/5/25 18:23
     * @Param [data, att] 数据集,属性
     * @return weka.core.Instances[] 划分后的数据集
     **/
    private Instances[] splitData(Instances data, Attribute att) {
        // 定义划分后的数据集
        Instances[] splitData = new Instances[att.numValues()];
        // 遍历划分后的数据集
        for(int i=0;i<splitData.length;i++){
            // 创建数据集 (这里主要是为了初始化 数据集 header)
            // Constructor copying all instances and references to the header
            // information from the given set of instances.
            splitData[i] = new Instances(data,0);
        }
        // 遍历数据集
        for(int i=0;i<data.numInstances();i++){
            // 获取实例
            Instance instance = data.instance(i);
            // 获取实例的属性值
            double value = instance.value(att);
            // 将实例添加到对应的数据集中
            splitData[(int) value].add(instance);
        }
        // 返回划分后的数据集
        return splitData;
    }

    private String toString(int level) {

        StringBuffer text = new StringBuffer();

        if (m_Attribute == null) {
            if (Instance.isMissingValue(m_ClassValue)) {
                text.append(": null");
            } else {
                text.append(": " + m_ClassAttribute.value((int) m_ClassValue));
            }
        } else {
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                text.append("\n");
                for (int i = 0; i < level; i++) {
                    text.append("|  ");
                }
                text.append(m_Attribute.name() + " = " + m_Attribute.value(j));
                text.append(m_Successors[j].toString(level + 1));
            }
        }
        return text.toString();
    }

    public String toString() {

        if ((m_Distribution == null) && (m_Successors == null)) {
            return "Id3: No model built yet.";
        }
        return "Id3\n\n" + toString(0);
    }

    /**
     * Main method.
     *
     * @param args the options for the classifier
     */
    public static void main(String[] args) {
        runClassifier(new myId3(), args);
    }
}
相关推荐
余炜yw42 分钟前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐1 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1231 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr2 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner2 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao2 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!2 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统
果冻人工智能2 小时前
OpenAI 是怎么“压力测试”大型语言模型的?
人工智能·语言模型·压力测试
日出等日落2 小时前
Windows电脑本地部署llamafile并接入Qwen大语言模型远程AI对话实战
人工智能·语言模型·自然语言处理
麦麦大数据2 小时前
Python棉花病虫害图谱系统CNN识别+AI问答知识neo4j vue+flask深度学习神经网络可视化
人工智能·python·深度学习