信号处理--使用EEGNet进行BCI脑电信号的分类

目录

理论

工具

方法实现

代码获取


理论

EEGNet作为一个比较成熟的框架,在BCI众多任务中,表现出不俗的性能。EEGNet 的主要特点包括:1)框架相对比较简单紧凑 2)适合许多的BCI脑电分析任务 3)使用两种卷积 Depth-wise convolution 和 separable convolution 实现普适特征的提取。

工具

Pytorch

P300 visual-evoked potentials数据集

error-related negativity responses (ERN) 数据集

movement-related cortical potentials (MRCP) 数据集

sensory motor rhythms (SMR) 数据集

方法实现

EEGNet模型定义

python 复制代码
class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.T = 120
        
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        
        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)
        
        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))
        
        # FC Layer
        # NOTE: This dimension will depend on the number of timestamps per sample in your data.
        # I have 120 timepoints. 
        self.fc1 = nn.Linear(4*2*7, 1)
        

    def forward(self, x):
        # Layer 1
        x = F.elu(self.conv1(x))
        x = self.batchnorm1(x)
        x = F.dropout(x, 0.25)
        x = x.permute(0, 3, 1, 2)
        
        # Layer 2
        x = self.padding1(x)
        x = F.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = F.dropout(x, 0.25)
        x = self.pooling2(x)
        
        # Layer 3
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.25)
        x = self.pooling3(x)
        
        # FC Layer
        x = x.view(-1, 4*2*7)
        x = F.sigmoid(self.fc1(x))
        return x


net = EEGNet().cuda(0)
print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters())

评估模型分类的相关指标

python 复制代码
def evaluate(model, X, Y, params = ["acc"]):
    results = []
    batch_size = 100
    
    predicted = []
    
    for i in range(len(X)/batch_size):
        s = i*batch_size
        e = i*batch_size+batch_size
        
        inputs = Variable(torch.from_numpy(X[s:e]).cuda(0))
        pred = model(inputs)
        
        predicted.append(pred.data.cpu().numpy())
        
        
    inputs = Variable(torch.from_numpy(X).cuda(0))
    predicted = model(inputs)
    
    predicted = predicted.data.cpu().numpy()
    
    for param in params:
        if param == 'acc':
            results.append(accuracy_score(Y, np.round(predicted)))
        if param == "auc":
            results.append(roc_auc_score(Y, predicted))
        if param == "recall":
            results.append(recall_score(Y, np.round(predicted)))
        if param == "precision":
            results.append(precision_score(Y, np.round(predicted)))
        if param == "fmeasure":
            precision = precision_score(Y, np.round(predicted))
            recall = recall_score(Y, np.round(predicted))
            results.append(2*precision*recall/ (precision+recall))
    return results

模型的训练和测试

python 复制代码
batch_size = 32

for epoch in range(10):  # loop over the dataset multiple times
    print "\nEpoch ", epoch
    
    running_loss = 0.0
    for i in range(len(X_train)/batch_size-1):
        s = i*batch_size
        e = i*batch_size+batch_size
        
        inputs = torch.from_numpy(X_train[s:e])
        labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)
        
        # wrap them in Variable
        inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        
        optimizer.step()
        
        running_loss += loss.data[0]
    
    # Validation accuracy
    params = ["acc", "auc", "fmeasure"]
    print params
    print "Training Loss ", running_loss
    print "Train - ", evaluate(net, X_train, y_train, params)
    print "Validation - ", evaluate(net, X_val, y_val, params)
    print "Test - ", evaluate(net, X_test, y_test, params)

模型提取部分特征的可视化

代码获取

信号处理-使用EEGNet进行BCI脑电信号的分类https://download.csdn.net/download/YINTENAXIONGNAIER/89025247

相关推荐
jiang_bluetooth1 天前
蓝牙典型射频架构剖析
蓝牙·信号处理·射频·pa·lna
通信小呆呆2 天前
神经网络在通信与雷达领域:从信号处理到智能决策
人工智能·神经网络·信号处理
新新学长搞科研2 天前
【安徽大学主办】第五届半导体与电子技术国际研讨会(ISSET 2026)
大数据·数据库·人工智能·自动化·信号处理·半导体·电子
高翔·权衡之境3 天前
主题7:缓存与队列——速度不匹配的通用解
开发语言·人工智能·物联网·缓存·信息与通信·信号处理
Irissgwe3 天前
九、Linux信号机制(一)
linux·信号处理·信号·进程信号·信号的产生
通信小呆呆6 天前
基于 ADMM-MFOCUSS 的捷变频雷达扩展目标稀疏重构原理
算法·重构·信息与通信·信号处理·雷达
2601_958352906 天前
拆解 EN-46:一块 15mA 的 DSP 芯片如何实现 50dB 降噪
人工智能·语音识别·信号处理·嵌入式开发·音频降噪·双麦波束成形·硬件拆解
好好学仿真6 天前
【故障诊断】DSCNN-HA-TL:融合Swin窗口注意力和全局注意力机制的变工况轴承故障诊断(迁移学习/小样本)
机器学习·信号处理·迁移学习·swintransformer·轴承故障诊断·深度可分离卷积·gam注意力
山河君6 天前
从 ACF 到 YIN:基频检测算法原理与实现
人工智能·算法·音视频·语音识别·信号处理
156002548407 天前
基于VU9P+Zynq的双FMC基带信号处理板(支持国产替换)
信号处理