EDA数据分析结合深度学习---基于EDA数据分析和MLP模型的天气预测(tensorflow实现)

前言

  • 这一篇文章数据量很大,EDA是我一直很想学的一个数据分析思想,这个思想主要思路是:数据缺失值、异常值、数据特征分析;
  • 正好这一篇可以正式入门,注意:对于各种类型图的特点还需要掌握 ,数据分析很需要根据不同的数据类型选取用不同图的展示,从而进行特征分析;
  • 在模型上,本文采用MLP模型进行预测,后期会加上机器学习、时间序列(ARIMA、SARIMA )模型,敬请期待;
  • 欢迎收藏 + 关注,本人将会持续更新。

文章目录

1、导入数据

1、导入库

python 复制代码
import pandas as pd 
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns 
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler  
import tensorflow as tf 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout

2、导入数据

python 复制代码
data_df = pd.read_csv('./weatherAUS.csv')
data = data_df.copy()
data.head()

| | Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
| 0 | 2008-12-01 | Albury | 13.4 | 22.9 | 0.6 | NaN | NaN | W | 44.0 | W | ... | 71.0 | 22.0 | 1007.7 | 1007.1 | 8.0 | NaN | 16.9 | 21.8 | No | No |
| 1 | 2008-12-02 | Albury | 7.4 | 25.1 | 0.0 | NaN | NaN | WNW | 44.0 | NNW | ... | 44.0 | 25.0 | 1010.6 | 1007.8 | NaN | NaN | 17.2 | 24.3 | No | No |
| 2 | 2008-12-03 | Albury | 12.9 | 25.7 | 0.0 | NaN | NaN | WSW | 46.0 | W | ... | 38.0 | 30.0 | 1007.6 | 1008.7 | NaN | 2.0 | 21.0 | 23.2 | No | No |
| 3 | 2008-12-04 | Albury | 9.2 | 28.0 | 0.0 | NaN | NaN | NE | 24.0 | SE | ... | 45.0 | 16.0 | 1017.6 | 1012.8 | NaN | NaN | 18.1 | 26.5 | No | No |

4 2008-12-05 Albury 17.5 32.3 1.0 NaN NaN W 41.0 ENE ... 82.0 33.0 1010.8 1006.0 7.0 8.0 17.8 29.7 No No

5 rows × 23 columns

python 复制代码
data.columns
Index(['Date', 'Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation',
       'Sunshine', 'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow'],
      dtype='object')
  1. Date 日期:记录天气数据的具体日期。
  2. Location 地点:数据收集的地理位置。
  3. MinTemp 最低温度:当天记录到的最低气温(单位:摄氏度)。
  4. MaxTemp 最高温度:当天记录到的最高气温(单位:摄氏度)。
  5. Rainfall 降雨量:当天的总降水量(单位:毫米)。
  6. Evaporation 蒸发量:当天从开放水面蒸发的水量(单位:毫米)。
  7. Sunshine 日照时长:当天直射太阳的累计时间(单位:小时)。
  8. WindGustDir 阵风方向:当天记录的最大阵风的方向(如北、东北等)。
  9. WindGustSpeed 阵风速度:当天记录的最大阵风的速度(单位:公里/小时)。
  10. WindDir9am 上午9点风向:上午9点时的风向。
  11. WindDir3pm 下午3点风向:下午3点时的风向。
  12. WindSpeed9am 上午9点风速:上午9点时的平均风速(单位:公里/小时)。
  13. WindSpeed3pm 下午3点风速:下午3点时的平均风速(单位:公里/小时)。
  14. Humidity9am 上午9点湿度:上午9点时的相对湿度(百分比)。
  15. Humidity3pm 下午3点湿度:下午3点时的相对湿度(百分比)。
  16. Pressure9am 上午9点气压:上午9点时的大气压力(单位:百帕)。
  17. Pressure3pm 下午3点气压:下午3点时的大气压力(单位:百帕)。
  18. Cloud9am 上午9点云量:上午9点时天空中云覆盖的比例(通常用Okta表示,0为晴天,8为完全多云)。
  19. Cloud3pm 下午3点云量:下午3点时天空中云覆盖的比例。
  20. Temp9am 上午9点温度:上午9点时的气温(单位:摄氏度)。
  21. Temp3pm 下午3点温度:下午3点时的气温(单位:摄氏度)。
  22. RainToday 今日是否下雨:当天是否有降雨,通常是二元变量(是/否或1/0)。
  23. RainTomorrow 明日是否下雨:预测第二天是否有降雨,也是二元变量。

3、数据初步处理

python 复制代码
# 查看数据量
data['Date'].count()
145460
python 复制代码
# 数据信息展示
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 23 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           145460 non-null  object 
 1   Location       145460 non-null  object 
 2   MinTemp        143975 non-null  float64
 3   MaxTemp        144199 non-null  float64
 4   Rainfall       142199 non-null  float64
 5   Evaporation    82670 non-null   float64
 6   Sunshine       75625 non-null   float64
 7   WindGustDir    135134 non-null  object 
 8   WindGustSpeed  135197 non-null  float64
 9   WindDir9am     134894 non-null  object 
 10  WindDir3pm     141232 non-null  object 
 11  WindSpeed9am   143693 non-null  float64
 12  WindSpeed3pm   142398 non-null  float64
 13  Humidity9am    142806 non-null  float64
 14  Humidity3pm    140953 non-null  float64
 15  Pressure9am    130395 non-null  float64
 16  Pressure3pm    130432 non-null  float64
 17  Cloud9am       89572 non-null   float64
 18  Cloud3pm       86102 non-null   float64
 19  Temp9am        143693 non-null  float64
 20  Temp3pm        141851 non-null  float64
 21  RainToday      142199 non-null  object 
 22  RainTomorrow   142193 non-null  object 
dtypes: float64(16), object(7)
memory usage: 25.5+ MB
python 复制代码
# 日期数据转换
data['Date'] = pd.to_datetime(data['Date'])

# 将日期转化为year、month、day,这样可以将日期类型转化为 数值类型,也可以探索 与年份、月份的关系
data['Year'] = data['Date'].dt.year
data['Month'] = data['Date'].dt.month
data['Day'] = data['Date'].dt.day

# 删除时间
data.drop('Date', axis=1, inplace=True)
python 复制代码
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow', 'Year', 'Month', 'Day'],
      dtype='object')

2、数据EDA分析

1、相关性分析

python 复制代码
# 选择数据列
columns_data = data[['MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine', 'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am', 'Temp3pm', 'RainToday', 'RainTomorrow', 'Year', 'Month', 'Day']]

# 将最后两列 RainToday, RainTomorrow 是数值, 这里NO-> 0, Yew-> 1
columns_data['RainToday'] = columns_data['RainToday'].map({'No': 0, 'Yes': 1})
columns_data['RainTomorrow'] = columns_data['RainTomorrow'].map({'No': 0, 'Yes': 1})

plt.figure(figsize=(15, 12))
ax = sns.heatmap(columns_data.corr(), annot=True, fmt='.2f')
plt.show()

初步观察,今天和明天是否会下雨与:风速、日照时长关系较大

2、是否会下雨

分别统计今天、明天下雨天数, 用统计图表示

python 复制代码
# 设置颜色
sns.set(style="whitegrid", palette='Set2')

# 创建画板
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# 图标样式
title_font = {'fontsize': 14, 'fontweight': 'bold', 'color': 'darkblue'}

# RainTomorrow
sns.countplot(x='RainTomorrow', data=data, ax=axes[0], edgecolor='black')
axes[0].set_title("Rain Tomorrow", fontdict=title_font)
axes[0].set_xlabel('Will it Rain Tomorrow?', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].tick_params(axis='x', labelsize=11)
axes[0].tick_params(axis='y', labelsize=11)

sns.countplot(x='RainToday', data=data, ax=axes[1], edgecolor='black')
axes[1].set_title("Rain Today", fontdict=title_font)
axes[1].set_xlabel('Will it Rain Today?', fontsize=12)
axes[1].set_ylabel('Count', fontsize=12)
axes[1].tick_params(axis='x', labelsize=11)
axes[1].tick_params(axis='y', labelsize=11)

plt.show()


这两张图来看,结合相关性分析,可以猜测:

  • 如果今天下雨,明天也可能下雨
  • 人工今天不下雨,明天也可能不下雨

3、地理位置是否与下雨有关

这个也可以用统计图进行分析

python 复制代码
data['Location'].value_counts()
Location
Canberra            3436
Sydney              3344
Darwin              3193
Melbourne           3193
Brisbane            3193
Adelaide            3193
Perth               3193
Hobart              3193
Albany              3040
MountGambier        3040
Ballarat            3040
Townsville          3040
GoldCoast           3040
Cairns              3040
Launceston          3040
AliceSprings        3040
Bendigo             3040
Albury              3040
MountGinini         3040
Wollongong          3040
Newcastle           3039
Tuggeranong         3039
Penrith             3039
Woomera             3009
Nuriootpa           3009
Cobar               3009
CoffsHarbour        3009
Moree               3009
Sale                3009
PerthAirport        3009
PearceRAAF          3009
Witchcliffe         3009
BadgerysCreek       3009
Mildura             3009
NorfolkIsland       3009
MelbourneAirport    3009
Richmond            3009
SydneyAirport       3009
WaggaWagga          3009
Williamtown         3009
Dartmoor            3009
Watsonia            3009
Portland            3009
Walpole             3006
NorahHead           3004
SalmonGums          3001
Katherine           1578
Nhil                1578
Uluru               1578
Name: count, dtype: int64

地点数有点多这里采用百分比条形图来展示

python 复制代码
pd.crosstab(data['Location'], data['RainToday'])

| RainToday | No | Yes |
| Location | | |
| Adelaide | 2402 | 689 |
| Albany | 2114 | 902 |
| Albury | 2394 | 617 |
| AliceSprings | 2788 | 244 |
| BadgerysCreek | 2345 | 583 |
| Ballarat | 2247 | 781 |
| Bendigo | 2472 | 562 |
| Brisbane | 2452 | 709 |
| Cairns | 2038 | 950 |
| Canberra | 2789 | 629 |
| Cobar | 2602 | 386 |
| CoffsHarbour | 2084 | 869 |
| Dartmoor | 2021 | 921 |
| Darwin | 2341 | 852 |
| GoldCoast | 2205 | 775 |
| Hobart | 2426 | 762 |
| Katherine | 1295 | 265 |
| Launceston | 2328 | 700 |
| Melbourne | 1799 | 636 |
| MelbourneAirport | 2356 | 653 |
| Mildura | 2680 | 327 |
| Moree | 2460 | 394 |
| MountGambier | 2110 | 921 |
| MountGinini | 2088 | 819 |
| Newcastle | 2224 | 731 |
| Nhil | 1327 | 242 |
| NorahHead | 2121 | 808 |
| NorfolkIsland | 2045 | 919 |
| Nuriootpa | 2411 | 592 |
| PearceRAAF | 2257 | 505 |
| Penrith | 2369 | 595 |
| Perth | 2548 | 645 |
| PerthAirport | 2442 | 567 |
| Portland | 1902 | 1094 |
| Richmond | 2391 | 560 |
| Sale | 2357 | 643 |
| SalmonGums | 2483 | 472 |
| Sydney | 2471 | 866 |
| SydneyAirport | 2231 | 774 |
| Townsville | 2513 | 520 |
| Tuggeranong | 2430 | 568 |
| Uluru | 1406 | 116 |
| WaggaWagga | 2440 | 536 |
| Walpole | 1870 | 949 |
| Watsonia | 2261 | 738 |
| Williamtown | 1853 | 700 |
| Witchcliffe | 2073 | 879 |
| Wollongong | 2269 | 713 |

Woomera 2789 202
python 复制代码
# 获取数据,统计yes、no数据量,生成连列表
x = pd.crosstab(data['Location'], data['RainToday'])
# 获取百分比, transpose转置,values转化为numpy
y = x / x.transpose().sum().values.reshape((-1, 1)) * 100

# 按照百分比不同排序
y = y.sort_values(by='Yes', ascending=True)

color=['#cc6699', '#006699', '#006666', '#862d86', '#ff9966']
y.Yes.plot(kind="barh",figsize=(15, 20),color=color)


这个结合地图,会更好分析,但是也说明了位置对下雨有影响

4、温度和压力对下雨的影响

这个用散点图最合适

python 复制代码
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow', 'Year', 'Month', 'Day'],
      dtype='object')
python 复制代码
# 压力
plt.figure(figsize=(8, 6))
sns.scatterplot(data=data, x='Pressure9am', y='Pressure3pm', hue='RainTomorrow')
plt.show()


python 复制代码
# 温度
plt.figure(figsize=(8, 6))
sns.scatterplot(data=data, x='Humidity9am', y='Humidity3pm', hue='RainTomorrow')
plt.show()


由上图可知,温度和压力对下雨有影响,尤其是下午3点的温度

5、气温对下雨的影响

python 复制代码
plt.figure(figsize=(8, 6))
sns.scatterplot(data=data, x='MaxTemp', y='MinTemp', hue='RainTomorrow')
plt.show()


这里可知,气温对下雨有影响,尤其当最大温度和最低温度接近的时候,影响较大

3、数据处理

1、数据处理

python 复制代码
data.isnull().sum()
Location             0
MinTemp           1485
MaxTemp           1261
Rainfall          3261
Evaporation      62790
Sunshine         69835
WindGustDir      10326
WindGustSpeed    10263
WindDir9am       10566
WindDir3pm        4228
WindSpeed9am      1767
WindSpeed3pm      3062
Humidity9am       2654
Humidity3pm       4507
Pressure9am      15065
Pressure3pm      15028
Cloud9am         55888
Cloud3pm         59358
Temp9am           1767
Temp3pm           3609
RainToday         3261
RainTomorrow      3267
Year                 0
Month                0
Day                  0
dtype: int64
python 复制代码
# 统计占比
data.isnull().sum() / data.shape[0] * 100
Location          0.000000
MinTemp           1.020899
MaxTemp           0.866905
Rainfall          2.241853
Evaporation      43.166506
Sunshine         48.009762
WindGustDir       7.098859
WindGustSpeed     7.055548
WindDir9am        7.263853
WindDir3pm        2.906641
WindSpeed9am      1.214767
WindSpeed3pm      2.105046
Humidity9am       1.824557
Humidity3pm       3.098446
Pressure9am      10.356799
Pressure3pm      10.331363
Cloud9am         38.421559
Cloud3pm         40.807095
Temp9am           1.214767
Temp3pm           2.481094
RainToday         2.241853
RainTomorrow      2.245978
Year              0.000000
Month             0.000000
Day               0.000000
dtype: float64
python 复制代码
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow', 'Year', 'Month', 'Day'],
      dtype='object')
python 复制代码
# 对于数值型数据,用** 随机选取数填充 **
lis = ['Evaporation', 'Sunshine', 'Cloud9am', 'Cloud3pm']
for col in lis:
    fill_list = data[col].dropna()  # 删除缺失值
    data[col] = data[col].fillna(pd.Series(np.random.choice(fill_list, size=len(data.index))))
python 复制代码
# 查看对象型数据
s = (data.dtypes == 'object')
object_list = list(s[s].index)
object_list
['Location',
 'WindGustDir',
 'WindDir9am',
 'WindDir3pm',
 'RainToday',
 'RainTomorrow']
python 复制代码
# 填充:频次最高
for col in object_list:
    fill_list = data[col].dropna()  # 删除缺失值
    data[col].fillna(data[col].mode()[0], inplace=True)
python 复制代码
# 其他的用中位数
t = (data.dtypes == 'float64')
num_cols = list(t[t].index)
for i in num_cols:
    data[i].fillna(data[i].median(), inplace=True)
python 复制代码
# 查看缺失值
data.isnull().sum()
Location         0
MinTemp          0
MaxTemp          0
Rainfall         0
Evaporation      0
Sunshine         0
WindGustDir      0
WindGustSpeed    0
WindDir9am       0
WindDir3pm       0
WindSpeed9am     0
WindSpeed3pm     0
Humidity9am      0
Humidity3pm      0
Pressure9am      0
Pressure3pm      0
Cloud9am         0
Cloud3pm         0
Temp9am          0
Temp3pm          0
RainToday        0
RainTomorrow     0
Year             0
Month            0
Day              0
dtype: int64
python 复制代码
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 25 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Location       145460 non-null  object 
 1   MinTemp        145460 non-null  float64
 2   MaxTemp        145460 non-null  float64
 3   Rainfall       145460 non-null  float64
 4   Evaporation    145460 non-null  float64
 5   Sunshine       145460 non-null  float64
 6   WindGustDir    145460 non-null  object 
 7   WindGustSpeed  145460 non-null  float64
 8   WindDir9am     145460 non-null  object 
 9   WindDir3pm     145460 non-null  object 
 10  WindSpeed9am   145460 non-null  float64
 11  WindSpeed3pm   145460 non-null  float64
 12  Humidity9am    145460 non-null  float64
 13  Humidity3pm    145460 non-null  float64
 14  Pressure9am    145460 non-null  float64
 15  Pressure3pm    145460 non-null  float64
 16  Cloud9am       145460 non-null  float64
 17  Cloud3pm       145460 non-null  float64
 18  Temp9am        145460 non-null  float64
 19  Temp3pm        145460 non-null  float64
 20  RainToday      145460 non-null  object 
 21  RainTomorrow   145460 non-null  object 
 22  Year           145460 non-null  int32  
 23  Month          145460 non-null  int32  
 24  Day            145460 non-null  int32  
dtypes: float64(16), int32(3), object(6)
memory usage: 26.1+ MB

2、数据划分与数据标准化

python 复制代码
# 对object进行标签编码
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
for i in object_list:
    data[i] = label_encoder.fit_transform(data[i])
python 复制代码
# 划分数据
X = data.drop(['RainTomorrow', 'Day'], axis=1).values
y = data['RainTomorrow'].values 

# 划分 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.fit_transform(X_test)

4、MLP模型构建

python 复制代码
model = Sequential()
model.add(Dense(units=24, activation='tanh'))
model.add(Dense(units=18, activation='tanh'))
model.add(Dense(units=23, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(units=12, activation='tanh'))
model.add(Dropout(0.2))
model.add(Dense(units=1, activation='sigmoid'))

5、模型训练

1、超参数设置

python 复制代码
from tensorflow.keras.optimizers import Adam 
# 超参数设置
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              metrics='accuracy')

2、模型训练

python 复制代码
epochs = 10

history = model.fit(x=X_train,
          y=y_train,
          validation_data=(X_test, y_test), 
          verbose=1,
          epochs=epochs,
          batch_size=32)
Epoch 1/10
3637/3637 [==============================] - 7s 2ms/step - loss: 0.4461 - accuracy: 0.8077 - val_loss: 0.3875 - val_accuracy: 0.8313
Epoch 2/10
3637/3637 [==============================] - 7s 2ms/step - loss: 0.3915 - accuracy: 0.8336 - val_loss: 0.3776 - val_accuracy: 0.8377
Epoch 3/10
3637/3637 [==============================] - 10s 3ms/step - loss: 0.3842 - accuracy: 0.8375 - val_loss: 0.3753 - val_accuracy: 0.8395
Epoch 4/10
3637/3637 [==============================] - 9s 2ms/step - loss: 0.3823 - accuracy: 0.8379 - val_loss: 0.3750 - val_accuracy: 0.8405
Epoch 5/10
3637/3637 [==============================] - 9s 2ms/step - loss: 0.3804 - accuracy: 0.8383 - val_loss: 0.3730 - val_accuracy: 0.8413
Epoch 6/10
3637/3637 [==============================] - 9s 2ms/step - loss: 0.3792 - accuracy: 0.8388 - val_loss: 0.3802 - val_accuracy: 0.8379
Epoch 7/10
3637/3637 [==============================] - 6s 2ms/step - loss: 0.3784 - accuracy: 0.8384 - val_loss: 0.3792 - val_accuracy: 0.8381
Epoch 8/10
3637/3637 [==============================] - 7s 2ms/step - loss: 0.3776 - accuracy: 0.8392 - val_loss: 0.3848 - val_accuracy: 0.8353
Epoch 9/10
3637/3637 [==============================] - 11s 3ms/step - loss: 0.3772 - accuracy: 0.8391 - val_loss: 0.3759 - val_accuracy: 0.8410
Epoch 10/10
3637/3637 [==============================] - 8s 2ms/step - loss: 0.3760 - accuracy: 0.8398 - val_loss: 0.3811 - val_accuracy: 0.8374

6、结果展示

python 复制代码
# 获取训练集和验证集损失率和准确率
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


相关推荐
IT古董1 小时前
【机器学习】主动学习-增加标签的操作方法-流式选择性采样(Stream-based selective sampling)
人工智能·学习·机器学习
KeyPan1 小时前
【机器学习:十九、反向传播】
人工智能·深度学习·机器学习
m0_743106463 小时前
【论文笔记】多个大规模数据集上的SOTA绝对位姿回归方法:Reloc3r
论文阅读·深度学习·计算机视觉·3d·几何学
埃菲尔铁塔_CV算法3 小时前
双线性插值算法:原理、实现、优化及在图像处理和多领域中的广泛应用与发展趋势(二)
c++·人工智能·算法·机器学习·计算机视觉
hnmpf4 小时前
flask_sqlalchemy relationship 子表排序
后端·python·flask
疯狂学习GIS4 小时前
互联网大中小厂实习面经:滴滴、美团、货拉拉、蔚来、信通院等
c++·python
Nobita Chen4 小时前
Python实现windows自动关机
开发语言·windows·python
码路刺客4 小时前
一学就废|Python基础碎片,OS模块
开发语言·python
z千鑫4 小时前
【Python】Python之Selenium基础教程+实战demo:提升你的测试+测试数据构造的效率!
开发语言·python·selenium
HyperAI超神经5 小时前
微软与腾讯技术交锋,TRELLIS引领3D生成领域多格式支持新方向
人工智能·深度学习·机器学习·计算机视觉·3d·大模型·数据集