运动想象 (MI) 迁移学习系列 (15) : 基于Wasserstein距离的改进域适应网络

运动想象迁移学习系列:基于Wasserstein距离的改进域适应网络

  • [0. 引言](#0. 引言)
  • [1. 主要贡献](#1. 主要贡献)
  • [2. 与以往方法的对比](#2. 与以往方法的对比)
  • [3. 提出的方法](#3. 提出的方法)
    • [3.1 特征提取器](#3.1 特征提取器)
    • [3.2 域判别器](#3.2 域判别器)
    • [3.3 分类器](#3.3 分类器)
    • [3.4 算法流程](#3.4 算法流程)
  • [4. 总结](#4. 总结)
  • 欢迎来稿

论文地址:https://ieeexplore.ieee.org/abstract/document/10035017

论文题目:Improved Domain Adaptation Network Based on Wasserstein Distance for Motor Imagery EEG Classification

论文代码:无

0. 引言

受生成对抗网络(GAN) 的启发,本研究旨在提出一种基于Wasserstein距离的改进域适应网络,该网络利用来自多个受试者(源域)的现有标记数据来提高单个受试者(目标域)的MI分类性能。具体来说,我们提出的框架由三个组件组成,包括特征提取器、域鉴别器和分类器。特征提取器采用注意力机制和方差层来提高对从不同MI类别中提取的特征的区分。接下来,领域鉴别器采用Wasserstein矩阵来测量源域与目标域之间的距离,并通过对抗学习策略对齐源域和目标域的数据分布。最后,分类器使用从源域获取的知识来预测目标域中的标签。

总得来说:将所有数据通过域适应的方法来提高某一数据的分类精度。本篇内容建议看着算法流程来进行解读。。。

1. 主要贡献

  1. 提出了一种基于Wasserstein距离矩阵的改进领域适应网络,将改进的特征提取器对抗性领域适应模型相结合。
  2. 在对抗域适应中,使用了 Wasserstein 距离矩阵,减少了跨受试者的差异,从而可以应用来自源受试者的标记数据来扩大训练数据量。

2. 与以往方法的对比

与以往的领域适应网络不同,本文提出的框架利用特征提取器中的注意力机制和方差层来增加对运动意象特征的判别。然后,基于Wasserstein距离矩阵而不是对抗性损失函数,该框架进行领域适应以减少跨受试者差异。图1显示了以往领域适应工作与本研究框架的比较。在我们的框架中,不仅不同运动意象任务的特征可以更具辨别性,而且不同参与者的MI-EEG数据的分布可以更好地对齐。通过这种方式,可以利用来自多个参与者的MI-EEG数据来帮助对单个参与者的数据进行分类解决数据短缺问题并改善分类结果。

3. 提出的方法

如下图所示,所提出的模型中有三个主要模块,包括特征提取器分类器域判别器。在模型的训练过程中,首先将源脑电信号和目标脑电信号发送到特征提取器,利用子带滤波器和卷积层结合CBAM提取空间信息。然后使用方差层来提取时间信息。随后,可以获得源和目标特征,分别定义源域 D s D_s Ds 和目标域 D t D_t Dt 。通过以对抗方式最小化两个域之间的 Wasserstein 距离,可以减少两个域之间的数据分布差异,从而对齐两个域的数据分布,并同时学习域不变特征表征。因此,可以利用多个受试者(源域)的标记数据来帮助提高单个受试者(目标域)的分类性能。

3.1 特征提取器

特征提取器主要包含两个成分:CBAM和方差层。其中,CBAM层由通道注意力模块和空间通道注意力模块组成,可以有效提取信号中的相关特征。方差层通过计算方差来或区域时间序列的特征v,可以表示为:

x V ( k ) = 1 w ∑ t = w ∗ k ( k + 1 ) ∗ w − 1 ( x ( t ) − μ ( k ) ) 2 \begin{equation*} x_{V} (k)=\frac {1}{w}\sum \limits _{t=w\ast k}^{(k+1)\ast w-1} {(x(t)-\mu (k))^{2}} \tag{4}\end{equation*} xV(k)=w1t=w∗k∑(k+1)∗w−1(x(t)−μ(k))2(4)

其中, μ ( k ) \mu (k) μ(k) 是 x ( t ) x(t) x(t) 在第 k 个窗口内的时间平均值。

3.2 域判别器

在我们的对抗性训练中,特征提取器从源域和目标域学习域不变特征表示,以使域鉴别器难以区分特征来自哪个域,而域鉴别器测量源域和目标域数据分布之间的 Wasserstein 距离,试图找出数据所属的域。最后,学习到的特征表示可以欺骗域鉴别器,这意味着两个域之间的Wasserstein距离最小化,换句话说,两个域之间的差异减小了。因此,两个域的边际数据分布是一致的。

对于域判别器来说,源域和目标域之间的Wasserstein距离可以通过最大化与参数 θ d θ_d θd 有关的域判别器损失 L w d L_{wd} Lwd 来评估:

L w d ( x s , x t ) = 1 n s ∑ x s ∈ D s f w ( f g ( x s ) ) − 1 n t ∑ x t ∈ D t f w ( f g ( x t ) ) \begin{align*} {\mathcal{ L}}{wd} (x^{s},x^{t})=\frac {1}{n^{s}}\sum \limits {x^{s}\in D^{s}} {f{w} (f{g} (x^{s}))} -\frac {1}{n^{t}}\sum \limits {x^{t}\in D^{t}} {f{w} (f_{g} (x^{t}))} \tag{5}\end{align*} Lwd(xs,xt)=ns1xs∈Ds∑fw(fg(xs))−nt1xt∈Dt∑fw(fg(xt))(5)

3.3 分类器

分类器旨在预测从特征提取器中学习的表示的标签。在这项工作中,目标特征的标签没有用于训练分类器。相反,分类器仅使用来自源域的标记 MI-EEG 数据进行训练。然后,将训练好的分类器直接应用于目标域数据预测。 交叉熵损失被记为:
L c l s = − E x ∼ D ∑ k = 1 c l s I ( y = = k ) log ( M ( x ) ) \begin{equation*} {\mathcal{ L}}{cls} =-\mathbb {E}{x\sim D} \sum \limits {k=1}^{cls} {\mathbb {I}{(y==k)}} \text {log}({\mathcal{ M}}(x)) \tag{8}\end{equation*} Lcls=−Ex∼Dk=1∑clsI(y==k)log(M(x))(8)

其中, I \mathbb {I} I 是指标函数,如果 y 等于 k,其结果为 1,如果不等于 k,其结果为 0; M \mathcal{M} M 是建议的模型。

3.4 算法流程

4. 总结

到此,使用 基于Wasserstein距离的改进域适应网络 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

欢迎来稿

欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!

相关推荐
LZXCyrus7 分钟前
【杂记】vLLM如何指定GPU单卡/多卡离线推理
人工智能·经验分享·python·深度学习·语言模型·llm·vllm
Enougme10 分钟前
Appium常用的使用方法(一)
python·appium
懷淰メ16 分钟前
PyQt飞机大战游戏(附下载地址)
开发语言·python·qt·游戏·pyqt·游戏开发·pyqt5
我感觉。25 分钟前
【机器学习chp4】特征工程
人工智能·机器学习·主成分分析·特征工程
hummhumm30 分钟前
第 22 章 - Go语言 测试与基准测试
java·大数据·开发语言·前端·python·golang·log4j
YRr YRr33 分钟前
深度学习神经网络中的优化器的使用
人工智能·深度学习·神经网络
DieYoung_Alive34 分钟前
一篇文章了解机器学习(下)
人工智能·机器学习
夏沫的梦35 分钟前
生成式AI对产业的影响与冲击
人工智能·aigc
hummhumm1 小时前
第 28 章 - Go语言 Web 开发入门
java·开发语言·前端·python·sql·golang·前端框架
goomind1 小时前
YOLOv8实战木材缺陷识别
人工智能·yolo·目标检测·缺陷检测·pyqt5·木材缺陷识别