- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
RNN心脏病识别
python
import tensorflow as tf

# Restrict TensorFlow to the first GPU (if one is present) and let its memory
# allocation grow on demand instead of reserving all device memory up front.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # if there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")
gpus  # notebook cell output: the detected GPU list (empty when no GPU is found)
[]
导入数据
python
import numpy as np
import pandas as pd

# Load the heart-disease dataset: 13 feature columns plus a final "target" label column.
csv_path = r"C:\Users\11054\Desktop\kLearning\R123\heart.csv"
df = pd.read_csv(csv_path)
df
| | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
| 1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
| 2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
| 3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
| 4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
| 299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 | 0 |
| 300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
| 301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
| 302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
303 rows × 14 columns
python
# Count missing values in each column (all counts are 0 for this dataset).
df.isna().sum()
age 0
sex 0
cp 0
trestbps 0
chol 0
fbs 0
restecg 0
thalach 0
exang 0
oldpeak 0
slope 0
ca 0
thal 0
target 0
dtype: int64
数据预处理
python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Features are every column except the last; the label is the last column.
X, y = df.iloc[:, :-1], df.iloc[:, -1]

# Hold out 10% of the rows for testing; a fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.1,
    random_state=1,
)
python
# Show the training-set shapes: (samples, features) and (samples,).
(X_train.shape, y_train.shape)
((272, 13), (272,))
标准化
python
# Standardize each feature column independently to zero mean and unit variance.
# The scaler is fitted on the training split only and then reused on the test
# split, so no test-set statistics leak into preprocessing.
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# Append a trailing axis so each sample becomes a sequence with one step per
# feature: (samples, features) -> (samples, features, 1), the layout SimpleRNN consumes.
X_train = X_train[..., None]
X_test = X_test[..., None]
python
import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input, LSTM, SimpleRNN

# Treat the standardized features as a univariate sequence (one step per feature):
# SimpleRNN(200) -> Dense(100) -> Dense(1, sigmoid) for binary classification.
# The input shape is declared with an explicit Input layer, because passing
# input_shape to the first layer triggers a Keras UserWarning. The shape is read
# from the prepared training data so it stays in sync with the preprocessing
# (here (13, 1)).
model = Sequential()
model.add(Input(shape=X_train.shape[1:]))
model.add(SimpleRNN(200, activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()
C:\Users\11054\.conda\envs\tf39\lib\site-packages\keras\src\layers\rnn\rnn.py:204: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(**kwargs)
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ simple_rnn_2 (SimpleRNN) │ (None, 200) │ 40,400 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_4 (Dense) │ (None, 100) │ 20,100 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_5 (Dense) │ (None, 1) │ 101 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 60,601 (236.72 KB)
Trainable params: 60,601 (236.72 KB)
Non-trainable params: 0 (0.00 B)
模型训练
python
# Train with Adam at a small learning rate. Binary cross-entropy pairs with the
# single sigmoid output unit. The test split is passed as validation_data, so the
# "val_*" metrics in the log are measured on it after every epoch.
epochs = 100
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=opt,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    batch_size=128,
    epochs=epochs,
    verbose=1,
)
Epoch 1/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 93ms/step - accuracy: 0.5453 - loss: 0.6920 - val_accuracy: 0.4516 - val_loss: 0.6892
Epoch 2/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5724 - loss: 0.6806 - val_accuracy: 0.5484 - val_loss: 0.6757
Epoch 3/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5508 - loss: 0.6761 - val_accuracy: 0.5806 - val_loss: 0.6621
Epoch 4/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5803 - loss: 0.6686 - val_accuracy: 0.7097 - val_loss: 0.6493
Epoch 5/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.6259 - loss: 0.6600 - val_accuracy: 0.7419 - val_loss: 0.6377
Epoch 6/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6521 - loss: 0.6524 - val_accuracy: 0.7742 - val_loss: 0.6270
Epoch 7/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7017 - loss: 0.6459 - val_accuracy: 0.7742 - val_loss: 0.6166
Epoch 8/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7033 - loss: 0.6409 - val_accuracy: 0.7742 - val_loss: 0.6064
Epoch 9/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.7209 - loss: 0.6337 - val_accuracy: 0.7742 - val_loss: 0.5962
Epoch 10/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.7111 - loss: 0.6282 - val_accuracy: 0.7742 - val_loss: 0.5851
Epoch 11/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7254 - loss: 0.6212 - val_accuracy: 0.8065 - val_loss: 0.5732
Epoch 12/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7557 - loss: 0.6125 - val_accuracy: 0.8065 - val_loss: 0.5611
Epoch 13/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7681 - loss: 0.6058 - val_accuracy: 0.8387 - val_loss: 0.5486
Epoch 14/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7688 - loss: 0.6011 - val_accuracy: 0.8387 - val_loss: 0.5352
Epoch 15/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7824 - loss: 0.5899 - val_accuracy: 0.8710 - val_loss: 0.5210
Epoch 16/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7812 - loss: 0.5804 - val_accuracy: 0.8710 - val_loss: 0.5058
Epoch 17/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.7879 - loss: 0.5700 - val_accuracy: 0.8710 - val_loss: 0.4898
Epoch 18/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7869 - loss: 0.5580 - val_accuracy: 0.8710 - val_loss: 0.4728
Epoch 19/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7840 - loss: 0.5494 - val_accuracy: 0.8710 - val_loss: 0.4558
Epoch 20/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7908 - loss: 0.5339 - val_accuracy: 0.8710 - val_loss: 0.4379
Epoch 21/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.8005 - loss: 0.5173 - val_accuracy: 0.8710 - val_loss: 0.4192
Epoch 22/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8060 - loss: 0.5147 - val_accuracy: 0.8710 - val_loss: 0.4008
Epoch 23/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8012 - loss: 0.4986 - val_accuracy: 0.8710 - val_loss: 0.3824
Epoch 24/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7907 - loss: 0.4896 - val_accuracy: 0.8710 - val_loss: 0.3638
Epoch 25/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.8089 - loss: 0.4668 - val_accuracy: 0.8710 - val_loss: 0.3451
Epoch 26/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.8003 - loss: 0.4653 - val_accuracy: 0.8710 - val_loss: 0.3275
Epoch 27/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.7934 - loss: 0.4620 - val_accuracy: 0.8710 - val_loss: 0.3126
Epoch 28/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8051 - loss: 0.4373 - val_accuracy: 0.8710 - val_loss: 0.2993
Epoch 29/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.7791 - loss: 0.4399 - val_accuracy: 0.8710 - val_loss: 0.2880
Epoch 30/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8080 - loss: 0.4224 - val_accuracy: 0.8710 - val_loss: 0.2817
Epoch 31/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8108 - loss: 0.4114 - val_accuracy: 0.9032 - val_loss: 0.2772
Epoch 32/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.8146 - loss: 0.4190 - val_accuracy: 0.8710 - val_loss: 0.2708
Epoch 33/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8192 - loss: 0.4095 - val_accuracy: 0.8710 - val_loss: 0.2687
Epoch 34/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8154 - loss: 0.3996 - val_accuracy: 0.8387 - val_loss: 0.2698
Epoch 35/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8116 - loss: 0.4052 - val_accuracy: 0.8387 - val_loss: 0.2676
Epoch 36/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8107 - loss: 0.4004 - val_accuracy: 0.8710 - val_loss: 0.2619
Epoch 37/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8268 - loss: 0.3980 - val_accuracy: 0.9032 - val_loss: 0.2591
Epoch 38/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8208 - loss: 0.3948 - val_accuracy: 0.9032 - val_loss: 0.2586
Epoch 39/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8343 - loss: 0.3842 - val_accuracy: 0.9032 - val_loss: 0.2586
Epoch 40/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8429 - loss: 0.3873 - val_accuracy: 0.9032 - val_loss: 0.2612
Epoch 41/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8323 - loss: 0.3854 - val_accuracy: 0.9032 - val_loss: 0.2559
Epoch 42/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8246 - loss: 0.3750 - val_accuracy: 0.9032 - val_loss: 0.2536
Epoch 43/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.8323 - loss: 0.3777 - val_accuracy: 0.8710 - val_loss: 0.2566
Epoch 44/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8179 - loss: 0.3889 - val_accuracy: 0.9032 - val_loss: 0.2586
Epoch 45/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8295 - loss: 0.3767 - val_accuracy: 0.9032 - val_loss: 0.2631
Epoch 46/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8224 - loss: 0.3760 - val_accuracy: 0.9032 - val_loss: 0.2713
Epoch 47/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8459 - loss: 0.3678 - val_accuracy: 0.9032 - val_loss: 0.2754
Epoch 48/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8372 - loss: 0.3651 - val_accuracy: 0.9032 - val_loss: 0.2713
Epoch 49/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8371 - loss: 0.3692 - val_accuracy: 0.9032 - val_loss: 0.2633
Epoch 50/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8535 - loss: 0.3444 - val_accuracy: 0.9032 - val_loss: 0.2626
Epoch 51/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8389 - loss: 0.3682 - val_accuracy: 0.9032 - val_loss: 0.2644
Epoch 52/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8370 - loss: 0.3691 - val_accuracy: 0.9032 - val_loss: 0.2661
Epoch 53/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8482 - loss: 0.3693 - val_accuracy: 0.9032 - val_loss: 0.2685
Epoch 54/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8521 - loss: 0.3515 - val_accuracy: 0.9032 - val_loss: 0.2743
Epoch 55/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8504 - loss: 0.3456 - val_accuracy: 0.9032 - val_loss: 0.2798
Epoch 56/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8553 - loss: 0.3351 - val_accuracy: 0.9032 - val_loss: 0.2785
Epoch 57/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.8657 - loss: 0.3317 - val_accuracy: 0.9032 - val_loss: 0.2773
Epoch 58/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.8562 - loss: 0.3383 - val_accuracy: 0.9032 - val_loss: 0.2802
Epoch 59/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8628 - loss: 0.3463 - val_accuracy: 0.8387 - val_loss: 0.2841
Epoch 60/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.8569 - loss: 0.3483 - val_accuracy: 0.8387 - val_loss: 0.2835
Epoch 61/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8675 - loss: 0.3362 - val_accuracy: 0.9032 - val_loss: 0.2820
Epoch 62/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8473 - loss: 0.3377 - val_accuracy: 0.9032 - val_loss: 0.2854
Epoch 63/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8675 - loss: 0.3280 - val_accuracy: 0.9032 - val_loss: 0.2920
Epoch 64/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8626 - loss: 0.3382 - val_accuracy: 0.9032 - val_loss: 0.2953
Epoch 65/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8618 - loss: 0.3182 - val_accuracy: 0.9032 - val_loss: 0.2951
Epoch 66/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8702 - loss: 0.3271 - val_accuracy: 0.9032 - val_loss: 0.2951
Epoch 67/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8673 - loss: 0.3294 - val_accuracy: 0.9032 - val_loss: 0.2977
Epoch 68/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8635 - loss: 0.3200 - val_accuracy: 0.9032 - val_loss: 0.3028
Epoch 69/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8723 - loss: 0.3158 - val_accuracy: 0.9032 - val_loss: 0.3083
Epoch 70/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8771 - loss: 0.3153 - val_accuracy: 0.9032 - val_loss: 0.3101
Epoch 71/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8683 - loss: 0.3253 - val_accuracy: 0.9032 - val_loss: 0.3085
Epoch 72/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8730 - loss: 0.3205 - val_accuracy: 0.9032 - val_loss: 0.3067
Epoch 73/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8720 - loss: 0.3280 - val_accuracy: 0.8710 - val_loss: 0.3030
Epoch 74/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8690 - loss: 0.3192 - val_accuracy: 0.8710 - val_loss: 0.3043
Epoch 75/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8768 - loss: 0.3114 - val_accuracy: 0.9032 - val_loss: 0.3111
Epoch 76/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8761 - loss: 0.3024 - val_accuracy: 0.9032 - val_loss: 0.3173
Epoch 77/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8617 - loss: 0.3163 - val_accuracy: 0.9032 - val_loss: 0.3179
Epoch 78/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8914 - loss: 0.2921 - val_accuracy: 0.8710 - val_loss: 0.3168
Epoch 79/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.8751 - loss: 0.2908 - val_accuracy: 0.8710 - val_loss: 0.3156
Epoch 80/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8855 - loss: 0.3034 - val_accuracy: 0.8710 - val_loss: 0.3149
Epoch 81/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8739 - loss: 0.3021 - val_accuracy: 0.8710 - val_loss: 0.3168
Epoch 82/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8682 - loss: 0.3094 - val_accuracy: 0.8710 - val_loss: 0.3213
Epoch 83/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8816 - loss: 0.2977 - val_accuracy: 0.8710 - val_loss: 0.3280
Epoch 84/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8942 - loss: 0.2852 - val_accuracy: 0.8710 - val_loss: 0.3302
Epoch 85/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8960 - loss: 0.2807 - val_accuracy: 0.8710 - val_loss: 0.3282
Epoch 86/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8997 - loss: 0.2902 - val_accuracy: 0.8710 - val_loss: 0.3313
Epoch 87/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8919 - loss: 0.2955 - val_accuracy: 0.8710 - val_loss: 0.3350
Epoch 88/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8928 - loss: 0.2953 - val_accuracy: 0.8710 - val_loss: 0.3369
Epoch 89/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8919 - loss: 0.2931 - val_accuracy: 0.8710 - val_loss: 0.3398
Epoch 90/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.8958 - loss: 0.2817 - val_accuracy: 0.8710 - val_loss: 0.3417
Epoch 91/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8968 - loss: 0.2933 - val_accuracy: 0.9032 - val_loss: 0.3443
Epoch 92/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8922 - loss: 0.2758 - val_accuracy: 0.9032 - val_loss: 0.3459
Epoch 93/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8872 - loss: 0.2865 - val_accuracy: 0.8710 - val_loss: 0.3443
Epoch 94/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.9008 - loss: 0.2677 - val_accuracy: 0.8710 - val_loss: 0.3408
Epoch 95/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8957 - loss: 0.2822 - val_accuracy: 0.8710 - val_loss: 0.3352
Epoch 96/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8940 - loss: 0.2733 - val_accuracy: 0.8710 - val_loss: 0.3328
Epoch 97/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.9007 - loss: 0.2686 - val_accuracy: 0.8710 - val_loss: 0.3295
Epoch 98/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8996 - loss: 0.2729 - val_accuracy: 0.8710 - val_loss: 0.3316
Epoch 99/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8874 - loss: 0.2709 - val_accuracy: 0.8710 - val_loss: 0.3394
Epoch 100/100
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.8807 - loss: 0.2768 - val_accuracy: 0.8710 - val_loss: 0.3405
模型评估
python
import matplotlib.pyplot as plt

# Per-epoch curves recorded by model.fit (training split vs. the split passed
# as validation_data).
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Build the x axis from the recorded history rather than the requested `epochs`,
# so the curves still line up if training stops early or `epochs` is changed.
# Epochs are numbered from 1 to match the "Epoch 1/100" training log.
epochs_range = range(1, len(acc) + 1)

plt.figure(figsize=(14, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')

plt.show()
python
# evaluate() returns [loss, accuracy] because the model was compiled with
# metrics=['accuracy']. The label is spelled out explicitly: under Keras 3,
# model.metrics_names[1] is the aggregate name "compile_metrics", not "accuracy".
scores = model.evaluate(X_test, y_test, verbose=0)
print(f"accuracy: {scores[1] * 100:.2f}%")
compile_metrics: 87.10%
个人总结
- RNN 的核心思想是通过引入循环结构来捕捉序列数据中的时间依赖性。
- RNN 的每个隐藏层单元不仅接收当前时间步的输入,还接收前一时间步隐藏层的状态作为输入。
$s_t$ 是 $t$ 时刻的隐藏状态,它是网络的"记忆"。$s_t$ 的计算依赖于前一时刻的隐藏状态和当前时刻的输入:
$$s_t = f(U x_t + W s_{t-1})$$ 其中 $U$ 为输入到隐藏层的权重矩阵,$W$ 为隐藏层到隐藏层的权重矩阵。
函数 $f$ 通常是诸如 $\tanh$ 或 $\mathrm{ReLU}$ 的非线性函数。
$s_{-1}$ 用于计算第一个隐藏状态,通常初始化为 $0$。