RNN Heart Disease Prediction

This article is part of the internal series of the 365天深度学习训练营 (365-day deep learning training camp).

Original author: K同学啊

I. Preparation

1. Importing the data

import pandas as pd
from keras.optimizers import Adam
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense,SimpleRNN
import warnings
warnings.filterwarnings('ignore')

# Load the heart disease dataset
df = pd.read_csv('heart.csv')

2. Inspecting the data

Check whether the data contains any missing values:

print(df.shape)
print(df.isnull().sum())
Output:
(303, 14)
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
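
No column contains missing values, so no cleaning is needed here. If any had been present, a typical (hypothetical) handling step could look like the sketch below:

# Only needed if the check above had reported missing values
df = df.dropna()                                  # drop incomplete rows, or
# df = df.fillna(df.median(numeric_only=True))    # ...impute with column medians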

II. Data Preprocessing

1. Splitting into training and test sets

# Features: all columns except the last; label: the final 'target' column
X = df.iloc[:,:-1]
y = df.iloc[:,-1]

# Hold out 10% of the data for testing
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=14)
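
A quick shape check, not shown in the original post, confirms the 90/10 split of the 303 rows:

print(X_train.shape, X_test.shape)   # expected: (272, 13) (31, 13)
print(y_train.shape, y_test.shape)   # expected: (272,) (31,)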

2. Standardizing the data

sc = StandardScaler()
X_train = sc.fit_transform(X_train)
# Fit the scaler on the training set only, then reuse it for the test set
# (calling fit_transform on X_test would leak test-set statistics).
X_test = sc.transform(X_test)
# Reshape to (samples, timesteps, features) so the RNN can consume the data:
# each of the 13 features becomes one timestep carrying a single value.
X_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test.reshape(X_test.shape[0],X_test.shape[1],1)
Output (truncated):
array([[[ 1.44626869],
        [ 0.54006172],
        [ 0.62321699],
        [ 1.37686599],
        [ 0.83801861],
        [-0.48989795],
        [ 0.92069654],
        [-1.38834656],
        [ 1.34839972],
        [ 1.83944021],
        [-0.74161985],
        [ 0.18805174],
        [ 1.09773445]],

       [[-0.11901962],
        [ 0.54006172],
        [ 1.4632051 ],
        [-0.7179976 ],
        [-1.01585167],
        [-0.48989795],
        [-0.86315301],
        [ 0.77440436],
        [-0.74161985],
        [ 0.85288923],
        [-0.74161985],
        [-0.78354893],
        [ 1.09773445]],

III. Building the RNN Model

model = Sequential()
# SimpleRNN layer with 200 units; input is 13 timesteps of 1 feature each
model.add(SimpleRNN(200,input_shape=(X_train.shape[1],1),activation='relu'))
model.add(Dense(100,activation='relu'))
# Single sigmoid output for binary classification (disease / no disease)
model.add(Dense(1,activation='sigmoid'))
model.summary()
Output:
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn (SimpleRNN)      (None, 200)               40400     
                                                                 
 dense (Dense)               (None, 100)               20100     
                                                                 
 dense_1 (Dense)             (None, 1)                 101       
                                                                 
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________
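
The parameter counts in the summary above can be verified by hand; the short check below (not from the original post) recomputes them for the layer sizes used here.

# SimpleRNN: recurrent weights + input weights + biases
units, input_dim = 200, 1
rnn_params = units * units + input_dim * units + units   # 40400
dense1_params = units * 100 + 100                         # 20100
dense2_params = 100 * 1 + 1                               # 101
print(rnn_params, dense1_params, dense2_params,
      rnn_params + dense1_params + dense2_params)         # 40400 20100 101 60601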

IV. Compiling the Model

optimizer = Adam(learning_rate=1e-4)
# Use binary cross-entropy as the loss, which suits this binary classification task,
# together with the Adam optimizer defined above, and track accuracy during training.
model.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])

V. Training the Model

epochs = 100
history = model.fit(x=X_train,y=y_train,validation_data=(X_test,y_test),verbose=1,
                    epochs=epochs,batch_size=128)

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
Output:
Epoch 1/100
3/3 [==============================] - 1s 130ms/step - loss: 0.6872 - accuracy: 0.5551 - val_loss: 0.6884 - val_accuracy: 0.5806
Epoch 2/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6763 - accuracy: 0.6250 - val_loss: 0.6848 - val_accuracy: 0.6129
Epoch 3/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6660 - accuracy: 0.6912 - val_loss: 0.6814 - val_accuracy: 0.6452
Epoch 4/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6562 - accuracy: 0.7426 - val_loss: 0.6781 - val_accuracy: 0.6452
Epoch 5/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6467 - accuracy: 0.7647 - val_loss: 0.6751 - val_accuracy: 0.6129
Epoch 6/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6375 - accuracy: 0.7941 - val_loss: 0.6722 - val_accuracy: 0.6452
Epoch 7/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6285 - accuracy: 0.8051 - val_loss: 0.6694 - val_accuracy: 0.6129
Epoch 8/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6193 - accuracy: 0.8015 - val_loss: 0.6666 - val_accuracy: 0.6129
Epoch 9/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6094 - accuracy: 0.8125 - val_loss: 0.6635 - val_accuracy: 0.5806
Epoch 10/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6002 - accuracy: 0.8162 - val_loss: 0.6602 - val_accuracy: 0.6129
Epoch 11/100
3/3 [==============================] - 0s 25ms/step - loss: 0.5903 - accuracy: 0.8125 - val_loss: 0.6565 - val_accuracy: 0.5806
Epoch 12/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5795 - accuracy: 0.8125 - val_loss: 0.6526 - val_accuracy: 0.5806
Epoch 13/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5686 - accuracy: 0.8125 - val_loss: 0.6484 - val_accuracy: 0.6129
Epoch 14/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5571 - accuracy: 0.8125 - val_loss: 0.6436 - val_accuracy: 0.6452
Epoch 15/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5451 - accuracy: 0.8125 - val_loss: 0.6377 - val_accuracy: 0.6452
Epoch 16/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5322 - accuracy: 0.8125 - val_loss: 0.6315 - val_accuracy: 0.6452
Epoch 17/100
3/3 [==============================] - 0s 24ms/step - loss: 0.5190 - accuracy: 0.8199 - val_loss: 0.6251 - val_accuracy: 0.6452
Epoch 18/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5053 - accuracy: 0.8199 - val_loss: 0.6190 - val_accuracy: 0.6774
Epoch 19/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4910 - accuracy: 0.8162 - val_loss: 0.6132 - val_accuracy: 0.6774
Epoch 20/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4765 - accuracy: 0.8199 - val_loss: 0.6076 - val_accuracy: 0.6774
Epoch 21/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4616 - accuracy: 0.8235 - val_loss: 0.6007 - val_accuracy: 0.6774
Epoch 22/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4470 - accuracy: 0.8125 - val_loss: 0.5943 - val_accuracy: 0.6774
Epoch 23/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4345 - accuracy: 0.8162 - val_loss: 0.5906 - val_accuracy: 0.6774
Epoch 24/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4219 - accuracy: 0.8162 - val_loss: 0.5901 - val_accuracy: 0.7419
Epoch 25/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4116 - accuracy: 0.8162 - val_loss: 0.5921 - val_accuracy: 0.7742
Epoch 26/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4056 - accuracy: 0.8272 - val_loss: 0.5990 - val_accuracy: 0.7419
Epoch 27/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3983 - accuracy: 0.8309 - val_loss: 0.5970 - val_accuracy: 0.7097
Epoch 28/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3920 - accuracy: 0.8309 - val_loss: 0.5914 - val_accuracy: 0.7097
Epoch 29/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3860 - accuracy: 0.8235 - val_loss: 0.5863 - val_accuracy: 0.7097
Epoch 30/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3802 - accuracy: 0.8235 - val_loss: 0.5724 - val_accuracy: 0.7097
Epoch 31/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3757 - accuracy: 0.8346 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 32/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3766 - accuracy: 0.8272 - val_loss: 0.5545 - val_accuracy: 0.7419
Epoch 33/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3706 - accuracy: 0.8272 - val_loss: 0.5608 - val_accuracy: 0.7419
Epoch 34/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3639 - accuracy: 0.8382 - val_loss: 0.5899 - val_accuracy: 0.7419
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3694 - accuracy: 0.8272 - val_loss: 0.6097 - val_accuracy: 0.7742
Epoch 36/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3682 - accuracy: 0.8346 - val_loss: 0.5859 - val_accuracy: 0.7419
Epoch 37/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3567 - accuracy: 0.8309 - val_loss: 0.5680 - val_accuracy: 0.7419
Epoch 38/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3497 - accuracy: 0.8419 - val_loss: 0.5528 - val_accuracy: 0.7419
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3484 - accuracy: 0.8603 - val_loss: 0.5417 - val_accuracy: 0.7742
Epoch 40/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3487 - accuracy: 0.8603 - val_loss: 0.5386 - val_accuracy: 0.6774
Epoch 41/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3473 - accuracy: 0.8640 - val_loss: 0.5383 - val_accuracy: 0.7097
Epoch 42/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3422 - accuracy: 0.8676 - val_loss: 0.5425 - val_accuracy: 0.7742
Epoch 43/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3353 - accuracy: 0.8713 - val_loss: 0.5467 - val_accuracy: 0.7419
Epoch 44/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3318 - accuracy: 0.8787 - val_loss: 0.5565 - val_accuracy: 0.7419
Epoch 45/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3289 - accuracy: 0.8750 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 46/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3263 - accuracy: 0.8750 - val_loss: 0.5548 - val_accuracy: 0.7419
Epoch 47/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3227 - accuracy: 0.8787 - val_loss: 0.5520 - val_accuracy: 0.7419
Epoch 48/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3191 - accuracy: 0.8824 - val_loss: 0.5564 - val_accuracy: 0.7419
Epoch 49/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3172 - accuracy: 0.8713 - val_loss: 0.5539 - val_accuracy: 0.7419
Epoch 50/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3149 - accuracy: 0.8824 - val_loss: 0.5381 - val_accuracy: 0.7419
Epoch 51/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3110 - accuracy: 0.8824 - val_loss: 0.5427 - val_accuracy: 0.7419
Epoch 52/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3084 - accuracy: 0.8787 - val_loss: 0.5510 - val_accuracy: 0.7419
Epoch 53/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3069 - accuracy: 0.8750 - val_loss: 0.5571 - val_accuracy: 0.7419
Epoch 54/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3052 - accuracy: 0.8860 - val_loss: 0.5468 - val_accuracy: 0.7419
Epoch 55/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3024 - accuracy: 0.8787 - val_loss: 0.5347 - val_accuracy: 0.7419
Epoch 56/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3010 - accuracy: 0.8787 - val_loss: 0.5417 - val_accuracy: 0.7419
Epoch 57/100
3/3 [==============================] - 0s 21ms/step - loss: 0.3013 - accuracy: 0.8860 - val_loss: 0.5496 - val_accuracy: 0.7419
Epoch 58/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2975 - accuracy: 0.8824 - val_loss: 0.5355 - val_accuracy: 0.7419
Epoch 59/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2954 - accuracy: 0.8787 - val_loss: 0.5198 - val_accuracy: 0.7419
Epoch 60/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2970 - accuracy: 0.8787 - val_loss: 0.5148 - val_accuracy: 0.7419
Epoch 61/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2991 - accuracy: 0.8824 - val_loss: 0.5187 - val_accuracy: 0.7419
Epoch 62/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2958 - accuracy: 0.8787 - val_loss: 0.5376 - val_accuracy: 0.7419
Epoch 63/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2891 - accuracy: 0.8860 - val_loss: 0.5659 - val_accuracy: 0.7419
Epoch 64/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2923 - accuracy: 0.8824 - val_loss: 0.5777 - val_accuracy: 0.7419
Epoch 65/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2892 - accuracy: 0.8824 - val_loss: 0.5560 - val_accuracy: 0.7419
Epoch 66/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2848 - accuracy: 0.8934 - val_loss: 0.5405 - val_accuracy: 0.7419
Epoch 67/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2828 - accuracy: 0.8897 - val_loss: 0.5334 - val_accuracy: 0.7419
Epoch 68/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2810 - accuracy: 0.8934 - val_loss: 0.5332 - val_accuracy: 0.7419
Epoch 69/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2792 - accuracy: 0.8934 - val_loss: 0.5307 - val_accuracy: 0.7419
Epoch 70/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2780 - accuracy: 0.8934 - val_loss: 0.5370 - val_accuracy: 0.7419
Epoch 71/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2763 - accuracy: 0.8934 - val_loss: 0.5459 - val_accuracy: 0.7419
Epoch 72/100
3/3 [==============================] - 0s 21ms/step - loss: 0.2762 - accuracy: 0.8971 - val_loss: 0.5583 - val_accuracy: 0.7419
Epoch 73/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2759 - accuracy: 0.8971 - val_loss: 0.5676 - val_accuracy: 0.7419
Epoch 74/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2764 - accuracy: 0.8934 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 75/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2747 - accuracy: 0.8934 - val_loss: 0.5540 - val_accuracy: 0.7419
Epoch 76/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2701 - accuracy: 0.8971 - val_loss: 0.5387 - val_accuracy: 0.7419
Epoch 77/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2689 - accuracy: 0.9044 - val_loss: 0.5308 - val_accuracy: 0.7419
Epoch 78/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2701 - accuracy: 0.9081 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 79/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2716 - accuracy: 0.9007 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 80/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2690 - accuracy: 0.9007 - val_loss: 0.5332 - val_accuracy: 0.7097
Epoch 81/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2650 - accuracy: 0.9154 - val_loss: 0.5418 - val_accuracy: 0.7419
Epoch 82/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2631 - accuracy: 0.9118 - val_loss: 0.5434 - val_accuracy: 0.7419
Epoch 83/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2620 - accuracy: 0.9154 - val_loss: 0.5406 - val_accuracy: 0.7419
Epoch 84/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2603 - accuracy: 0.9154 - val_loss: 0.5395 - val_accuracy: 0.7419
Epoch 85/100
3/3 [==============================] - 0s 26ms/step - loss: 0.2588 - accuracy: 0.9154 - val_loss: 0.5497 - val_accuracy: 0.7419
Epoch 86/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2562 - accuracy: 0.9081 - val_loss: 0.5687 - val_accuracy: 0.7419
Epoch 87/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2609 - accuracy: 0.8971 - val_loss: 0.5754 - val_accuracy: 0.7419
Epoch 88/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2569 - accuracy: 0.8971 - val_loss: 0.5555 - val_accuracy: 0.7419
Epoch 89/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2532 - accuracy: 0.9081 - val_loss: 0.5399 - val_accuracy: 0.7419
Epoch 90/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2545 - accuracy: 0.9191 - val_loss: 0.5361 - val_accuracy: 0.7419
Epoch 91/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2578 - accuracy: 0.9118 - val_loss: 0.5375 - val_accuracy: 0.7419
Epoch 92/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2572 - accuracy: 0.9118 - val_loss: 0.5507 - val_accuracy: 0.7419
Epoch 93/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2516 - accuracy: 0.9118 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 94/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2487 - accuracy: 0.9118 - val_loss: 0.5705 - val_accuracy: 0.7419
Epoch 95/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2464 - accuracy: 0.9118 - val_loss: 0.5551 - val_accuracy: 0.7419
Epoch 96/100
3/3 [==============================] - 0s 20ms/step - loss: 0.2454 - accuracy: 0.9191 - val_loss: 0.5480 - val_accuracy: 0.7419
Epoch 97/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2438 - accuracy: 0.9154 - val_loss: 0.5543 - val_accuracy: 0.7419
Epoch 98/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2447 - accuracy: 0.9118 - val_loss: 0.5534 - val_accuracy: 0.7419
Epoch 99/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2446 - accuracy: 0.9118 - val_loss: 0.5425 - val_accuracy: 0.7419
Epoch 100/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2434 - accuracy: 0.9118 - val_loss: 0.5213 - val_accuracy: 0.7742

VI. Visualizing the Results

epochs_range = range(epochs)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,acc,label='training accuracy')
plt.plot(epochs_range,val_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,loss,label='training loss')
plt.plot(epochs_range,val_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()
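
As an optional check, not part of the original post, the trained model can also be scored directly on the held-out test set; this minimal sketch assumes `model`, `X_test`, and `y_test` from the cells above are still in scope.

from sklearn.metrics import classification_report

# Overall test-set loss and accuracy
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f'test loss: {test_loss:.4f}  test accuracy: {test_acc:.4f}')

# Turn the sigmoid outputs into 0/1 predictions with a 0.5 threshold
y_prob = model.predict(X_test, verbose=0).ravel()
y_pred = (y_prob > 0.5).astype(int)
print(classification_report(y_test, y_pred))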

Summary:

1. Model input requirements
RNN input format: many deep learning models, in particular RNNs and LSTMs, expect three-dimensional input of shape (samples, timesteps, features). This lets the model process sequence data and learn temporal dependencies.
2. Original data shape
After standardization, X_train and X_test have the two-dimensional shape (samples, features). For example, if X_train has 100 samples and 10 features, its shape is (100, 10).
3. Purpose of the reshape
Reshaping to 3D: X_train.reshape(X_train.shape[0], X_train.shape[1], 1) changes the shape to (samples, features, 1), where the trailing 1 is the number of features per timestep; in this univariate setup each timestep carries a single value.
For example, if X_train originally has shape (100, 10), it becomes (100, 10, 1): 100 samples, each with 10 timesteps (one per feature), as shown in the sketch after this list.
4. Matching the model's structure
With this reshape, the data can be fed into the RNN correctly, allowing it to pick up patterns in how the features vary across the sequence.
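
A minimal NumPy sketch of the reshape described in points 3 and 4 (illustrative shapes only, not taken from the dataset above):

import numpy as np

# 100 samples with 10 standardized features each
X = np.random.randn(100, 10)
print(X.shape)        # (100, 10)

# Add a trailing feature axis: (samples, timesteps, features)
X_seq = X.reshape(X.shape[0], X.shape[1], 1)
print(X_seq.shape)    # (100, 10, 1)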