Machine Learning Practice (6): Machine Learning Algorithms in Practice

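This article is compiled from the materials of the Heywhale (和鲸) Python "Machine Learning Principles and Practice" challenge training camp, with my own understanding added (by GPT4o).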

Original event link

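Original author: vgbhfive, who has many years of experience in risk-control engine development and financial model development, currently works as a risk-control R&D engineer, and has extensive experience in data analysis, financial model development, and risk-control engine development.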


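In the previous levels we covered logistic regression, KMeans, decision trees, SVM (support vector machines), and XGBoost. After working through the model exercises on real data and the challenge questions, you should know how to train a model, so let's start the final level of this training camp: a hands-on Canberra weather prediction project.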

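Below I list the concrete steps. You need to fill in the code and then submit the file! You can refer to the earlier levels while completing the project.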

Before we start, here are the datasets used in this project:

  • train_weather-6.csv: training dataset.
  • test_weather-6.csv: test dataset for submission.
  • submit_result-6.csv: submission result dataset.

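The features in the dataset have the following meanings:

| Feature column | Meaning |
| --- | --- |
| Date | Date |
| Location | City where the observation was made |
| MinTemp | Minimum temperature of the day (°C) |
| MaxTemp | Maximum temperature of the day (°C). The original notes say temperatures are strings, but the temperature columns load as float64 in the info() output further down |
| Rainfall | Rainfall of the day (mm) |
| Evaporation | Pan evaporation (mm) in the 24 hours to 9 am |
| Sunshine | Hours of sunshine in the day |
| WindGustDir | Direction of the strongest wind gust in the 24 hours to midnight |
| WindGustSpeed | Speed of the strongest wind gust (km/h) in the 24 hours to midnight |
| WindDir9am | Wind direction at 9 am |
| WindDir3pm | Wind direction at 3 pm |
| WindSpeed9am | Wind speed averaged over the 10 minutes before 9 am, i.e. 8:50 to 9:00 (km/h) |
| WindSpeed3pm | Wind speed averaged over the 10 minutes before 3 pm, i.e. 14:50 to 15:00 (km/h) |
| Humidity9am | Humidity at 9 am |
| Humidity3pm | Humidity at 3 pm |
| Pressure9am | Atmospheric pressure at 9 am (hPa) |
| Pressure3pm | Atmospheric pressure at 3 pm |
| Cloud9am | Cloud cover at 9 am, taking values in [0, 8] in steps of 1; 0 means almost no cloud, 8 means the sky is almost completely covered |
| Cloud3pm | Cloud cover at 3 pm |
| Temp9am | Temperature at 9 am (°C) |
| Temp3pm | Temperature at 3 pm (°C) |
| RainTomorrow | Label: whether it rains tomorrow |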

Import dependencies

python
# Import dependencies

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score

Load the data

python
# Load the data

train = pd.read_csv('./data/train_weather-6.csv', index_col='Unnamed: 0')
test = pd.read_csv('./data/test_weather-6.csv', index_col='Unnamed: 0')
submit = pd.read_csv('./data/submit_result-6.csv', index_col='Unnamed: 0')
train.head(), test.head()
(         Date     Location  MinTemp  MaxTemp  Rainfall  Evaporation  Sunshine  \
 0  2009-10-14    NorahHead     15.1     23.9       0.0          NaN       NaN   
 1  2011-09-26      Walpole      9.7     14.2       7.6          NaN       NaN   
 2  2010-04-20  Williamtown     13.2     25.4       0.0          3.2       8.8   
 3  2011-07-12       Hobart      7.6     14.8       0.0          4.0       7.0   
 4  2015-04-13  Williamtown     12.9     22.2       0.0          4.0       7.9   
 
   WindGustDir  WindGustSpeed WindDir9am  ... WindSpeed3pm  Humidity9am  \
 0         SSW           67.0         NW  ...         22.0         38.0   
 1         WSW           50.0        WNW  ...         28.0         91.0   
 2         ENE           30.0          W  ...         17.0         79.0   
 3         WNW           94.0        WNW  ...         35.0         52.0   
 4           S           37.0         SW  ...         20.0         69.0   
 
    Humidity3pm  Pressure9am  Pressure3pm  Cloud9am  Cloud3pm  Temp9am  \
 0         68.0       1001.9       1002.4       NaN       NaN     19.8   
 1         56.0       1008.2       1007.7       NaN       NaN     11.1   
 2         63.0       1025.2       1021.5       6.0       5.0     21.2   
 3         45.0       1004.6       1001.4       NaN       NaN     11.1   
 4         52.0       1023.0       1021.2       6.0       2.0     18.8   
 
    Temp3pm  RainTomorrow  
 0     14.3            No  
 1     13.4           Yes  
 2     24.0            No  
 3     12.9            No  
 4     20.6            No  
 
 [5 rows x 22 columns],
          Date      Location  MinTemp  MaxTemp  Rainfall  Evaporation  \
 0  2016-06-09      Ballarat      7.1     13.0       8.8          NaN   
 1  2009-10-24       Walpole     13.2     18.3       0.0          NaN   
 2  2015-09-21  PerthAirport      9.2     22.7       0.0          5.0   
 3  2011-12-06         Cobar     15.3     26.1       0.0         10.4   
 4  2014-03-15          Sale     11.9     31.8       0.0          5.0   
 
    Sunshine WindGustDir  WindGustSpeed WindDir9am  ... WindSpeed9am  \
 0       NaN           N           41.0          N  ...         24.0   
 1       NaN           E           48.0        ESE  ...         24.0   
 2      11.1         ENE           52.0        ENE  ...         26.0   
 3       NaN           E           44.0          E  ...         24.0   
 4       4.1          NW           72.0          E  ...          6.0   
 
    WindSpeed3pm  Humidity9am  Humidity3pm  Pressure9am  Pressure3pm  Cloud9am  \
 0          22.0        100.0         98.0       1001.7       1005.4       8.0   
 1          20.0         73.0         73.0       1027.6       1023.8       NaN   
 2          20.0         45.0         25.0       1030.1       1025.9       0.0   
 3          19.0         48.0         40.0       1013.2       1009.8       7.0   
 4          19.0         89.0         25.0       1006.7       1001.0       7.0   
 
    Cloud3pm  Temp9am  Temp3pm  
 0       8.0      8.6     11.5  
 1       NaN     14.2     17.0  
 2       0.0     15.1     22.5  
 3       7.0     17.5     24.3  
 4       6.0     16.2     27.4  
 
 [5 rows x 21 columns])
python
# Submission result dataset
submit.reset_index(drop=False, inplace=True)
submit.columns = ['', 'RainTomorrow']
submit.head()

| | | RainTomorrow |
| 0 | 0 | NaN |
| 1 | 1 | NaN |
| 2 | 2 | NaN |
| 3 | 3 | NaN |
| 4 | 4 | NaN |
python
submit.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 28439 entries, 0 to 28438
Data columns (total 2 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0                 28439 non-null  int64  
 1   RainTomorrow  0 non-null      float64
dtypes: float64(1), int64(1)
memory usage: 444.5 KB

Basic data analysis

python
# Basic data analysis

# Check missing values and dtypes in the training and test sets
train.info(), test.info()
<class 'pandas.core.frame.DataFrame'>
Index: 113754 entries, 0 to 113753
Data columns (total 22 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           113754 non-null  object 
 1   Location       113754 non-null  object 
 2   MinTemp        113229 non-null  float64
 3   MaxTemp        113486 non-null  float64
 4   Rainfall       112572 non-null  float64
 5   Evaporation    64963 non-null   float64
 6   Sunshine       59409 non-null   float64
 7   WindGustDir    106296 non-null  object 
 8   WindGustSpeed  106348 non-null  float64
 9   WindDir9am     105718 non-null  object 
 10  WindDir3pm     110731 non-null  object 
 11  WindSpeed9am   112671 non-null  float64
 12  WindSpeed3pm   111645 non-null  float64
 13  Humidity9am    112334 non-null  float64
 14  Humidity3pm    110841 non-null  float64
 15  Pressure9am    102497 non-null  float64
 16  Pressure3pm    102529 non-null  float64
 17  Cloud9am       70713 non-null   float64
 18  Cloud3pm       67987 non-null   float64
 19  Temp9am        113018 non-null  float64
 20  Temp3pm        111548 non-null  float64
 21  RainTomorrow   113754 non-null  object 
dtypes: float64(16), object(6)
memory usage: 20.0+ MB
<class 'pandas.core.frame.DataFrame'>
Index: 28439 entries, 0 to 28438
Data columns (total 21 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Date           28439 non-null  object 
 1   Location       28439 non-null  object 
 2   MinTemp        28327 non-null  float64
 3   MaxTemp        28385 non-null  float64
 4   Rainfall       28215 non-null  float64
 5   Evaporation    16387 non-null  float64
 6   Sunshine       14968 non-null  float64
 7   WindGustDir    26567 non-null  object 
 8   WindGustSpeed  26575 non-null  float64
 9   WindDir9am     26462 non-null  object 
 10  WindDir3pm     27684 non-null  object 
 11  WindSpeed9am   28174 non-null  float64
 12  WindSpeed3pm   27918 non-null  float64
 13  Humidity9am    28085 non-null  float64
 14  Humidity3pm    27742 non-null  float64
 15  Pressure9am    25682 non-null  float64
 16  Pressure3pm    25683 non-null  float64
 17  Cloud9am       17823 non-null  float64
 18  Cloud3pm       17112 non-null  float64
 19  Temp9am        28271 non-null  float64
 20  Temp3pm        27919 non-null  float64
dtypes: float64(16), object(5)
memory usage: 4.8+ MB





(None, None)
python
# Summary statistics for the training and test sets (look for invalid values)

train.describe().T, test.describe().T
(                  count         mean        std    min     25%     50%  \
 MinTemp        113229.0    12.176037   6.398791   -8.5     7.6    12.0   
 MaxTemp        113486.0    23.222816   7.118185   -4.8    17.9    22.6   
 Rainfall       112572.0     2.347680   8.466572    0.0     0.0     0.0   
 Evaporation     64963.0     5.470719   4.229935    0.0     2.6     4.8   
 Sunshine        59409.0     7.622586   3.778445    0.0     4.9     8.4   
 WindGustSpeed  106348.0    39.957395  13.574900    6.0    31.0    39.0   
 WindSpeed9am   112671.0    13.994169   8.884425    0.0     7.0    13.0   
 WindSpeed3pm   111645.0    18.626325   8.790884    0.0    13.0    19.0   
 Humidity9am    112334.0    68.824764  19.063076    0.0    57.0    70.0   
 Humidity3pm    110841.0    51.466659  20.799362    0.0    37.0    52.0   
 Pressure9am    102497.0  1017.651395   7.111363  980.5  1012.9  1017.6   
 Pressure3pm    102529.0  1015.258031   7.040286  978.2  1010.4  1015.2   
 Cloud9am        70713.0     4.433188   2.886888    0.0     1.0     5.0   
 Cloud3pm        67987.0     4.500478   2.722538    0.0     2.0     5.0   
 Temp9am        113018.0    16.983173   6.491592   -7.2    12.3    16.7   
 Temp3pm        111548.0    21.681986   6.939722   -5.4    16.6    21.1   
 
                   75%     max  
 MinTemp          16.8    33.9  
 MaxTemp          28.2    48.1  
 Rainfall          0.8   371.0  
 Evaporation       7.4   145.0  
 Sunshine         10.6    14.5  
 WindGustSpeed    48.0   135.0  
 WindSpeed9am     19.0   130.0  
 WindSpeed3pm     24.0    87.0  
 Humidity9am      83.0   100.0  
 Humidity3pm      66.0   100.0  
 Pressure9am    1022.4  1041.0  
 Pressure3pm    1020.0  1038.9  
 Cloud9am          7.0     8.0  
 Cloud3pm          7.0     9.0  
 Temp9am          21.6    40.2  
 Temp3pm          26.4    46.7  ,
                  count         mean        std    min     25%     50%     75%  \
 MinTemp        28327.0    12.227822   6.421153   -8.0     7.7    12.0    17.0   
 MaxTemp        28385.0    23.242649   7.115455   -3.8    18.0    22.6    28.3   
 Rainfall       28215.0     2.359128   8.459732    0.0     0.0     0.0     0.6   
 Evaporation    16387.0     5.466278   4.020351    0.0     2.6     4.8     7.4   
 Sunshine       14968.0     7.633852   3.793839    0.0     4.8     8.5    10.7   
 WindGustSpeed  26575.0    40.091929  13.644014    7.0    31.0    39.0    48.0   
 WindSpeed9am   28174.0    14.033258   8.928980    0.0     7.0    13.0    19.0   
 WindSpeed3pm   27918.0    18.682570   8.853018    0.0    13.0    19.0    24.0   
 Humidity9am    28085.0    68.919993  19.004234    1.0    57.0    70.0    83.0   
 Humidity3pm    27742.0    51.546320  20.791669    0.0    37.0    52.0    66.0   
 Pressure9am    25682.0  1017.663192   7.082063  984.6  1013.0  1017.6  1022.4   
 Pressure3pm    25683.0  1015.258891   7.022385  977.1  1010.5  1015.2  1020.0   
 Cloud9am       17823.0     4.453066   2.887546    0.0     1.0     5.0     7.0   
 Cloud3pm       17112.0     4.513850   2.713102    0.0     2.0     5.0     7.0   
 Temp9am        28271.0    17.004842   6.497904   -5.9    12.3    16.7    21.6   
 Temp3pm        27919.0    21.708206   6.929170   -5.1    16.6    21.2    26.5   
 
                   max  
 MinTemp          31.4  
 MaxTemp          47.0  
 Rainfall        278.4  
 Evaporation      60.8  
 Sunshine         14.1  
 WindGustSpeed   135.0  
 WindSpeed9am     87.0  
 WindSpeed3pm     83.0  
 Humidity9am     100.0  
 Humidity3pm     100.0  
 Pressure9am    1040.9  
 Pressure3pm    1039.6  
 Cloud9am          9.0  
 Cloud3pm          8.0  
 Temp9am          38.9  
 Temp3pm          46.1  )
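
In the summary above, Cloud3pm reaches 9.0 in the training set and Cloud9am reaches 9.0 in the test set, while the feature table gives [0, 8] as the valid range. A range check like this can be sketched as a small helper. The helper name `find_out_of_range`, the `VALID_RANGES` dictionary, and the toy data are assumptions for illustration and are not part of the original notebook.

```python
import numpy as np
import pandas as pd

# Assumed bounds: only the Cloud range [0, 8] comes from the feature table.
VALID_RANGES = {"Cloud9am": (0, 8), "Cloud3pm": (0, 8)}


def find_out_of_range(df, ranges):
    """Return index labels of rows where any checked column is outside its range.

    NaN compares as False on both sides, so missing values are not flagged.
    """
    mask = pd.Series(False, index=df.index)
    for col, (low, high) in ranges.items():
        mask |= (df[col] < low) | (df[col] > high)
    return df.index[mask]


# Toy frame standing in for the real data
toy = pd.DataFrame({
    "Cloud9am": [0.0, 8.0, 9.0, np.nan],
    "Cloud3pm": [7.0, np.nan, 3.0, 8.0],
})
bad_rows = find_out_of_range(toy, VALID_RANGES)
print(bad_rows.tolist())  # [2]
```

Keeping the bounds in one dictionary makes it easy to extend the same check to other columns once their valid ranges are known.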

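Handling invalid data

python
# Handle invalid values. Both the training and test sets contain invalid values, so drop the affected rows by index
# (if the test set has invalid rows, the corresponding index must also be dropped from the submission set).
# Columns to check for invalid values (cloud cover is only valid in [0, 8])
columns_to_check = ['Cloud9am', 'Cloud3pm']

# Find the index of every row with an invalid value in those columns
train_illegal_indices = train[(train[columns_to_check] > 8.0).any(axis=1)].index
test_illegal_indices = test[(test[columns_to_check] > 8.0).any(axis=1)].index
# Print the indices of the affected rows
print("Train rows with invalid values:", train_illegal_indices)
print("Test rows with invalid values:", test_illegal_indices)
# Drop those rows
train_cleaned = train.drop(train_illegal_indices)
test_cleaned = test.drop(test_illegal_indices)
submit_cleaned = submit.drop(test_illegal_indices)
Train rows with invalid values: Index([88723], dtype='int64')
Test rows with invalid values: Index([20890, 22448], dtype='int64')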
python
test_cleaned.info(),submit_cleaned.info()
<class 'pandas.core.frame.DataFrame'>
Index: 28437 entries, 0 to 28438
Data columns (total 21 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Date           28437 non-null  object 
 1   Location       28437 non-null  object 
 2   MinTemp        28325 non-null  float64
 3   MaxTemp        28383 non-null  float64
 4   Rainfall       28213 non-null  float64
 5   Evaporation    16386 non-null  float64
 6   Sunshine       14966 non-null  float64
 7   WindGustDir    26566 non-null  object 
 8   WindGustSpeed  26574 non-null  float64
 9   WindDir9am     26460 non-null  object 
 10  WindDir3pm     27682 non-null  object 
 11  WindSpeed9am   28172 non-null  float64
 12  WindSpeed3pm   27916 non-null  float64
 13  Humidity9am    28083 non-null  float64
 14  Humidity3pm    27740 non-null  float64
 15  Pressure9am    25680 non-null  float64
 16  Pressure3pm    25681 non-null  float64
 17  Cloud9am       17821 non-null  float64
 18  Cloud3pm       17111 non-null  float64
 19  Temp9am        28269 non-null  float64
 20  Temp3pm        27917 non-null  float64
dtypes: float64(16), object(5)
memory usage: 4.8+ MB
<class 'pandas.core.frame.DataFrame'>
Index: 28437 entries, 0 to 28438
Data columns (total 2 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0                 28437 non-null  int64  
 1   RainTomorrow  0 non-null      float64
dtypes: float64(1), int64(1)
memory usage: 666.5 KB





(None, None)
python
(train['Cloud9am'] > 8.0).value_counts(),(train['Cloud3pm'] > 8.0).value_counts()
(Cloud9am
 False    113754
 Name: count, dtype: int64,
 Cloud3pm
 False    113753
 True          1
 Name: count, dtype: int64)
python
(test['Cloud9am'] > 8.0).value_counts(),(test['Cloud3pm'] > 8.0).value_counts()
(Cloud9am
 False    28437
 True         2
 Name: count, dtype: int64,
 Cloud3pm
 False    28439
 Name: count, dtype: int64)
python
# Use the date as the index
train_cleaned.set_index('Date', inplace=True)
test_cleaned.set_index('Date', inplace=True)
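python
# Fill missing values in the categorical (discrete) columns.
# Discrete data can take only a finite or countable number of values, usually represented as integers.

# The training and test sets are handled separately because the test set has no RainTomorrow column.

cate_columns = ['RainTomorrow', 'WindDir3pm', 'WindDir9am', 'WindGustDir']

si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")  # fill missing values with the mode
train_cleaned[cate_columns] = si.fit_transform(train_cleaned[cate_columns])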
python
# Test set
cate_columns = ['WindDir3pm', 'WindDir9am', 'WindGustDir']

si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")  # fill missing values with the mode
test_cleaned[cate_columns] = si.fit_transform(test_cleaned[cate_columns])
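
The cells above fit a separate imputer on the test set, so the test set is filled with its own modes rather than the training modes. A common alternative, shown here only as a sketch on toy data, is to fit the imputer on the training columns and reuse it on the test columns. The frames `toy_train` and `toy_test` and the name `wind_cols` are made up for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Columns shared by both sets (RainTomorrow is excluded because the test set lacks it)
wind_cols = ["WindDir3pm", "WindDir9am", "WindGustDir"]

# Toy data standing in for train_cleaned / test_cleaned
toy_train = pd.DataFrame({
    "WindDir3pm": ["N", "N", "SE", np.nan],
    "WindDir9am": ["W", np.nan, "W", "E"],
    "WindGustDir": ["S", "S", np.nan, "NW"],
})
toy_test = pd.DataFrame({
    "WindDir3pm": [np.nan, "SE"],
    "WindDir9am": ["E", np.nan],
    "WindGustDir": [np.nan, "NW"],
})

imputer = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
# Learn the modes from the training data only ...
toy_train[wind_cols] = imputer.fit_transform(toy_train[wind_cols])
# ... and apply the same modes to the test data
toy_test[wind_cols] = imputer.transform(toy_test[wind_cols])

print(toy_test)
```

With this arrangement the test rows are filled with "N", "W", and "S", the modes learned from the toy training frame.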
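python
# Fill missing values in the continuous columns.
# Continuous data can take any value within an interval, usually represented as real numbers.
# This step is left commented out, as in the original; XGBoost accepts NaN in numeric features.

# cate_columns = ['RainTomorrow', 'WindDir3pm', 'WindDir9am', 'WindGustDir', 'Location']

# columns = train_cleaned.columns.to_list()
# for col in cate_columns:
#     columns.remove(col)

# impmean = SimpleImputer(missing_values=np.nan, strategy="mean")
# train_cleaned[columns] = impmean.fit_transform(train_cleaned[columns])  # impute the numeric columns, not cate_columns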
python
# Test set (also left commented out)

# cate_columns = ['WindDir3pm', 'WindDir9am', 'WindGustDir', 'Location']

# columns = test_cleaned.columns.to_list()
# for col in cate_columns:
#     columns.remove(col)

# impmean = SimpleImputer(missing_values=np.nan, strategy="mean")
# test_cleaned[columns] = impmean.fit_transform(test_cleaned[columns])  # impute the numeric columns, not cate_columns
python
# Encode the categorical columns as integers (label encoding)

cate_columns = ['RainTomorrow', 'WindDir3pm', 'WindDir9am', 'WindGustDir', 'Location']
lb = LabelEncoder()
for col in cate_columns:
    train_cleaned[col] = lb.fit_transform(train_cleaned[col])
python
# Test set

cate_columns = ['WindDir3pm', 'WindDir9am', 'WindGustDir', 'Location']
lb = LabelEncoder()
for col in cate_columns:
    test_cleaned[col] = lb.fit_transform(test_cleaned[col])
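
Because the test set is encoded with freshly fitted encoders, a category is only guaranteed to get the same integer in both sets when both sets contain exactly the same categories in that column. One way to guard against a mismatch, sketched here on toy data, is to fit one encoder per column on the training set and reuse it for the test set. The `encoders` dictionary and the toy frames are assumptions for illustration.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy data standing in for train_cleaned / test_cleaned
toy_train = pd.DataFrame({
    "Location": ["Hobart", "Walpole", "Cobar"],
    "WindGustDir": ["N", "S", "N"],
})
toy_test = pd.DataFrame({
    "Location": ["Walpole", "Hobart"],
    "WindGustDir": ["S", "S"],
})

encoders = {}  # one fitted encoder per column, kept for later reuse
for col in ["Location", "WindGustDir"]:
    enc = LabelEncoder()
    toy_train[col] = enc.fit_transform(toy_train[col])
    # transform() raises ValueError if the test set has a label unseen in training
    toy_test[col] = enc.transform(toy_test[col])
    encoders[col] = enc

print(toy_test)
```

In this toy example "S" maps to 1 in both frames. Fitting a separate encoder on the test column alone would map "S" to 0, because "N" never appears there.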

Train the model

python
# Feature correlation

# Select the numeric columns
numeric_cols = train_cleaned.select_dtypes(include='number').columns

# Compute the correlation matrix
correlation_matrix = train_cleaned[numeric_cols].corr()
# Extract the correlations with the 'RainTomorrow' column
rain_correlation = correlation_matrix['RainTomorrow']

# Print the result
print(rain_correlation)
sns.heatmap(correlation_matrix)
Location        -0.002738
MinTemp          0.083424
MaxTemp         -0.161241
Rainfall         0.240423
Evaporation     -0.118702
Sunshine        -0.450789
WindGustDir      0.054245
WindGustSpeed    0.232409
WindDir9am       0.037897
WindDir3pm       0.032786
WindSpeed9am     0.091004
WindSpeed3pm     0.087242
Humidity9am      0.257926
Humidity3pm      0.448985
Pressure9am     -0.248096
Pressure3pm     -0.227403
Cloud9am         0.319130
Cloud3pm         0.384321
Temp9am         -0.026820
Temp3pm         -0.195017
RainTomorrow     1.000000
Name: RainTomorrow, dtype: float64





[Figure: heatmap of the feature correlation matrix]


python
# Split the training data into training and validation subsets

x = train_cleaned.drop('RainTomorrow', axis=1)
y = train_cleaned['RainTomorrow']

# Hold out 20% of the rows for validation
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
x_train.head(), x_test.head(), y_train.head(), y_test.head()
(            Location  MinTemp  MaxTemp  Rainfall  Evaporation  Sunshine  \
 Date                                                                      
 2009-08-29        38     14.2     27.0       0.2          9.2       5.2   
 2009-09-01        17      3.3     14.9       0.0          NaN       NaN   
 2016-01-13        31     15.3     31.0       0.6          8.8      13.3   
 2009-09-30         6      5.6     24.1       0.0          2.6       NaN   
 2009-08-05        19      6.1     16.1       0.4          1.4       7.2   
 
             WindGustDir  WindGustSpeed  WindDir9am  WindDir3pm  WindSpeed9am  \
 Date                                                                           
 2009-08-29            7           74.0           3           4          26.0   
 2009-09-01            7           39.0           6          14          17.0   
 2016-01-13            2           43.0           2           9          17.0   
 2009-09-30            5           56.0           5           3          15.0   
 2009-08-05            3           50.0           3           6          28.0   
 
             WindSpeed3pm  Humidity9am  Humidity3pm  Pressure9am  Pressure3pm  \
 Date                                                                           
 2009-08-29          30.0         43.0         50.0       1008.0        998.1   
 2009-09-01          22.0         61.0         44.0       1014.5       1015.3   
 2016-01-13          17.0         50.0         22.0       1023.4       1019.7   
 2009-09-30          24.0         53.0         19.0       1014.6       1009.5   
 2009-08-05          19.0         80.0         51.0       1019.9       1017.2   
 
             Cloud9am  Cloud3pm  Temp9am  Temp3pm  
 Date                                              
 2009-08-29       7.0       5.0     17.9     22.6  
 2009-09-01       NaN       NaN      9.4     13.8  
 2016-01-13       6.0       1.0     20.7     30.0  
 2009-09-30       5.0       2.0     12.7     23.6  
 2009-08-05       7.0       7.0      8.9     14.9  ,
             Location  MinTemp  MaxTemp  Rainfall  Evaporation  Sunshine  \
 Date                                                                      
 2012-07-14        33      8.7     11.1       2.0          0.4       0.9   
 2013-07-28        21      6.2     21.0       0.0          2.0      10.4   
 2016-08-22         9      1.1     13.1       0.0          NaN       NaN   
 2012-09-15        29      5.5     25.6       0.0          NaN      11.0   
 2011-02-14        42     16.6     27.6       0.0          6.6      11.4   
 
             WindGustDir  WindGustSpeed  WindDir9am  WindDir3pm  WindSpeed9am  \
 Date                                                                           
 2012-07-14           12           48.0           3          12          17.0   
 2013-07-28            3           41.0           0           4          17.0   
 2016-08-22            3           28.0           8          15           7.0   
 2012-09-15            2           33.0          11           0           7.0   
 2011-02-14            0           39.0           0           1          28.0   
 
             WindSpeed3pm  Humidity9am  Humidity3pm  Pressure9am  Pressure3pm  \
 Date                                                                           
 2012-07-14          30.0         96.0         96.0       1003.7       1007.8   
 2013-07-28          22.0         71.0         36.0       1029.8       1025.8   
 2016-08-22           2.0         90.0         52.0       1017.6       1013.1   
 2012-09-15          17.0         54.0         22.0       1019.3       1016.9   
 2011-02-14          13.0         52.0         44.0       1021.8       1020.0   
 
             Cloud9am  Cloud3pm  Temp9am  Temp3pm  
 Date                                              
 2012-07-14       7.0       8.0      9.1      9.2  
 2013-07-28       1.0       3.0     13.0     20.7  
 2016-08-22       7.0       8.0      7.0     12.7  
 2012-09-15       NaN       NaN     16.0     25.4  
 2011-02-14       1.0       1.0     19.7     27.0  ,
 Date
 2009-08-29    0
 2009-09-01    0
 2016-01-13    0
 2009-09-30    0
 2009-08-05    0
 Name: RainTomorrow, dtype: int32,
 Date
 2012-07-14    1
 2013-07-28    0
 2016-08-22    1
 2012-09-15    0
 2011-02-14    0
 Name: RainTomorrow, dtype: int32)
python
# Build the model
model = XGBClassifier(max_depth=3, learning_rate=0.5, n_estimators=50, gamma=0.5, min_child_weight=5, random_state=42)
model.fit(x_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=0.5, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=0.5, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=3, max_leaves=None,
              min_child_weight=5, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=50, n_jobs=None,
              num_parallel_tree=None, random_state=42, ...)
python
# Predict on the held-out validation split and compute accuracy

y_pred = model.predict(x_test)

acc = accuracy_score(y_test, y_pred)
acc
0.8565337787350007
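
Accuracy alone can look good when one class dominates. Whether that applies here depends on how many rainy days the data contains, which the steps above do not check. As a sketch on made-up labels (not the real predictions), the confusion matrix and recall show what accuracy can hide:

```python
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Made-up labels: 8 dry days (0) and 2 rainy days (1)
toy_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
# Made-up predictions that miss one of the two rainy days
toy_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

toy_acc = accuracy_score(toy_true, toy_pred)
toy_cm = confusion_matrix(toy_true, toy_pred)  # rows: true class, columns: predicted class
toy_recall = recall_score(toy_true, toy_pred)  # share of rainy days that were caught

print(toy_acc)     # 0.9
print(toy_cm)
print(toy_recall)  # 0.5
```

Here the accuracy is 0.9 even though only half of the rainy days are detected. The same calls can be applied to y_test and y_pred from the cell above.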

Hyperparameter tuning

python
# Hyperparameter tuning

param_grid = {
    "max_depth": [3, 5, 7, 10],
    "learning_rate": [0.01, 0.1, 0.5],
    "n_estimators": [50, 100, 200],
    "gamma": [0, 0.1, 0.2, 0.5],
    "min_child_weight": [1, 3, 5]
}

# GridSearchCV
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=5, scoring='accuracy', verbose=2)
grid_search.fit(x_train, y_train)
    Fitting 5 folds for each of 432 candidates, totalling 2160 fits
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=50; total time=   0.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=50; total time=   0.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=50; total time=   0.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=200; total time=   1.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=200; total time=   1.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=1, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=50; total time=   0.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=50; total time=   0.6s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=200; total time=   1.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=200; total time=   1.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=200; total time=   1.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=3, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=50; total time=   0.6s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=50; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=50; total time=   0.5s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=100; total time=   0.9s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=100; total time=   0.7s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=100; total time=   0.8s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.4s
    [CV] END gamma=0, learning_rate=0.01, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.3s
    [CV] END gamma=0, learning_rate=0.01, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.6s
    [CV] END gamma=0, learning_rate=0.01, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.6s
    [CV] END gamma=0, learning_rate=0.01, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.7s
    ...
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.5s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=3, min_child_weight=5, n_estimators=200; total time=   1.4s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.7s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.8s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=50; total time=   0.7s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=100; total time=   0.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=100; total time=   0.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=100; total time=   0.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=100; total time=   0.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=100; total time=   0.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=200; total time=   1.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=200; total time=   1.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=200; total time=   1.5s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=200; total time=   1.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=1, n_estimators=200; total time=   1.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=50; total time=   0.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=50; total time=   0.7s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=50; total time=   0.7s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=50; total time=   0.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=50; total time=   0.6s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=100; total time=   1.0s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=100; total time=   1.0s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=100; total time=   1.1s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=100; total time=   1.0s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=100; total time=   1.0s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weight=3, n_estimators=200; total time=   1.9s
    [CV] END gamma=0.5, learning_rate=0.1, max_depth=5, min_child_weigh...

D:\Anacanda3\envs\pytorch_cuda12_0_py310\lib\site-packages\numpy\ma\core.py:2820: RuntimeWarning: invalid value encountered in cast
  _data = np.array(data, dtype=dtype, copy=copy,
GridSearchCV(cv=5,
             estimator=XGBClassifier(base_score=None, booster=None,
                                     callbacks=None, colsample_bylevel=None,
                                     colsample_bynode=None,
                                     colsample_bytree=None, device=None,
                                     early_stopping_rounds=None,
                                     enable_categorical=False, eval_metric=None,
                                     feature_types=None, gamma=0.5,
                                     grow_policy=None, importance_type=None,
                                     interaction_constraints=None,
                                     learning_rate=0.5, ma...
                                     max_delta_step=None, max_depth=3,
                                     max_leaves=None, min_child_weight=5,
                                     missing=nan, monotone_constraints=None,
                                     multi_strategy=None, n_estimators=50,
                                     n_jobs=None, num_parallel_tree=None,
                                     random_state=42, ...),
             param_grid={'gamma': [0, 0.1, 0.2, 0.5],
                         'learning_rate': [0.01, 0.1, 0.5],
                         'max_depth': [3, 5, 7, 10],
                         'min_child_weight': [1, 3, 5],
                         'n_estimators': [50, 100, 200]},
             scoring='accuracy', verbose=2)
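The length of the log above follows directly from the size of the grid. As a rough sketch, the `param_grid` shown in the repr can be passed to scikit-learn's `ParameterGrid` to count how many candidates and fits such a search performs (the variable names `n_candidates` and `n_fits` are only illustrative):

```python
from sklearn.model_selection import ParameterGrid

# Grid copied from the GridSearchCV repr above
param_grid = {
    "gamma": [0, 0.1, 0.2, 0.5],
    "learning_rate": [0.01, 0.1, 0.5],
    "max_depth": [3, 5, 7, 10],
    "min_child_weight": [1, 3, 5],
    "n_estimators": [50, 100, 200],
}
cv = 5  # number of folds, as in GridSearchCV(cv=5, ...)

n_candidates = len(ParameterGrid(param_grid))  # every combination is one candidate
n_fits = n_candidates * cv                     # each candidate is fit once per fold
print(n_candidates, n_fits)
```

With a grid of this size, setting `n_jobs=-1` on `GridSearchCV`, or switching to `RandomizedSearchCV` with a fixed `n_iter`, are common ways to reduce the wall-clock time.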
python
# Best parameters and best cross-validated accuracy
grid_search.best_params_, grid_search.best_score_
({'gamma': 0.2,
  'learning_rate': 0.1,
  'max_depth': 7,
  'min_child_weight': 3,
  'n_estimators': 200},
 0.8611788496103643)
python
# Best model: retrain with the best parameters found by the grid search
best_model = XGBClassifier(max_depth=7, learning_rate=0.1, n_estimators=200, gamma=0.2, min_child_weight=3, random_state=42)
best_model.fit(x_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=0.2, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=0.1, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=7, max_leaves=None,
              min_child_weight=3, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=200, n_jobs=None,
              num_parallel_tree=None, random_state=42, ...)
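The cell above rebuilds the model by hand from the best parameters. Because `GridSearchCV` defaults to `refit=True`, the search object also keeps a model already retrained on the full training data, available as `best_estimator_`. The sketch below illustrates that the two routes agree. It runs on synthetic data and uses `DecisionTreeClassifier` as a lightweight stand-in for `XGBClassifier`; both are assumptions made only to keep the example small and self-contained.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; the article uses the weather training set instead
X, y = make_classification(n_samples=300, n_features=8, random_state=42)

search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),  # stand-in for XGBClassifier
    param_grid={"max_depth": [2, 4, 6], "min_samples_leaf": [1, 5]},
    scoring="accuracy",
    cv=5,
)
search.fit(X, y)

# Route 1: reuse the model GridSearchCV already refit on all of X, y
refit_model = search.best_estimator_

# Route 2: rebuild it by hand from best_params_, as the article does
manual_model = DecisionTreeClassifier(random_state=42, **search.best_params_)
manual_model.fit(X, y)

same = np.array_equal(refit_model.predict(X), manual_model.predict(X))
print(search.best_params_, same)
```

Unpacking `best_params_` with `**` also avoids retyping the values, which removes one source of copy errors.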
python
test_cleaned.info()
<class 'pandas.core.frame.DataFrame'>
Index: 28437 entries, 2016-06-09 to 2013-07-03
Data columns (total 20 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Location       28437 non-null  int32  
 1   MinTemp        28325 non-null  float64
 2   MaxTemp        28383 non-null  float64
 3   Rainfall       28213 non-null  float64
 4   Evaporation    16386 non-null  float64
 5   Sunshine       14966 non-null  float64
 6   WindGustDir    28437 non-null  int32  
 7   WindGustSpeed  26574 non-null  float64
 8   WindDir9am     28437 non-null  int32  
 9   WindDir3pm     28437 non-null  int32  
 10  WindSpeed9am   28172 non-null  float64
 11  WindSpeed3pm   27916 non-null  float64
 12  Humidity9am    28083 non-null  float64
 13  Humidity3pm    27740 non-null  float64
 14  Pressure9am    25680 non-null  float64
 15  Pressure3pm    25681 non-null  float64
 16  Cloud9am       17821 non-null  float64
 17  Cloud3pm       17111 non-null  float64
 18  Temp9am        28269 non-null  float64
 19  Temp3pm        27917 non-null  float64
dtypes: float64(16), int32(4)
memory usage: 5.1+ MB
python
y_pred2 = best_model.predict(test_cleaned)
y_pred2
array([1, 0, 0, ..., 0, 0, 0])
python
submit_cleaned['RainTomorrow'] = y_pred2
submit_cleaned.head()

| | | RainTomorrow |
|---|---|---|
| 0 | 0 | 1 |
| 1 | 1 | 0 |
| 2 | 2 | 0 |
| 3 | 3 | 0 |
| 4 | 4 | 1 |
python
len(y_pred2),len(submit_cleaned)
(28437, 28437)

Submitting the results

python
submit_cleaned.to_csv('./data/第六关结果提交.csv',index = False)
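Before uploading, it can help to confirm what the written file will contain. The sketch below is only an illustration: `preds` and `submission` are hypothetical stand-ins for `y_pred2` and `submit_cleaned`, and an in-memory buffer replaces the real output path. It shows that `index=False` keeps the row index out of the CSV, so only the actual columns are written.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical stand-ins for y_pred2 and submit_cleaned
preds = np.array([1, 0, 0, 0, 1])
submission = pd.DataFrame({"RainTomorrow": preds})

# Basic checks before writing: one prediction per row, labels are 0 or 1
assert len(submission) == len(preds)
assert set(submission["RainTomorrow"].unique()) <= {0, 1}

buffer = io.StringIO()                  # in-memory file, nothing is written to disk
submission.to_csv(buffer, index=False)  # index=False: no extra index column
buffer.seek(0)

round_trip = pd.read_csv(buffer)
print(list(round_trip.columns), len(round_trip))
```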
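python
si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")  # fill missing values with the mode

This line handles missing values in the categorical columns of the dataset: each missing entry is replaced with the most frequent value (the mode) of its column. A closer look at the code follows.

Key code walkthrough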

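python
# Import the required libraries
import numpy as np
from sklearn.impute import SimpleImputer

**Create a `SimpleImputer` instance**:
```python
si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")  # fill missing values with the mode
```
  • `SimpleImputer` is a class for filling in missing values.
  • `missing_values=np.nan` tells it to treat `NaN` as the missing-value marker.
  • `strategy="most_frequent"` fills each gap with the most frequent value (the mode) of its column. This suits categorical variables, where a mean or median is not defined.
  • `si` is the resulting `SimpleImputer` instance, used to fill the missing values in the selected columns.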
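To see what the imputer actually learns, the fitted object exposes the per-column fill values in its `statistics_` attribute. The following is a minimal sketch on made-up values (the column names are borrowed from the dataset, the rows are not):

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Toy frame: column names borrowed from the dataset, values made up
toy = pd.DataFrame(
    {
        "WindDir9am": ["N", "N", "E", np.nan],  # clear mode: "N"
        "WindDir3pm": ["W", "S", np.nan, "E"],  # three-way tie
    },
    dtype=object,
)

si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
si.fit(toy)

# statistics_ holds the fill value learned for each column, in column order
fill_values = dict(zip(toy.columns, si.statistics_))
print(fill_values)
```

When several values tie for most frequent, scikit-learn documents that only the smallest one is used, which for strings means the first in sorted order. That is why the tied column here is filled with `E`.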

Usage example

Suppose we have a DataFrame `df` that holds the categorical columns above, some of which contain missing values. `SimpleImputer` can fill them in.

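python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Build a sample DataFrame
data = {
    'RainTomorrow': ['Yes', 'No', np.nan, 'Yes', 'No'],
    'WindDir3pm': ['N', 'S', 'E', np.nan, 'W'],
    'WindDir9am': [np.nan, 'N', 'E', 'W', 'N'],
    'WindGustDir': ['W', np.nan, 'N', 'E', 'S']
}

df = pd.DataFrame(data)
cate_columns = list(df.columns)  # every column in this example is categorical

# Print the DataFrame before imputation
print("DataFrame before imputation:")
print(df)

# Fill the missing values with SimpleImputer
si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")
df[cate_columns] = si.fit_transform(df[cate_columns])

# Print the DataFrame after imputation
print("\nDataFrame after imputation:")
print(df)

Result

DataFrame before imputation:
  RainTomorrow WindDir3pm WindDir9am WindGustDir
0          Yes          N        NaN           W
1           No          S          N         NaN
2          NaN          E          E           N
3          Yes        NaN          W           E
4           No          W          N           S

DataFrame after imputation:
  RainTomorrow WindDir3pm WindDir9am WindGustDir
0          Yes          N          N           W
1           No          S          N           E
2           No          E          E           N
3          Yes          E          W           E
4           No          W          N           S

In this example, `SimpleImputer` replaces every missing value with the most frequent value of its column. When several values tie for most frequent, as in `RainTomorrow`, `WindDir3pm`, and `WindGustDir` here, the smallest one is used (for strings, the first in sorted order), so those columns are filled with `No`, `E`, and `E` respectively.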
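The example above fits and applies the imputer on a single frame. With separate training and test files, one common pattern is to learn the modes from the training data only and reuse them on the test data, so both sets are filled consistently. The sketch below assumes that pattern; `train_toy` and `test_toy` are made-up stand-ins for the real files, and `cate_columns` lists only two columns for brevity.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

cate_columns = ["WindGustDir", "WindDir9am"]  # subset, for illustration only

# Made-up rows standing in for the real train/test files
train_toy = pd.DataFrame(
    {
        "WindGustDir": ["W", "W", "N", np.nan],
        "WindDir9am": ["S", np.nan, "S", "E"],
    },
    dtype=object,
)
test_toy = pd.DataFrame(
    {
        "WindGustDir": [np.nan, "E"],
        "WindDir9am": ["N", np.nan],
    },
    dtype=object,
)

si = SimpleImputer(missing_values=np.nan, strategy="most_frequent")

# Learn the modes from the training data only ...
train_toy[cate_columns] = si.fit_transform(train_toy[cate_columns])
# ... and reuse them, unchanged, on the test data
test_toy[cate_columns] = si.transform(test_toy[cate_columns])

print(test_toy)
```

Calling `fit_transform` again on the test set would instead compute new modes from the test rows, which may differ from the ones the model saw during training.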
