This article is an internal post of the 365-Day Deep Learning Training Camp (365天深度学习训练营).
Original author: K同学啊
I. Preparation
1. Importing libraries and data
```python
import warnings

import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
from keras.optimizers import Adam

warnings.filterwarnings('ignore')  # silence library warnings in the notebook output

# heart.csv: 303 patient records, 13 clinical features plus the binary `target` label
df = pd.read_csv('heart.csv')
```
2. Inspecting the data
Check whether any values are missing:
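```python
print(df.shape)           # (rows, columns)
print(df.isnull().sum())  # number of missing values in each column
```

Output (no column has missing values):

```
(303, 14)
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
```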
II. Data Preprocessing
1. Splitting into training and test sets
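```python
X = df.iloc[:, :-1]  # the 13 feature columns
y = df.iloc[:, -1]   # the `target` label (last column)
# Hold out 10% of the rows (31 of 303) as the test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=14)
```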
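As a sanity check, the split sizes can be worked out by hand. The sketch below assumes scikit-learn rounds the test count up for a float `test_size` and gives the remainder to the training set. This assumption is consistent with the 31-sample validation set implied by the training log. `split_sizes` and `steps_per_epoch` are hypothetical helpers, not library functions.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Return (n_train, n_test) for a float test_size, rounding the test count up (assumed)."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def steps_per_epoch(n_train: int, batch_size: int) -> int:
    """Number of mini-batches per epoch; the last batch may be smaller than batch_size."""
    return math.ceil(n_train / batch_size)


n_train, n_test = split_sizes(303, 0.1)
print(n_train, n_test)                # 272 31
print(steps_per_epoch(n_train, 128))  # 3
```

This matches the "3/3" steps per epoch in the training log. It also explains why every `val_accuracy` value there is a multiple of 1/31 (for example 0.7742 ≈ 24/31).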
2. Standardizing the data
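```python
sc = StandardScaler()
X_train = sc.fit_transform(X_train)  # learn per-feature mean/std on the training set only
X_test = sc.transform(X_test)        # reuse the training statistics; refitting on X_test would scale it by its own statistics
# RNN layers expect 3-D input: (samples, timesteps, features_per_step)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)
```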
Output (excerpt showing the first two samples, each of shape (13, 1)):

```
array([[[ 1.44626869],
        [ 0.54006172],
        [ 0.62321699],
        [ 1.37686599],
        [ 0.83801861],
        [-0.48989795],
        [ 0.92069654],
        [-1.38834656],
        [ 1.34839972],
        [ 1.83944021],
        [-0.74161985],
        [ 0.18805174],
        [ 1.09773445]],

       [[-0.11901962],
        [ 0.54006172],
        [ 1.4632051 ],
        [-0.7179976 ],
        [-1.01585167],
        [-0.48989795],
        [-0.86315301],
        [ 0.77440436],
        [-0.74161985],
        [ 0.85288923],
        [-0.74161985],
        [-0.78354893],
        [ 1.09773445]],
```
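`StandardScaler` rescales each column to z = (x - mean) / std. The mean and std are the per-column mean and population standard deviation learned during `fit`. The NumPy sketch below mirrors that on a toy matrix with made-up numbers. It also shows why the test set should be transformed with the training statistics (what `sc.transform` does) rather than refit: a refit maps the same raw values to different z-scores. The helper names `fit_scaler` and `transform` are illustrative, not scikit-learn's internals.

```python
import numpy as np


def fit_scaler(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Learn per-column mean and population std (ddof=0)."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0] = 1.0  # guard constant columns against division by zero
    return mean, std


def transform(X: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    return (X - mean) / std


# Toy data: 4 training rows and 2 test rows, 2 features each (made-up values)
toy_train = np.array([[40.0, 120.0], [50.0, 130.0], [60.0, 140.0], [70.0, 150.0]])
toy_test = np.array([[55.0, 135.0], [65.0, 145.0]])

mean, std = fit_scaler(toy_train)       # statistics come from the training set only
Z_train = transform(toy_train, mean, std)
Z_test = transform(toy_test, mean, std)  # reuse the training statistics

print(Z_train.mean(axis=0), Z_train.std(axis=0))  # ~0 and ~1 per column
print(Z_test)

# Refitting on the test set gives it its own centre and scale instead
Z_test_refit = transform(toy_test, *fit_scaler(toy_test))
print(Z_test_refit)  # values [[-1, -1], [1, 1]]
```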
III. Building the RNN Model
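```python
model = Sequential()
# 200-unit SimpleRNN: reads the 13 standardized features one per time step and returns its last hidden state
model.add(SimpleRNN(200, input_shape=(X_train.shape[1], 1), activation='relu'))
model.add(Dense(100, activation='relu'))
# Single sigmoid unit: predicted probability of the positive class (target = 1)
model.add(Dense(1, activation='sigmoid'))
model.summary()
```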
Output:

```
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 simple_rnn (SimpleRNN)      (None, 200)               40400
 dense (Dense)               (None, 100)               20100
 dense_1 (Dense)             (None, 1)                 101
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________
```
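The sketch below shows where the 40,400 parameters of `simple_rnn` come from and what the layer computes. It is a NumPy forward pass under the standard SimpleRNN formulation h_t = relu(x_t · W + h_(t-1) · U + b), with h_0 = 0. It returns only the last state, as the layer does by default (`return_sequences=False`). The names `W`, `U`, `b` and the random initialization are illustrative assumptions, not Keras internals.

```python
import numpy as np


def simple_rnn_forward(x, W, U, b):
    """x: (batch, timesteps, input_dim) -> last hidden state of shape (batch, units)."""
    batch, timesteps, _ = x.shape
    h = np.zeros((batch, U.shape[0]))  # h_0 = 0
    for t in range(timesteps):         # one step per feature column in this article
        h = np.maximum(0.0, x[:, t, :] @ W + h @ U + b)  # relu activation
    return h


def simple_rnn_params(input_dim, units):
    # kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return input_dim * units + units * units + units


def dense_params(input_dim, units):
    return input_dim * units + units  # weights + bias


rng = np.random.default_rng(0)
units, timesteps, input_dim = 200, 13, 1
W = rng.normal(scale=0.1, size=(input_dim, units))
U = rng.normal(scale=0.1, size=(units, units))
b = np.zeros(units)

x = rng.normal(size=(4, timesteps, input_dim))  # 4 fake standardized samples
h_last = simple_rnn_forward(x, W, U, b)
print(h_last.shape)  # (4, 200)

total = simple_rnn_params(1, 200) + dense_params(200, 100) + dense_params(100, 1)
print(simple_rnn_params(1, 200), dense_params(200, 100), dense_params(100, 1), total)
# 40400 20100 101 60601
```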
IV. Compiling the Model
```python
optimizer = Adam(learning_rate=1e-4)
# Binary cross-entropy suits this binary classification task; use the Adam optimizer defined above and track accuracy
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
```
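For a sigmoid output p and a label y ∈ {0, 1}, binary cross-entropy averages -[y·log(p) + (1 - y)·log(1 - p)] over the batch. The NumPy sketch below implements that formula. The small epsilon clip is an assumption added to keep log() finite. The sketch also shows why the first epoch starts near 0.69 (0.6872 in the training log): a model that outputs p ≈ 0.5 for every sample has a loss of ln 2 ≈ 0.693.

```python
import numpy as np


def binary_crossentropy(y_true, p_pred, eps=1e-7):
    """Mean binary cross-entropy; clipping p to [eps, 1 - eps] keeps log() finite."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1.0 - eps)
    return float(np.mean(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))))


# An untrained model that outputs 0.5 for everything: loss = ln 2
print(binary_crossentropy([1, 0, 1, 0], [0.5, 0.5, 0.5, 0.5]))  # ~0.6931
# Confident, correct predictions push the loss toward 0
print(binary_crossentropy([1, 0], [0.95, 0.05]))  # ~0.0513
```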
V. Training the Model
```python
epochs = 100
# Note: the held-out test set doubles as the validation set here
history = model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test), verbose=1,
                    epochs=epochs, batch_size=128)
# Per-epoch metrics recorded by Keras' History callback
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
```
Output:

```
Epoch 1/100
3/3 [==============================] - 1s 130ms/step - loss: 0.6872 - accuracy: 0.5551 - val_loss: 0.6884 - val_accuracy: 0.5806
Epoch 2/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6763 - accuracy: 0.6250 - val_loss: 0.6848 - val_accuracy: 0.6129
Epoch 3/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6660 - accuracy: 0.6912 - val_loss: 0.6814 - val_accuracy: 0.6452
Epoch 4/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6562 - accuracy: 0.7426 - val_loss: 0.6781 - val_accuracy: 0.6452
Epoch 5/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6467 - accuracy: 0.7647 - val_loss: 0.6751 - val_accuracy: 0.6129
Epoch 6/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6375 - accuracy: 0.7941 - val_loss: 0.6722 - val_accuracy: 0.6452
Epoch 7/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6285 - accuracy: 0.8051 - val_loss: 0.6694 - val_accuracy: 0.6129
Epoch 8/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6193 - accuracy: 0.8015 - val_loss: 0.6666 - val_accuracy: 0.6129
Epoch 9/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6094 - accuracy: 0.8125 - val_loss: 0.6635 - val_accuracy: 0.5806
Epoch 10/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6002 - accuracy: 0.8162 - val_loss: 0.6602 - val_accuracy: 0.6129
Epoch 11/100
3/3 [==============================] - 0s 25ms/step - loss: 0.5903 - accuracy: 0.8125 - val_loss: 0.6565 - val_accuracy: 0.5806
Epoch 12/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5795 - accuracy: 0.8125 - val_loss: 0.6526 - val_accuracy: 0.5806
Epoch 13/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5686 - accuracy: 0.8125 - val_loss: 0.6484 - val_accuracy: 0.6129
Epoch 14/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5571 - accuracy: 0.8125 - val_loss: 0.6436 - val_accuracy: 0.6452
Epoch 15/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5451 - accuracy: 0.8125 - val_loss: 0.6377 - val_accuracy: 0.6452
Epoch 16/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5322 - accuracy: 0.8125 - val_loss: 0.6315 - val_accuracy: 0.6452
Epoch 17/100
3/3 [==============================] - 0s 24ms/step - loss: 0.5190 - accuracy: 0.8199 - val_loss: 0.6251 - val_accuracy: 0.6452
Epoch 18/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5053 - accuracy: 0.8199 - val_loss: 0.6190 - val_accuracy: 0.6774
Epoch 19/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4910 - accuracy: 0.8162 - val_loss: 0.6132 - val_accuracy: 0.6774
Epoch 20/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4765 - accuracy: 0.8199 - val_loss: 0.6076 - val_accuracy: 0.6774
Epoch 21/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4616 - accuracy: 0.8235 - val_loss: 0.6007 - val_accuracy: 0.6774
Epoch 22/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4470 - accuracy: 0.8125 - val_loss: 0.5943 - val_accuracy: 0.6774
Epoch 23/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4345 - accuracy: 0.8162 - val_loss: 0.5906 - val_accuracy: 0.6774
Epoch 24/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4219 - accuracy: 0.8162 - val_loss: 0.5901 - val_accuracy: 0.7419
Epoch 25/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4116 - accuracy: 0.8162 - val_loss: 0.5921 - val_accuracy: 0.7742
Epoch 26/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4056 - accuracy: 0.8272 - val_loss: 0.5990 - val_accuracy: 0.7419
Epoch 27/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3983 - accuracy: 0.8309 - val_loss: 0.5970 - val_accuracy: 0.7097
Epoch 28/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3920 - accuracy: 0.8309 - val_loss: 0.5914 - val_accuracy: 0.7097
Epoch 29/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3860 - accuracy: 0.8235 - val_loss: 0.5863 - val_accuracy: 0.7097
Epoch 30/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3802 - accuracy: 0.8235 - val_loss: 0.5724 - val_accuracy: 0.7097
Epoch 31/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3757 - accuracy: 0.8346 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 32/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3766 - accuracy: 0.8272 - val_loss: 0.5545 - val_accuracy: 0.7419
Epoch 33/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3706 - accuracy: 0.8272 - val_loss: 0.5608 - val_accuracy: 0.7419
Epoch 34/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3639 - accuracy: 0.8382 - val_loss: 0.5899 - val_accuracy: 0.7419
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3694 - accuracy: 0.8272 - val_loss: 0.6097 - val_accuracy: 0.7742
Epoch 36/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3682 - accuracy: 0.8346 - val_loss: 0.5859 - val_accuracy: 0.7419
Epoch 37/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3567 - accuracy: 0.8309 - val_loss: 0.5680 - val_accuracy: 0.7419
Epoch 38/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3497 - accuracy: 0.8419 - val_loss: 0.5528 - val_accuracy: 0.7419
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3484 - accuracy: 0.8603 - val_loss: 0.5417 - val_accuracy: 0.7742
Epoch 40/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3487 - accuracy: 0.8603 - val_loss: 0.5386 - val_accuracy: 0.6774
Epoch 41/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3473 - accuracy: 0.8640 - val_loss: 0.5383 - val_accuracy: 0.7097
Epoch 42/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3422 - accuracy: 0.8676 - val_loss: 0.5425 - val_accuracy: 0.7742
Epoch 43/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3353 - accuracy: 0.8713 - val_loss: 0.5467 - val_accuracy: 0.7419
Epoch 44/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3318 - accuracy: 0.8787 - val_loss: 0.5565 - val_accuracy: 0.7419
Epoch 45/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3289 - accuracy: 0.8750 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 46/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3263 - accuracy: 0.8750 - val_loss: 0.5548 - val_accuracy: 0.7419
Epoch 47/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3227 - accuracy: 0.8787 - val_loss: 0.5520 - val_accuracy: 0.7419
Epoch 48/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3191 - accuracy: 0.8824 - val_loss: 0.5564 - val_accuracy: 0.7419
Epoch 49/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3172 - accuracy: 0.8713 - val_loss: 0.5539 - val_accuracy: 0.7419
Epoch 50/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3149 - accuracy: 0.8824 - val_loss: 0.5381 - val_accuracy: 0.7419
Epoch 51/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3110 - accuracy: 0.8824 - val_loss: 0.5427 - val_accuracy: 0.7419
Epoch 52/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3084 - accuracy: 0.8787 - val_loss: 0.5510 - val_accuracy: 0.7419
Epoch 53/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3069 - accuracy: 0.8750 - val_loss: 0.5571 - val_accuracy: 0.7419
Epoch 54/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3052 - accuracy: 0.8860 - val_loss: 0.5468 - val_accuracy: 0.7419
Epoch 55/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3024 - accuracy: 0.8787 - val_loss: 0.5347 - val_accuracy: 0.7419
Epoch 56/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3010 - accuracy: 0.8787 - val_loss: 0.5417 - val_accuracy: 0.7419
Epoch 57/100
3/3 [==============================] - 0s 21ms/step - loss: 0.3013 - accuracy: 0.8860 - val_loss: 0.5496 - val_accuracy: 0.7419
Epoch 58/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2975 - accuracy: 0.8824 - val_loss: 0.5355 - val_accuracy: 0.7419
Epoch 59/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2954 - accuracy: 0.8787 - val_loss: 0.5198 - val_accuracy: 0.7419
Epoch 60/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2970 - accuracy: 0.8787 - val_loss: 0.5148 - val_accuracy: 0.7419
Epoch 61/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2991 - accuracy: 0.8824 - val_loss: 0.5187 - val_accuracy: 0.7419
Epoch 62/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2958 - accuracy: 0.8787 - val_loss: 0.5376 - val_accuracy: 0.7419
Epoch 63/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2891 - accuracy: 0.8860 - val_loss: 0.5659 - val_accuracy: 0.7419
Epoch 64/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2923 - accuracy: 0.8824 - val_loss: 0.5777 - val_accuracy: 0.7419
Epoch 65/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2892 - accuracy: 0.8824 - val_loss: 0.5560 - val_accuracy: 0.7419
Epoch 66/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2848 - accuracy: 0.8934 - val_loss: 0.5405 - val_accuracy: 0.7419
Epoch 67/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2828 - accuracy: 0.8897 - val_loss: 0.5334 - val_accuracy: 0.7419
Epoch 68/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2810 - accuracy: 0.8934 - val_loss: 0.5332 - val_accuracy: 0.7419
Epoch 69/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2792 - accuracy: 0.8934 - val_loss: 0.5307 - val_accuracy: 0.7419
Epoch 70/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2780 - accuracy: 0.8934 - val_loss: 0.5370 - val_accuracy: 0.7419
Epoch 71/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2763 - accuracy: 0.8934 - val_loss: 0.5459 - val_accuracy: 0.7419
Epoch 72/100
3/3 [==============================] - 0s 21ms/step - loss: 0.2762 - accuracy: 0.8971 - val_loss: 0.5583 - val_accuracy: 0.7419
Epoch 73/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2759 - accuracy: 0.8971 - val_loss: 0.5676 - val_accuracy: 0.7419
Epoch 74/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2764 - accuracy: 0.8934 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 75/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2747 - accuracy: 0.8934 - val_loss: 0.5540 - val_accuracy: 0.7419
Epoch 76/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2701 - accuracy: 0.8971 - val_loss: 0.5387 - val_accuracy: 0.7419
Epoch 77/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2689 - accuracy: 0.9044 - val_loss: 0.5308 - val_accuracy: 0.7419
Epoch 78/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2701 - accuracy: 0.9081 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 79/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2716 - accuracy: 0.9007 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 80/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2690 - accuracy: 0.9007 - val_loss: 0.5332 - val_accuracy: 0.7097
Epoch 81/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2650 - accuracy: 0.9154 - val_loss: 0.5418 - val_accuracy: 0.7419
Epoch 82/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2631 - accuracy: 0.9118 - val_loss: 0.5434 - val_accuracy: 0.7419
Epoch 83/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2620 - accuracy: 0.9154 - val_loss: 0.5406 - val_accuracy: 0.7419
Epoch 84/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2603 - accuracy: 0.9154 - val_loss: 0.5395 - val_accuracy: 0.7419
Epoch 85/100
3/3 [==============================] - 0s 26ms/step - loss: 0.2588 - accuracy: 0.9154 - val_loss: 0.5497 - val_accuracy: 0.7419
Epoch 86/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2562 - accuracy: 0.9081 - val_loss: 0.5687 - val_accuracy: 0.7419
Epoch 87/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2609 - accuracy: 0.8971 - val_loss: 0.5754 - val_accuracy: 0.7419
Epoch 88/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2569 - accuracy: 0.8971 - val_loss: 0.5555 - val_accuracy: 0.7419
Epoch 89/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2532 - accuracy: 0.9081 - val_loss: 0.5399 - val_accuracy: 0.7419
Epoch 90/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2545 - accuracy: 0.9191 - val_loss: 0.5361 - val_accuracy: 0.7419
Epoch 91/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2578 - accuracy: 0.9118 - val_loss: 0.5375 - val_accuracy: 0.7419
Epoch 92/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2572 - accuracy: 0.9118 - val_loss: 0.5507 - val_accuracy: 0.7419
Epoch 93/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2516 - accuracy: 0.9118 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 94/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2487 - accuracy: 0.9118 - val_loss: 0.5705 - val_accuracy: 0.7419
Epoch 95/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2464 - accuracy: 0.9118 - val_loss: 0.5551 - val_accuracy: 0.7419
Epoch 96/100
3/3 [==============================] - 0s 20ms/step - loss: 0.2454 - accuracy: 0.9191 - val_loss: 0.5480 - val_accuracy: 0.7419
Epoch 97/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2438 - accuracy: 0.9154 - val_loss: 0.5543 - val_accuracy: 0.7419
Epoch 98/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2447 - accuracy: 0.9118 - val_loss: 0.5534 - val_accuracy: 0.7419
Epoch 99/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2446 - accuracy: 0.9118 - val_loss: 0.5425 - val_accuracy: 0.7419
Epoch 100/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2434 - accuracy: 0.9118 - val_loss: 0.5213 - val_accuracy: 0.7742
```
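With a single sigmoid output and the `'accuracy'` metric, Keras reports binary accuracy. A prediction counts as correct when its probability, thresholded at 0.5, matches the 0/1 label. The sketch below reproduces that rule in NumPy. `binary_accuracy` here is a local helper written for illustration, not the Keras function itself. Because the validation set has 31 samples, the final `val_accuracy` of 0.7742 corresponds to 24 of 31 correct predictions.

```python
import numpy as np


def binary_accuracy(y_true, p_pred, threshold=0.5):
    """Fraction of samples whose thresholded probability equals the 0/1 label."""
    y_hat = (np.asarray(p_pred) > threshold).astype(int)
    return float(np.mean(y_hat == np.asarray(y_true)))


print(binary_accuracy([1, 0, 1, 1], [0.9, 0.2, 0.4, 0.7]))  # 0.75

# On a 31-sample validation set, the only possible accuracies are k/31
print(round(24 / 31, 4))  # 0.7742, the final val_accuracy in the log
print(round(18 / 31, 4))  # 0.5806, the epoch-1 val_accuracy
```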
VI. Visualizing the Results
```python
epochs_range = range(1, epochs + 1)  # x-axis: epoch numbers 1..100, matching the log
plt.figure(figsize=(14, 4))

# Left panel: accuracy curves
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='training accuracy')
plt.plot(epochs_range, val_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')

# Right panel: loss curves
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='training loss')
plt.plot(epochs_range, val_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()
```
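The curves show where validation loss stops improving, and this point can also be read off numerically from the `val_loss` list collected during training. The hypothetical helper `best_epoch` below returns the 1-based epoch with the lowest validation loss. In the log above, that is epoch 60 (val_loss 0.5148), while the training loss keeps decreasing to 0.2434 at epoch 100.

```python
def best_epoch(val_loss):
    """Return (1-based epoch, value) of the lowest validation loss; ties go to the earliest epoch."""
    idx = min(range(len(val_loss)), key=val_loss.__getitem__)
    return idx + 1, val_loss[idx]


# Short made-up curve for illustration; with the real run, pass the val_loss list
print(best_epoch([0.69, 0.61, 0.55, 0.57, 0.60]))  # (3, 0.55)
```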
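Summary:
1. Input format expected by the model. Recurrent layers such as RNN and LSTM expect three-dimensional input of shape (samples, timesteps, features). This lets the model process sequence data and learn dependencies across time steps.
2. Original shape of the data. After standardization, X_train and X_test have shape (samples, features). For example, 100 samples with 10 features give a shape of (100, 10). In this article X_train is (272, 13).
3. Purpose of the reshape. X_train.reshape(X_train.shape[0], X_train.shape[1], 1) changes the shape to (samples, features, 1). Each feature becomes one time step, and the trailing 1 is the number of features per time step: a univariate sequence with one value per step. For example, (100, 10) becomes (100, 10, 1), meaning 100 samples, each a sequence of 10 time steps (one per original feature).
4. Matching the model structure. After this reshape, the RNN can consume the data step by step. Note that here the "time steps" are the 13 clinical feature columns in their column order, not measurements taken over time. The recurrence therefore models the features along that fixed ordering rather than a genuine temporal pattern.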
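The reshape in points 3 and 4 can be checked directly in NumPy. This small sketch uses a made-up 2 × 13 matrix standing in for the standardized features. It confirms that after `reshape(n, 13, 1)`, time step t of each sample holds exactly feature column t. The same array can also be built with `np.expand_dims`.

```python
import numpy as np

n_samples, n_features = 2, 13
X = np.arange(n_samples * n_features, dtype=float).reshape(n_samples, n_features)

# (samples, features) -> (samples, timesteps=features, features_per_step=1)
X_seq = X.reshape(X.shape[0], X.shape[1], 1)
print(X_seq.shape)  # (2, 13, 1)

# Step t of every sample is feature column t of the original matrix
t = 4
print(X_seq[:, t, 0], X[:, t])  # both hold the values 4 and 17

# Adding a trailing axis produces the same array
print(np.array_equal(X_seq, np.expand_dims(X, axis=-1)))  # True
```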