【信号处理】基于变分自编码器(VAE)的脑电信号增强典型方法实现(tensorflow)

关于

在脑电信号分析处理任务中,数据不均衡是一个常见的问题。针对数据不均衡,传统方法有过采样和欠采样方法来应对,但是效果有限。本项目通过变分自编码器对脑电信号进行生成增强,提高增强样本的多样性,从而提高最终的后端分析性能。

EEG数据增强方法参考:https://dlib.phenikaa-uni.edu.vn/bitstream/PNK/8319/1/Data%20Augmentation%20techniques%20in%20time%20series%20domain%20a%20survey%20and%20taxonomy-2023.pdf

工具

数据集下载地址: BCI Competition IV

方法实现

加载必要的库函数和数据

python 复制代码
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Dense, Lambda, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras import backend as K



direc = r'bci_iv_2a_data/A01/train/0/'     #data directory

train_dataset = []
train_label = []

test_dataset = []
test_label = []

files = os.listdir(direc)
for j, name in enumerate(files):
    filename = glob.glob(direc + '/'+ name)
    df = pd.read_csv(filename[0], index_col=None, header=None)
    df = df.drop(0, axis=1)     #dropping column of channel names
    df = df.iloc[:,0:1000]      #taking 1000 timesteps
    train_dataset.append(np.array(df))
            


train_dataset = np.array(train_dataset)
train_data = np.expand_dims(train_dataset,axis=-1)

VAE模型>编码器定义

python 复制代码
# VAE model
input_shape=(X_train.shape[1:])
batch_size = 32
kernel_size = 5
filters = 16
latent_dim = 2
epochs = 1000

# reparameterization
def sampling(args): 
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon




# encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs

filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(1, 50),strides=(1,25),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)


filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(22, 1),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z_log_var = z_log_var + 1e-8 

# reparameterization
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) 

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

VAE模型>解码器定义

python 复制代码
# decoder 
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=filters,kernel_size=(22, 1),activation='relu',)(x)
x = BatchNormalization()(x)

filters = filters// 2
x = Conv2DTranspose(filters=filters,kernel_size=(1, 50),activation='relu',strides=(1,25))(x)
x = BatchNormalization()(x)

filters = filters// 2
outputs = Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
python 复制代码
# VAE model (merging encoder and decoder)
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')
vae.summary()

定义损失函数

python 复制代码
# defining Custom loss function 
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

reconstruction_loss *= input_shape[0] * input_shape[1]
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

#optimizer
optimizer = Adam(learning_rate=0.001, beta_1=0.5, beta_2=0.999)

# compiling vae
vae.compile(optimizer=optimizer, loss=None)
vae.summary()

模型配置和训练

python 复制代码
# early stopping callback
callbacks = EarlyStopping(monitor = 'val_loss',
                          mode='min',
                          patience =50,
                          verbose = 1,
                          restore_best_weights = True)


# fit vae model
history = vae.fit(X_train,X_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, X_test),callbacks=callbacks)

训练流程可视化

python 复制代码
# loss curves
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('loss curves')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

中间隐空间特征2D可视化

python 复制代码
# 2D plot of the classes in latent space
z_m, _, _ = encoder.predict(X_test,batch_size=batch_size)
plt.figure(figsize=(12, 10))
plt.scatter(z_m[:, 0], z_m[:, 1], c=X_test[:,0,0,0])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

数据合成

python 复制代码
# predicting on validation data
pred=vae.predict(X_test)

代码获取

附文章底部;

相关项目开发,问题咨询,欢迎交流沟通。

相关推荐
Deepoch22 分钟前
Deepoc 大模型在无人机行业应用效果的方法
人工智能·科技·ai·语言模型·无人机
Deepoch25 分钟前
Deepoc 大模型:无人机行业的智能变革引擎
人工智能·科技·算法·ai·动态规划·无人机
kngines1 小时前
【字节跳动】数据挖掘面试题0003:有一个文件,每一行是一个数字,如何用 MapReduce 进行排序和求每个用户每个页面停留时间
人工智能·数据挖掘·mapreduce·面试题
Binary_ey1 小时前
AR衍射光波导设计遇瓶颈,OAS 光学软件来破局
人工智能·软件需求·光学软件·光波导
昵称是6硬币1 小时前
YOLOv11: AN OVERVIEW OF THE KEY ARCHITECTURAL ENHANCEMENTS目标检测论文精读(逐段解析)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
平和男人杨争争1 小时前
机器学习2——贝叶斯理论下
人工智能·机器学习
静心问道1 小时前
XLSR-Wav2Vec2:用于语音识别的无监督跨语言表示学习
人工智能·学习·语音识别
算家计算1 小时前
5 秒预览物理世界,2 行代码启动生成——ComfyUI-Cosmos-Predict2 本地部署教程,重塑机器人训练范式!
人工智能·开源
摆烂工程师2 小时前
国内如何安装和使用 Claude Code 教程 - Windows 用户篇
人工智能·ai编程·claude
云天徽上9 天前
【目标检测】图像处理基础:像素、分辨率与图像格式解析
图像处理·人工智能·目标检测·计算机视觉·数据可视化