决策树经常处理分类问题,近来的调查表明决策树也是经常使用的数据挖掘算法。
决策树的流程图:
长方形代表判断模块(decision block),椭圆形代表中止模块(terminating block),表示已经得出结论,可以中止运行。
从判断模块引出左右箭头称作分支(branch),它可以到底另一个判断模块或者中止模块。
决策树算法能够读取数据集合,构建决策树
决策树的一个重要任务是为了数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,在这些机器根据数据集创建规则时,就是机器学习的过程。
3.1 决策树的构造
- 优点:计算复杂度不高,输出结果易于理解,对中间值确实不敏感,可以处理不相关特征数据。
- 缺点:可能会产生过度匹配问题。
- 适用数据类型:数值型和标称型。
创建分支的伪代码函数createBranch()如下所示:
检测数据集中的每个子项是否属于同一分类:
If so return 类标签;
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
## 递归调用createBranch
调用createBranch并增加返回结果到分支节点中
return 分支节点
决策树的一般流程:
- 收集数据:可以使用任何方法。
- 准备数据:树构造只是用于标称型数据,因此数值型数据必须离散化。
- 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
- 训练算法:构造树的数据结构。
- 测试算法:使用经验树计算错误率。
- 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
3.1.1 信息增益
在划分数据集之前之后信息发生的变换称为信息增益。
获得信息增益最高的特征就是最好的选择。
熵:定义为信息的期望值
如果待分类的事务划分在多个分类之中,在符号 x i x_i xi的信息定义为:
l ( x i ) = − log 2 p ( x i ) l(x_i) = - \log _{2} p(x_i) l(xi)=−log2p(xi)
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
H = − ∑ i = 1 n p ( x i ) log 2 p ( x i ) H = - \sum _{i = 1}^{n} p(x_i) \log _{2} p(x_i) H=−i=1∑np(xi)log2p(xi)
计算给定数据集的香农熵
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCount = {}
##为所有可能分类创建字典
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCount.keys():
labelCount[currentLabel] = 0
labelCount[currentLabel] += 1
shannonEnt = 0.0
for key in labelCount:
prob = float(labelCount[key]) / numEntries
shannonEnt -= prob * log(prob, 2) ##以2为底数求对数
return shannonEnt
##简单鱼鉴定数据集
def creatDataSet():
dateSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dateSet, labels
if __name__ == '__main__':
dataSet, labels = creatDataSet()
shannonEnt = calcShannonEnt(dataSet)
print(shannonEnt)
3.1.2 划分数据集
按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
'''
:param dataSet: 待划分的数据集
:param axis: 划分数据集的特征
:param value: 需要返回的特征的值
:return: 结果数据集
'''
retDataSet = [] ##创建一个新的list对象
for featVec in dataSet:
if featVec[axis] == value:
##抽取
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
if __name__ == '__main__':
dataSet, labels = creatDataSet()
retDataSet = splitDataSet(dataSet, 0, 1)
print(retDataSet)
Python列表对象的方法,append()和extend()用于在列表末尾添加元素,但它们的用法和效果有所不同。
- append()方法:
用法:list.append(element)
参数:element,要添加的单个元素。
功能:将指定的元素添加到列表的末尾。
结果:在列表的末尾添加一个新的元素,并扩展列表的长度。
特点:可以添加任意类型的元素,包括可迭代对象(如列表)。当添加可迭代对象时,整个对象作为单个元素添加到列表中。 - extend()方法:
用法:list.extend(iterable)
参数:iterable,一个可迭代对象,如列表、元组、字符串等。
功能:将可迭代对象中的每个元素添加到列表的末尾。
结果:不会创建一个新的列表对象,而是在原列表的末尾追加元素。
特点:只能添加可迭代对象的元素值,而不是整个可迭代对象。当添加的是另一个列表时,extend()会将那个列表中的每个元素逐个添加到原列表中。 - 总结:
append()适用于添加单个元素或可迭代对象,但整个可迭代对象被视为单个元素添加。
extend()适用于添加可迭代对象的每个元素,而不是整个可迭代对象。
这些方法都不返回新列表,而是直接修改原列表。
选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
## 创建唯一的分类标签列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
## 计算每种划分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
## 计算最好的信息增益
if (infoGain > bestInfoGain) :
bestInfoGain = infoGain
bestFeature = i
return bestFeature
if __name__ == '__main__':
dataSet, labels = creatDataSet()
feature = chooseBestFeatureToSplit(dataSet)
print(feature)
3.1.3 递归构建决策树
其工作原理如下:
创建原始数据集,然后基于最好的属性值划分数据集合,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将向下递归到树分支的下一个节点,在这个节点上我们再次划分数据。因此我们可以采用递归的原则处理数据集。
递归结束的条件:
程序遍历完所有划分数据集的性质,或者每个分支下的所有实例都有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。
import operator
def majorityCnt(classList):
'''
:param classList: 分类名称列表
:return: 出现次数最多的分类标签
'''
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
创建树的函数代码:
def createTree(dataSet, labels):
'''
:param dataSet: 数据集
:param labels: 标签列
:return:
'''
classList = [example[-1] for example in dataSet]
## 类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
## 遍历完所有特征值时返回出现次数最多的
if len(dataSet[0]) == 1 :
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
##得到列表包含的所有属性值
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value]= createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
if __name__ == '__main__':
myDat, labels = creatDataSet()
myTree = createTree(myDat, labels)
print(myTree)
3.2 在Python中使用Matplotlib注解绘制树形图
3.2.1 Matplotlib注解
Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。
使用文本注解绘制树节点
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
## 定义文本框的箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
'''
绘制带箭头的注解
:param nodeTxt:
:param centerPt:
:param parentPt:
:param nodeType:
:return:
'''
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon= False)
plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
if __name__ == '__main__':
createPlot();
3.2.2 构造注解树
获得叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs = 0;
firstStr = list(myTree)[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
## 测试节点的数据类型是否为字典
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree)[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
if __name__ == '__main__':
listOfTrees = retrieveTree(1)
print(listOfTrees)
myTree = retrieveTree(0)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))
plotTree函数
def plotTree(myTree, parentPt, nodeTxt):
## 计算树的宽与高
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree)[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
## 标记子节点属性值
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
## 减少y偏移值
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
##更新后的createPlot
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon= False)
# plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
if __name__ == '__main__':
myTree = retrieveTree(0)
createPlot(myTree)
3.3 测试和存储分类器
3.3.1 测试算法:使用决策树执行分类
使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree)[0]
secondDict = inputTree[firstStr]
## 将标签字符串转换为索引
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
myDat, labels = creatDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
print(myTree)
print(classify(myTree, labels, [1, 0]))
print(classify(myTree, labels, [1, 1]))
3.3.2 使用算法:决策树的存储
使用pickle模块存储决策树:
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
if __name__ == '__main__':
myDat, labels = creatDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
storeTree(myTree, './resource/classifierStorage.txt')
classifyTree = grabTree('./resource/classifierStorage.txt')
print(classifyTree)
3.4 示例:使用决策树预测隐形眼镜类型
def use():
fr = open('./resource/lenses.txt', 'r')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)
if __name__ == '__main__':
use()
决策树非常好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要地叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。