运动想象 (MI) 迁移学习系列 (15) : 基于Wasserstein距离的改进域适应网络

运动想象迁移学习系列:基于Wasserstein距离的改进域适应网络

  • [0. 引言](#0. 引言)
  • [1. 主要贡献](#1. 主要贡献)
  • [2. 与以往方法的对比](#2. 与以往方法的对比)
  • [3. 提出的方法](#3. 提出的方法)
    • [3.1 特征提取器](#3.1 特征提取器)
    • [3.2 域判别器](#3.2 域判别器)
    • [3.3 分类器](#3.3 分类器)
    • [3.4 算法流程](#3.4 算法流程)
  • [4. 总结](#4. 总结)
  • 欢迎来稿

论文地址:https://ieeexplore.ieee.org/abstract/document/10035017

论文题目:Improved Domain Adaptation Network Based on Wasserstein Distance for Motor Imagery EEG Classification

论文代码:无

0. 引言

受生成对抗网络(GAN) 的启发,本研究旨在提出一种基于Wasserstein距离的改进域适应网络,该网络利用来自多个受试者(源域)的现有标记数据来提高单个受试者(目标域)的MI分类性能。具体来说,我们提出的框架由三个组件组成,包括特征提取器、域鉴别器和分类器。特征提取器采用注意力机制和方差层来提高对从不同MI类别中提取的特征的区分。接下来,领域鉴别器采用Wasserstein矩阵来测量源域与目标域之间的距离,并通过对抗学习策略对齐源域和目标域的数据分布。最后,分类器使用从源域获取的知识来预测目标域中的标签。

总得来说:将所有数据通过域适应的方法来提高某一数据的分类精度。本篇内容建议看着算法流程来进行解读。。。

1. 主要贡献

  1. 提出了一种基于Wasserstein距离矩阵的改进领域适应网络,将改进的特征提取器对抗性领域适应模型相结合。
  2. 在对抗域适应中,使用了 Wasserstein 距离矩阵,减少了跨受试者的差异,从而可以应用来自源受试者的标记数据来扩大训练数据量。

2. 与以往方法的对比

与以往的领域适应网络不同,本文提出的框架利用特征提取器中的注意力机制和方差层来增加对运动意象特征的判别。然后,基于Wasserstein距离矩阵而不是对抗性损失函数,该框架进行领域适应以减少跨受试者差异。图1显示了以往领域适应工作与本研究框架的比较。在我们的框架中,不仅不同运动意象任务的特征可以更具辨别性,而且不同参与者的MI-EEG数据的分布可以更好地对齐。通过这种方式,可以利用来自多个参与者的MI-EEG数据来帮助对单个参与者的数据进行分类解决数据短缺问题并改善分类结果。

3. 提出的方法

如下图所示,所提出的模型中有三个主要模块,包括特征提取器分类器域判别器。在模型的训练过程中,首先将源脑电信号和目标脑电信号发送到特征提取器,利用子带滤波器和卷积层结合CBAM提取空间信息。然后使用方差层来提取时间信息。随后,可以获得源和目标特征,分别定义源域 D s D_s Ds 和目标域 D t D_t Dt 。通过以对抗方式最小化两个域之间的 Wasserstein 距离,可以减少两个域之间的数据分布差异,从而对齐两个域的数据分布,并同时学习域不变特征表征。因此,可以利用多个受试者(源域)的标记数据来帮助提高单个受试者(目标域)的分类性能。

3.1 特征提取器

特征提取器主要包含两个成分:CBAM和方差层。其中,CBAM层由通道注意力模块和空间通道注意力模块组成,可以有效提取信号中的相关特征。方差层通过计算方差来或区域时间序列的特征v,可以表示为:

x V ( k ) = 1 w ∑ t = w ∗ k ( k + 1 ) ∗ w − 1 ( x ( t ) − μ ( k ) ) 2 \begin{equation*} x_{V} (k)=\frac {1}{w}\sum \limits _{t=w\ast k}^{(k+1)\ast w-1} {(x(t)-\mu (k))^{2}} \tag{4}\end{equation*} xV(k)=w1t=w∗k∑(k+1)∗w−1(x(t)−μ(k))2(4)

其中, μ ( k ) \mu (k) μ(k) 是 x ( t ) x(t) x(t) 在第 k 个窗口内的时间平均值。

3.2 域判别器

在我们的对抗性训练中,特征提取器从源域和目标域学习域不变特征表示,以使域鉴别器难以区分特征来自哪个域,而域鉴别器测量源域和目标域数据分布之间的 Wasserstein 距离,试图找出数据所属的域。最后,学习到的特征表示可以欺骗域鉴别器,这意味着两个域之间的Wasserstein距离最小化,换句话说,两个域之间的差异减小了。因此,两个域的边际数据分布是一致的。

对于域判别器来说,源域和目标域之间的Wasserstein距离可以通过最大化与参数 θ d θ_d θd 有关的域判别器损失 L w d L_{wd} Lwd 来评估:

L w d ( x s , x t ) = 1 n s ∑ x s ∈ D s f w ( f g ( x s ) ) − 1 n t ∑ x t ∈ D t f w ( f g ( x t ) ) \begin{align*} {\mathcal{ L}}{wd} (x^{s},x^{t})=\frac {1}{n^{s}}\sum \limits {x^{s}\in D^{s}} {f{w} (f{g} (x^{s}))} -\frac {1}{n^{t}}\sum \limits {x^{t}\in D^{t}} {f{w} (f_{g} (x^{t}))} \tag{5}\end{align*} Lwd(xs,xt)=ns1xs∈Ds∑fw(fg(xs))−nt1xt∈Dt∑fw(fg(xt))(5)

3.3 分类器

分类器旨在预测从特征提取器中学习的表示的标签。在这项工作中,目标特征的标签没有用于训练分类器。相反,分类器仅使用来自源域的标记 MI-EEG 数据进行训练。然后,将训练好的分类器直接应用于目标域数据预测。 交叉熵损失被记为:
L c l s = − E x ∼ D ∑ k = 1 c l s I ( y = = k ) log ( M ( x ) ) \begin{equation*} {\mathcal{ L}}{cls} =-\mathbb {E}{x\sim D} \sum \limits {k=1}^{cls} {\mathbb {I}{(y==k)}} \text {log}({\mathcal{ M}}(x)) \tag{8}\end{equation*} Lcls=−Ex∼Dk=1∑clsI(y==k)log(M(x))(8)

其中, I \mathbb {I} I 是指标函数,如果 y 等于 k,其结果为 1,如果不等于 k,其结果为 0; M \mathcal{M} M 是建议的模型。

3.4 算法流程

4. 总结

到此,使用 基于Wasserstein距离的改进域适应网络 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

欢迎来稿

欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!

相关推荐
SpikeKing10 分钟前
LLM - 使用 LLaMA-Factory 微调大模型 环境配置与训练推理 教程 (1)
人工智能·llm·大语言模型·llama·环境配置·llamafactory·训练框架
小码的头发丝、17 分钟前
Django中ListView 和 DetailView类的区别
数据库·python·django
黄焖鸡能干四碗39 分钟前
信息化运维方案,实施方案,开发方案,信息中心安全运维资料(软件资料word)
大数据·人工智能·软件需求·设计规范·规格说明书
40 分钟前
开源竞争-数据驱动成长-11/05-大专生的思考
人工智能·笔记·学习·算法·机器学习
ctrey_1 小时前
2024-11-4 学习人工智能的Day21 openCV(3)
人工智能·opencv·学习
攻城狮_Dream1 小时前
“探索未来医疗:生成式人工智能在医疗领域的革命性应用“
人工智能·设计·医疗·毕业
忘梓.1 小时前
划界与分类的艺术:支持向量机(SVM)的深度解析
机器学习·支持向量机·分类
Chef_Chen1 小时前
从0开始机器学习--Day17--神经网络反向传播作业
python·神经网络·机器学习
千澜空1 小时前
celery在django项目中实现并发任务和定时任务
python·django·celery·定时任务·异步任务
学习前端的小z1 小时前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc