【信号处理】基于变分自编码器(VAE)的脑电信号增强典型方法实现(tensorflow)

关于

在脑电信号分析处理任务中,数据不均衡是一个常见的问题。针对数据不均衡,传统方法有过采样和欠采样方法来应对,但是效果有限。本项目通过变分自编码器对脑电信号进行生成增强,提高增强样本的多样性,从而提高最终的后端分析性能。

EEG数据增强方法参考:https://dlib.phenikaa-uni.edu.vn/bitstream/PNK/8319/1/Data%20Augmentation%20techniques%20in%20time%20series%20domain%20a%20survey%20and%20taxonomy-2023.pdf

工具

数据集下载地址: BCI Competition IV

方法实现

加载必要的库函数和数据

python 复制代码
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Dense, Lambda, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras import backend as K



direc = r'bci_iv_2a_data/A01/train/0/'     #data directory

train_dataset = []
train_label = []

test_dataset = []
test_label = []

files = os.listdir(direc)
for j, name in enumerate(files):
    filename = glob.glob(direc + '/'+ name)
    df = pd.read_csv(filename[0], index_col=None, header=None)
    df = df.drop(0, axis=1)     #dropping column of channel names
    df = df.iloc[:,0:1000]      #taking 1000 timesteps
    train_dataset.append(np.array(df))
            


train_dataset = np.array(train_dataset)
train_data = np.expand_dims(train_dataset,axis=-1)

VAE模型>编码器定义

python 复制代码
# VAE model
input_shape=(X_train.shape[1:])
batch_size = 32
kernel_size = 5
filters = 16
latent_dim = 2
epochs = 1000

# reparameterization
def sampling(args): 
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon




# encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs

filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(1, 50),strides=(1,25),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)


filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(22, 1),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z_log_var = z_log_var + 1e-8 

# reparameterization
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) 

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

VAE模型>解码器定义

python 复制代码
# decoder 
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=filters,kernel_size=(22, 1),activation='relu',)(x)
x = BatchNormalization()(x)

filters = filters// 2
x = Conv2DTranspose(filters=filters,kernel_size=(1, 50),activation='relu',strides=(1,25))(x)
x = BatchNormalization()(x)

filters = filters// 2
outputs = Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
python 复制代码
# VAE model (merging encoder and decoder)
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')
vae.summary()

定义损失函数

python 复制代码
# defining Custom loss function 
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

reconstruction_loss *= input_shape[0] * input_shape[1]
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

#optimizer
optimizer = Adam(learning_rate=0.001, beta_1=0.5, beta_2=0.999)

# compiling vae
vae.compile(optimizer=optimizer, loss=None)
vae.summary()

模型配置和训练

python 复制代码
# early stopping callback
callbacks = EarlyStopping(monitor = 'val_loss',
                          mode='min',
                          patience =50,
                          verbose = 1,
                          restore_best_weights = True)


# fit vae model
history = vae.fit(X_train,X_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, X_test),callbacks=callbacks)

训练流程可视化

python 复制代码
# loss curves
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('loss curves')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

中间隐空间特征2D可视化

python 复制代码
# 2D plot of the classes in latent space
z_m, _, _ = encoder.predict(X_test,batch_size=batch_size)
plt.figure(figsize=(12, 10))
plt.scatter(z_m[:, 0], z_m[:, 1], c=X_test[:,0,0,0])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

数据合成

python 复制代码
# predicting on validation data
pred=vae.predict(X_test)

代码获取

附文章底部;

相关项目开发,问题咨询,欢迎交流沟通。

相关推荐
mit6.82429 分钟前
[AI React Web] 包与依赖管理 | `axios`库 | `framer-motion`库
前端·人工智能·react.js
小阿鑫44 分钟前
不要太信任Cursor,这位网友被删库了。。。
人工智能·aigc·cursor·部署mcp
说私域1 小时前
基于定制开发开源 AI 智能名片 S2B2C 商城小程序的热点与人工下发策略研究
人工智能·小程序
GoGeekBaird2 小时前
GoHumanLoopHub开源上线,开启Agent人际协作新方式
人工智能·后端·github
Jinkxs2 小时前
测试工程师的AI转型指南:从工具使用到测试策略重构
人工智能·重构
别惹CC3 小时前
Spring AI 进阶之路01:三步将 AI 整合进 Spring Boot
人工智能·spring boot·spring
stbomei5 小时前
当 AI 开始 “理解” 情感:情感计算技术正在改写人机交互规则
人工智能·人机交互
Moshow郑锴10 小时前
人工智能中的(特征选择)数据过滤方法和包裹方法
人工智能
TY-202510 小时前
【CV 目标检测】Fast RCNN模型①——与R-CNN区别
人工智能·目标检测·目标跟踪·cnn
CareyWYR11 小时前
苹果芯片Mac使用Docker部署MinerU api服务
人工智能