Signal Processing: Eye-State Analysis Based on EEG Signals

This is a small project for a special-topics design course in bioinformatics. The goal is to classify the eyes-open and eyes-closed states from the provided 14-channel EEG signals. Each EEG recording lasts about 117 seconds.

Contents

Load the required libraries

Read the EEG data and inspect its properties

Plot the multi-channel correlation matrix

Plot the relative proportion of the two classes

Dataset splitting and preprocessing

Model definition and visualization

Model training and visualization of the training process

Model evaluation


Load the required libraries

python
import tensorflow as tf  # the TF2-style tf.keras API is used throughout
from sklearn.metrics import confusion_matrix
import numpy as np
from scipy.io import loadmat
import os
from pywt import wavedec
from functools import reduce
from scipy import signal
from scipy.stats import entropy
from scipy.fft import fft, ifft
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from tensorflow import keras as K
import matplotlib.pyplot as plt
import scipy
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold,cross_validate
from tensorflow.keras.layers import Dense, Activation, Flatten, concatenate, Input, Dropout, LSTM, Bidirectional,BatchNormalization,PReLU,ReLU,Reshape
# from keras.wrappers.scikit_learn import KerasClassifier  # not used below; removed in recent Keras releases
from sklearn.metrics import classification_report
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.decomposition import PCA
from tensorflow import keras
from sklearn.model_selection import cross_val_score
from tensorflow.keras.layers import Conv1D,Conv2D,Add
from tensorflow.keras.layers import MaxPool1D, MaxPooling2D
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

Read the EEG data and inspect its properties

python
df = pd.read_csv("../input/eye-state-classification-eeg-dataset/EEG_Eye_State_Classification.csv")

df.info()
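
Before any processing it helps to look at the raw table. The sketch below is a minimal check of the shape, the first rows, and the label counts; it assumes the CSV follows the usual layout of the Kaggle EEG Eye State file, i.e. 14 electrode columns plus a binary eyeDetection column (commonly documented as 1 = eyes closed, 0 = eyes open).

python
# Quick sanity check of the raw table (assumes the standard Kaggle/UCI layout:
# 14 electrode columns such as AF3, F7, ... plus a binary 'eyeDetection' label).
print(df.shape)                            # roughly 14,980 rows x 15 columns in the standard dataset
print(df.head())
print(df['eyeDetection'].value_counts())   # class counts for open (0) vs. closed (1)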

Plot the multi-channel correlation matrix

python
plt.figure(figsize = (15,15))
cor_matrix = df.corr()
sns.heatmap(cor_matrix,annot=True)

Plot the relative proportion of the two classes

python
# Plotting target distribution 
plt.figure(figsize=(6,6))
df['eyeDetection'].value_counts().plot.pie(explode=[0.1,0.1], autopct='%1.1f%%', shadow=True, textprops={'fontsize':16}).set_title("Target distribution")

Dataset splitting and preprocessing

python
data = df.copy()
y = data.pop('eyeDetection')   # binary label: eyes open vs. eyes closed
x = data                       # the 14 EEG channel columns

# Standardize each channel to zero mean and unit variance.
# Note: the scaler is fitted on the full dataset before the split, which leaks
# some test-set statistics; fitting on x_train only would be stricter.
x_new = StandardScaler().fit_transform(x)

x_new = pd.DataFrame(x_new)
x_new.columns = x.columns

# Hold out 15% of the samples for testing.
x_train, x_test, y_train, y_test = train_test_split(x_new, y, test_size=0.15)

# Reshape each sample to (14, 1): a length-14 sequence with one feature per step,
# which is the input format expected by the Bi-LSTM model below.
x_train = np.array(x_train).reshape(-1, 14, 1)
x_test = np.array(x_test).reshape(-1, 14, 1)
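
The reshape above treats the 14 electrode readings of each sample as a short sequence with one feature per step. A minimal sketch to confirm the resulting shapes and the class balance of the held-out split (no assumptions beyond the variables defined above):

python
# Confirm the tensors have the (samples, 14, 1) layout expected by the Bi-LSTM.
print(x_train.shape)   # (n_train, 14, 1)
print(x_test.shape)    # (n_test, 14, 1)
# Class balance of the held-out split (train_test_split above is not stratified).
print(y_test.value_counts(normalize=True))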

Model definition and visualization

python
inputs = tf.keras.Input(shape=(14, 1))

# Per-step projection of each channel value before the recurrent layers.
Dense1 = Dense(64, activation='relu', kernel_regularizer=keras.regularizers.l2())(inputs)

#Dense2 = Dense(128, activation = 'relu',kernel_regularizer=keras.regularizers.l2())(Dense1)
#Dense3 = Dense(256, activation = 'relu',kernel_regularizer=keras.regularizers.l2())(Dense2)

# Two stacked bidirectional LSTMs read the 14 channels as a sequence,
# with dropout between them for regularization.
lstm_1 = Bidirectional(LSTM(256, return_sequences=True))(Dense1)
drop = Dropout(0.3)(lstm_1)
lstm_3 = Bidirectional(LSTM(128, return_sequences=True))(drop)
drop2 = Dropout(0.3)(lstm_3)

flat = Flatten()(drop2)

#Dense_1 = Dense(256, activation = 'relu')(flat)

Dense_2 = Dense(128, activation='relu')(flat)
# Single sigmoid unit for the binary eyes-open / eyes-closed decision.
outputs = Dense(1, activation='sigmoid')(Dense_2)

model = tf.keras.Model(inputs, outputs)

model.summary()

tf.keras.utils.plot_model(model)



def train_model(model, x_train, y_train, x_test, y_test, save_to, epoch=2):

    opt_adam = keras.optimizers.Adam(learning_rate=0.001)

    # Stop early when validation loss stalls, keep the checkpoint with the best
    # validation accuracy, and decay the learning rate exponentially.
    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
    mc = ModelCheckpoint(save_to + '_best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
    lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: 0.001 * np.exp(-epoch / 10.))

    model.compile(optimizer=opt_adam,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        batch_size=20,
                        epochs=epoch,
                        validation_data=(x_test, y_test),
                        callbacks=[es, mc, lr_schedule])

    # Reload and return the best checkpoint rather than the last-epoch weights.
    saved_model = load_model(save_to + '_best_model.h5')

    return saved_model, history
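
The LearningRateScheduler callback decays the learning rate as lr(epoch) = 0.001 * exp(-epoch / 10), roughly a tenfold reduction every 23 epochs. A small, purely illustrative sketch of that curve over the 100 training epochs used below:

python
import numpy as np
import matplotlib.pyplot as plt

# Plot the exponential decay applied by the LearningRateScheduler above:
# lr(epoch) = 0.001 * exp(-epoch / 10)
epochs = np.arange(100)
plt.plot(epochs, 0.001 * np.exp(-epochs / 10.))
plt.title('learning rate schedule')
plt.xlabel('epoch')
plt.ylabel('learning rate')
plt.show()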

Model training and visualization of the training process

python
model,history = train_model(model, x_train, y_train,x_test, y_test, save_to= './', epoch = 100)


plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

Model evaluation

python
y_pred = model.predict(x_test)
# Threshold the sigmoid outputs at 0.5 to get hard 0/1 class labels
# (np.int is removed in recent NumPy versions, so plain int is used).
y_pred = np.array(y_pred >= 0.5, dtype=int)
confusion_matrix(y_test, y_pred)



print(classification_report(y_test, y_pred))
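
As a complement to the confusion matrix and classification report, a heatmap of the confusion matrix and a ROC-AUC score give a quick visual and threshold-free summary. This is a minimal sketch reusing seaborn and sklearn.metrics, both imported earlier; it assumes y_pred holds the thresholded labels from the cell above.

python
# Optional extra evaluation: confusion-matrix heatmap plus accuracy and ROC-AUC.
y_prob = model.predict(x_test).ravel()   # raw sigmoid scores in [0, 1]
cm = confusion_matrix(y_test, y_pred)

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('predicted')
plt.ylabel('true')
plt.title('confusion matrix')
plt.show()

print('accuracy:', accuracy_score(y_test, y_pred))
print('ROC-AUC :', metrics.roc_auc_score(y_test, y_prob))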