【信号处理】基于变分自编码器(VAE)的脑电信号增强典型方法实现(tensorflow)

关于

在脑电信号分析处理任务中,数据不均衡是一个常见的问题。针对数据不均衡,传统方法有过采样和欠采样方法来应对,但是效果有限。本项目通过变分自编码器对脑电信号进行生成增强,提高增强样本的多样性,从而提高最终的后端分析性能。

EEG数据增强方法参考:https://dlib.phenikaa-uni.edu.vn/bitstream/PNK/8319/1/Data%20Augmentation%20techniques%20in%20time%20series%20domain%20a%20survey%20and%20taxonomy-2023.pdf

工具

数据集下载地址: BCI Competition IV

方法实现

加载必要的库函数和数据

python 复制代码
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Dense, Lambda, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras import backend as K



direc = r'bci_iv_2a_data/A01/train/0/'     #data directory

train_dataset = []
train_label = []

test_dataset = []
test_label = []

files = os.listdir(direc)
for j, name in enumerate(files):
    filename = glob.glob(direc + '/'+ name)
    df = pd.read_csv(filename[0], index_col=None, header=None)
    df = df.drop(0, axis=1)     #dropping column of channel names
    df = df.iloc[:,0:1000]      #taking 1000 timesteps
    train_dataset.append(np.array(df))
            


train_dataset = np.array(train_dataset)
train_data = np.expand_dims(train_dataset,axis=-1)

VAE模型>编码器定义

python 复制代码
# VAE model
input_shape=(X_train.shape[1:])
batch_size = 32
kernel_size = 5
filters = 16
latent_dim = 2
epochs = 1000

# reparameterization
def sampling(args): 
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon




# encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs

filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(1, 50),strides=(1,25),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)


filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(22, 1),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z_log_var = z_log_var + 1e-8 

# reparameterization
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) 

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

VAE模型>解码器定义

python 复制代码
# decoder 
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=filters,kernel_size=(22, 1),activation='relu',)(x)
x = BatchNormalization()(x)

filters = filters// 2
x = Conv2DTranspose(filters=filters,kernel_size=(1, 50),activation='relu',strides=(1,25))(x)
x = BatchNormalization()(x)

filters = filters// 2
outputs = Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
python 复制代码
# VAE model (merging encoder and decoder)
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')
vae.summary()

定义损失函数

python 复制代码
# defining Custom loss function 
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

reconstruction_loss *= input_shape[0] * input_shape[1]
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

#optimizer
optimizer = Adam(learning_rate=0.001, beta_1=0.5, beta_2=0.999)

# compiling vae
vae.compile(optimizer=optimizer, loss=None)
vae.summary()

模型配置和训练

python 复制代码
# early stopping callback
callbacks = EarlyStopping(monitor = 'val_loss',
                          mode='min',
                          patience =50,
                          verbose = 1,
                          restore_best_weights = True)


# fit vae model
history = vae.fit(X_train,X_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, X_test),callbacks=callbacks)

训练流程可视化

python 复制代码
# loss curves
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('loss curves')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

中间隐空间特征2D可视化

python 复制代码
# 2D plot of the classes in latent space
z_m, _, _ = encoder.predict(X_test,batch_size=batch_size)
plt.figure(figsize=(12, 10))
plt.scatter(z_m[:, 0], z_m[:, 1], c=X_test[:,0,0,0])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

数据合成

python 复制代码
# predicting on validation data
pred=vae.predict(X_test)

代码获取

附文章底部;

相关项目开发,问题咨询,欢迎交流沟通。

相关推荐
Tutankaaa1 天前
从被动接受到主动挑战:知识竞赛如何重塑学习价值
人工智能·经验分享·笔记·学习
Jmayday1 天前
机器学习基本理论
人工智能·机器学习
ZhengEnCi1 天前
01b-上下文向量与信息瓶颈
人工智能
王_teacher1 天前
机器学习 矩阵求导 完整公式+严谨推导
人工智能·线性代数·考研·机器学习·矩阵·线性回归
码以致用1 天前
DeerFlow Memory架构
人工智能·ai·架构·agent
ting94520001 天前
从零构建大模型实战:数据处理与 GPT-2 完整实现
人工智能
学点程序1 天前
Manifest:帮个人 AI Agent 降低模型成本的开源路由器
人工智能·开源
可观测性用观测云1 天前
观测云 x AI Agent:运维智能化的范式跃迁实践
人工智能
数数科技的数据干货1 天前
ThinkingAI携手华为云,共建企业级AI Agent平台Agentic Engine
人工智能·ai·华为云·agent
人工智能AI技术1 天前
春招急救:7天面试突击方案
人工智能