运动想象 (MI) 迁移学习系列 (15) : 基于Wasserstein距离的改进域适应网络

运动想象迁移学习系列:基于Wasserstein距离的改进域适应网络

  • [0. 引言](#0. 引言)
  • [1. 主要贡献](#1. 主要贡献)
  • [2. 与以往方法的对比](#2. 与以往方法的对比)
  • [3. 提出的方法](#3. 提出的方法)
    • [3.1 特征提取器](#3.1 特征提取器)
    • [3.2 域判别器](#3.2 域判别器)
    • [3.3 分类器](#3.3 分类器)
    • [3.4 算法流程](#3.4 算法流程)
  • [4. 总结](#4. 总结)
  • 欢迎来稿

论文地址:https://ieeexplore.ieee.org/abstract/document/10035017

论文题目:Improved Domain Adaptation Network Based on Wasserstein Distance for Motor Imagery EEG Classification

论文代码:无

0. 引言

受生成对抗网络(GAN) 的启发,本研究旨在提出一种基于Wasserstein距离的改进域适应网络,该网络利用来自多个受试者(源域)的现有标记数据来提高单个受试者(目标域)的MI分类性能。具体来说,我们提出的框架由三个组件组成,包括特征提取器、域鉴别器和分类器。特征提取器采用注意力机制和方差层来提高对从不同MI类别中提取的特征的区分。接下来,领域鉴别器采用Wasserstein矩阵来测量源域与目标域之间的距离,并通过对抗学习策略对齐源域和目标域的数据分布。最后,分类器使用从源域获取的知识来预测目标域中的标签。

总得来说:将所有数据通过域适应的方法来提高某一数据的分类精度。本篇内容建议看着算法流程来进行解读。。。

1. 主要贡献

  1. 提出了一种基于Wasserstein距离矩阵的改进领域适应网络,将改进的特征提取器对抗性领域适应模型相结合。
  2. 在对抗域适应中,使用了 Wasserstein 距离矩阵,减少了跨受试者的差异,从而可以应用来自源受试者的标记数据来扩大训练数据量。

2. 与以往方法的对比

与以往的领域适应网络不同,本文提出的框架利用特征提取器中的注意力机制和方差层来增加对运动意象特征的判别。然后,基于Wasserstein距离矩阵而不是对抗性损失函数,该框架进行领域适应以减少跨受试者差异。图1显示了以往领域适应工作与本研究框架的比较。在我们的框架中,不仅不同运动意象任务的特征可以更具辨别性,而且不同参与者的MI-EEG数据的分布可以更好地对齐。通过这种方式,可以利用来自多个参与者的MI-EEG数据来帮助对单个参与者的数据进行分类解决数据短缺问题并改善分类结果。

3. 提出的方法

如下图所示,所提出的模型中有三个主要模块,包括特征提取器分类器域判别器。在模型的训练过程中,首先将源脑电信号和目标脑电信号发送到特征提取器,利用子带滤波器和卷积层结合CBAM提取空间信息。然后使用方差层来提取时间信息。随后,可以获得源和目标特征,分别定义源域 D s D_s Ds 和目标域 D t D_t Dt 。通过以对抗方式最小化两个域之间的 Wasserstein 距离,可以减少两个域之间的数据分布差异,从而对齐两个域的数据分布,并同时学习域不变特征表征。因此,可以利用多个受试者(源域)的标记数据来帮助提高单个受试者(目标域)的分类性能。

3.1 特征提取器

特征提取器主要包含两个成分:CBAM和方差层。其中,CBAM层由通道注意力模块和空间通道注意力模块组成,可以有效提取信号中的相关特征。方差层通过计算方差来或区域时间序列的特征v,可以表示为:

x V ( k ) = 1 w ∑ t = w ∗ k ( k + 1 ) ∗ w − 1 ( x ( t ) − μ ( k ) ) 2 \begin{equation*} x_{V} (k)=\frac {1}{w}\sum \limits _{t=w\ast k}^{(k+1)\ast w-1} {(x(t)-\mu (k))^{2}} \tag{4}\end{equation*} xV(k)=w1t=w∗k∑(k+1)∗w−1(x(t)−μ(k))2(4)

其中, μ ( k ) \mu (k) μ(k) 是 x ( t ) x(t) x(t) 在第 k 个窗口内的时间平均值。

3.2 域判别器

在我们的对抗性训练中,特征提取器从源域和目标域学习域不变特征表示,以使域鉴别器难以区分特征来自哪个域,而域鉴别器测量源域和目标域数据分布之间的 Wasserstein 距离,试图找出数据所属的域。最后,学习到的特征表示可以欺骗域鉴别器,这意味着两个域之间的Wasserstein距离最小化,换句话说,两个域之间的差异减小了。因此,两个域的边际数据分布是一致的。

对于域判别器来说,源域和目标域之间的Wasserstein距离可以通过最大化与参数 θ d θ_d θd 有关的域判别器损失 L w d L_{wd} Lwd 来评估:

L w d ( x s , x t ) = 1 n s ∑ x s ∈ D s f w ( f g ( x s ) ) − 1 n t ∑ x t ∈ D t f w ( f g ( x t ) ) \begin{align*} {\mathcal{ L}}{wd} (x^{s},x^{t})=\frac {1}{n^{s}}\sum \limits {x^{s}\in D^{s}} {f{w} (f{g} (x^{s}))} -\frac {1}{n^{t}}\sum \limits {x^{t}\in D^{t}} {f{w} (f_{g} (x^{t}))} \tag{5}\end{align*} Lwd(xs,xt)=ns1xs∈Ds∑fw(fg(xs))−nt1xt∈Dt∑fw(fg(xt))(5)

3.3 分类器

分类器旨在预测从特征提取器中学习的表示的标签。在这项工作中,目标特征的标签没有用于训练分类器。相反,分类器仅使用来自源域的标记 MI-EEG 数据进行训练。然后,将训练好的分类器直接应用于目标域数据预测。 交叉熵损失被记为:
L c l s = − E x ∼ D ∑ k = 1 c l s I ( y = = k ) log ( M ( x ) ) \begin{equation*} {\mathcal{ L}}{cls} =-\mathbb {E}{x\sim D} \sum \limits {k=1}^{cls} {\mathbb {I}{(y==k)}} \text {log}({\mathcal{ M}}(x)) \tag{8}\end{equation*} Lcls=−Ex∼Dk=1∑clsI(y==k)log(M(x))(8)

其中, I \mathbb {I} I 是指标函数,如果 y 等于 k,其结果为 1,如果不等于 k,其结果为 0; M \mathcal{M} M 是建议的模型。

3.4 算法流程

4. 总结

到此,使用 基于Wasserstein距离的改进域适应网络 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

欢迎来稿

欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!

相关推荐
阿俊仔(摸鱼版)1 分钟前
Python 常用运维模块之OS模块篇
运维·开发语言·python·云服务器
lly_csdn12334 分钟前
【Image Captioning】DynRefer
python·深度学习·ai·图像分类·多模态·字幕生成·属性识别
速融云1 小时前
汽车制造行业案例 | 发动机在制造品管理全解析(附解决方案模板)
大数据·人工智能·自动化·汽车·制造
西猫雷婶1 小时前
python学opencv|读取图像(四十一 )使用cv2.add()函数实现各个像素点BGR叠加
开发语言·python·opencv
金融OG1 小时前
99.11 金融难点通俗解释:净资产收益率(ROE)VS投资资本回报率(ROIC)VS总资产收益率(ROA)
大数据·python·算法·机器学习·金融
AI明说1 小时前
什么是稀疏 MoE?Doubao-1.5-pro 如何以少胜多?
人工智能·大模型·moe·豆包
XianxinMao1 小时前
重构开源LLM分类:从二分到三分的转变
人工智能·语言模型·开源
Elastic 中国社区官方博客2 小时前
使用 Elasticsearch 导航检索增强生成图表
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
小唐C++2 小时前
C++小病毒-1.0勒索
开发语言·c++·vscode·python·算法·c#·编辑器
云天徽上2 小时前
【数据可视化】全国星巴克门店可视化
人工智能·机器学习·信息可视化·数据挖掘·数据分析