Introduction
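This article covers the Kaggle monthly competition "Binary Prediction of Poisonous Mushrooms". Like the earlier Playground competitions, this one is beginner-friendly. Unlike them, its dataset contains a large number of missing values, and how those values are handled directly affects the score. This topic is therefore split into two articles, each handling the missing values differently. The choice of model and its hyperparameters is not the focus here. This first article uses ColumnTransformer and Pipeline to make the data processing more systematic.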
Problem description
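Poisonous-mushroom prediction is a classic machine-learning classification problem. Like the Iris dataset, the original mushroom data can be loaded through scikit-learn, which fetches it from OpenML.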
Method 1:
```python
from sklearn.datasets import fetch_openml

# Downloads the original mushroom dataset from OpenML (requires internet access)
mushroom_data = fetch_openml(name='mushroom', version=1)
```
Method 2:
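Download the data directly from the UCI Machine Learning Repository (version 1 or version 2).
The competition dataset was generated synthetically from this original data, and its features match version 1.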
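The dataset has the following 21 fields (the target plus 20 features):
Class: whether the mushroom is poisonous.
Cap Diameter: the measurement across the widest part of the cap. It indicates the mushroom's size, which ranges from a few millimeters to several centimeters.
Cap Shape: the overall shape of the cap, such as conical, bell-shaped, flat, or wavy. This feature helps distinguish species.
Cap Surface: the texture and appearance of the cap surface. It can be smooth, scaly, sticky, or wrinkled, which gives clues to the mushroom's identity.
Cap Color: the color of the cap. It varies widely and may change as the mushroom matures, and it can be a key factor in identifying a species.
Does Bruise or Bleed: whether the mushroom changes color when bruised or releases a colored liquid. This reaction is important for identification.
Gill Attachment: how the gills attach to the stem. They can be free (not attached), attached (joined to the stem), or decurrent (running down the stem).
Gill Spacing: the distance between the gills. Gills can be crowded, widely spaced, or somewhere in between.
Gill Color: the color of the gills. It helps distinguish species and may change with age.
Stem Height: the length of the stem from the ground to where it joins the cap. Variation in stem height aids identification.
Stem Width: the diameter of the stem. It can be narrow, medium, or wide, and varies by species.
Stem Root: the base of the stem, which can be swollen, bulbous, or tapered. Examining the stem base helps identify certain mushrooms.
Stem Surface: the texture and appearance of the stem surface. It can be smooth, fibrous, scaly, or rough.
Stem Color: the color of the stem. It may be uniform or vary along its length.
Veil Type: the type of veil on the mushroom. A partial veil covers the gills and usually forms a ring. A universal veil encloses the whole mushroom at an early stage.
Veil Color: the color of the veil. It can be a key identifying feature, especially when distinguishing similar-looking species.
Has Ring: whether there is a ring (also called an annulus) around the stem, which is a remnant of the partial veil.
Ring Type: the type of ring present, such as single, double, flaring, or pendant (hanging).
Spore Print Color: the color of the spore print obtained by placing the cap on white paper. This is a key identification feature.
Habitat: the environment where the mushroom grows, such as woodland, grassland, or urban areas. Habitat can help narrow down the possible species.
Season: the time of year when the mushroom appears. Different species fruit in different seasons, which aids identification.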
Feature | Values or format |
---|---|
Class | edible=e, poisonous=p |
Cap Diameter | float number in cm |
Cap Shape | bell=b, conical=c, convex=x, flat=f, sunken=s, knobbed=k |
Cap Surface | fibrous=f, grooves=g, scaly=y, smooth=s |
Cap Color | brown=n, buff=b, cinnamon=c, gray=g, green=r, pink=p, purple=u, red=e, white=w, yellow=y |
Does Bruise or Bleed | bruises=t, no=f |
Gill Attachment | attached=a, descending=d, free=f, notched=n |
Gill Spacing | close=c, crowded=w, distant=d |
Gill Color | black=k, brown=n, buff=b, chocolate=h, gray=g, green=r, orange=o, pink=p, purple=u, red=e, white=w, yellow=y |
Stem Height | float number in cm |
Stem Width | float number in mm |
Stem Root | bulbous=b, swollen=s, club=c, cup=u, equal=e, rhizomorphs=z, rooted=r |
Stem Surface | see cap-surface + none=f |
Stem Color | see cap-color + none=f |
Veil Type | partial=p, universal=u |
Veil Color | see cap-color + none=f |
Has Ring | ring=t, none=f |
Ring Type | cobwebby=c, evanescent=e, flaring=r, grooved=g, large=l, pendant=p, sheathing=s, zone=z, scaly=y, movable=m, none=f, unknown=? |
Spore Print Color | black=k, brown=n, buff=b, chocolate=h, green=r, orange=o, purple=u, white=w, yellow=y |
Habitat | grasses=g, leaves=l, meadows=m, paths=p, heaths=h, urban=u, waste=w, woods=d |
Season | spring=s, summer=u, autumn=a, winter=w |
To better understand what each feature describes, refer to a diagram of mushroom anatomy.
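The table above is effectively a codebook for the single-letter values. The following is a minimal sketch that decodes a few columns into readable labels for exploration. It assumes a hand-copied subset of the table stored in a dict. The names `CODEBOOK` and `decode_column` are hypothetical. Any value that is not in the codebook becomes NaN.

```python
import pandas as pd

# Hypothetical subset of the codebook table, keyed by the CSV column names
CODEBOOK = {
    "class": {"e": "edible", "p": "poisonous"},
    "season": {"s": "spring", "u": "summer", "a": "autumn", "w": "winter"},
    "gill-spacing": {"c": "close", "w": "crowded", "d": "distant"},
}

def decode_column(series: pd.Series) -> pd.Series:
    """Map single-letter codes to readable labels; codes not in the codebook become NaN."""
    return series.map(CODEBOOK[series.name])

# Made-up rows; "3.61" imitates the invalid values that appear in the real data
df = pd.DataFrame({
    "class": ["e", "p"],
    "season": ["a", "w"],
    "gill-spacing": ["c", "3.61"],
})
decoded = df.apply(decode_column)  # applies decode_column to each column
print(decoded)
```

`Series.map` with a dict returns NaN for unmapped keys, so invalid entries surface as missing values rather than silently passing through.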
Objective
Predict whether each mushroom is edible or poisonous (edible=e, poisonous=p).
Load libraries
```python
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.metrics import accuracy_score, matthews_corrcoef
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
```
Load data
```python
train = pd.read_csv('/kaggle/input/playground-series-s4e8/train.csv')
test = pd.read_csv('/kaggle/input/playground-series-s4e8/test.csv')
submission = pd.read_csv('/kaggle/input/playground-series-s4e8/sample_submission.csv')
```
```python
train.head()
```
| | id | class | cap-diameter | cap-shape | cap-surface | cap-color | does-bruise-or-bleed | gill-attachment | gill-spacing | gill-color | ... | stem-root | stem-surface | stem-color | veil-type | veil-color | has-ring | ring-type | spore-print-color | habitat | season |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | e | 8.80 | f | s | u | f | a | c | w | ... | NaN | NaN | w | NaN | NaN | f | f | NaN | d | a |
1 | 1 | p | 4.51 | x | h | o | f | a | c | n | ... | NaN | y | o | NaN | NaN | t | z | NaN | d | w |
2 | 2 | e | 6.94 | f | s | b | f | x | c | w | ... | NaN | s | n | NaN | NaN | f | f | NaN | l | w |
3 | 3 | e | 3.88 | f | y | g | f | s | NaN | g | ... | NaN | NaN | w | NaN | NaN | f | f | NaN | d | u |
4 | 4 | e | 5.85 | x | l | w | f | d | NaN | w | ... | NaN | NaN | w | NaN | NaN | f | f | NaN | g | a |
Explore the data
```python
train.shape
```
(3116945, 22)
```python
train.info()
```
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3116945 entries, 0 to 3116944
Data columns (total 22 columns):
Column Dtype
0 id int64
1 class object
2 cap-diameter float64
3 cap-shape object
4 cap-surface object
5 cap-color object
6 does-bruise-or-bleed object
7 gill-attachment object
8 gill-spacing object
9 gill-color object
10 stem-height float64
11 stem-width float64
12 stem-root object
13 stem-surface object
14 stem-color object
15 veil-type object
16 veil-color object
17 has-ring object
18 ring-type object
19 spore-print-color object
20 habitat object
21 season object
dtypes: float64(3), int64(1), object(18)
memory usage: 523.2+ MB
```python
train.isnull().sum()
```
id 0
class 0
cap-diameter 4
cap-shape 40
cap-surface 671023
cap-color 12
does-bruise-or-bleed 8
gill-attachment 523936
gill-spacing 1258435
gill-color 57
stem-height 0
stem-width 0
stem-root 2757023
stem-surface 1980861
stem-color 38
veil-type 2957493
veil-color 2740947
has-ring 24
ring-type 128880
spore-print-color 2849682
habitat 45
season 0
dtype: int64
```python
test.head()
```
| | id | cap-diameter | cap-shape | cap-surface | cap-color | does-bruise-or-bleed | gill-attachment | gill-spacing | gill-color | stem-height | ... | stem-root | stem-surface | stem-color | veil-type | veil-color | has-ring | ring-type | spore-print-color | habitat |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 3116945 | 8.64 | x | NaN | n | t | NaN | NaN | w | 11.13 | ... | b | NaN | w | u | w | t | g | NaN | d |
1 | 3116946 | 6.90 | o | t | o | f | NaN | c | y | 1.27 | ... | NaN | NaN | n | NaN | NaN | f | f | NaN | d |
2 | 3116947 | 2.00 | b | g | n | f | NaN | c | n | 6.18 | ... | NaN | NaN | n | NaN | NaN | f | f | NaN | d |
3 | 3116948 | 3.47 | x | t | n | f | s | c | n | 4.98 | ... | NaN | NaN | w | NaN | n | t | z | NaN | d |
4 | 3116949 | 6.17 | x | h | y | f | p | NaN | y | 6.73 | ... | NaN | NaN | y | NaN | y | t | NaN | NaN | d |
```python
test.info()
```
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2077964 entries, 0 to 2077963
Data columns (total 21 columns):
Column Dtype
0 id int64
1 cap-diameter float64
2 cap-shape object
3 cap-surface object
4 cap-color object
5 does-bruise-or-bleed object
6 gill-attachment object
7 gill-spacing object
8 gill-color object
9 stem-height float64
10 stem-width float64
11 stem-root object
12 stem-surface object
13 stem-color object
14 veil-type object
15 veil-color object
16 has-ring object
17 ring-type object
18 spore-print-color object
19 habitat object
20 season object
dtypes: float64(3), int64(1), object(17)
memory usage: 332.9+ MB
```python
test.isnull().sum()
```
id 0
cap-diameter 7
cap-shape 31
cap-surface 446904
cap-color 13
does-bruise-or-bleed 10
gill-attachment 349821
gill-spacing 839595
gill-color 49
stem-height 1
stem-width 0
stem-root 1838012
stem-surface 1321488
stem-color 21
veil-type 1971545
veil-color 1826124
has-ring 19
ring-type 86195
spore-print-color 1899617
habitat 25
season 0
dtype: int64
```python
# Print every column's unique values to inspect the categories
for column in train.columns:
    print(column)
    print(train[column].unique())
```
id
[ 0 1 2 ... 3116942 3116943 3116944]
class
['e' 'p']
cap-diameter
[ 8.8 4.51 6.94 ... 38.11 55.63 54.07]
cap-shape
['f' 'x' 'p' 'b' 'o' 'c' 's' 'd' 'e' 'n' nan 'w' 'k' 'l' '19.29' '5 f' 't'
'g' 'z' 'a' '2.85' '7 x' 'r' 'u' '3.55' 'is s' 'y' '4.22' '3.6' '21.56'
'i' '6 x' '24.16' '8' 'm' 'ring-type' '10.13' 'is p' '7.43' 'h' '0.82'
'10.46' '2.77' '2.94' '12.62' '5.15' '19.04' '4.97' '49.21' 'b f' '9.13'
'1.66' '3.37' '7.21' '3.25' '11.12' '3 x' '4.3' '7.41' '6.21' '8.29'
'54.78' '20.25' '3.52' '3.04' '2.63' '3.91' '6.44' '8.3' '7.6' '17.44'
'4.33' '2.82' '6.53' '19.06']
cap-surface
['s' 'h' 'y' 'l' 't' 'e' 'g' nan 'd' 'i' 'w' 'k' '15.94' 'f' 'n' 'r' 'o'
'a' 'u' 'z' '2.7' 'does l' '5.07' 'p' 'b' 'm' 'cap-diameter' '1.43' 'x'
'7.14' 'c' 'is h' 'does t' '0.85' '6.57' '12.79' '6.45' '4.66' '23.18'
'3.06' '16.39' '4.21' 'veil-color' '11.78' '8.1' 'has-ring' 'does h'
'1.42' 'class' 'has h' 'does None' '10.83' 'season' '8.96' '14.04' '5.73'
'is None' '24.38' '2.81' '0.88' '2.11' '2.79' 'ring-type'
'does-bruise-or-bleed' '4.93' 'spore-print-color' 'spore-color' '2.92'
'2.51' '7.99' 'is y' '3.64' '3.33' '41.91' '12.2' '8.01' '9.22' '1.14'
'6.49' '10.34' '10.1' '1.08' 'is k' '0.87']
cap-color
['u' 'o' 'b' 'g' 'w' 'n' 'e' 'y' 'r' 'p' 'k' 'l' 'i' 'h' 'd' 's' 'a' 'f'
'2.05' 'season' 'c' 'x' '13' '7.72' 'm' 'z' '6.76' '7.15' 't' 'ring-type'
nan 'class' '12.89' '8.83' '24.75' '22.38' '1.51' '10.1' '17.94' '3.57'
'does n' '4.89' '6.2' '21.53' '6.41' '4.98' '3.95' 'does-bruise-or-bleed'
'6.59' '5.25' 'veil-color' '6.9' '5.41' '11.13' '3.11' '2.57' '17.93'
'2.7' '8.57' '11.92' '3.08' '2.82' '4.24' '17.19' '3.34' '7' '2.9' '6.36'
'5.91' '10.56' '26.89' '4. n' '20.62' 'stem-surface' '20.02' '20' '25.98'
'8.67' '9.02']
does-bruise-or-bleed
['f' 't' 'd' 'has-ring' 'w' 'o' 'b' 'x' 'p' nan 'g' 'y' 'r' 'a' 'l' 'i'
'c' 'n' 'z' 's' 'k' 'h' '3.43' 'e' '4.42' '2.9' 'u']
gill-attachment
['a' 'x' 's' 'd' 'e' nan 'f' 'p' 'l' 'm' 'b' '32.54' 'n' 'g' 'i' 'u'
'does-bruise-or-bleed' 't' 'o' 'c' 'w' '4.64' 'k' 'r' '4.77' 'h' 'p p'
'7.92' 'z' 'season' 'y' '8.79' 'does None' 'has f' 'ring-type' '16.33'
'10.85' '20.07' '2.82' '7.86' '3.91' 'does' '10.23' '6.74' '0.92' '3.45'
'1' 'is a' '3.71' '50.44' '11.62' 'has d' '1.32' '8.47' '6.11' '2.41'
'2.54' '6.32' '19.65' '15.49' '4.01' '8.37' 'does f' '28.7' '13.15'
'1.37' '28.15' '7.09' '9.88' '2.67' '18.21' '1.48' '5.93' '1.51' '16.27'
'11.26' '2.79' 'is f' '13.94']
gill-spacing
['c' nan 'd' 'f' 'x' 'b' 'a' '3.61' '2.69' 'k' '4.8' 'e' 'y' 'class' 's'
'9.01' 'p' '3.92' '5.22' '6.67' '4.04' 't' '0.73' 'i' '3.57' '24.38' 'w'
'h' 'cap-surface' 'l' '1' '12.27' '5.42' 'r' '1.6' 'n' 'g' '0' '3.81'
'4.09' '1.36' '3.24' '5.55' '5.7' '3.62' 'does f' '6.4' '1.88' '55.13']
gill-color
['w' 'n' 'g' 'k' 'y' 'f' 'p' 'o' 'b' 'u' 'e' 'r' 'd' 't' '3.45' 'z' '5'
'3.39' 'season' 'h' 'x' 's' '4' 'class' 'c' 'm' 'spacing' '0.92' nan
'18.12' 'l' 'does w' 'a' '7.59' '8.06' '6.19' 'has-ring' '4.49' '9.46'
'5.01' 'ring-type' '3.4' 'i' '17' '10.07' 'stem-root' '20.6'
'spore-print-color' '18.03' 'does-bruise-or-bleed' '8.83' 'habitat'
'10.21' '4.64' '6.4' 'is y' 'e y' '1.91' 'does n' '16.41' '6.41'
'veil-type' '20.44' '8.37']
stem-height
[ 4.51 4.79 6.85 ... 26.09 47.33 26.53]
stem-width
[15.39 6.48 9.93 ... 66.91 79.92 53.44]
stem-root
[nan 'b' 'c' 'r' 's' 'f' '5.59' '2.77' '20.01' 'y' 'o' 'k' 'd' 'n' 'w' 'u'
'p' 'x' 'i' '10.87' 'a' '3.63' 't' 'm' 'l' 'h' 'g' '16.88' '15.69' '1.48'
'3.23' 'e' '20.0' '18.06' 'z' 'spore-print-color' '3.49' '13.03' '7.15']
stem-surface
[nan 'y' 's' 't' 'g' 'h' 'k' 'i' 'f' 'l' 'd' 'x' '12.04' 'w' 'a' 'o' 'c'
'n' 'm' 'e' 'p' 'z' '6.58' '4.34' 'b' '3.89' 'r' '25.83' '1.59' '0.0'
'5.97' '5.81' 'u' 'season' '10.48' '3.68' '5.56' '4.41' '5.48' '5.51'
'class' 'has-ring' '13.1' '17.46' '5.35' '7.23' 'does None' '1.03'
'does s' '7.45' 'has h' 'does-bruise-or-bleed' '1.94' '49.46' '19.35'
'2.68' '4.74' 'spore-print-color' '10.93' '24.12' '13.94']
stem-color
['w' 'o' 'n' 'y' 'e' 'u' 'p' 'f' 'g' 'r' 'k' 'l' 'b' '3.13' 't' 'z' 'a'
'h' 'd' nan 's' '7.33' 'is n' 'i' 'c' 'x' 'e n' '1.75' 'm' '33.52'
'ring-type' '2.78' 'spore-print-color' '23.59' '8.32' 'is w' '26.4'
'4.75' '7.84' 'class' '2.75' '8.49' '4.49' '1.41' '17.45' '3.53' '12.92'
'3.98' '20.07' '7.7' '22.6' '6.31' '6.09' '3.56' '3.37' '4.62' '2.54'
'39.51' '18.06' '4.33']
veil-type
[nan 'u' 'd' 'a' 'h' '21.11' 'g' 'c' 'e' 'y' 'i' 'f' 'is None' 't' 'w' 'p'
'b' 's' 'k' 'r' 'l' 'n' '5.94']
veil-color
[nan 'n' 'w' 'k' 'y' 'e' 'u' 'p' 'd' 'g' 'r' 'h' 's' '8.25' 't' 'c' 'o'
'i' '2.49' 'f' 'a' 'b' 'l' 'z' '3.32']
has-ring
['f' 't' 'h' 'r' 'y' 'c' 'e' 'g' 'l' 's' nan 'p' 'x' 'k' 'z' 'f has-ring'
'd' 'o' 'n' 'm' 'i' '10.3' 'w' 'a']
ring-type
['f' 'z' 'e' nan 'p' 'l' 'g' 'r' 'm' 'y' 'h' 'o' 't' 'ring-type' 'a' 'd'
's' 'x' '4' 'b' '15' 'u' 'n' 'w' 'does f' '3.12' 'i' 'season' 'k' 'c'
'does-bruise-or-bleed' '11' '23.6' '1' '14' '2' 'spore-print-color'
'class' 'sp' '2.87' '8.25']
spore-print-color
[nan 'k' 'w' 'p' 'n' 'r' 'u' 'g' 't' 'f' 'd' 'l' 'y' 'a' 's' '2.49' 'e'
'o' 'c' 'b' '10 None' 'h' 'x' '9 None' 'i' 'm' 'veil-color' 'class'
'2.62' 'season' '9.55' '6.36' '4.58']
habitat
['d' 'l' 'g' 'h' 'p' 'm' 'u' 'w' 'y' nan 'n' 'a' 's' 'k' 'habitat' 'z'
'8.09' '17.1' 'b' 't' 'c' '9.28' 'ring-type' 'e' 'r'
'does-bruise-or-bleed' 'f' 'is w' 'o' '2.94' 'x' '4' 'is h' '5.56'
'class' 'i' '10.07' '7.31' '5.62' 'spore-print-color' 'cap-diameter'
'3.11' '16.46' '7.37' 'veil-type' '17.38' '1.66' '6.63' '18.35' '6.75'
'2.44' '3.68' '2.25']
season
['a' 'w' 'u' 's']
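The output above shows that the dataset has relatively few features but a large number of missing values, and that the features are either categorical or numeric. The categorical columns also contain values outside the codebook, such as numbers ('19.29'), column names ('ring-type'), and fragments ('is s').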
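The pipeline in this article treats these noisy strings as ordinary categories. One optional alternative, which is not used for the results reported here, is the hypothetical helper `clean_categories` sketched below. It assumes that every valid code is a single lowercase letter and converts anything else to NaN, so the imputer handles those values like other missing data. It can also blank out codes that appear fewer than `min_count` times. The helper expects string (object) columns.

```python
import numpy as np
import pandas as pd

def clean_categories(df: pd.DataFrame, cat_cols, min_count=1) -> pd.DataFrame:
    """Return a copy where invalid or rare categorical codes are replaced by NaN.

    Assumption: a valid code is exactly one lowercase letter a-z.
    """
    out = df.copy()
    for col in cat_cols:
        s = out[col]
        # Keep only values that look like a valid code; everything else becomes NaN
        s = s.where(s.str.fullmatch(r"[a-z]", na=False))
        # Optionally treat codes seen fewer than min_count times as missing too
        counts = s.value_counts()
        rare = counts[counts < min_count].index
        s = s.where(~s.isin(rare))
        out[col] = s
    return out

# Made-up example; "19.29" and "5 f" imitate the noise seen in cap-shape
df = pd.DataFrame({
    "cap-shape": ["x", "f", "19.29", "5 f", None, "x", "z"],
    "season": ["a", "w", "u", "s", "a", "w", "u"],
})
cleaned = clean_categories(df, ["cap-shape"], min_count=2)
print(cleaned["cap-shape"].tolist())
```

On the competition data, you would call it on `train` and `test` with the list of categorical column names before fitting. Whether this improves the MCC would need to be checked on the validation split.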
Handling missing values
Keep only the columns with less than 70% missing values, and drop the others.
```python
train.columns[train.isnull().mean()>.7]
```
Index(['stem-root', 'veil-type', 'veil-color', 'spore-print-color'], dtype='object')
```python
test.columns[test.isnull().mean()>.7]
```
Index(['stem-root', 'veil-type', 'veil-color', 'spore-print-color'], dtype='object')
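```python
# Drop columns with 70% or more missing values. The ratios are computed on the
# training set so that train and test keep exactly the same feature columns.
drop_cols = train.columns[train.isnull().mean() >= 0.7]
train = train.drop(columns=drop_cols)
test = test.drop(columns=drop_cols)
```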
```python
# Encode the target as binary: edible (e) -> 0, poisonous (p) -> 1
class_dict = {'e': 0, 'p': 1}
train['class'] = train['class'].map(class_dict)
```
Modeling
```python
X = train.drop('class', axis=1)
y = train['class']

# Split the feature columns into categorical and numeric groups
# (the numeric group also contains the id column)
cat_cols = X.select_dtypes(exclude='number').columns.tolist()
num_cols = X.select_dtypes(include='number').columns.tolist()
```
```python
# Numeric columns: fill missing values with the median
num_transformer = SimpleImputer(strategy='median')
# Categorical columns: fill missing values with the most frequent value, then ordinal-encode;
# categories not seen during fit are encoded as NaN
cat_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('enc', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=np.nan))
])
preprocessor = ColumnTransformer(
    transformers=[
        ('num', num_transformer, num_cols),
        ('cat', cat_transformer, cat_cols)
    ], remainder='passthrough')
```
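Pipeline chains the processing steps, and ColumnTransformer applies a different transformer to each group of columns. Together they keep all preprocessing in one place, and the same preprocessing is applied to both training and test data.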
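The toy run below shows what the preprocessor does. It uses a tiny hand-made DataFrame with made-up values, not competition data, and the same transformer setup as above. The numeric NaN is filled with the median, and the categorical NaN is filled with the most frequent value. The categories are then ordinal-encoded in sorted order. A category that never appeared during `fit` becomes NaN because of `handle_unknown='use_encoded_value', unknown_value=np.nan`.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder

# Made-up toy data: one numeric and one categorical column, both with gaps
train_toy = pd.DataFrame({
    "cap-diameter": [2.0, np.nan, 6.0, 10.0],
    "cap-shape": pd.Series(["x", "f", np.nan, "x"], dtype=object),
})
test_toy = pd.DataFrame({
    "cap-diameter": [np.nan, 4.0],
    "cap-shape": pd.Series(["b", "f"], dtype=object),  # "b" never appears in train_toy
})

cat_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("enc", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)),
])
preprocessor = ColumnTransformer(transformers=[
    ("num", SimpleImputer(strategy="median"), ["cap-diameter"]),
    ("cat", cat_transformer, ["cap-shape"]),
])

preprocessor.fit(train_toy)  # learns median 6.0, mode "x", categories ["f", "x"]
train_out = preprocessor.transform(train_toy)
test_out = preprocessor.transform(test_toy)
print(train_out)  # rows: [2, 1], [6, 0], [6, 1], [10, 1]
print(test_out)   # rows: [6, nan], [4, 0]
```

All statistics (median, mode, category list) are learned only from the data passed to `fit`. This is why the full pipeline can later be applied to the test set without leaking information from it.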
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# These hyperparameters come from an earlier tuning run
model = XGBClassifier(n_jobs=-1, max_depth=15, min_child_weight=8.038088382806158,
                      learning_rate=0.1858697647052074, n_estimators=197,
                      colsample_bytree=0.5226315394655169, random_state=42)

model_pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('model', model)
])
model_pipeline.fit(X_train, y_train)
pred = model_pipeline.predict(X_test)

# Evaluate on the hold-out split with the Matthews correlation coefficient
mcc = matthews_corrcoef(y_test, pred)
print(mcc)
```
0.9826202433642893
matthews_corrcoef
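Computes the Matthews correlation coefficient (MCC). In machine learning, MCC measures the quality of binary and multiclass classifications. It takes true and false positives and negatives into account. It is generally regarded as a balanced measure that remains usable even when the classes differ greatly in size. MCC is a correlation coefficient between -1 and +1: +1 represents a perfect prediction, 0 a prediction no better than random, and -1 a completely inverted prediction. The statistic is also known as the phi coefficient.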
MCC is the evaluation metric required by the competition.
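For the binary case, MCC can be computed directly from the confusion-matrix counts:

MCC = (TP × TN - FP × FN) / sqrt((TP + FP)(TP + FN)(TN + FP)(TN + FN))

The sketch below implements this formula by hand on made-up labels and compares the result with scikit-learn's `matthews_corrcoef`. The name `mcc_from_counts` is a hypothetical helper. Returning 0 when the denominator is zero matches what scikit-learn does in that case.

```python
import math
from sklearn.metrics import confusion_matrix, matthews_corrcoef

def mcc_from_counts(tp, tn, fp, fn):
    """Binary MCC from confusion-matrix counts; 0.0 when the denominator is zero."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

# Made-up labels: 0 = edible (e), 1 = poisonous (p)
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]

# With labels=[0, 1], ravel() returns the counts in the order tn, fp, fn, tp
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
manual = mcc_from_counts(tp, tn, fp, fn)
print(manual, matthews_corrcoef(y_true, y_pred))  # both values are 0.5
```

Here TP = 3, TN = 3, FP = 1 and FN = 1, so MCC = (9 - 1) / sqrt(4 × 4 × 4 × 4) = 8 / 16 = 0.5.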
```python
# Refit on the full training set, then predict the test set
model_pipeline.fit(X, y)
prediction = model_pipeline.predict(test)
```
The pipeline applies the same preprocessing to the test data before predicting.
Submit the results
```python
# Map the 0/1 predictions back to the e/p labels and write the submission file
submission['class'] = prediction
submission['class'] = submission['class'].map({0: 'e', 1: 'p'})
submission.to_csv('submission.csv', index=False)
```
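Before uploading, a quick sanity check can catch misaligned rows or labels that failed to map. The helper `check_submission` below is hypothetical and not part of the original workflow. It assumes that the submission has the `id` and `class` columns of `sample_submission.csv`, that its rows follow the order of `test`, and that every label must be 'e' or 'p'.

```python
import pandas as pd

def check_submission(sub: pd.DataFrame, test: pd.DataFrame) -> None:
    """Hypothetical pre-upload check; raises ValueError if something looks wrong."""
    if list(sub.columns) != ["id", "class"]:
        raise ValueError(f"unexpected columns: {list(sub.columns)}")
    if len(sub) != len(test):
        raise ValueError("row count differs from the test set")
    if not (sub["id"].to_numpy() == test["id"].to_numpy()).all():
        raise ValueError("ids are not in the same order as the test set")
    if not sub["class"].isin(["e", "p"]).all():
        raise ValueError("every label must be 'e' or 'p'")

# Made-up example: two test rows and their mapped 0/1 predictions
test_toy = pd.DataFrame({"id": [3116945, 3116946]})
sub_toy = pd.DataFrame({"id": [3116945, 3116946], "class": [1, 0]})
sub_toy["class"] = sub_toy["class"].map({0: "e", 1: "p"})
check_submission(sub_toy, test_toy)  # passes silently
```

In the notebook, the equivalent call would be `check_submission(submission, test)` right before `to_csv`. A NaN label, for example from a prediction that was not 0 or 1, fails the last check.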
Submit to the Kaggle platform
Conclusion
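This article tackled a binary classification problem, focusing on how to handle missing values. It used Pipeline and ColumnTransformer, which make data processing much more efficient and simplify the workflow. The article did not perform hyperparameter optimization or model ensembling. Readers who want a better score can work on those areas.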
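Because this article did not devote much space to exploratory data analysis (EDA), the next article, "Binary Prediction of Poisonous Mushrooms (Part 2)", will cover it. That article will also build a deep learning model, approaching the problem from a different angle to strengthen our skills.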