【信号处理】基于变分自编码器(VAE)的脑电信号增强典型方法实现(tensorflow)

关于

在脑电信号分析处理任务中,数据不均衡是一个常见的问题。针对数据不均衡,传统方法有过采样和欠采样方法来应对,但是效果有限。本项目通过变分自编码器对脑电信号进行生成增强,提高增强样本的多样性,从而提高最终的后端分析性能。

EEG数据增强方法参考:https://dlib.phenikaa-uni.edu.vn/bitstream/PNK/8319/1/Data%20Augmentation%20techniques%20in%20time%20series%20domain%20a%20survey%20and%20taxonomy-2023.pdf

工具

数据集下载地址: BCI Competition IV

方法实现

加载必要的库函数和数据

python 复制代码
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Dense, Lambda, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras import backend as K



direc = r'bci_iv_2a_data/A01/train/0/'     #data directory

train_dataset = []
train_label = []

test_dataset = []
test_label = []

files = os.listdir(direc)
for j, name in enumerate(files):
    filename = glob.glob(direc + '/'+ name)
    df = pd.read_csv(filename[0], index_col=None, header=None)
    df = df.drop(0, axis=1)     #dropping column of channel names
    df = df.iloc[:,0:1000]      #taking 1000 timesteps
    train_dataset.append(np.array(df))
            


train_dataset = np.array(train_dataset)
train_data = np.expand_dims(train_dataset,axis=-1)

VAE模型>编码器定义

python 复制代码
# VAE model
input_shape=(X_train.shape[1:])
batch_size = 32
kernel_size = 5
filters = 16
latent_dim = 2
epochs = 1000

# reparameterization
def sampling(args): 
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon




# encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs

filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(1, 50),strides=(1,25),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)


filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(22, 1),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z_log_var = z_log_var + 1e-8 

# reparameterization
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) 

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

VAE模型>解码器定义

python 复制代码
# decoder 
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=filters,kernel_size=(22, 1),activation='relu',)(x)
x = BatchNormalization()(x)

filters = filters// 2
x = Conv2DTranspose(filters=filters,kernel_size=(1, 50),activation='relu',strides=(1,25))(x)
x = BatchNormalization()(x)

filters = filters// 2
outputs = Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
python 复制代码
# VAE model (merging encoder and decoder)
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')
vae.summary()

定义损失函数

python 复制代码
# defining Custom loss function 
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

reconstruction_loss *= input_shape[0] * input_shape[1]
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

#optimizer
optimizer = Adam(learning_rate=0.001, beta_1=0.5, beta_2=0.999)

# compiling vae
vae.compile(optimizer=optimizer, loss=None)
vae.summary()

模型配置和训练

python 复制代码
# early stopping callback
callbacks = EarlyStopping(monitor = 'val_loss',
                          mode='min',
                          patience =50,
                          verbose = 1,
                          restore_best_weights = True)


# fit vae model
history = vae.fit(X_train,X_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, X_test),callbacks=callbacks)

训练流程可视化

python 复制代码
# loss curves
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('loss curves')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

中间隐空间特征2D可视化

python 复制代码
# 2D plot of the classes in latent space
z_m, _, _ = encoder.predict(X_test,batch_size=batch_size)
plt.figure(figsize=(12, 10))
plt.scatter(z_m[:, 0], z_m[:, 1], c=X_test[:,0,0,0])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

数据合成

python 复制代码
# predicting on validation data
pred=vae.predict(X_test)

代码获取

附文章底部;

相关项目开发,问题咨询,欢迎交流沟通。

相关推荐
柠檬071112 小时前
vector<cv::point2f>如何快速转成opencv mat
人工智能·opencv·计算机视觉
Pyeako12 小时前
Opencv计算机视觉
人工智能·python·深度学习·opencv·计算机视觉
aopstudio12 小时前
ASR概念和术语学习指南(2):传统 ASR 系统的工作流程
人工智能·语音识别·asr
雅欣鱼子酱12 小时前
ECP5702 PD诱骗协议芯片,单芯片取电5V~20V输出给后端充电模板!
网络·人工智能·芯片·电子元器件
司南OpenCompass12 小时前
司南“六位一体”评测体系的一年演进
人工智能·大模型·多模态模型·大模型评测·司南评测·ai评测
大模型任我行12 小时前
电信:Agent记忆管理决策理论框架DAM
人工智能·语言模型·自然语言处理·论文笔记
学习3人组12 小时前
目标检测训练常见问题排查清单
人工智能·目标检测·计算机视觉
Coder_Boy_12 小时前
基于SpringAI的智能AIOps项目:微服务与DDD多模块融合设计概述
java·运维·人工智能·微服务·faiss
Apache IoTDB12 小时前
TsFile 开源文件格式:AI 时代工业时序数据集新选择,让数据资产“活”起来
人工智能·开源
com_4sapi12 小时前
星链引擎4SAPICOM:全球API服务平台优选,助力企业高效连接智能生态
大数据·人工智能·云计算