References:
https://blog.csdn.net/weixin_45666566/article/details/107954454
https://blog.csdn.net/Elenstone/article/details/105328111
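The data file itself is not included in the post; judging from the feature names and class labels in the output at the bottom, 决策树.csv is a small loan-application style table with four feature columns and one class column. A hypothetical stand-in for testing could be generated as below (only 有工作, 有自己的房子, 同意 and 不同意 actually appear in the output; the other column names and all row values are assumptions):

```python
# Hypothetical stand-in for D:/PythonData/决策树.csv (column names and values are assumed).
import pandas as pd

df = pd.DataFrame({
    '年龄': ['青年', '青年', '中年', '老年'],       # assumed feature
    '有工作': ['否', '是', '否', '是'],
    '有自己的房子': ['否', '否', '是', '否'],
    '信贷情况': ['一般', '好', '好', '非常好'],     # assumed feature
    '类别': ['不同意', '同意', '同意', '同意'],     # class column
})
df.to_csv('决策树.csv', index=False, encoding='utf-8-sig')
```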
The code is as follows:
```python
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import operator
def loadDataSet():
    # read the CSV; the first four columns are features, the last column is the class label
    csv = pd.read_csv(filepath_or_buffer=r'D:/PythonData/决策树.csv')
    dataSet = np.array(csv)
    labels = np.array(csv.columns)[:4]
    targets = sorted(np.unique(dataSet[:, -1:].flatten()), reverse=True)
    return dataSet, labels, targets
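# With the two class labels seen in the output at the bottom, targets sorts to
# ['同意', '不同意'], so calcProbabilityEnt() measures the fraction of '同意' samples.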
def calcProbabilityEnt(dataSet, targets):
    numEntries = len(dataSet)  # number of samples in the current subset
    if numEntries == 0:        # an empty partition would otherwise divide by zero
        return 0
    feaCounts = 0
    fea1 = targets[0]
    for featVec in dataSet:
        if featVec[-1] == fea1:
            feaCounts += 1
    probabilityEnt = float(feaCounts) / numEntries
    return probabilityEnt
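# Worked example: if a partition holds 3 '同意' rows and 2 '不同意' rows,
# calcProbabilityEnt returns 3/5 = 0.6, and the two-class Gini index used
# in chooseBestFeatureToSplit() is 2 * 0.6 * (1 - 0.6) = 0.48.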
def splitDataSet(dataSet, index, value):
    # split dataSet into the rows whose feature at `index` equals `value` (retDataSet)
    # and the remaining rows (noRetDataSet); the used feature column is dropped from both
    retDataSet = []
    noRetDataSet = []
    for featVec in dataSet:
        if featVec[index] == value:
            retDataSet.append(np.concatenate((featVec[:index], featVec[index+1:])))
        else:
            noRetDataSet.append(np.concatenate((featVec[:index], featVec[index+1:])))
    return retDataSet, noRetDataSet
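# Example (the value '否' is one of the attribute values seen in the output below):
# splitDataSet(dataSet, index=0, value='否') returns the rows whose first feature
# equals '否' with that column removed, plus all remaining rows as the second list.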
def chooseBestFeatureToSplit(dataSet, targets):
    numFeatures = len(dataSet[0]) - 1
    if numFeatures == 1:
        return 0
    bestGini = 1
    bestFeatureIndex = -1
    for i in range(numFeatures):
        # unique values of feature i in the current subset
        uniqueVals = set(example[i] for example in dataSet)
        for value in uniqueVals:
            # binary partition: rows with feature i == value versus the rest
            subDataSet, noSubDataSet = splitDataSet(dataSet=dataSet, index=i, value=value)
            prob = len(subDataSet) / float(len(dataSet))
            noProb = len(noSubDataSet) / float(len(dataSet))
            probabilityEnt = calcProbabilityEnt(subDataSet, targets)
            noProbabilityEnt = calcProbabilityEnt(noSubDataSet, targets)
            # weighted Gini index of the partition; 2 * p * (1 - p) is the Gini of two classes
            feaGini = round(prob * 2 * probabilityEnt * (1 - probabilityEnt)
                            + noProb * 2 * noProbabilityEnt * (1 - noProbabilityEnt), 2)
            if bestGini > feaGini:
                bestGini = feaGini
                bestFeatureIndex = i
    return bestFeatureIndex
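# Note: only the index of the best feature is returned; the per-value binary
# partitions above are used just to score features with the Gini index, and
# createTree() below then branches on every value of the chosen feature.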
def majorityCnt(classList):
    # return the class label that occurs most often in classList
    classCount = {}
    for vote in classList:
        try:
            classCount[vote] += 1
        except KeyError:
            classCount[vote] = 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
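# Worked example: majorityCnt(['同意', '不同意', '同意']) returns '同意'.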
def createTree(dataSet, labels,targets):
    classList = [example[-1] for example in dataSet]
    # stop when every sample in the subset has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stop when only the class column is left; fall back to the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList=classList)
    bestFeatIndex = chooseBestFeatureToSplit(dataSet=dataSet, targets=targets)
    bestFeatLabel = labels[bestFeatIndex]
    labels = np.delete(labels, bestFeatIndex)  # np.delete returns a new array, so keep the result
    uniqueVals = set(example[bestFeatIndex] for example in dataSet)  # values of the best feature
    myTree = {bestFeatLabel: {}}  # the tree is stored as nested dictionaries
    for value in uniqueVals:
        subLabels = labels.copy()  # copy so the recursive calls cannot affect each other
        subDataSet, noSubDataSet = splitDataSet(dataSet, bestFeatIndex, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels, targets)  # recurse into each branch
    return myTree
if __name__ == '__main__':
    dataSet, labels, targets = loadDataSet()
    print(createTree(dataSet, labels, targets))
```
The output is as follows:
```shell
PS D:\PythonWorkSpace> & E:/anaconda3/python.exe d:/PythonWorkSpace/DecisionTreeDemo.py
{'有自己的房子': {'否': {'有工作': {'否': '不同意', '是': '同意'}}, '是': '同意'}}
```
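The tree comes back as nested dictionaries keyed by feature name, so classifying a new sample is just a walk down the dictionary. Below is a minimal sketch of such a helper; the name `classify`, the label order, and the example sample are assumptions, not part of the original post:

```python
# Hypothetical helper: walk the nested-dict tree produced by createTree().
def classify(tree, featLabels, testVec):
    featName = next(iter(tree))                    # feature tested at this node
    featIndex = list(featLabels).index(featName)   # position of that feature in the sample
    subTree = tree[featName][testVec[featIndex]]   # follow the branch for the sample's value
    if isinstance(subTree, dict):                  # internal node: keep descending
        return classify(subTree, featLabels, testVec)
    return subTree                                 # leaf: predicted class

# Assuming the feature order ['年龄', '有工作', '有自己的房子', '信贷情况'], the tree printed
# above classifies ['青年', '是', '否', '一般'] as '同意' (有自己的房子='否' -> 有工作='是').
```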