Unsupervised Domain Adaption (UDA)及domain shift介绍

UDA

UDA想解决的问题是目标域上数据标签的缺乏,具体而言,存在着源域和目标域,源域上存在大量的标注样本对 <math xmlns="http://www.w3.org/1998/Math/MathML"> D s = { ( X i , y i ) } D_s=\{(X_i,y_i)\} </math>Ds={(Xi,yi)},我们可以在上面以有监督的方式训练各种模型,但此外我们想要将模型迁移到一个不存在标注的目标域 <math xmlns="http://www.w3.org/1998/Math/MathML"> D t = { X i } D_t=\{X_i\} </math>Dt={Xi}上,由于不存在便签,因此我们无法对训练好的模型进行简单的finetune,而是是要通过已知源域、目标域的特征空间来进行模型的迁移,这就是UDA。

简单来看,UDA整体就类似于模型的训练过程,即在含标签的训练集数据上进行模型的训练,然后迁移到标签未知的测试集(或者实际应用场景)中进行评估,这里存在着一个显著的区别--模型训练时我们假定了测试集和训练集独立同分布,在同一个域中。而UDA的场景则更为复杂,其源域和目标域可能是并不完全对齐的,因此在UDA中一个很简单的思路就是寻找某种变换,将源域和目标域映射到同一空间,然后在该空间内使用源域数据进行模型训练,这样就能使得训练好的模型可以完美的迁移到目标域了,如下图所示: 那么在这种思路下问题就变成了如何对齐源域和目标域,有几种比较典型的思路:

  1. 最小化域间差异 ,找一个变换,使得变换前后源域目标域的数据分布在某种度量下最小,这里的问题就是如何定义数据分布,特征和标签构成的分布有联合分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( X , y ) P(X,y) </math>P(X,y),条件分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( X ∣ y ) , P ( y ∣ X ) P(X|y), P(y|X) </math>P(X∣y),P(y∣X)以及边缘分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( X ) , P ( y ) P(X), P(y) </math>P(X),P(y),那么该如何适配这些分布呢,尤其是在目标域标签未知的情况下;
  2. 域不变特征: 这一思路也较为直观,就是直接找到两个域所共享的特征,无论是在源域还是在目标域都可以用这一特征来进行判别。
  3. 对抗网络 :利用对抗网络的思想,包含特征提取网络和类判别网络、域判别网络,使得网络无法确认提取到的特征是来自于源域还是目标域^1^,但都能给出良好的判别效果,其实算是寻找域不变特征

self-training

除了这一类UDA本土的方法外,还有一些人把Semi-Supervised-Learning中的Self-training方法用到了UDA领域。Self-training的常用任务场景和UDA类似,但感觉上并没有考虑域分布不同的问题,更多是在同一个域,只是某些label未知。Self-training旨在利用有标签的数据对模型进行训练,然后用训练好的模型对无标记数据进行预测得到伪标签(可以是通过阈值阶截断后的hard target或者是soft target),然后把伪标注样本也纳入训练,以实现半监督训练的效果^2^。

self-training本身显著的问题在于它生成的伪标签非常的noisy,而且通常会只采纳高置信度伪标注样本纳入后续训练^3^,使得低置信度样本得不到充分的训练;并且这种bias也使得模型会给许多样本同一类别^4^。  而将self-training 方法迁移到UDA领域面临的主要问题就是源域目标域的分布不同,这和self-training的任务场景有着显著差异,需要进行适当的修改。

domain shift

既然要考虑到源域和目标域数据分布不同,就得细致的去了解这种分布不同有哪些表现形式,应该如何解决。domain shift(域偏移)又可以细分为label shift和covariate shift两种,前者指数据的条件分布相同而边缘分布不同,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( X ∣ y ) ≠ P t ( X ∣ y ) , P s ( X ) = P t ( X ) P_s(X|y)\neq P_t(X|y),\ P_s(X)=P_t(X) </math>Ps(X∣y)=Pt(X∣y), Ps(X)=Pt(X)后者指数据的边缘分布不同而条件分布相同 <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( y ∣ X ) = P t ( y ∣ X ) , P s ( X ) ≠ P t ( X ) P_s(y|X)=P_t(y|X),\ P_s(X)\neq P_t(X) </math>Ps(y∣X)=Pt(y∣X), Ps(X)=Pt(X),这两种偏移在深度学习的域迁移问题中广泛存在^5^,以下进行详细介绍。

Covariate shift

covariate shift描述的是两个域体条件分布一致,但是边缘分布不一致。举个不恰当的例子,A地和B地由于气候的不同,某天两地下雨还是天气的概率不同,但是一旦确定了当天天气,那么该地居民出门是否带伞的概率相同。

covariate shift的出现使得我们在源域上构建的模型并不能够良好的迁移到目标域上,除非源域上设立的模型空间恰好包含了目标域的模型空间。同样以天气为例,A地某人看到蜻蜓低飞决定带伞出门,但B地没有蜻蜓,也就无法借由蜻蜓这一特征来判断是否要带伞出门,而倘若B地有蚂蚁,我们建立起蜻蜓和蚂蚁在不同天气下行为的联系,就可根据这一特效进行B地预测。更为严谨的数学推导见^6^ 。典型图如下: 源域和目标域的条件分布可以用同一函数表示(True func),但由于边缘分布存在显著差异,在使用线性假设空间的情况下,我们在源域上能够学到的最优假设同样为线性(Learned func),因此在目标域上表现不佳。

在已知covariate shift存在的情况下,我们需要对源域的预测模型进行调整以适应目标域的数据分布,具体表现为在源域训练的模型优化时要添加基于分布的系数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> arg ⁡ min ⁡ θ 1 m s ∑ i = 1 m s P t ( X i ) P s ( X i ) H ( F θ ( X i ) , y i ) \underset{\theta}{\arg\min} \ \ \ \frac{1}{m_s}\sum_{i=1}^{m_s}\frac{P_t(X_i)}{P_s(X_i)}H(F_{\theta}(X_i),y_i) </math>θargmin ms1i=1∑msPs(Xi)Pt(Xi)H(Fθ(Xi),yi)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> H H </math>H为交叉熵函数, <math xmlns="http://www.w3.org/1998/Math/MathML"> F θ ( X i ) F_{\theta}(X_i) </math>Fθ(Xi)为模型在源域的预测输出。可以看出问题的关键就在于如何求取 <math xmlns="http://www.w3.org/1998/Math/MathML"> P t ( X i ) P s ( X i ) \frac{P_t(X_i)}{P_s(X_i)} </math>Ps(Xi)Pt(Xi),显然我们是无法直接得到源域和目标域数据分布的,为了计算该值,研究者也进行了大量的研究,典型方法如额外训练LR分类器来判断该样本属于目标域还是源域,并将输出值之比作为 <math xmlns="http://www.w3.org/1998/Math/MathML"> P t ( X i ) P s ( X i ) \frac{P_t(X_i)}{P_s(X_i)} </math>Ps(Xi)Pt(Xi).

Label shift

待更新

参考文献

Footnotes

  1. 基于对抗的迁移学习方法: DANN域对抗网络

  2. self training 文章梳理

  3. FixMatch

  4. NeurIPS 2021 | 助力半监督学习:课程伪标签方法FlexMatch和统一开源库TorchSSL

  5. Generalized Label Shift

  6. 基于样例的迁移学习------Covariate Shift------原始文章解读

相关推荐
千天夜18 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
大数据面试宝典19 分钟前
用AI来写SQL:让ChatGPT成为你的数据库助手
数据库·人工智能·chatgpt
封步宇AIGC24 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_5236742126 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
HappyAcmen36 分钟前
IDEA部署AI代写插件
java·人工智能·intellij-idea
噜噜噜噜鲁先森1 小时前
看懂本文,入门神经网络Neural Network
人工智能
InheritGuo2 小时前
It’s All About Your Sketch: Democratising Sketch Control in Diffusion Models
人工智能·计算机视觉·sketch
weixin_307779132 小时前
证明存在常数c, C > 0,使得在一系列特定条件下,某个特定投资时刻出现的概率与天数的对数成反比
人工智能·算法·机器学习
封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-3.4.1.6.A股宏观经济数据
人工智能·python·机器学习·数据挖掘
Jack黄从零学c++2 小时前
opencv(c++)图像的灰度转换
c++·人工智能·opencv