目录
[1. Introduction](#1. Introduction)
[2. Related Work](#2. Related Work)
[3. Method](#3. Method)
[3.1. DG from the Causal View](#3.1. DG from the Causal View)
[3.2.2 Causal Factorization Module](#3.2.2 Causal Factorization Module)
[3.2.3 Adversarial Mask Module](#3.2.3 Adversarial Mask Module)
[4. Experiment](#4. Experiment)
[4.1. Datasets](#4.1. Datasets)
[4.2. Implementation Details](#4.2. Implementation Details)
[4.3. Experimental Results](#4.3. Experimental Results)
论文连接:https://arxiv.org/pdf/2203.14237
摘要
领域泛化(DG)本质上是一个分布外问题,旨在将从多个源领域学到的知识泛化到未见过的目标领域。
GAP:
主流方法是利用统计模型来模拟数据和标签之间的依赖关系,从而学习独立于领域的表征。尽管如此,统计模型只是对现实的简单描述,因为它们只需要对依赖性而非内在因果机制进行建模。当依赖性随目标分布发生变化时,统计模型可能无法泛化。
方法:
为此,作者引入了一个通用的结构因果模型来形式化 DG 问题。
具体来说,作者假设每个输入都是由因果因素(其与标签的关系在不同领域是不变的)和非因果因素(与类别无关)混合构建而成,只有前者才会导致分类判断。目标是从输入中提取因果因素,然后重建不变的因果机制。然而,由于所需的因果/非因果因素是无法观测到的,因此这一理论想法与 DG 的实际情况相去甚远。
理想的因果因素应满足三个基本属性:与非因果因素分离、共同独立、因果关系足以进行分类。在此基础上,作者提出了一种因果关系启发表征学习(CIRL)算法,该算法强制要求表征满足上述属性,然后利用它们来模拟因果因素,从而提高泛化能力。
在几个广泛使用的数据集上的大量实验结果验证了方法的有效性。
1. Introduction
背景介绍:
近年来,随着现实世界中任务的复杂性不断增加,分布外(OOD)问题对基于深度神经网络的i.i.d. 假设 [29, 30, 36]。将源域上训练好的模型直接应用到分布不同的未知目标域上,通常会出现灾难性的性能下降 [17, 35, 37, 66]。为了解决领域转移问题,领域泛化(Domain Generalization,DG)引起了越来越多的关注,它旨在将从多个源领域提取的知识泛化到一个未知的目标领域[2, 25, 28, 41]。
为了提高泛化能力,人们提出了许多 DG 方法,大致可分为不变表示学习 [12, 28, 31, 40]、领域增强 [62,69,73,78] 和元学习 [2,9,26] 等。虽然这些研究取得了可喜的成果,但也存在一个内在问题。这些努力只是试图弥补 OOD 数据带来的问题,并对数据和标签之间的统计依赖性进行建模,而没有解释其背后的因果机制。最近有观点认为 [51],这种做法可能不够充分,要在 i.i.d. 设置之外很好地进行泛化,需要学习的不仅仅是变量之间的统计依赖性,而是一个潜在的因果模型 [4、46、50、51、58、63]。例如,在图像识别任务中,所有的长颈鹿都很可能在草地上,这就显示了很高的统计依赖性,当背景在目标方向上发生变化时,这很容易误导模型做出错误的预测。毕竟,长颈鹿的头部、颈部等特征而非背景才是长颈鹿的特征。
文章引入结构因果模型(SCM)[57] 来解决 DG 问题,旨在挖掘数据与标签之间的内在因果机制,实现更好的泛化能力。具体来说,假定与类别相关的信息为数据中的信息作为因果因素,其与标签的关系独立于领域,例如数字识别中的"形状"。而与类别无关的信息则被假定为非因果因素,通常是与领域相关的信息,如数字识别中的 "笔迹风格"。如图 1 所示,每个原始数据 X 都是由因果因素 S 和非因果因素 U 混合构建而成,只有前者会对类别标签 Y 产生因果影响。
(图 1. DG 的 SCM。实线箭头表示父节点导致子节点;而虚线箭头表示存在统计依赖性。)
模型的目标是从原始输入 X 中提取因果因素 S,然后重建不变的因果机制,这可以借助因果干预 P(Y ∣do(U ), S)来实现。操作符 do(⋅) [13] 表示对变量的干预。遗憾的是,无法直接将原始输入因子化为 X = f (S, U ),因为因果/非因果因子是无法观察到的,也无法表述,这使得因果推理尤其具有挑战性[60, 64]。
为了将理论想法付诸实践,根据 [51, 54, 58] 的研究,因果因子 S 应满足三个属性: 1) 与非因果因素 U 分离;2) S 的因果化应该是共同独立的;3) 对于分类任务 X -→ Y 来说,因果足够,即包含所有因果信息。
如图 2 (a)所示,与 U 的混合会导致 S 包含潜在的非因果信息,而共同依赖因子化会使 S 成为冗余,进一步导致遗漏某些潜在的因果信息。相比之下,图 2 (b) 中的因果因子 S 是符合所有要求的理想因果因子。受此启发,作者提出了一种因果关系启发表征学习(CIRL)算法,强制要求学习到的表征具有上述特性,然后利用表征的每个二维扩展来模仿因果因子的因子化,因为因果因子的因子化具有更强的泛化能力。
简而言之,对于每个输入,首先利用一个因果监督模块,通过生成具有扰动领域相关信息的新数据,将因果因素 S 从非因果因素 U 中分离出来。与原始数据相比,生成的数据具有不同的非因果因素 U,但具有相同的因果因素 S,因此表征会强制保持不变。此外,提出了一个因式分解模块,使表示的每个维度表征联合独立,可以利用它们来近似地推测因果关系。此外,为了在因果关系上实现充分的分类,设计了一个对抗性掩码模块,通过掩码器和表征生成器之间的对抗性学习,迭代检测包含相对较少因果关系信息的维度,并迫使它们包含更多新的因果关系信息。工作贡献如下:
指出了仅对统计依赖性建模的不足,并将基于因果关系的观点引入 DG,以挖掘内在的因果机制。
强调了理想的因果因素应具备的三个属性,并提出了一种 CIRL 算法来学习能够模仿因果因素的因果表征,这种因果表征具有更好的泛化能力。
在几个广泛使用的数据集上进行的大量实验和分析结果证明了方法的有效性和优越性。
2. Related Work
因果机制[19, 47, 50]关注的是,统计依赖性("看到别人吃药就说明他生病了")无法可靠地预测反事实输入("停止吃药并不会让他健康")的结果。一般来说,可以将其视为推理链的组成部分[24],为以下内容提供预测与观察到的分布相去甚远的情况。从这个意义上说,挖掘因果机制意味着获取稳健的知识,这些知识超越了观察到的数据分布的支持范围[59]。在过去几年中,因果关系与泛化之间的联系越来越受到关注[39, 51]。人们提出了许多基于因果关系的方法,以获得不变的因果机制[16, 65, 71]或恢复因果特征[6, 13, 33, 55],从而提高 OOD 的泛化能力。值得注意的是,这些方法一般都依赖于对因果图或结构方程的限制性假设。
最近,MatchDG [38]将因果关系引入了 DG 文献,通过对比学习强制跨域输入具有相同的表示(如果它们来自同一对象)。 CIRL 在学习因果表征方面与 MatchDG 有关。但是,CIRL 的不同之处在于,它是在一个理论性更强的形式基础上,明确利用维度表征来模仿因果因素,并且只依赖于一个更一般的因果结构模型,而没有限制性假设。从本质上讲,CIRL 可以看作是带有干预的因果因式分解,这与对象条件的 MatchDG 显然不同。
3. Method
将从因果关系的角度来考虑 DG,并使用图 1 所示的一般结构因果模型。证明,如果因果因素是给定的,那么构建内在因果机制(形式化为条件分布)是可行的。然而,正如文献[1]所讨论的,由于因果因素是不可观测的,因此很难准确地恢复因果因素。因此,建议根据因果因素的属性来学习因果表征,作为一种模仿,同时继承其优越的泛化能力。
3.1. DG from the Causal View
DG 的主流是对观测输入与相应标签之间的统计依赖性建模,即 P (X,Y),并假定这种依赖性在不同领域间存在差异。为了获得不变的依赖性,他们一般会强制要求分布在边际上或条件上是域不变的,即尽量减小 P (X) 或 P (X ∣ Y) 的跨域差距。然而,由于统计依赖性无法解释输入和标签之间的内在因果机制,它往往会随领域而变化。因此,学习到的源域之间的不变依赖性在未见过的目标域上仍可能失效。与此同时,因果机制通常会在不同领域之间保持稳定 [51]。正如 Reichenbach [54] 在原则 1 中所说的那样,首先阐明了因果关系和统计依赖性之间的联系。
原则 1([54])。共因原理:如果两个可观测值X和Y在统计上是相关的,那么存在一个变量S,它对两者都有因果影响,并解释了所有的依赖性,即当以S为条件时,它们是独立的。基于原则 1,形式化了以下结构因果模型(SCM)来描述 DG 问题:
V1、V2 是共同独立的未解释噪音变量。至于 f、h、g,它们可以看作是未知的结构函数。
因此,对于任意分布 P(X,Y )∈ P,如果给定因果因子 S,则存在一般条件分布 P(Y ∣ S),即不变因果机制。
根据上述讨论,如果能获取因果因素,那么通过优化 h 很容易得到在 i.i.d. 假设之外具有良好普适性的因果机制:
其中,l(⋅, ⋅) 是交叉熵损失。
遗憾的是,并没有先验地获得因果因素 S,而是获得了原始图像 X,而这些图像通常是非结构化的。由于这些因果因素无法观测且定义不清,因此要想彻底重建这些因果因素和机制,在某种程度上是不切实际的。此外,更重要的是,正如[34,59,60]所讨论的,可以提取哪些因素及其粒度取决于可用的分布变化、监督信号和显式干预。不过,显而易见的是,因果因素仍需符合某些要求。之前的研究 [51, 58] 指出,如原则 2 所述,因果因素应是共同独立的。
**原则 2 ( [51, 58])。独立因果机制 (ICM) 原则:**每个变量的条件分布与其原因(即其机制)无关,也不影响其他机制。
由于式 (1) 中的 S 代表所有因果关系因素 {s1, s2, . ,sN },该原则告诉我们:1)改变(或干预)一个机制 P( ∣ )不会改变任何其他机制 P( ∣ ),i ≠ j [58](表示 si 在因果图中的父节点,由于 S 已是根节点,可将其视为 所包含的因果信息)2)知道其他一些机制 P ( ∣ ) 并不能给我们提供机制 P ( ∣ ) 的信息[21]。因此,我们可以将因果因素的联合分布条件因子化为如下,涉及因果分解:
因此,基于共同原因原则(原则 1)中因果变量的定义和 ICM 原则(原则 2)中因果机制的性质,因果因素 S 应满足三个基本属性:
(1) 因果因素 S 应与非因果因素 U 分开,即 S ⫫ U。 因此,对 U 进行干预不会改变 S;
(2) s1, ..., sN 应该是共同独立的,其中任何一个都不包含其他因素的信息。
(3) 因果因子 S 对分类任务 X -→ Y 来说应该是因果充分的,即包含能解释所有统计依赖关系的信息。
因此,建议不直接重建因果因素,而是学习因果表征,迫使它们具有与因果因素相同的属性。具体细节将在第 3.2 节中解释。
3.2.因果关系启发的表示学习
在本节中,将说明受上述因果关系启发而提出的表征学习算法,该算法由三个模块组成:因果干预模块、因果因式分解模块和对抗掩码模块。整个框架如图 3 所示。
(图 3. CIRL 的框架。首先通过因果干预模块生成增强图像,并对非因果因素进行干预。原始图像和增强图像的表征都会被发送到因果化模块,该模块会施加一个因果化损失,以强制表征与非因果因素分离并共同独立。最后,对抗掩码模块在生成器和掩码器之间进行对抗,使表征在因果关系上足以进行分类。)
首先要通过因果干预将因果因素 S 从非因果因素 U 的混合物中分离出来。具体来说,虽然公式(2)中的因果因素提取器 g(⋅) 的明确形式在一般情况下是未知的,但有先验知识,即因果因素 S 应该对 U 的干预保持不变,即 P (S ∣ do(U )) 。而在 DG 文献中,确实知道一些与领域相关的信息无法确定输入的类别,这些信息可被视为非因果因素,并被一些技术所捕捉 [73, 76, 78]。
例如,傅立叶变换有一个众所周知的特性:傅立叶频谱的相位分量保留了原始信号的高级语义,而振幅分量则包含了低级统计信息[45, 52]。因此,在对 U 进行干预时,会干扰幅度信息,而保持相位信息不变,就像文献 [73] 所做的那样。形式上,给定原始输入图像 ,其傅里叶变换可表述为 :
和分别表示幅度和相位分量。傅立叶变换 F(⋅) 及其逆 (⋅)可通过 FFT 算法 [44] 有效计算。然后,通过线性插值原始图像 的振幅频谱和从任意源域随机采样的图像 的振幅频谱,对振幅形成进行扰动:
其中,λ ∼ U (0, η) 和 η 控制扰动的强度。然后,将扰动振幅谱与原始相位分量相结合,通过反傅里叶变换生成增强图像 :
将 CNN 模型实现的表征生成器表示为 ,表征表示为 ,其中 N 是维数。为了模拟对 U 的干预保持不变的因果因素,对 ˆg 进行了优化,以强制表征在维度上对上述干预保持不变:
其中, 和 分别表示 和 的第 i 列的 Z -score归一化。为批量大小, 和 分别为 i∈ {1, . , B}.利用 COR 函数来衡量干预前后表征的相关性。这样,就可以通过使表征 R 独立于 U 来实现模拟因果因素 S 的第一步。
3.2.2 Causal Factorization Module
正如在第 3.1 节中提出的,因果因子 s1、s2 ... ... , sN 应该是共同独立的,即其中任何一个因素都不包含其他因素的信息。因此,打算让任何两个维度的表征相互独立:
请注意,为了节省计算成本,省略了 或 内的约束。统一方程的优化目标等式 (7) 和等式 (8)、建立相关矩阵C:
其中 < ⋅ > 表示内积运算。因此,和的相同维度可以被视为需要最大化相关性的正对,而不同维度可以被视为需要最小化相关性的负对。基于此,设计了分解损失 ,其公式如下:
备注 1. 式(10)中的目标可以使相关矩阵 C 的二乘元素近似为 1,这意味着在 "相关矩阵 "之前和之后的表示都可以近似为 1。对非因果因素的干预是不变的。这表明可以有效地将因果因素从非因果因素的混合物中分离出来。
此外,它还使 C 的非对角元素接近于 0,即强制各表征的对角元素共同独立。因此,通过最小化 ,可以将嘈杂和依赖的表征转化为干净和独立的表征,从而满足理想因果因子的前两个属性。
3.2.3 Adversarial Mask Module
为了成功完成分类任务 X → Y,表示应该是因果充分的,包含所有支持信息。最直接的方法是在多个源域中利用监督标签 y:
其中 是分类器。然而,这种直接的方法不能保证学到的表示的每个维度都是重要的,即包含足够的用于分类的潜在因果信息。具体来说,可能存在较差的维度,它们携带相对较少的因果信息,然后对分类做出很小的贡献。因此,作者建议检测这些维度,迫使它们做出更大贡献。由于因式分解模块还要求各维度是共同独立的,因此检测到的次要维度会包含更多其他维度没有包含的新因果信息,从而使整个表征更具因果性。
因此,为了检测劣势维度,设计了一个对抗性掩模模块 。构建一个基于神经网络的掩码器,用 表示来学习每个维度的贡献,其中κ∈(0,1)比值最大的维度被视为优势维度,其余的被视为劣势维度:
采用常用的可导 Gumbel-Softmax 技巧 [20] 对 κN 值接近 1 的掩模进行采样(个人理解是:使用 Gumbel-Softmax 生成稀疏化的掩码,根据掩码值的排序,选取贡献最大的前 κ⋅N个维度)。通过将学习到的表示乘以获得的掩码 m 和 1−m,可以分别获得表示的优势和劣势。然后,将它们输入两个不同的分类器和。等式 (11) 可以重写如下:
通过最小化 和最大化 来优化掩码器,同时通过最小化两个监督损失来优化生成器 和分类器和。
备注 2. 所提出的对抗性掩码模块可以精确地检测劣势维度,因为 1) 对于优化的 以基于现有掩码维度最小化 ,学习 m 来选择最大化 的维度可以找到贡献较小的劣势维度,2)优劣维度集相互补充,如果一个维度不被视为优维度,那么它将被视为劣维度,因此优维度的选择将有助于劣维度的选择。此外,与只优化公式 (11) 相比,对抗掩码模块与因果系数化模块相结合,可以帮助生成更具因果合理性的表征,因为通过优化 以最小化 和 ,劣势维度被迫携带更多因果信息,并与现有优势维度相互独立。
最后,学习到的表征将通过迭代地将较差的表征"替换"为新的较好的表征来充分接近因果关系。需要明确的是,提出的 CIRL 的总体优化目标总结如下:
其中 τ 是权衡参数。请注意,在推理过程中使用了整个表示 r 和分类器 。
备注 3. 请注意,特征维数的影响可以忽略不计。通过三个模块的协同优化,整个表征中包含的因果信息总量会不断增加,直到学习到的表征能够解释输入和标签之间的所有统计依赖关系,而与特征维度无关。补充材料中的实验分析验证了论证。
4. Experiment
4.1. Datasets
Digits-DG [75] 包含四个数字域,包括 MNIST [22]、MNIST-M [10]、SVHN [43] 和 SYN [10],这些数字域在字体风格、背景和笔画颜色方面存在巨大差异。按照文献[75],为每个领域的每个类别随机选择 600 张图像,然后将 80% 的数据用于训练,20% 的数据用于验证。
PACS[25]是专门为DG提出的,它包含来自四个领域(艺术绘画、卡通、照片和素描)的9, 991张图像,风格差异较大。每个领域有 7 个类别:狗、大象、长颈鹿、吉他、房子、马和人。为了公平比较,使用[25]提供的原始训练-验证分割。
Office-Home [68] 是办公室和家庭环境中的对象识别数据集,收集了 65 个类别的 15, 500 张图像。这 65 个类别由四个领域(艺术、剪贴画、产品和现实世界)共享,这提供观点和图像风格。按照[73],每个域被分为 90% 用于训练,10% 用于验证。
4.2. Implementation Details
按照常用的 "留出一个域 "方案[25],指定一个域作为评估的未见目标域,然后用其余域进行训练。对于 Digits-DG,所有图像的大小都调整为 32 × 32,使用迷你批量 SGD 优化器从头开始训练网络,批量大小为 128,动量为 0.9,权重衰减为 5e-4,持续 50 个 epochs。学习率每 20 个epoch衰减 0.1。在 PACS 和 Office-Home 中,所有图像的大小均调整为 224 × 224。使用迷你批量 SGD 从头开始训练网络,批量大小为 16,动量为 0.9,权重衰减为 5e-4,共 50 个历时,学习率在总历时的 80% 时衰减 0.1。至于超参数 κ 和 τ,它们的值是根据源验证集上的结果选择的,因为目标域在训练过程中是不可见的。具体来说,我们为 Digits-DG 和 PACS 设置了 κ = 60%,而为 Office-Home 设置了 κ = 80%。所有结果均基于三次重复运行的平均准确率。
4.3. Experimental Results
表 1 列出了 Digits-DG 的结果,其中 CIRL 的平均准确率超过了所有比较过的基线方法。请注意,CIRL 分别以 8.0% 和 7.9% 的较大优势超过了基于领域不变表示的 CCSA [40] 和 MMD-AAE [28],这表明挖掘数据与标签之间的内在因果机制而不是超统计依赖性的重要性。此外,还将 CIRL 与 FACT [73]进行了比较,因为因果干预模块采用了相同的增强技术。值得一提的是,FACT 是 DG 领域最先进的方法,1.0% 的性能提升具有挑战性。而 CIRL 比 FACT 提高了 1.0%,这进一步验证了方法的有效性。