The previous post covered information entropy and information gain. With the theory in place, let's now build a decision tree.
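1. Tree basics
- Binary tree: a tree consists of a root node and child nodes, and child nodes are either leaf nodes or non-leaf nodes. In a binary tree every node has at most two branches, the left subtree and the right subtree of that node. The figure below shows a binary tree:
Root node: the topmost node. Each tree has exactly one, and it is the ancestor of every other node.
Leaf node: a node at the bottom of a branch. A leaf has no left or right branch (subtree), so it has no children.
Non-leaf node: a node in the middle of the tree, with a left subtree, a right subtree, or both.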
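The binary tree data structure in code:
```java
public class BTree<T> {
    /** Root node. */
    private Node<T> root;

    /** A tree node. Declared static so that its T does not shadow the outer class's T. */
    static class Node<T> {
        /** Data stored in the node. */
        private T data;
        /** Left subtree. */
        private Node<T> left;
        /** Right subtree. */
        private Node<T> right;

        /** Creates a node holding the given data. */
        public Node(T data) {
            this.data = data;
        }
    }
}
```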
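Special binary trees:
- Full binary tree: a binary tree in which every non-leaf node has both a left and a right subtree.
Put another way: if no node in the tree has only a left subtree or only a right subtree, it is a full binary tree. The tree in the figure above is a full binary tree. Many textbooks reserve the name for the stricter case in which all leaves also sit on the same level; that stricter form, often called a perfect binary tree, is the reference tree in the next definition.
- Complete binary tree: take a binary tree of depth k with n nodes and number its nodes from top to bottom, left to right. If the node numbered i (1 ≤ i ≤ n) occupies the same position as node i in the perfect binary tree of depth k, the tree is a complete binary tree.
Put another way: every level except possibly the last is completely filled, and the nodes on the last level are packed to the left. As a result at most one node has a single subtree, that subtree is a left subtree, and every node lines up one-to-one with a node of the perfect binary tree. See the figure below.
- Binary search tree (BST): this one concerns the node values. Every value in a node's left subtree is smaller than the node's value, and every value in its right subtree is larger. A BST can be built by recursive insertion. The figure below shows a binary search tree:
- Balanced binary search tree (AVL tree): a plain binary search tree can grow very deep (inserting already-sorted values produces a chain), which raises the cost of a lookup. The AVL tree was designed to keep the depth small, and it introduces a parameter called the balance factor.
Balance factor = height of the left subtree - height of the right subtree. An AVL tree keeps this value at -1, 0 or 1 for every node.
Inserting a node may unbalance the tree; balance is then restored with left or right rotations.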
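To make the BST rule and the balance factor concrete, here is a minimal sketch. The class name `BstSketch` and its methods are made up for this illustration and are not part of the decision tree code later in the post; rotations are left out, so the sketch only measures imbalance and does not repair it.

```java
import java.util.Arrays;
import java.util.List;

class BstSketch {
    /** Node of a binary search tree over int values (illustrative only). */
    static final class Node {
        final int value;
        Node left, right;
        Node(int value) { this.value = value; }
    }

    /** Recursively inserts value and returns the subtree root; duplicates are ignored. */
    static Node insert(Node node, int value) {
        if (node == null) return new Node(value);
        if (value < node.value) node.left = insert(node.left, value);
        else if (value > node.value) node.right = insert(node.right, value);
        return node;
    }

    /** Height counted in nodes: an empty subtree is 0, a single node is 1. */
    static int height(Node node) {
        return node == null ? 0 : 1 + Math.max(height(node.left), height(node.right));
    }

    /** Balance factor = height of left subtree - height of right subtree. */
    static int balanceFactor(Node node) {
        return node == null ? 0 : height(node.left) - height(node.right);
    }

    /** Builds a BST by inserting the values in the given order. */
    static Node build(List<Integer> values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return root;
    }

    public static void main(String[] args) {
        Node balanced = build(Arrays.asList(4, 2, 6, 1, 3, 5, 7));
        Node skewed = build(Arrays.asList(1, 2, 3, 4)); // sorted input degenerates into a chain
        System.out.println(height(balanced) + " " + balanceFactor(balanced)); // prints 3 0
        System.out.println(height(skewed) + " " + balanceFactor(skewed));     // prints 4 -3
    }
}
```

The second tree shows why the AVL tree exists: four sorted insertions already give a root balance factor of -3, well outside the allowed range.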
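- Multi-way tree: a tree consists of a root node and child nodes, and child nodes are either leaf nodes or non-leaf nodes. Each non-leaf node may have any number of branches. To represent this with a fixed number of fields, each node stores two references: its first child and its right sibling. The figure below shows the layout:
The multi-way tree data structure in code:
```java
public class MultiTree<T> {
    /** Root node. */
    private Node<T> root;

    /** A tree node. Declared static so that its T does not shadow the outer class's T. */
    static class Node<T> {
        /** Data stored in the node. */
        private T data;
        /** First child. */
        private Node<T> child;
        /** Right sibling. */
        private Node<T> rightBrother;

        /** Creates a node holding the given data. */
        public Node(T data) {
            this.data = data;
        }
    }
}
```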
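The child / right-sibling layout can feel indirect at first, so here is a small sketch of how children are appended and enumerated with it. `SiblingTreeSketch`, `appendChild` and `childValues` are hypothetical names chosen for this example; the decision tree code later in the post links nodes by hand instead.

```java
import java.util.ArrayList;
import java.util.List;

class SiblingTreeSketch {
    /** Node in first-child / right-sibling form (illustrative only). */
    static final class Node<T> {
        final T data;
        Node<T> child;        // first (leftmost) child
        Node<T> rightBrother; // next sibling to the right

        Node(T data) { this.data = data; }

        /** Appends newChild as the last child of this node. */
        void appendChild(Node<T> newChild) {
            if (child == null) {
                child = newChild;
                return;
            }
            Node<T> last = child;
            while (last.rightBrother != null) last = last.rightBrother;
            last.rightBrother = newChild;
        }

        /** Collects the data of all direct children, left to right. */
        List<T> childValues() {
            List<T> values = new ArrayList<>();
            for (Node<T> c = child; c != null; c = c.rightBrother) values.add(c.data);
            return values;
        }
    }

    public static void main(String[] args) {
        Node<String> root = new Node<>("color");
        root.appendChild(new Node<>("green"));
        root.appendChild(new Node<>("black"));
        root.appendChild(new Node<>("white"));
        System.out.println(root.childValues()); // prints [green, black, white]
    }
}
```

Only the first child hangs off the parent directly; every further child is reached by following `rightBrother`, which is why two references per node are enough for any number of branches.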
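Because a node in a decision tree may have more than two branches, we use a multi-way tree to build it.
2. Building the decision tree
For how information entropy and information gain are computed, see:
The Programmer's Growth Path, Data Mining Series: Decision Tree Classification (1): Information Entropy and Information Gain
Now that we know how to compute information entropy and information gain, let's build a decision tree.
- First, we need to design the decision tree class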
```java
/**
 * Decision tree
 * @author zygswo
 *
 * @param <T> type of the value stored in each node
 */
public class DecisionTree<T> {
    private Node<T> root;

    public DecisionTree() {
        this.root = new Node<T>(null);
    }

    public DecisionTree(Node<T> root) {
        this.root = root;
    }

    public Node<T> getRoot() {
        return root;
    }

    public void setRoot(Node<T> root) {
        this.root = root;
    }

    /** Sets the first child of the root. */
    public void addChildNode(Node<T> child) {
        this.root.addChild(child);
    }

    /** Sets the right sibling of the root. */
    public void addRightBrother(Node<T> rightBrother) {
        this.root.addRightBrother(rightBrother);
    }

    /**
     * Tree node in first-child / right-sibling form
     */
    static class Node<T> {
        private T nodeValue;          // value held by the node
        private Node<T> child;        // first child
        private Node<T> rightBrother; // right sibling

        public Node() {
        }

        public Node(T nodeValue) {
            this.nodeValue = nodeValue;
        }

        public T getNodeValue() {
            return nodeValue;
        }

        public void setNodeValue(T nodeValue) {
            this.nodeValue = nodeValue;
        }

        public Node<T> getChild() {
            return child;
        }

        public void setChild(Node<T> child) {
            this.child = child;
        }

        public Node<T> getRightBrother() {
            return rightBrother;
        }

        public void setRightBrother(Node<T> rightBrother) {
            this.rightBrother = rightBrother;
        }

        public void addChild(Node<T> child) {
            this.child = child;
        }

        public void addRightBrother(Node<T> rightBrother) {
            this.rightBrother = rightBrother;
        }

        public boolean hasChild() {
            return this.child != null;
        }

        public boolean hasRightBrother() {
            return this.rightBrother != null;
        }

        @Override
        public String toString() {
            return "Node [nodeValue=" + nodeValue + ", child=" + child + ", rightBrother=" + rightBrother + "]";
        }
    }
}
```
- Get the data set (the feature values and the label of every sample):
```json
[
{"uid":"1","color":"青绿","root":"蜷缩","sound":"浊响","stripe":"清晰","base":"凹陷","touch":"硬滑","result":"1"},
{"uid":"2","color":"乌黑","root":"蜷缩","sound":"沉闷","stripe":"清晰","base":"凹陷","touch":"硬滑","result":"1"},
{"uid":"3","color":"乌黑","root":"蜷缩","sound":"浊响","stripe":"清晰","base":"凹陷","touch":"硬滑","result":"1"},
{"uid":"4","color":"青绿","root":"蜷缩","sound":"沉闷","stripe":"清晰","base":"凹陷","touch":"硬滑","result":"1"},
{"uid":"5","color":"浅白","root":"蜷缩","sound":"浊响","stripe":"清晰","base":"凹陷","touch":"硬滑","result":"1"},
{"uid":"6","color":"青绿","root":"稍蜷","sound":"浊响","stripe":"清晰","base":"稍凹","touch":"软粘","result":"1"},
{"uid":"7","color":"乌黑","root":"稍蜷","sound":"浊响","stripe":"稍糊","base":"稍凹","touch":"软粘","result":"1"},
{"uid":"8","color":"乌黑","root":"稍蜷","sound":"浊响","stripe":"清晰","base":"稍凹","touch":"硬滑","result":"1"},
{"uid":"9","color":"乌黑","root":"稍蜷","sound":"沉闷","stripe":"稍糊","base":"稍凹","touch":"硬滑","result":"0"},
{"uid":"10","color":"青绿","root":"硬挺","sound":"清脆","stripe":"清晰","base":"平坦","touch":"软粘","result":"0"},
{"uid":"11","color":"浅白","root":"硬挺","sound":"清脆","stripe":"模糊","base":"平坦","touch":"硬滑","result":"0"},
{"uid":"12","color":"浅白","root":"蜷缩","sound":"浊响","stripe":"模糊","base":"平坦","touch":"软粘","result":"0"},
{"uid":"13","color":"青绿","root":"稍蜷","sound":"浊响","stripe":"稍糊","base":"凹陷","touch":"硬滑","result":"0"},
{"uid":"14","color":"浅白","root":"稍蜷","sound":"沉闷","stripe":"稍糊","base":"凹陷","touch":"硬滑","result":"0"},
{"uid":"15","color":"乌黑","root":"稍蜷","sound":"浊响","stripe":"清晰","base":"稍凹","touch":"软粘","result":"0"},
{"uid":"16","color":"浅白","root":"蜷缩","sound":"浊响","stripe":"模糊","base":"平坦","touch":"硬滑","result":"0"},
{"uid":"17","color":"青绿","root":"蜷缩","sound":"沉闷","stripe":"稍糊","base":"稍凹","touch":"硬滑","result":"0"}
]
```
- Build the sample class for the data set:
```java
package classificationUtil;

import java.io.Serializable;

/**
 * Test item (one sample of the data set)
 * @author zygswo
 *
 */
public class TestItem implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Sample ID */
    private String uid;

    /** Color (色泽) */
    @Calc(altName = "色泽")
    private String color;

    /** Root (根蒂) */
    @Calc(altName = "根蒂")
    private String root;

    /** Knock sound (敲声) */
    @Calc(altName = "敲声")
    private String sound;

    /** Stripe texture (纹理) */
    @Calc(altName = "纹理")
    private String stripe;

    /** Navel (脐部) */
    @Calc(altName = "脐部")
    private String base;

    /** Touch (触感) */
    @Calc(altName = "触感")
    private String touch;

    /** Label: whether this is a good melon */
    private String result;

    public String getUid() {
        return uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    public String getColor() {
        return color;
    }

    public void setColor(String color) {
        this.color = color;
    }

    public String getRoot() {
        return root;
    }

    public void setRoot(String root) {
        this.root = root;
    }

    public String getSound() {
        return sound;
    }

    public void setSound(String sound) {
        this.sound = sound;
    }

    public String getStripe() {
        return stripe;
    }

    public void setStripe(String stripe) {
        this.stripe = stripe;
    }

    public String getBase() {
        return base;
    }

    public void setBase(String base) {
        this.base = base;
    }

    public String getTouch() {
        return touch;
    }

    public void setTouch(String touch) {
        this.touch = touch;
    }

    public String getResult() {
        return result;
    }

    public void setResult(String result) {
        this.result = result;
    }

    @Override
    public String toString() {
        return "TestItem [uid=" + uid + ", color=" + color + ", root=" + root + ", sound=" + sound + ", stripe="
                + stripe + ", base=" + base + ", touch=" + touch + ", result=" + result + "]";
    }
}
```
*Fields annotated with @Calc are features; result is the label.
The code for the Calc annotation is:
```java
package classificationUtil;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a field as a feature. Annotated fields are the ones used
 * when computing information entropy and information gain.
 * @author zygswo
 *
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Calc {
    /** Display name of the feature. */
    String altName();
}
```
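The annotation only pays off once it is read back through reflection, so here is a small self-contained sketch of that lookup. `CalcScanSketch`, the cut-down `Melon` class and `featureNames` are hypothetical stand-ins for this example; the nested `Calc` simply mirrors the annotation above so the snippet compiles on its own.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.Map;

class CalcScanSketch {
    /** Local copy of the Calc annotation so the sketch is self-contained. */
    @Target(ElementType.FIELD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Calc {
        String altName();
    }

    /** Hypothetical two-feature sample, a cut-down stand-in for TestItem. */
    static class Melon {
        String uid;
        @Calc(altName = "色泽") String color;
        @Calc(altName = "根蒂") String root;
        String result;
    }

    /** Maps each @Calc field name to its altName; fields without the annotation are skipped. */
    static Map<String, String> featureNames(Class<?> cls) {
        Map<String, String> names = new LinkedHashMap<>();
        for (Field f : cls.getDeclaredFields()) {
            Calc calc = f.getAnnotation(Calc.class);
            if (calc != null) {
                names.put(f.getName(), calc.altName());
            }
        }
        return names;
    }

    public static void main(String[] args) {
        // uid and result carry no @Calc, so only color and root are reported
        System.out.println(featureNames(Melon.class));
    }
}
```

`RetentionPolicy.RUNTIME` is what makes this work: with the default class-file retention, `getAnnotation` would return null and every field would be filtered out.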
- Build the decision tree node
```java
package classificationUtil;

import java.util.List;

public class TreeNode<T> {
    /**
     * Whether this is a leaf node; a leaf prints its classification result
     */
    private boolean isLeaf;

    /**
     * Node name (the feature this node splits on)
     */
    private String nodeName;

    /**
     * Branch name (the feature value that leads to this node)
     */
    private String branchName;

    /**
     * Data subset that reaches this node
     */
    private List<T> items;

    /**
     * Classification result
     */
    private String result;

    public boolean isLeaf() {
        return isLeaf;
    }

    public void setLeaf(boolean isLeaf) {
        this.isLeaf = isLeaf;
    }

    public String getBranchName() {
        return branchName;
    }

    public void setBranchName(String branchName) {
        this.branchName = branchName;
    }

    public String getResult() {
        return result;
    }

    public void setResult(String result) {
        this.result = result;
    }

    public String getNodeName() {
        return nodeName;
    }

    public void setNodeName(String nodeName) {
        this.nodeName = nodeName;
    }

    public List<T> getItems() {
        return items;
    }

    public void setItems(List<T> items) {
        this.items = items;
    }

    @Override
    public String toString() {
        if (this.isLeaf) {
            return this.branchName + "-->" + this.result;
        }
        // The root has no incoming branch, so only the feature name is printed
        if (this.branchName == null || "".equals(this.branchName)) {
            return this.nodeName;
        }
        // No feature could split this subset any further
        if (this.nodeName == null || "".equals(this.nodeName)) {
            return this.branchName + "--undetermined";
        }
        return this.branchName + "--" + this.nodeName;
    }
}
```
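*A few notes on the fields:
isLeaf: whether the node is a leaf; toString uses it to decide what to print for the tree
nodeName: the node name, that is, the name of a feature, such as "color (色泽)" in the data set
branchName: the branch name, that is, one of the values of the parent's feature, such as "青绿" for color (色泽)
items: the data subset that reaches this node
result: the classification result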
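- Next we can write the code that builds the decision tree. The approach is:
1. Get the current data set: the full data set on the first call, a data subset on every later call
2. Get the feature set of the data set
3. Compute the information entropy of the current data set with respect to the label
4. Compute the information gain and information gain ratio of each feature in the feature set with respect to the label
5. Pick the best splitting feature, store it as the name of the current treeNode, partition the data by that feature, and let m be the number of partitions
6. Set the partition index i to 1 and the previous sibling node prevNode = null
7. For the i-th data subset (i <= m), compute its information entropy. If the entropy is 0, the subset contains a single class and cannot be split further, so create a leaf node and set its data subset, branch name and classification result. Otherwise create a branch node and set its data subset and branch name. If prevNode is null this is the first child, so attach the new node as the first child of treeNode; otherwise attach it as the right sibling of prevNode. Then set prevNode to the new node
8. If the new node is not a leaf, repeat steps 1-7 on it until leaf nodes are reached
9. Increment i and repeat steps 7-8 until i > m
10. Done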
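Steps 3 and 4 can be checked by hand before reading the full implementation. The sketch below works from class counts taken from the data set above: 8 samples labelled 1 and 9 labelled 0 overall, and for the stripe feature 清晰 = 7/2, 稍糊 = 1/4, 模糊 = 0/3. `GainRatioSketch` and its method names are made up for this example and use plain count arrays instead of the reflection-based helpers shown later.

```java
import java.util.Locale;

class GainRatioSketch {
    /** Shannon entropy (base 2) of a distribution given as raw counts. */
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) total += c;
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue; // 0 * log(0) is taken as 0
            double p = (double) c / total;
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    /** Information gain: label entropy minus the size-weighted entropy of each partition. */
    static double gain(int[] labelCounts, int[][] partitions) {
        int total = 0;
        for (int c : labelCounts) total += c;
        double g = entropy(labelCounts);
        for (int[] part : partitions) {
            int size = 0;
            for (int c : part) size += c;
            g -= (double) size / total * entropy(part);
        }
        return g;
    }

    /** Split information: entropy of the partition sizes themselves. */
    static double splitInfo(int[][] partitions) {
        int[] sizes = new int[partitions.length];
        for (int i = 0; i < partitions.length; i++) {
            for (int c : partitions[i]) sizes[i] += c;
        }
        return entropy(sizes);
    }

    /** Gain ratio = information gain / split information. */
    static double gainRatio(int[] labelCounts, int[][] partitions) {
        return gain(labelCounts, partitions) / splitInfo(partitions);
    }

    public static void main(String[] args) {
        int[] label = {8, 9};                        // good melons, bad melons
        int[][] stripe = {{7, 2}, {1, 4}, {0, 3}};   // 清晰, 稍糊, 模糊 as {good, bad}
        System.out.println(String.format(Locale.ROOT, "%.3f %.3f %.3f",
                entropy(label), gain(label, stripe), gainRatio(label, stripe)));
        // prints 0.998 0.381 0.263
    }
}
```

The gain ratio divides by the split information so that a feature with many distinct values does not win simply because it fragments the data.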
The complete code is shown below.
The first call initialises the parameters:
```java
// Imports for this method; they go at the top of the enclosing class file.
// JSON is assumed here to be fastjson's com.alibaba.fastjson.JSON, whose parseArray(String, Class) matches the call below.
import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.alibaba.fastjson.JSON;

/**
 * Runs the decision tree
 * @param fileAddr path of the data set file
 * @param cls entity class of a sample
 * @param resColName name of the label field
 */
public static <T> DecisionTree<TreeNode<T>> run(
        String fileAddr,
        Class<T> cls,
        String resColName
) {
    // 1. Initialise the data set
    List<T> items = new ArrayList<T>();        // data set
    List<String> features = new ArrayList<>(); // feature set
    try {
        // Read the whole file as UTF-8 (the feature values are Chinese text).
        // A missing file now raises an IOException instead of being created empty.
        String itemsStr = new String(Files.readAllBytes(Paths.get(fileAddr)), StandardCharsets.UTF_8);
        System.out.println(itemsStr);
        // Parse the data set
        items = JSON.parseArray(itemsStr, cls);
        // Collect the feature set: only fields annotated with @Calc are features
        for (Field f : cls.getDeclaredFields()) {
            if (f.getAnnotation(Calc.class) != null) {
                features.add(f.getName());
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
    // Group the samples by label
    Map<String, List<T>> resColRes = calcNb(items, resColName);
    // The entropy of the label is the initial entropy
    double initResColXinxiShangRes = calcXinxishang(resColRes);
    // items: data set, features: feature set, root: root of the decision tree,
    // initResColXinxiShangRes: initial entropy, resColName: label name
    DecisionTree.Node<TreeNode<T>> root = new DecisionTree.Node<TreeNode<T>>();
    run(items, features, root, initResColXinxiShangRes, resColName);
    // Print the decision tree. firstSearch is not listed in the DecisionTree class above;
    // it is expected to walk the tree from the root and print each node.
    DecisionTree<TreeNode<T>> decisionTree = new DecisionTree<>(root);
    decisionTree.firstSearch(root);
    return decisionTree;
}
```
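The `firstSearch` method called above does not appear in the `DecisionTree` listing in this post. The following is one possible shape for it, written as a guess: a depth-first walk over the child / right-sibling links that indents each node by its depth. `FirstSearchSketch`, the three-argument signature and the sample labels are assumptions made for this illustration, not the author's code.

```java
import java.util.ArrayList;
import java.util.List;

class FirstSearchSketch {
    /** Minimal stand-in for DecisionTree.Node (illustrative only). */
    static final class Node<T> {
        final T nodeValue;
        Node<T> child;
        Node<T> rightBrother;
        Node(T nodeValue) { this.nodeValue = nodeValue; }
    }

    /** Depth-first walk: emit a node, then its subtree, then its siblings to the right. */
    static <T> void firstSearch(Node<T> node, int depth, List<String> out) {
        for (Node<T> cur = node; cur != null; cur = cur.rightBrother) {
            StringBuilder line = new StringBuilder();
            for (int i = 0; i < depth; i++) line.append("  ");
            out.add(line.append(cur.nodeValue).toString());
            firstSearch(cur.child, depth + 1, out);
        }
    }

    /** Builds a tiny hand-made tree (labels are invented) and returns its printed lines. */
    static List<String> demoLines() {
        Node<String> root = new Node<>("stripe");
        Node<String> clear = new Node<>("clear--touch");
        Node<String> blurry = new Node<>("blurry-->0");
        Node<String> hard = new Node<>("hard-->1");
        Node<String> soft = new Node<>("soft-->0");
        root.child = clear;
        clear.rightBrother = blurry;
        clear.child = hard;
        hard.rightBrother = soft;
        List<String> out = new ArrayList<>();
        firstSearch(root, 0, out);
        return out;
    }

    public static void main(String[] args) {
        for (String line : demoLines()) {
            System.out.println(line);
        }
    }
}
```

Because each node's `toString` already encodes the branch and the feature or result, printing the value with indentation is enough to read the tree top-down.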
Every later call recurses:
```java
/**
 * Recursive build step
 * @param items data set (a subset after the first call)
 * @param features feature set
 * @param treeNode decision tree node to expand
 * @param resXinxiShang label entropy of items, computed by the caller
 * @param labelName name of the label field
 */
private static <T> void run(
        List<T> items,
        List<String> features,
        DecisionTree.Node<TreeNode<T>> treeNode,
        double resXinxiShang,
        String labelName
) {
    // Nothing to do when there are no samples or no features
    if (items.isEmpty() || features.isEmpty()) {
        return;
    }
    // Feature chosen for this split
    String nextFeature = "";
    double maxXinxiZengyiRate = Double.NEGATIVE_INFINITY;
    for (String feature : features) {
        // The information gain starts from the label entropy of the current data set
        double xinxiZengyi = resXinxiShang;
        // Partition the samples by the value of this feature
        Map<String, List<T>> fRes = calcNb(items, feature);
        // Entropy of the partition sizes (the split information)
        double xinxiShang = calcXinxishang(fRes);
        // A feature with a single value cannot split the data; skipping it also
        // avoids a division by zero when computing the gain ratio below
        if (xinxiShang == 0) {
            continue;
        }
        for (Map.Entry<String, List<T>> entry : fRes.entrySet()) {
            // Samples that share one value of this feature
            List<T> resTmp = entry.getValue();
            // Group them by label
            Map<String, List<T>> temp = calcNb(resTmp, labelName);
            // Label entropy inside this partition
            double xinxishangTemp = calcXinxishang(temp);
            // Subtract the size-weighted entropy to obtain the information gain
            xinxiZengyi -= resTmp.size() * xinxishangTemp / items.size();
        }
        System.out.println("Information gain of " + feature + ": " + xinxiZengyi);
        // Compute the gain ratio and keep the maximum
        double xinxiZengyiRate = xinxiZengyi / xinxiShang;
        if (xinxiZengyiRate > maxXinxiZengyiRate) {
            nextFeature = feature;
            maxXinxiZengyiRate = xinxiZengyiRate;
        }
    }
    // Record the best splitting feature as this node's name.
    // The root arrives without a value, so create one; later nodes already carry
    // their branch name. The value is stored on treeNode itself, because a child
    // added here would be overwritten by the first branch node below.
    TreeNode<T> result = treeNode.getNodeValue();
    if (result == null) {
        result = new TreeNode<T>();
        result.setLeaf(false);
    }
    result.setNodeName(nextFeature);
    treeNode.setNodeValue(result);
    // The chosen feature is not removed from features: the list is shared with
    // sibling branches, and inside a subset the feature has a single value,
    // so the check above already skips it.
    // features.remove(nextFeature);
    // Previous sibling among the children created so far
    DecisionTree.Node<TreeNode<T>> prevNode = null;
    // Partition the samples by the best splitting feature
    Map<String, List<T>> fRes = calcNb(items, nextFeature);
    System.out.println("items => " + items.toString());
    for (Map.Entry<String, List<T>> entry : fRes.entrySet()) {
        // Samples that share one value of the chosen feature
        List<T> resTmp = entry.getValue();
        // Group them by label and compute the label entropy
        Map<String, List<T>> temp = calcNb(resTmp, labelName);
        double xinxishangTemp = calcXinxishang(temp);
        TreeNode<T> nodeResult = new TreeNode<T>();
        nodeResult.setBranchName(entry.getKey()); // branch name = feature value, e.g. 青绿 for color
        nodeResult.setItems(resTmp);              // data subset of this node
        // Entropy 0 means the subset holds a single class, so this is a leaf
        boolean isLeaf = xinxishangTemp == 0;
        nodeResult.setLeaf(isLeaf);
        if (isLeaf) {
            try {
                // Read the classification result from the first sample via reflection
                Field f = resTmp.get(0).getClass().getDeclaredField(labelName);
                f.setAccessible(true);
                nodeResult.setResult(String.valueOf(f.get(resTmp.get(0))));
            } catch (ReflectiveOperationException | SecurityException | IllegalArgumentException e) {
                e.printStackTrace();
                continue; // the label could not be read, so no node is added
            }
        }
        DecisionTree.Node<TreeNode<T>> subNode =
                new DecisionTree.Node<TreeNode<T>>(nodeResult);
        // Link to the previous sibling if there is one, otherwise become the first child
        if (prevNode != null) {
            prevNode.addRightBrother(subNode);
        } else {
            treeNode.addChild(subNode);
        }
        prevNode = subNode;
        // A non-leaf node is split further on its own subset
        if (!isLeaf) {
            run(resTmp, features, subNode, xinxishangTemp, labelName);
        }
    }
}
```
Other methods: computing the information entropy and grouping the samples by a field:
```java
/**
 * Computes the information entropy
 * @param inputDataMap samples grouped by value
 * @return information entropy in bits
 */
private static <T> double calcXinxishang(Map<String, List<T>> inputDataMap) {
    double totalNb = 0.0, res = 0.0;
    // Count all samples
    for (List<T> group : inputDataMap.values()) {
        if (group != null) {
            totalNb += group.size();
        }
    }
    // Sum -p * log2(p) over the groups
    for (List<T> group : inputDataMap.values()) {
        if (group == null || group.isEmpty()) {
            continue; // an empty group contributes nothing (and log(0) is undefined)
        }
        double temp = group.size() / totalNb;
        res -= temp * (Math.log(temp) / Math.log(2));
    }
    return res;
}

/**
 * Groups the samples by the value of one field
 * @param inputDataSet input data set
 * @param calcColumnName name of the field to group by
 * @return samples grouped by field value
 */
private static <T> Map<String, List<T>> calcNb(List<T> inputDataSet, String calcColumnName) {
    // LinkedHashMap (import java.util.LinkedHashMap) keeps the groups in first-seen
    // order, so the printed tree is the same on every run
    Map<String, List<T>> res = new LinkedHashMap<String, List<T>>();
    if (inputDataSet == null || inputDataSet.isEmpty()) {
        return res;
    }
    Class<?> cls = inputDataSet.get(0).getClass();
    for (Field f : cls.getDeclaredFields()) {
        if (!f.getName().equalsIgnoreCase(calcColumnName)) {
            continue;
        }
        f.setAccessible(true);
        for (T inputData : inputDataSet) {
            try {
                // String.valueOf tolerates a null field value
                String value = String.valueOf(f.get(inputData));
                res.computeIfAbsent(value, k -> new ArrayList<>()).add(inputData);
            } catch (IllegalArgumentException | IllegalAccessException e) {
                e.printStackTrace();
            }
        }
    }
    return res;
}
```
The main method:
```java
public static void main(String[] args) {
    // Run the decision tree: datas_2.txt is the data set file shown above,
    // TestItem is the sample class, and result is the label name
    run("xxx\\datas_2.txt", TestItem.class, "result");
}
```
P.S. Pre-pruning and post-pruning will be covered later, when time permits.