一、决策树ID3算法
相比于logistic回归、BP网络、支持向量机等基于超平面的方法,决策树更像一种算法,里面的数学原理并不是很多,较好理解。
决策树就是一个不断地属性选择、属性划分地过程,直到满足某一情况就停止划分。
- 当前样本全部属于同一类别了(信息增益为0);
- 已经是空叶子了(没有样本了);
- 当前叶子节点所有样本所有属性上取值相同,无法划分了(信息增益为0)。
信息增益如何计算?根据信息熵地变化量,信息熵减少最大地属性就是我们要选择地属性。
信息熵定义:
E n t ( D ) = − ∑ k = 1 ∣ y ∣ p k l o g 2 p k Ent(D)=-\sum_{k=1}^{|y|}p_klog_2p_k Ent(D)=−k=1∑∣y∣pklog2pk
信息增益定义:
G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 v ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D,a)=Ent(D)-\sum_{v=1}^v\frac{|D^v|}{|D|}Ent(D^v) Gain(D,a)=Ent(D)−v=1∑v∣D∣∣Dv∣Ent(Dv)
信息增益越大,则意味着属性a来划分所获得的"纯度提升"越大。
ID3就是以信息增益作为属性选择和划分的标准的。有了决策树生长和停止生长的条件,剩下的其实就是一些编程技巧了,我们就可以进行编码了。
除此之外,决策树还有C4.5等其它实现的算法,包括基尼系数、增益率、剪枝、预剪枝等防止过拟合的方法,但决策树最本质、朴素的思想还是在ID3中体现的最好。
具体可以参考这篇博客:机器学习06:决策树学习.
二、基于weka平台实现ID3决策树
java
package weka.classifiers.myf;
import weka.classifiers.Classifier;
import weka.core.*;
/**
* @author YFMan
* @Description 自定义的 ID3 分类器
* @Date 2023/5/25 18:07
*/
public class myId3 extends Classifier {
// 当前节点 的 后续节点
private myId3[] m_Successors;
// 当前节点的划分属性 (如果为空,说明当前节点是叶子节点;否则,说明当前节点是中间节点)
private Attribute m_Attribute;
// 当前节点的类别分布 (如果为中间节点,全为 0;为叶子节点,为类别分布)
private double[] m_Distribution;
// 当前节点的类别 (如果为中间节点,为 0;为叶子节点,为类别分布)
// (用于获取类别的索引,对于算法本身没用,但对于可视化 决策树有用)
private double m_ClassValue;
// 当前节点的类别属性 (如果为中间节点,为 null;为叶子节点,为类别属性)
// (用于获取类别的名称,对于算法本身没用,但对于可视化 决策树有用)
private Attribute m_ClassAttribute;
/*
* @Author YFMan
* @Description 根据训练数据 建立 决策树
* @Date 2023/5/25 18:43
* @Param [data]
* @return void
**/
public void buildClassifier(Instances data) throws Exception {
// 建树
makeTree(data);
}
/*
* @Author YFMan
* @Description 根据训练数据 建立 决策树
* @Date 2023/5/25 18:43
* @Param [data] 训练数据
* @return void
**/
private void makeTree(Instances data) throws Exception {
// 如果是空叶子,拒绝建树 (拒判)
if (data.numInstances() == 0) {
m_Attribute = null;
m_ClassValue = Instance.missingValue();
m_Distribution = new double[data.numClasses()];
return;
}
// 计算 所有属性的 信息增益
double[] infoGains = new double[data.numAttributes()];
// 遍历所有属性
for(int i = 0; i < data.numAttributes(); i++) {
// 如果是类别属性,跳过
if (i == data.classIndex()) {
infoGains[i] = 0;
} else {
// 计算信息增益
infoGains[i] = computeInfoGain(data, data.attribute(i));
}
}
// 选择信息增益最大的属性
m_Attribute = data.attribute(Utils.maxIndex(infoGains));
// 如果信息增益为 0,说明当前节点包含的样例都属于同一类别,直接设置为叶子节点
if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
// 设置为叶子节点
m_Attribute = null;
m_Distribution = new double[data.numClasses()];
// 遍历所有样例
for (int i = 0; i < data.numInstances(); i++) {
// 获取当前样例的类别
Instance inst = data.instance(i);
// 统计类别分布
m_Distribution[(int) inst.classValue()]++;
}
// 归一化
Utils.normalize(m_Distribution);
// 设置类别
m_ClassValue = Utils.maxIndex(m_Distribution);
m_ClassAttribute = data.classAttribute();
} else { // 否则,递归建树
// 划分数据集
Instances[] splitData = splitData(data, m_Attribute);
// 创建叶子
m_Successors = new myId3[m_Attribute.numValues()];
// 叶子再去长叶子,递归调用
for (int j = 0; j < m_Attribute.numValues(); j++) {
m_Successors[j] = new myId3();
m_Successors[j].makeTree(splitData[j]);
}
}
}
/*
* @Author YFMan
* @Description 根据 instance 进行分类
* @Date 2023/5/25 18:33
* @Param [instance] 待分类的实例
* @return double[] 类别分布
**/
public double[] distributionForInstance(Instance instance)
throws NoSupportForMissingValuesException {
// 如果到达叶子节点,返回类别分布
if (m_Attribute == null) {
// 如果 m_Distribution 全为 0(是空叶子),随机返回一个类别分布
if (Utils.eq(Utils.sum(m_Distribution), 0)) {
// 在 0~类别数-1 之间随机选择一个类别
m_Distribution = new double[m_ClassAttribute.numValues()];
m_Distribution[(int) Math.round(Math.random() * m_ClassAttribute.numValues())] = 1.0;
}
return m_Distribution;
} else {
// 否则,递归调用
return m_Successors[(int) instance.value(m_Attribute)].
distributionForInstance(instance);
}
}
/*
* @Author YFMan
* @Description 计算当前数据集 选择某个属性的 信息增益
* @Date 2023/5/25 18:29
* @Param [data, att] 当前数据集,选择的属性
* @return double 信息增益
**/
private double computeInfoGain(Instances data, Attribute att)
throws Exception {
// 计算 data 的信息熵
double infoGain = computeEntropy(data);
// 计算 data 按照 att 属性进行划分的信息熵
// 划分数据集
Instances[] splitData = splitData(data, att);
// 遍历划分后的数据集
for (Instances instances : splitData) {
// 计算概率
double probability = (double) instances.numInstances() / data.numInstances();
// 计算信息熵
infoGain -= probability * computeEntropy(instances);
}
// 返回信息增益
return infoGain;
}
/*
* @Author YFMan
* @Description 计算信息熵
* @Date 2023/5/25 18:18
* @Param [data] 计算的数据集
* @return double 信息熵
**/
private double computeEntropy(Instances data) throws Exception {
// 计不同类别的数量
double[] classCounts = new double[data.numClasses()];
// 遍历数据集
for(int i=0;i<data.numInstances();i++){
// 获取类别
int classIndex = (int) data.instance(i).classValue();
// 数量加一
classCounts[classIndex]++;
}
// 计算信息熵
double entropy = 0;
// 遍历类别
for (double classCount : classCounts) {
// 注意:这里是大于 0,因为 log2(0) = -Infinity;
// 如果是等于 0,那么计算结果就是 NaN,熵就出错了
if(classCount > 0){
// 计算概率
double probability = classCount / data.numInstances();
// 计算信息熵
entropy -= probability * Utils.log2(probability);
}
}
// 返回信息熵
return entropy;
}
/*
* @Author YFMan
* @Description 根据属性划分数据集
* @Date 2023/5/25 18:23
* @Param [data, att] 数据集,属性
* @return weka.core.Instances[] 划分后的数据集
**/
private Instances[] splitData(Instances data, Attribute att) {
// 定义划分后的数据集
Instances[] splitData = new Instances[att.numValues()];
// 遍历划分后的数据集
for(int i=0;i<splitData.length;i++){
// 创建数据集 (这里主要是为了初始化 数据集 header)
// Constructor copying all instances and references to the header
// information from the given set of instances.
splitData[i] = new Instances(data,0);
}
// 遍历数据集
for(int i=0;i<data.numInstances();i++){
// 获取实例
Instance instance = data.instance(i);
// 获取实例的属性值
double value = instance.value(att);
// 将实例添加到对应的数据集中
splitData[(int) value].add(instance);
}
// 返回划分后的数据集
return splitData;
}
private String toString(int level) {
StringBuffer text = new StringBuffer();
if (m_Attribute == null) {
if (Instance.isMissingValue(m_ClassValue)) {
text.append(": null");
} else {
text.append(": " + m_ClassAttribute.value((int) m_ClassValue));
}
} else {
for (int j = 0; j < m_Attribute.numValues(); j++) {
text.append("\n");
for (int i = 0; i < level; i++) {
text.append("| ");
}
text.append(m_Attribute.name() + " = " + m_Attribute.value(j));
text.append(m_Successors[j].toString(level + 1));
}
}
return text.toString();
}
public String toString() {
if ((m_Distribution == null) && (m_Successors == null)) {
return "Id3: No model built yet.";
}
return "Id3\n\n" + toString(0);
}
/**
* Main method.
*
* @param args the options for the classifier
*/
public static void main(String[] args) {
runClassifier(new myId3(), args);
}
}