import operator
from collections import Counter
def majorityCnt(classList):
    """Return the most frequent class label in ``classList``.

    When several labels share the highest count, the one that appears
    first in ``classList`` is returned.

    Args:
        classList: Sequence of hashable class labels.

    Returns:
        The label with the highest number of occurrences.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    # most_common(1) yields [(label, count)] for the top label, or [] when
    # there is nothing to count (which makes the indexing below raise).
    return Counter(classList).most_common(1)[0][0]
# Level 2: create-tree function
from ex03_lib import majorityCnt,splitDataSet,chooseBestFeatureToSplit
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: Non-empty list of rows; the last element of every row is
            its class label, the preceding elements are feature values.
        labels: Feature names, indexed like the feature columns of
            ``dataSet``. This list is NOT modified.

    Returns:
        A class label when the recursion stops at a leaf, otherwise a dict
        of the form ``{feature_label: {feature_value: subtree_or_label}}``.

    Raises:
        IndexError: If ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]  # class label of every row

    # Stop when every remaining row has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop when only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]

    # Remove the chosen feature's name from a private copy so the caller's
    # ``labels`` list stays intact and can still be used after the call.
    subLabels = list(labels)
    del subLabels[bestFeat]

    myTree = {bestFeatLabel: {}}
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        # NOTE(review): splitDataSet lives in ex03_lib; presumably it returns
        # the rows whose column ``bestFeat`` equals ``value`` with that
        # column removed -- confirm against its implementation.
        subset = splitDataSet(dataSet, bestFeat, value)
        # Each branch receives its own copy of the remaining labels.
        myTree[bestFeatLabel][value] = createTree(subset, subLabels[:])
    return myTree
# Level 3: count the leaf nodes
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    A tree node is a dict whose first key is the node label and whose
    value under that key is a mapping ``{branch_value: child}``. Every
    child that is not a dict (or dict subclass) is a leaf.

    Args:
        myTree: A tree node as described above, or a bare leaf value.

    Returns:
        The number of leaves; ``1`` when ``myTree`` itself is a leaf.

    Raises:
        IndexError: If ``myTree`` or a nested tree node is an empty dict.
    """
    # isinstance (rather than comparing the type's name) also recognises
    # dict subclasses such as OrderedDict as subtrees.
    if not isinstance(myTree, dict):
        return 1
    rootLabel = list(myTree)[0]  # only the first key is treated as the node
    branches = myTree[rootLabel]
    # A leaf child contributes 1, a subtree contributes its own leaf count.
    return sum(getNumLeafs(child) for child in branches.values())
# Level 4: compute the tree depth
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    A tree node is a dict whose first key is the node label and whose
    value under that key is a mapping ``{branch_value: child}``. Depth is
    the largest number of decision nodes on any path from the root to a
    leaf, so a node whose children are all leaves has depth 1.

    Args:
        myTree: A tree node as described above, or a bare leaf value.

    Returns:
        The depth as an int; ``0`` for a bare leaf or for a node without
        any branches.

    Raises:
        IndexError: If ``myTree`` or a nested tree node is an empty dict.
    """
    # A bare leaf has no decision nodes below it. Returning 0 here keeps the
    # function usable on leaves, matching how getNumLeafs treats them.
    if not isinstance(myTree, dict):
        return 0
    rootLabel = list(myTree)[0]  # only the first key is treated as the node
    branches = myTree[rootLabel]
    # Each branch adds one level on top of the depth of whatever hangs below.
    return max(
        (1 + getTreeDepth(child) for child in branches.values()),
        default=0,
    )