The ID3 Algorithm: Decision Tree Learning in Python

Algorithm flow

Input: tree-growth constraints (maximum depth and minimum node sample count, both optional) and a training set (feature values discrete or continuous, labels discrete).

Output: a decision tree.

Procedure: at each node, split on the attribute with the largest information gain, until all samples at the current node belong to the same class or the information gain becomes too small.

Information gain

Suppose the samples fall into $K$ classes and the current node holds $n_1, n_2, \ldots, n_K$ samples of each class. The information entropy of the node is

$$I(n_1, n_2, \ldots, n_K) = -\sum_{i=1}^{K} \frac{n_i}{\sum_{j=1}^{K} n_j} \log_2 \frac{n_i}{\sum_{j=1}^{K} n_j}$$
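
For example, the root of the Play-tennis data used later holds 5 "No" and 9 "Yes" samples, so

$$I(5, 9) = -\frac{5}{14}\log_2\frac{5}{14} - \frac{9}{14}\log_2\frac{9}{14} \approx 0.9403$$

which matches the root entropy 0.9402859586706311 printed in the test runs below.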

Suppose attribute $A$ takes $v$ distinct values, so splitting the current node on $A$ yields $v$ child nodes, the $i$-th of which holds $n_{i1}, n_{i2}, \ldots, n_{iK}$ samples of each class. The expected entropy after splitting on $A$ is

$$E(A) = \sum_{i=1}^{v} \frac{\sum_{j=1}^{K} n_{ij}}{\sum_{j=1}^{K} n_j} I(n_{i1}, n_{i2}, \ldots, n_{iK})$$

The information gain of attribute $A$ at the current node is therefore

$$Gain(A) = I(n_1, n_2, \ldots, n_K) - E(A)$$
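
Continuing the example: splitting the Play-tennis root on outlook yields children Sunny (3 No, 2 Yes), Overcast (0 No, 4 Yes) and Rain (2 No, 3 Yes), so

$$E(\text{outlook}) = \frac{5}{14} \cdot 0.9710 + \frac{4}{14} \cdot 0 + \frac{5}{14} \cdot 0.9710 \approx 0.6935$$

$$Gain(\text{outlook}) = 0.9403 - 0.6935 \approx 0.2467$$

This is the largest gain among the four attributes, which is why the first test below places outlook at the root.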

Problems that can arise during decision tree learning and their remedies

Irrelevant attributes: an attribute is statistically independent of the class distribution. Its information gain is then negligible, so splitting can stop and the node is labeled with the most frequent class.

Inadequate attributes: samples of different classes share exactly the same feature values. The information gain is then $0$, so splitting stops and the node is labeled with the most frequent class.

Unknown attribute values: some attribute values in the dataset are missing. Samples or attributes containing unknown values can be removed during preprocessing.
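
With pandas either remedy is a single call; a toy frame for illustration (column names are made up):

import pandas as pd
df = pd.DataFrame({'a': [1.0, None, 3.0], 'b': ['x', 'y', 'z'], 'label': [0, 1, 0]})
print(df.dropna())               # drop the samples (rows) containing unknown values
print(df.drop(columns = ['a']))  # or drop the affected attribute (column)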

Overfitting: the tree fits the training data but generalizes poorly. The tree-growth parameters (maximum depth, minimum node sample count) can be constrained.
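
These constraints correspond to the two constructor parameters of the implementation below, e.g. ID3(5, 3) for a maximum depth of 5 and a minimum of 3 samples per split, as used later in the Carsdata test.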

Empty branches: during learning a child node receives $0$ samples; this can only happen for $v \ge 3$ (with $v = 2$, an empty child means every sample falls into the other child, so $E(A) = I$, the gain is $0$, and such a split is never chosen). The empty node is labeled with the parent's most frequent class.

Reference code follows (it handles discrete attribute values only):

import numpy as np
class ID3:
    def __init__(self, max_depth = 0, min_samples_split = 0):
        self.max_depth, self.min_samples_split = max_depth, min_samples_split
    def __EI(self, *n):
        # Entropy I(n_1, ..., n_K) of a node, given per-class sample counts
        n = np.array([i for i in n if i > 0])
        if n.shape[0] <= 1:
            return 0
        p = n / np.sum(n)
        return -np.dot(p, np.log2(p))
    def __Gain(self, A: np.ndarray):
        # Information gain of a split described by the v-by-K contingency matrix A:
        # parent entropy minus the size-weighted average of the child entropies
        return self.__EI(*np.sum(A, axis = 0)) - np.average(np.frompyfunc(self.__EI, A.shape[1], 1)(*A.T), weights = np.sum(A, axis = 1))
    def fit(self, X: np.ndarray, y):
        # DX[i]: value domain of attribute i; Dy: label domain; yn: labels as indices into Dy
        self.DX, (self.Dy, yn) = [np.unique(X[:, i]) for i in range(X.shape[1])], np.unique(y, return_inverse = True)
        self.Dy: np.ndarray
        self.value = []
        def fitcur(n, h, p = 0):
            # Grow the subtree over sample indices n at depth h;
            # p is the parent's majority class, used when a branch is empty
            self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))
            r: np.ndarray = np.unique(y[n])
            if r.shape[0] == 0: # Empty branch: inherit the parent's majority class
                return p
            elif r.shape[0] == 1: # Pure node
                return yn[n[0]]
            elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Growth limits against overfitting
                return np.argmax(np.bincount(yn[n]))
            else:
                # P[i][j]: indices of the samples taking the j-th value of attribute i
                P = [[n[np.where(X[n, i] == j)[0]] for j in self.DX[i]] for i in range(X.shape[1])]
                # G[i]: information gain of splitting on attribute i
                G = [self.__Gain(np.array([[np.where(y[s] == j)[0].shape[0] for j in self.Dy] for s in part])) for part in P]
                m = np.argmax(G)
                if G[m] < 1e-9: # Inadequate attributes
                    return np.argmax(np.bincount(yn[n]))
                return (m,) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m])
        self.tree = fitcur(np.arange(X.shape[0]), 0)
    def predict(self, X):
        # Walk the tuple-encoded tree: an internal node is (attribute, child_0, ..., child_{v-1});
        # a leaf is an index into Dy
        def precur(n, x):
            return precur(n[1 + np.where(self.DX[n[0]] == x[n[0]])[0][0]], x) if isinstance(n, tuple) else self.Dy[n]
        return np.array([precur(self.tree, x) for x in X])
    def visualize(self, header):
        # Pretty-print the tree; header holds the display name of each attribute
        i = iter(self.value)
        def visval():
            v = next(i)
            print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')
        def viscur(n, h, c):
            for i in h[:-1]:
                print('%c   ' % ('│' if i else ' '), end = '')
            if len(h) > 0:
                print('%c── ' % ('├' if h[-1] else '└'), end = '')
                print('[%s] ' % c, end = '')
            if isinstance(n, tuple):
                print(header[n[0]], end = '')
                visval()
                print()
                for i in range(len(n) - 1):
                    viscur(n[i + 1], h + [i < len(n) - 2], str(self.DX[n[0]][i]))
            else:
                print(self.Dy[n], end = '')
                visval()
                print()
        viscur(self.tree, [], '')
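
A minimal smoke test on a hand-made toy dataset (illustrative values, not from the post's datasets):

# Attribute 0 determines the label; attribute 1 is noise
X = np.array([['a', 'x'], ['a', 'y'], ['b', 'x'], ['b', 'y']], dtype = np.object_)
y = np.array(['no', 'no', 'yes', 'yes'], dtype = np.object_)
dt = ID3()
dt.fit(X, y)
print(dt.predict(np.array([['b', 'x']], dtype = np.object_)))  # ['yes']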

Discretizing continuous attribute values

For a continuous attribute, take the midpoints between adjacent sorted attribute values in the training set as candidate cut points; each cut point splits the current node's samples into $2$ parts, and the cut point with the largest information gain is selected.
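
A standalone sketch of this selection rule (independent of the class below; entropy as defined above):

import numpy as np

def entropy(counts):
    # Entropy from per-class counts; zero counts are ignored
    p = np.array([c for c in counts if c > 0], dtype = float)
    return 0.0 if p.shape[0] <= 1 else float(-(p / p.sum() * np.log2(p / p.sum())).sum())

def best_cut(x, y):
    # Candidate cut points: midpoints of adjacent distinct values of x;
    # return the (cut, gain) pair with the largest information gain
    u, classes = np.sort(np.unique(x)), np.unique(y)
    parent = entropy([np.sum(y == c) for c in classes])
    best = (None, -1.0)
    for t in (u[:-1] + u[1:]) / 2:
        e = sum(len(s) / len(y) * entropy([np.sum(s == c) for c in classes]) for s in (y[x < t], y[x >= t]))
        if parent - e > best[1]:
            best = (t, parent - e)
    return best

print(best_cut(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 0, 1, 1])))  # cut 2.5, gain 1.0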

Modifying the code above yields a version that also handles continuous attribute values:

import numpy as np
class ID3:
    def __init__(self, max_depth = 0, min_samples_split = 0):
        self.max_depth, self.min_samples_split = max_depth, min_samples_split
    def __EI(self, *n):
        # Entropy I(n_1, ..., n_K) of a node, given per-class sample counts
        n = np.array([i for i in n if i > 0])
        if n.shape[0] <= 1:
            return 0
        p = n / np.sum(n)
        return -np.dot(p, np.log2(p))
    def __Gain(self, A: np.ndarray):
        # Information gain: parent entropy minus the size-weighted average of child entropies
        return self.__EI(*np.sum(A, axis = 0)) - np.average([self.__EI(*a) for a in A], weights = np.sum(A, axis = 1))
    def fit(self, X: np.ndarray, y):
        # c[i]: True if attribute i is continuous (all of its values are numeric)
        self.c = np.array([(np.all([isinstance(j, (int, float)) for j in i])) for i in X.T])
        # DX[i]: value domain of discrete attribute i (None if continuous); yn: labels as indices into Dy
        self.DX, (self.Dy, yn) = [np.unique(X[:, i]) if not self.c[i] else None for i in range(X.shape[1])], np.unique(y, return_inverse = True)
        self.Dy: np.ndarray
        self.value = []
        def Part(n, a):
            # Partition sample indices n on attribute a: for a continuous attribute return
            # (best cut point, 2-way partition); for a discrete one return (None, v-way partition)
            if self.c[a]:
                u = np.sort(np.unique(X[n, a]))
                if u.shape[0] < 2: # a single value cannot be cut
                    return None
                v = np.array([(u[i - 1] + u[i]) / 2 for i in range(1, u.shape[0])]) # midpoints of adjacent values
                P = [[n[np.where(X[n, a] < i)[0]], n[np.where(X[n, a] >= i)[0]]] for i in v]
                m = np.argmax([self.__Gain([[np.where(y[s] == j)[0].shape[0] for j in self.Dy] for s in p]) for p in P])
                return v[m], P[m]
            else:
                return None, [n[np.where(X[n, a] == i)[0]] for i in self.DX[a]]
        def fitcur(n: np.ndarray, h, p = 0):
            # Grow the subtree over sample indices n at depth h;
            # p is the parent's majority class, used when a branch is empty
            self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))
            r: np.ndarray = np.unique(y[n])
            if r.shape[0] == 0: # Empty branch: inherit the parent's majority class
                return p
            elif r.shape[0] == 1: # Pure node
                return yn[n[0]]
            elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Growth limits against overfitting
                return np.argmax(np.bincount(yn[n]))
            else:
                # Best partition of each attribute; None if a continuous attribute has too few values to cut
                P = [Part(n, i) for i in range(X.shape[1])]
                G = [self.__Gain([[np.where(y[s] == j)[0].shape[0] for j in self.Dy] for s in q[1]]) if q is not None else 0 for q in P]
                m = np.argmax(G)
                if G[m] < 1e-9: # Inadequate attributes
                    return np.argmax(np.bincount(yn[n]))
                # Continuous split nodes also store the cut point: (attribute, cut, left, right)
                return ((m, P[m][0]) if self.c[m] else (m,)) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m][1])
        self.tree = fitcur(np.arange(X.shape[0]), 0)
    def predict(self, X):
        # Walk the tree: continuous nodes are (attribute, cut, left, right),
        # discrete nodes (attribute, children...); a leaf is an index into Dy
        def precur(n, x):
            return precur(n[(2 if x[n[0]] < n[1] else 3) if self.c[n[0]] else (1 + np.where(self.DX[n[0]] == x[n[0]])[0][0])], x) if isinstance(n, tuple) else self.Dy[n]
        return np.array([precur(self.tree, x) for x in X])
    def visualize(self, header):
        # Pretty-print the tree; header holds the display name of each attribute
        i = iter(self.value)
        def visval():
            v = next(i)
            print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')
        def viscur(n, h, c):
            for i in h[:-1]:
                print('%c   ' % ('│' if i else ' '), end = '')
            if len(h) > 0:
                print('%c── ' % ('├' if h[-1] else '└'), end = '')
                print('[%s] ' % c, end = '')
            if isinstance(n, tuple):
                print(header[n[0]], end = '')
                visval()
                print()
                if self.c[n[0]]:
                    for i in range(2):
                        viscur(n[2 + i], h + [i < 1], ('< ', '>= ')[i] + str(n[1]))
                else:
                    for i in range(len(n) - 1):
                        viscur(n[1 + i], h + [i < len(n) - 2], str(self.DX[n[0]][i]))
            else:
                print(self.Dy[n], end = '')
                visval()
                print()
        viscur(self.tree, [], '')
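
A quick sanity check on a tiny mixed dataset (toy values; attribute 0 is continuous, attribute 1 discrete):

X = np.array([[1.0, 'x'], [2.0, 'x'], [3.0, 'y'], [4.0, 'y']], dtype = np.object_)
y = np.array(['no', 'no', 'yes', 'yes'], dtype = np.object_)
dt = ID3()
dt.fit(X, y)
print(dt.predict(np.array([[3.5, 'x']], dtype = np.object_)))  # ['yes']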

Experiments

The experiments use the following datasets:
Play tennis (source: kaggle): discrete attributes
Mushroom classification (source: kaggle): discrete attributes
Carsdata (source: kaggle): continuous attributes
Iris (source: sklearn.datasets): continuous attributes

The contents of play_tennis.csv:

day outlook temp humidity wind play
D1 Sunny Hot High Weak No
D2 Sunny Hot High Strong No
D3 Overcast Hot High Weak Yes
D4 Rain Mild High Weak Yes
D5 Rain Cool Normal Weak Yes
D6 Rain Cool Normal Strong No
D7 Overcast Cool Normal Strong Yes
D8 Sunny Mild High Weak No
D9 Sunny Cool Normal Weak Yes
D10 Rain Mild Normal Weak Yes
D11 Sunny Mild Normal Strong Yes
D12 Overcast Mild High Strong Yes
D13 Overcast Hot Normal Weak Yes
D14 Rain Mild High Strong No

Tests on the Play tennis dataset

Binary classification test with the default attributes; the code:

import pandas as pd
class Datasets:
    def __init__(self, fn):
        # Read the CSV and strip stray whitespace from cells and column names
        # (DataFrame.map requires pandas >= 2.1; on older versions use applymap)
        self.df = pd.read_csv('Datasets\\%s' % fn).map(lambda x: x.strip() if isinstance(x, str) else x)
        self.df.rename(columns = lambda x: x.strip(), inplace = True)
    def getData(self, DX, Dy, drop = False):
        # With drop = True, discard rows containing empty cells and convert numeric columns
        dfn = self.df.loc[~self.df.eq('').any(axis = 1)].apply(pd.to_numeric, errors = 'ignore') if drop else self.df
        return dfn[DX].to_numpy(dtype = np.object_), dfn[Dy].to_numpy(dtype = np.object_)

# play_tennis.csv
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
dt11 = ID3()
dt11.fit(X, y)
dt11.visualize(a)
print()

The result:

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])
    ├── [High] No (entropy = 0, samples = 3, value = [3 0])
    └── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])

Inadequate-attribute test

Swapping the target: the remaining attributes (including play) now predict outlook, a three-class problem, to test inadequate attributes; the code:

# play_tennis.csv for inadequate attribute test and class > 2
a = ['temp', 'humidity', 'wind', 'play']
X, y = Datasets('play_tennis.csv').getData(a, 'outlook')
dt12 = ID3(10)
dt12.fit(X, y)
dt12.visualize(a)
print(dt12.predict([['Cool', 'Normal', 'Weak', 'Yes']]))
print()

The result (the leaf reached via [Yes], [Cool], [Weak] holds one Rain and one Sunny sample with identical features, an inadequate-attribute case, so splitting stops and the node takes the most frequent class):

play (entropy = 1.5774062828523454, samples = 14, value = [4 5 5])
├── [No] temp (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   ├── [Cool] Rain (entropy = 0, samples = 1, value = [0 1 0])
│   ├── [Hot] Sunny (entropy = 0, samples = 2, value = [0 0 2])
│   └── [Mild] wind (entropy = 1.0, samples = 2, value = [0 1 1])
│       ├── [Strong] Rain (entropy = 0, samples = 1, value = [0 1 0])
│       └── [Weak] Sunny (entropy = 0, samples = 1, value = [0 0 1])
└── [Yes] temp (entropy = 1.5304930567574826, samples = 9, value = [4 3 2])
    ├── [Cool] wind (entropy = 1.584962500721156, samples = 3, value = [1 1 1])
    │   ├── [Strong] Overcast (entropy = 0, samples = 1, value = [1 0 0])
    │   └── [Weak] Rain (entropy = 1.0, samples = 2, value = [0 1 1])
    ├── [Hot] Overcast (entropy = 0, samples = 2, value = [2 0 0])
    └── [Mild] wind (entropy = 1.5, samples = 4, value = [1 2 1])
        ├── [Strong] humidity (entropy = 1.0, samples = 2, value = [1 0 1])
        │   ├── [High] Overcast (entropy = 0, samples = 1, value = [1 0 0])
        │   └── [Normal] Sunny (entropy = 0, samples = 1, value = [0 0 1])
        └── [Weak] Rain (entropy = 0, samples = 2, value = [0 2 0])
['Rain']

Empty-branch test

Binary classification with the default attributes, with some data modified to produce an empty branch; the code:

# play_tennis.csv modified to generate empty branch
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
X[2, 2], X[13, 2] = 'Low', 'Low'
dt13 = ID3()
dt13.fit(X, y)
dt13.visualize(a)
print(dt13.predict([['Sunny', 'Hot', 'Low', 'Weak']]))
print()

The result (note the empty [Low] branch: it received 0 samples and inherits the parent node's majority class No):

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])
    ├── [High] No (entropy = 0, samples = 3, value = [3 0])
    ├── [Low] No (entropy = 0, samples = 0, value = [0 0])
    └── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])
['No']

Tests on the Mushroom classification dataset

Binary classification with the default attributes, ignoring the attribute with unknown values, with a train/test split; the code:

# mushrooms.csv ignoring attribute 'stalk-root' with unknown value
a = ['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor', 'gill-attachment', 'gill-spacing', 'gill-size',
        'gill-color', 'stalk-shape', 'stalk-surface-above-ring', 'stalk-surface-below-ring', 'stalk-color-above-ring', 'stalk-color-below-ring', 'veil-type',
        'veil-color', 'ring-number', 'ring-type', 'spore-print-color', 'population', 'habitat']
X, y = Datasets('mushrooms.csv').getData(a, 'class')
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt21 = ID3()
dt21.fit(X_train, y_train)
dt21.visualize(a)
print()
y_pred = dt21.predict(X_test)
print(classification_report(y_test, y_pred))

The result:

odor (entropy = 0.9990161113058208, samples = 6093, value = [3159 2934])
├── [a] e (entropy = 0, samples = 298, value = [298   0])
├── [c] p (entropy = 0, samples = 138, value = [  0 138])
├── [f] p (entropy = 0, samples = 1636, value = [   0 1636])
├── [l] e (entropy = 0, samples = 297, value = [297   0])
├── [m] p (entropy = 0, samples = 26, value = [ 0 26])
├── [n] spore-print-color (entropy = 0.19751069442516636, samples = 2645, value = [2564   81])
│   ├── [b] e (entropy = 0, samples = 32, value = [32  0])
│   ├── [h] e (entropy = 0, samples = 35, value = [35  0])
│   ├── [k] e (entropy = 0, samples = 974, value = [974   0])
│   ├── [n] e (entropy = 0, samples = 1013, value = [1013    0])
│   ├── [o] e (entropy = 0, samples = 33, value = [33  0])
│   ├── [r] p (entropy = 0, samples = 50, value = [ 0 50])
│   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   ├── [w] habitat (entropy = 0.34905151737109524, samples = 473, value = [442  31])
│   │   ├── [d] gill-size (entropy = 0.7062740891876007, samples = 26, value = [ 5 21])
│   │   │   ├── [b] e (entropy = 0, samples = 5, value = [5 0])
│   │   │   └── [n] p (entropy = 0, samples = 21, value = [ 0 21])
│   │   ├── [g] e (entropy = 0, samples = 222, value = [222   0])
│   │   ├── [l] cap-color (entropy = 0.7553754125614287, samples = 46, value = [36 10])
│   │   │   ├── [b] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [c] e (entropy = 0, samples = 20, value = [20  0])
│   │   │   ├── [e] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [g] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [n] e (entropy = 0, samples = 16, value = [16  0])
│   │   │   ├── [p] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [r] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [w] p (entropy = 0, samples = 7, value = [0 7])
│   │   │   └── [y] p (entropy = 0, samples = 3, value = [0 3])
│   │   ├── [m] e (entropy = 0, samples = 0, value = [0 0])
│   │   ├── [p] e (entropy = 0, samples = 35, value = [35  0])
│   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   └── [w] e (entropy = 0, samples = 144, value = [144   0])
│   └── [y] e (entropy = 0, samples = 35, value = [35  0])
├── [p] p (entropy = 0, samples = 187, value = [  0 187])
├── [s] p (entropy = 0, samples = 433, value = [  0 433])
└── [y] p (entropy = 0, samples = 433, value = [  0 433])

              precision    recall  f1-score   support

           e       1.00      1.00      1.00      1049
           p       1.00      1.00      1.00       982

    accuracy                           1.00      2031
   macro avg       1.00      1.00      1.00      2031
weighted avg       1.00      1.00      1.00      2031

Tests on the Carsdata dataset

Three-class classification with the default attributes, ignoring samples with unknown values, with a train/test split, a maximum depth of 5, and a minimum node sample count of 3; the code:

# cars.csv ignoring samples with unknown value
a = ['mpg', 'cylinders', 'cubicinches', 'hp', 'weightlbs', 'time-to-60', 'year']
X, y = Datasets('cars.csv').getData(a, 'brand', True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt31 = ID3(5, 3)
dt31.fit(X_train, y_train)
dt31.visualize(a)
print()
y_pred = dt31.predict(X_test)
print(classification_report(y_test, y_pred))

The result:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [< 191.0] year (entropy = 1.5833913647120852, samples = 105, value = [37 33 35])
│   ├── [< 1981.5] cubicinches (entropy = 1.5558899087683136, samples = 87, value = [37 27 23])
│   │   ├── [< 121.5] cubicinches (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   │   ├── [< 114.0] cubicinches (entropy = 1.4844331941390079, samples = 47, value = [18 21  8])
│   │   │   │   ├── [< 87.0] Japan. (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │   │   └── [>= 87.0] Europe. (entropy = 1.521560239117063, samples = 39, value = [17 14  8])
│   │   │   └── [>= 114.0] weightlbs (entropy = 0.5665095065529053, samples = 15, value = [13  2  0])
│   │   │       ├── [< 2571.0] Europe. (entropy = 0.9709505944546686, samples = 5, value = [3 2 0])
│   │   │       └── [>= 2571.0] Europe. (entropy = 0, samples = 10, value = [10  0  0])
│   │   └── [>= 121.5] weightlbs (entropy = 1.3593308322365363, samples = 25, value = [ 6  4 15])
│   │       ├── [< 3076.5] hp (entropy = 0.9917601481809735, samples = 20, value = [ 1  4 15])
│   │       │   ├── [< 92.5] US. (entropy = 0, samples = 11, value = [ 0  0 11])
│   │       │   └── [>= 92.5] Japan. (entropy = 1.3921472236645345, samples = 9, value = [1 4 4])
│   │       └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
│   └── [>= 1981.5] mpg (entropy = 0.9182958340544896, samples = 18, value = [ 0  6 12])
│       ├── [< 31.3] US. (entropy = 0, samples = 9, value = [0 0 9])
│       └── [>= 31.3] mpg (entropy = 0.9182958340544896, samples = 9, value = [0 6 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 4, value = [0 4 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
└── [>= 191.0] US. (entropy = 0, samples = 87, value = [ 0  0 87])

              precision    recall  f1-score   support

     Europe.       0.50      0.80      0.62        10
      Japan.       0.83      0.56      0.67        18
         US.       0.94      0.94      0.94        36

    accuracy                           0.81        64
   macro avg       0.76      0.77      0.74        64
weighted avg       0.84      0.81      0.81        64

Mixed discrete and continuous attribute test

One attribute is pre-discretized using the node cut points found by the tree above; the train/test split is made the same way and the growth constraints are unchanged, testing classification on a mix of discrete and continuous attributes; the code:

# cars.csv with attribute 'cubicinches' discretized
def find(a, n):
    # Index of the first element of the sorted list a that is >= n
    def findcur(r):
        return findcur(r + 1) if r < len(a) and a[r] < n else r
    return findcur(0)
# Bin 'cubicinches' at the cut points 121.5 and 191.0 found by the tree above
X[:, 2] = np.array([('a', 'b', 'c')[find([121.5, 191.0], i)] for i in X[:, 2]])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt32 = ID3(5, 3)
dt32.fit(X_train, y_train)
dt32.visualize(a)
y_pred = dt32.predict(X_test)
print(classification_report(y_test, y_pred))

The result:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [a] year (entropy = 1.5231103605784926, samples = 74, value = [31 28 15])
│   ├── [< 1981.5] weightlbs (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   ├── [< 2571.0] weightlbs (entropy = 1.4729350396193688, samples = 50, value = [20 22  8])
│   │   │   ├── [< 2271.5] mpg (entropy = 1.5038892873131435, samples = 42, value = [19 15  8])
│   │   │   │   ├── [< 30.25] Europe. (entropy = 1.2640886121123147, samples = 23, value = [15  5  3])
│   │   │   │   └── [>= 30.25] Japan. (entropy = 1.4674579648482995, samples = 19, value = [ 4 10  5])
│   │   │   └── [>= 2271.5] mpg (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │       ├── [< 37.9] Japan. (entropy = 0, samples = 7, value = [0 7 0])
│   │   │       └── [>= 37.9] Europe. (entropy = 0, samples = 1, value = [1 0 0])
│   │   └── [>= 2571.0] cylinders (entropy = 0.41381685030363374, samples = 12, value = [11  1  0])
│   │       ├── [< 3.5] Japan. (entropy = 0, samples = 1, value = [0 1 0])
│   │       └── [>= 3.5] Europe. (entropy = 0, samples = 11, value = [11  0  0])
│   └── [>= 1981.5] mpg (entropy = 0.9798687566511528, samples = 12, value = [0 5 7])
│       ├── [< 31.3] US. (entropy = 0, samples = 4, value = [0 0 4])
│       └── [>= 31.3] mpg (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
├── [b] weightlbs (entropy = 1.2910357498542626, samples = 31, value = [ 6  5 20])
│   ├── [< 3076.5] hp (entropy = 0.9293550115186283, samples = 26, value = [ 1  5 20])
│   │   ├── [< 93.5] US. (entropy = 0, samples = 16, value = [ 0  0 16])
│   │   └── [>= 93.5] time-to-60 (entropy = 1.360964047443681, samples = 10, value = [1 5 4])
│   │       ├── [< 15.5] cylinders (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│   │       │   ├── [< 5.0] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│   │       │   └── [>= 5.0] US. (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   │       └── [>= 15.5] Europe. (entropy = 1.0, samples = 2, value = [1 0 1])
│   └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
└── [c] US. (entropy = 0, samples = 87, value = [ 0  0 87])
              precision    recall  f1-score   support

     Europe.       0.57      0.40      0.47        10
      Japan.       0.65      0.72      0.68        18
         US.       0.92      0.94      0.93        36

    accuracy                           0.80        64
   macro avg       0.71      0.69      0.70        64
weighted avg       0.79      0.80      0.79        64

Tests on the Iris dataset

Three-class classification with the default attributes, with a train/test split and a maximum growth depth of 3; the code:

# iris dataset
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt41 = ID3(3)
dt41.fit(X_train, y_train)
dt41.visualize(iris['feature_names'])
print()
y_pred = dt41.predict(X_test)
print(classification_report(y_test, y_pred))

The result:

petal length (cm) (entropy = 1.5807197138422102, samples = 112, value = [34 41 37])
├── [< 2.45] 0 (entropy = 0, samples = 34, value = [34  0  0])
└── [>= 2.45] petal width (cm) (entropy = 0.9981021327390103, samples = 78, value = [ 0 41 37])
    ├── [< 1.75] petal length (cm) (entropy = 0.4394969869215134, samples = 44, value = [ 0 40  4])
    │   ├── [< 4.95] 1 (entropy = 0.17203694935311378, samples = 39, value = [ 0 38  1])
    │   └── [>= 4.95] 2 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
    └── [>= 1.75] petal length (cm) (entropy = 0.19143325481419343, samples = 34, value = [ 0  1 33])
        ├── [< 4.85] 2 (entropy = 0.9182958340544896, samples = 3, value = [0 1 2])
        └── [>= 4.85] 2 (entropy = 0, samples = 31, value = [ 0  0 31])

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        13

    accuracy                           1.00        38
   macro avg       1.00      1.00      1.00        38
weighted avg       1.00      1.00      1.00        38
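
As a cross-check, sklearn's tree learner with the entropy criterion can be run on the same split (a sketch; tie-breaking may make the trees differ slightly):

from sklearn.tree import DecisionTreeClassifier, export_text
clf = DecisionTreeClassifier(criterion = 'entropy', max_depth = 3)
clf.fit(X_train, y_train)
print(export_text(clf, feature_names = iris['feature_names']))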