Sd-CDA (自退化对比域适应框架):解决工业故障诊断中数据不平衡问题

现代工业故障诊断任务常常面临分布差异和双不平衡的双重挑战。现有的域适应方法很少关注普遍存在的双不平衡问题,导致域适应性能差或甚至产生负面迁移。在这项工作中,提出了一种自降级对比域适应(SdCDA)诊断框架,用于处理双不平衡数据下的域差异。该框架首先通过基于模型剪枝的不平衡感知对比学习来预训练特征提取器,以自监督的方式高效学习特征表示。然后,它基于监督对比域对抗学习(SupCon-DA)将样本推向域边界之外,并确保特征提取器生成的特征足够区分。此外,提出了剪枝对比域对抗学习(PSupCon-DA),以在对抗训练期间自动重新加权关注少数群体,从而提高对双不平衡数据的性能。

1 方法

Sd-CDA 框架包含两个主要部分:不平衡感知对比表示学习(Ia-CLR)和边界感知对抗域适应(Ba-ADA)。

1.1 Ia-CLR (Imbalance-aware contrastive representation learning)

Sd-CDA 框架中的预训练步骤,旨在解决工业故障诊断中数据双不平衡问题对特征学习的影响。与原始的 SimCLR 相比,Ia-CLR 可以自动检测少数类并在损失函数中增加它们的比例,以避免训练过程中对多数类的偏见。与现有的监督设置中的不平衡学习方法相比,Ia-CLR 不需要为目标样本的少数类寻找最优权重或伪标签来进行重新加权或重新平衡。++++更重要的是,Ia-CLR 非常适合 UDA 设置,因为它不需要目标样本的标签。++++

目标: 通过对比学习,学习对数据不平衡鲁棒的表征。

方法:

  • 使用 SimCLR 作为基础,并进行改进,引入模型剪枝。
  • 将原始数据输入到两个特征提取器:G 和 Gp(G 的剪枝版本)。
  • 将两个特征提取器输出的特征输入到同一个投影头 P。
  • 使用 PNT-Xent 损失函数,最小化 G 和 Gp 输出的正样本对之间的差异,并最大化负样本对之间的差异。
  • 由于剪枝模型 Gp 更容易忘记少数类样本,因此 PNT-Xent 损失函数会自动给少数类样本更高的权重,从而更好地学习少数类样本的特征。

1.2 Ba-ADA (Boundary-aware adversarial domain adaptation)

Sd-CDA 框架中的域适应步骤,旨在解决域适应过程中特征提取器容易生成边界特征的问题,并确保所有样本的特征具有区分度。包含以下四部分:

  • 特征提取器 G: 从源域和目标域中提取特征表示。
  • 域判别器 D: 判断特征来自源域还是目标域。
  • 剪枝域判别器 Dp: 域判别器 D 的剪枝版本,用于检测边界样本。
  • 标签分类器 C: 对特征进行分类。

目标: 通过对抗学习和对比学习,学习域不变且具有判别性的特征。

方法:

  • 使用 DANN 作为基础,并进行改进,引入监督对比学习和模型剪枝。
  • 训练特征提取器 G、域判别器 D 和标签分类器 C。
  • 使用域对抗训练,使 G 学习域不变特征。
  • 冻结 G 和 C,并对 D 进行剪枝,得到 Dp(D 的剪枝版本)。
  • 使用 SupCon-DA 损失函数,最小化 D 和 Dp 输出的正样本对之间的差异,并最大化负样本对之间的差异。
  • 由于剪枝模型 Dp 更容易忘记边界样本,因此 SupCon-DA 损失函数会自动给边界样本更高的权重,从而更好地学习边界样本的特征,并使 G 生成的特征远离域边界。

1.3 Sd-CDA 的优势

  • 处理数据不平衡: Ia-CLR 通过剪枝模型自动给少数类样本更高的权重,Ba-ADA 通过剪枝模型自动给边界样本更高的权重,从而有效地处理数据不平衡问题。
  • 学习域不变特征: DANN 的域对抗训练使 G 学习域不变特征。
  • 学习具有判别性的特征: SupCon-DA 的监督对比学习使 G 生成的特征远离域边界,从而学习具有判别性的特征。

2 实验

2.1 机械滚动轴承数据集

2.1.1 数据集

Case Western Reserve University 提供的机械滚动轴承振动信号数据集。

2.1.2 实验设置

++++数据集划分为源域和目标域,并构造了四种不同的数据不平衡情况:++++ ++++B2B(Balanced to Balanced)、B2I (Balanced to Imbalanced)++++ ++++、I2B、I2I。++++

使用 CNN 作为基础特征提取器,FNN 作为域判别器、投影头和标签分类器。

对比方法:CNN、JDA、DANN、ConDA、Imba-DA。

2.1.3 结果

  • Sd-CDA 在所有四种数据不平衡情况下都取得了比其他方法更好的性能,尤其是在 I2B 和 I2I 情况下。
  • Sd-CDA 在难分类的类别(例如,LBF 故障)上取得了比其他方法更好的准确率。

2.2 工业三相流数据集

2.2.1 数据集

Cranfield University 的工业三相流设施数据集。

2.2.2 实验设置

数据集划分为源域和目标域,并构造了四种不同的数据不平衡情况:B2B(Balanced to Balanced)、B2I (Balanced to Imbalanced)、I2B、I2I。

使用 FNN 作为基础特征提取器,FNN 作为域判别器、投影头和标签分类器。

对比方法:FNN、JDA、DANN、ConDA、Imba-DA。

2.2.3 结果

  • Sd-CDA 在所有四种数据不平衡情况下都取得了比其他方法更好的性能,尤其是在 B2I 和 I2I 情况下。
  • Sd-CDA 在难分类的类别上取得了比其他方法更好的准确率。

2.3 深入分析

  • 不同 αd 和 λbd 的影响: 实验表明,αd 和 λbd 的选择对 Sd-CDA 的性能有显著影响。αd 越大,Lp_bd 越大,对判别器 D 的推动力越强,但过大的 αd 会导致训练过程发散。λbd 越大,Lp_bd 对 D 的更新步长越大,但过大的 λbd 会导致训练过程不稳定。

注:

  • 消融实验: 消融实验表明,Ia-CLR 和 Ba-ADA 都是 Sd-CDA 的重要组成部分,它们都对提高诊断准确率做出了贡献。++++在 B2I 和 I2B 情况下,Ia-CLR 的贡献更大;在 I2I 情况下,Ba-ADA 的贡献更大。++++

3 相关知识

3.1 基本参数αd 及λbd

αd 和 λbd 是 Sd-CDA 框架中两个重要的超参数,分别控制着模型训练过程中的剪枝程度和特征区分度学习程度。

3.1.1 αd (Pruning Proportion for Domain Discriminator)

  • αd 表示在边界感知对抗域适应 (Ba-ADA) 阶段,对域判别器 D 进行剪枝时保留的参数比例。
  • 通过剪枝,模型会丢失一些参数,从而"忘记"一些学习到的特征。由于剪枝程度越大,模型丢失的信息越多,因此 αd 控制了剪枝的程度,进而影响了模型对边界样本的识别能力。

3.1.2 λbd (Loss Weight for Pruned Supervised Contrastive Domain Adversarial Loss)

  • 在 Ba-ADA 阶段,用于控制 Pruned Supervised Contrastive Domain Adversarial Loss (PSupCon-DA) 在总损失函数中的权重。
  • PSupCon-DA 损失函数用于惩罚模型对来自不同域的样本生成相同特征的情况,从而引导模型学习到更具区分度的特征。λbd 控制了 PSupCon-DA 损失函数在总损失函数中的权重,进而影响了模型对特征区分度的学习程度。

3.2 对比方法CNN、JDA、DANN、ConDA、Imba-DA

  • CNN: 使用卷积神经网络 (CNN) 作为特征提取器,对源域和目标域进行分类,但未考虑域差异和双不平衡问题。
  • JDA: 通过联合分布适配 (JDA) 对齐两个域的联合分布,但可能需要预定义的度量函数。
  • DANN: 使用域对抗神经网络 (DANN) 隐式地测量域间差异,但容易生成边界特征。
  • ConDA: 使用对比学习 (ConDA) 预训练特征提取器,并通过对抗学习进一步对齐两个域,但未考虑双不平衡问题。
  • Imba-DA: 使用成本敏感学习解决双不平衡问题,但可能需要大量的标签数据。
相关推荐
陈苏同学1 分钟前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
吾名招财19 分钟前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
FL162386312930 分钟前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~36 分钟前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
我是哈哈hh39 分钟前
专题十_穷举vs暴搜vs深搜vs回溯vs剪枝_二叉树的深度优先搜索_算法专题详细总结
服务器·数据结构·c++·算法·机器学习·深度优先·剪枝
鼠鼠龙年发大财1 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
Tisfy1 小时前
LeetCode 2187.完成旅途的最少时间:二分查找
算法·leetcode·二分查找·题解·二分
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙1 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习