R3打卡——tensorflow实现RNN心脏病预测

🍨 本文为 🔗365天深度学习训练营中的学习记录博客

1.检查GPU

复制代码
import tensorflow as tf
import pandas     as pd
import numpy      as np

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
print(gpus)

2.查看数据

复制代码
import pandas as pd
import numpy as np

df = pd.read_csv("data/heart.csv")
df

# 检查是否有空值
df.isnull().sum()

3.划分数据集

复制代码
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X = df.iloc[:,:-1]
y = df.iloc[:,-1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)

X_train.shape, y_train.shape

# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc      = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test  = sc.transform(X_test)

X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test  = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

​​​

​​​​​

4.创建模型与编译训练

复制代码
import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN

model = Sequential()
model.add(SimpleRNN(200, input_shape= (13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

​​​​5.编译及训练模型

复制代码
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=["accuracy"])

epochs = 100

history = model.fit(X_train, y_train, 
                    epochs=epochs, 
                    batch_size=128, 
                    validation_data=(X_test, y_test),
                    verbose=1)

​​​​​

6.结果可视化

复制代码
import matplotlib.pyplot as plt
from datetime import datetime
current_time = datetime.now() # 获取当前时间

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​​​

​​​​​7.模型评估

复制代码
scores = model.evaluate(X_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

​​​​​​​​​​​

总结:

1.RNN

1. RNN是什么?

RNN是一种专门处理序列数据 的神经网络。它的特点是能够利用​"记忆"​​(隐藏状态)来捕捉序列中的时序信息。比如:

  • 自然语言(句子中的单词顺序)
  • 时间序列数据(股票价格、传感器数据)
  • 语音信号(声音的先后顺序)

传统神经网络(如CNN)假设输入是独立的,而RNN通过循环结构让当前时刻的输出依赖之前的输入,从而建模时间依赖关系。


2. 核心思想:循环与记忆

  • 隐藏状态(Hidden State)​:RNN的"记忆单元",保存了历史信息。
  • 循环机制 :每一步的隐藏状态会传递到下一步,与当前输入共同决定输出。
    公式简化版
    ht=f(ht−1,xt)
    其中,ht是当前时刻的隐藏状态,xt是当前输入,f是激活函数(如tanh)
2. 实验概述

本实验基于心脏病诊断数据构建二分类模型,目标是通过患者生理特征预测心脏病存在与否。采用标准化预处理后的数据输入SimpleRNN网络,结合全连接层进行特征学习与分类。实验通过100轮训练,以Adam优化器(学习率1e-4)和二元交叉熵损失函数优化模型,最终通过准确率评估性能,并可视化训练过程曲线。


3. 核心结果分析
  • 模型有效性

    模型在训练集与测试集上均展现出稳定的学习趋势。训练后期,训练集与验证集的准确率曲线趋近 ,表明模型未出现显著过拟合;损失曲线同步下降,反映优化过程收敛稳定。测试集准确率(需运行后获取具体数值)可初步验证模型对心脏病预测任务的适用性。

  • 结构适配性

    尽管SimpleRNN常用于时间序列数据,但本实验将非时序的静态特征重塑为序列格式输入RNN,可能引入冗余计算。若数据无时序关联性,可尝试全连接网络(DNN)​卷积网络(CNN)​简化结构,提升效率。

相关推荐
IT古董2 小时前
【第五章:计算机视觉-项目实战之图像分割实战】1.图像分割理论-(2)图像分割衍生:语义分割、实例分割、弱监督语义分割
人工智能·计算机视觉
大明者省4 小时前
《青花》歌曲,使用3D表现出意境
人工智能
一朵小红花HH4 小时前
SimpleBEV:改进的激光雷达-摄像头融合架构用于三维目标检测
论文阅读·人工智能·深度学习·目标检测·机器学习·计算机视觉·3d
Daitu_Adam4 小时前
R语言——ggmap包可视化地图
人工智能·数据分析·r语言·数据可视化
weixin_377634844 小时前
【阿里DeepResearch】写作组件WebWeaver详解
人工智能
AndrewHZ4 小时前
【AI算力系统设计分析】1000PetaOps 算力云计算系统设计方案(大模型训练推理专项版)
人工智能·深度学习·llm·云计算·模型部署·大模型推理·算力平台
GilgameshJSS4 小时前
STM32H743-ARM例程3-SYSTICK定时闪烁LED
arm开发·stm32·单片机·嵌入式硬件·学习
AI_gurubar5 小时前
[NeurIPS‘25] AI infra / ML sys 论文(解析)合集
人工智能
胡耀超5 小时前
PaddleLabel百度飞桨Al Studio图像标注平台安装和使用指南(包冲突 using the ‘flask‘ extra、眼底医疗分割数据集演示)
人工智能·百度·开源·paddlepaddle·图像识别·图像标注·paddlelabel
聆思科技AI芯片5 小时前
【AI入门课程】2、AI 的载体 —— 智能硬件
人工智能·单片机·智能硬件