R3打卡——tensorflow实现RNN心脏病预测

🍨 本文为 🔗365天深度学习训练营中的学习记录博客

1.检查GPU

复制代码
import tensorflow as tf
import pandas     as pd
import numpy      as np

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
print(gpus)

2.查看数据

复制代码
import pandas as pd
import numpy as np

df = pd.read_csv("data/heart.csv")
df

# 检查是否有空值
df.isnull().sum()

3.划分数据集

复制代码
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X = df.iloc[:,:-1]
y = df.iloc[:,-1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)

X_train.shape, y_train.shape

# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc      = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test  = sc.transform(X_test)

X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test  = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

​​​

​​​​​

4.创建模型与编译训练

复制代码
import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN

model = Sequential()
model.add(SimpleRNN(200, input_shape= (13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

​​​​5.编译及训练模型

复制代码
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=["accuracy"])

epochs = 100

history = model.fit(X_train, y_train, 
                    epochs=epochs, 
                    batch_size=128, 
                    validation_data=(X_test, y_test),
                    verbose=1)

​​​​​

6.结果可视化

复制代码
import matplotlib.pyplot as plt
from datetime import datetime
current_time = datetime.now() # 获取当前时间

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​​​

​​​​​7.模型评估

复制代码
scores = model.evaluate(X_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

​​​​​​​​​​​

总结:

1.RNN

1. RNN是什么?

RNN是一种专门处理序列数据 的神经网络。它的特点是能够利用​"记忆"​​(隐藏状态)来捕捉序列中的时序信息。比如:

  • 自然语言(句子中的单词顺序)
  • 时间序列数据(股票价格、传感器数据)
  • 语音信号(声音的先后顺序)

传统神经网络(如CNN)假设输入是独立的,而RNN通过循环结构让当前时刻的输出依赖之前的输入,从而建模时间依赖关系。


2. 核心思想:循环与记忆

  • 隐藏状态(Hidden State)​:RNN的"记忆单元",保存了历史信息。
  • 循环机制 :每一步的隐藏状态会传递到下一步,与当前输入共同决定输出。
    公式简化版
    ht=f(ht−1,xt)
    其中,ht是当前时刻的隐藏状态,xt是当前输入,f是激活函数(如tanh)
2. 实验概述

本实验基于心脏病诊断数据构建二分类模型,目标是通过患者生理特征预测心脏病存在与否。采用标准化预处理后的数据输入SimpleRNN网络,结合全连接层进行特征学习与分类。实验通过100轮训练,以Adam优化器(学习率1e-4)和二元交叉熵损失函数优化模型,最终通过准确率评估性能,并可视化训练过程曲线。


3. 核心结果分析
  • 模型有效性

    模型在训练集与测试集上均展现出稳定的学习趋势。训练后期,训练集与验证集的准确率曲线趋近 ,表明模型未出现显著过拟合;损失曲线同步下降,反映优化过程收敛稳定。测试集准确率(需运行后获取具体数值)可初步验证模型对心脏病预测任务的适用性。

  • 结构适配性

    尽管SimpleRNN常用于时间序列数据,但本实验将非时序的静态特征重塑为序列格式输入RNN,可能引入冗余计算。若数据无时序关联性,可尝试全连接网络(DNN)​卷积网络(CNN)​简化结构,提升效率。

相关推荐
Cat_Rocky14 小时前
docker简单学习
学习·docker·容器
2501_9479082014 小时前
F5携手亚马逊云科技与微软参与NSS Labs AI研究报告,定义AI运行时安全测试基准
人工智能·科技·microsoft
Jagger_14 小时前
我终于想明白了,为什么我不会赚钱。
人工智能
xixixi7777714 小时前
跨境AI服务:多语种大模型+卫星通信+量子加密+数据脱敏+安全审计,合规·高效·安全三重保障
人工智能·安全·大模型·通信·卫星通信·审计·量子安全
中金快讯14 小时前
光大同创(301387)外骨骼机器人订单落地,轻量化方案获军方认证。
人工智能
qingwufeiyang_53014 小时前
Mybatis-plus学习笔记1
笔记·学习·mybatis
无垠的广袤14 小时前
【“星睿O6”AI PC开发套件评测】基于 OpenClaw 的物体识别
linux·人工智能·opencv·摄像头·openclaw
bingd0114 小时前
慕课网、CSDN、菜鸟教程…2026 国内编程学习平台实测对比
java·开发语言·人工智能·python·学习
qq_4112624214 小时前
设备的选型与其优势
人工智能·物联网·ai
乐迪信息14 小时前
乐迪信息:智慧港口AI防爆摄像机实现船舶违规靠岸自动抓拍
大数据·人工智能·算法·安全·目标跟踪