【论文阅读】Associative Alignment for Few-shot Image Classification

用于小样本图像分类的关联对齐

引用:Afrasiyabi A, Lalonde J F, Gagné C. Associative alignment for few-shot image classification[C]//Computer Vision--ECCV 2020: 16th European Conference, Glasgow, UK, August 23--28, 2020, Proceedings, Part V 16. Springer International Publishing, 2020: 18-35.

论文地址:下载地址

论文代码:https://github.com/ArmanAfrasiyabi/associative-alignment-fs

Abstract

小样本图像分类旨在从每个"新类别"仅有的少量样本中训练模型。本文提出了一种关联对齐的思想,利用部分基础数据,通过将新类的训练实例与基础训练集中密切相关的样本对齐,来扩展新类的有效训练集规模。这种方法通过添加额外的"相关基础"样本到少量的新类样本中,从而允许更具建设性的微调。我们提出了两种关联对齐策略:1)一种度量学习损失,用于最小化相关基础样本与特征空间中新类样本中心点之间的距离;2)基于Wasserstein距离的条件对抗性对齐损失。在四个标准数据集和三个骨干网络上的实验表明,结合我们的基于中心点的对齐损失,在物体识别、细粒度分类和跨域适应的5-shot学习中,分别比现有技术取得了4.4%、1.2%和6.2%的绝对精度提升。

1 Introduction

尽管最近取得了进展,但在较少监督下对新概念进行泛化仍然是计算机视觉中的一项挑战。在图像分类的背景下,小样本学习旨在获得一个模型,使其在仅有极少训练样本时也能学习识别新类别的图像。

元学习 ^1^ ^2^ ^3^ ^4^ 是实现这一目标的可能方法,它通过从大量有标注的数据(即"基础"类别)中提取通用知识,训练出一个模型,之后能够仅凭少量样本学习对"新"概念进行分类。这是通过反复从大量基础图像池中抽取小的子集来实现的,实际上是模拟了小样本场景。标准的迁移学习也被探索为一种替代方法 ^5^ ^6^ ^7^。这种方法的思想是先在基础样本上对网络进行预训练,然后在新类样本上微调分类层。有趣的是,Chen 等人 ^5^ 证明,这种方法的表现与更复杂的元学习策略相当。然而,在对新类进行微调时,需要冻结网络的特征编码器部分,否则网络会过拟合到新类样本上。我们推测,这会限制性能,如果整个网络都能够适应新类别,可能会带来性能提升。

在本文中,我们提出了一种方法,能够在防止过拟合的同时,不限制网络在小样本图像分类中的学习能力。我们的方法以标准的迁移学习策略 ^5^ 作为起点,随后利用与少量新类样本最相似的基础类别(在特征空间中),有效地提供额外的训练样本。我们称这些相似的类别为"相关基础"类别。当然,相关基础类别与新类别代表了不同的概念,因此直接在它们上进行微调可能会让网络混淆(见图1-(a))。本文的关键思想是在特征空间中将新样本与相关基础样本对齐(图1-(b))。

为此,我们提出了两种可能的关联对齐解决方案:1)中心点对齐,受ProtoNet ^3^ 的启发,通过显式地缩小类内变化来获益,训练过程更稳定,但假设类分布能够被单模态很好地逼近。对抗性对齐,受WGAN ^8^ 的启发,不做这种假设,但由于评估网络的存在,训练复杂度更高。我们通过广泛的实验表明,我们的基于中心点的对齐过程在多个标准基准上的小样本分类中达到了当前最先进的性能。类似的结果也通过我们的对抗性对齐获得,表明了我们关联对齐方法的有效性。

我们提出了以下贡献。首先,我们提出了两种在特征空间中将新类对齐到相关基础类的方法,从而能够有效地训练整个网络以进行小样本图像分类。其次,我们引入了一个强大的基线模型,该模型将标准的迁移学习 ^5^ 与一个附加的角度边距损失 ^9^ 结合在一起,并在基础类别上进行预训练时通过早停来对网络进行正则化。我们发现,这个简单的基线实际上在最佳情况下将整体准确率提高了3%。第三,我们通过广泛的实验------在四个标准数据集上,并使用三个广为人知的骨干特征提取器------证明了我们提出的基于中心点的对齐在三种场景下显著超越了当前最先进的技术:通用物体识别(在mini-ImageNet、tieredImageNet和FC100上的5-shot学习中,整体准确率分别提升1.7%、4.4%和2.1%),细粒度分类(在CUB上的提升为1.2%),以及跨域适应(从mini-ImageNet到CUB的提升为6.2%)使用ResNet-18骨干网络。

主要的小样本学习方法可以大致分为元学习和标准迁移学习。此外,数据增强和正则化技术(通常在元学习中使用)也被用于小样本学习。我们简要回顾了每个类别中的相关工作。值得注意的是,多个不同的计算机视觉问题,如物体计数 ^10^、视频分类 ^11^、运动预测 ^12^ 和物体检测 ^13^,都被表述为小样本学习问题。在这里,我们主要关注图像分类领域的工作。

元学习 这类方法将小样本学习框定为情景训练 ^14^ ^1^ ^2^ ^15^ ^3^ ^16^ ^13^ ^17^。情景通过在训练基础类别(这些类别的样本量较大)时假装处于小样本情境下进行定义。初始化方法和度量方法是与本文相关的情景训练方案的两种变体。初始化方法 ^1^ ^18^ ^19^ 学习一个初始模型,该模型能够通过少量的梯度步适应少量的新样本。相比之下,我们的方法执行了更多的更新,但要求新样本与其相关基础样本之间保持对齐。度量方法 ^20^ ^21^ ^22^ ^23^ ^24^ ^25^ ^3^ ^26^ ^27^ ^4^ ^28^ ^29^ 学习一个度量,旨在减少类内差异,同时在基础类别上进行训练。例如,ProtoNet ^3^ 旨在学习一个特征空间,在该空间中,给定类的实例接近相应的原型(质心),从而实现基于距离的精确分类。我们的中心对齐策略借鉴了这种基于距离的准则,但将其用于在特征空间中匹配分布,而不是构建分类器。

标准迁移学习 这种方法的策略是先在基础类别上对网络进行预训练,然后在新样本上进行微调 ^5^ ^6^ ^7^。尽管其方法简单,Chen 等人 ^5^ 最近表明,当使用深层骨干网络作为特征提取器时,这种方法可以取得与元学习相似的泛化性能。然而,他们也表明,由于过拟合的倾向,在微调时必须冻结预训练特征提取器的权重。尽管我们提出的训练过程与基础类别中的标准微调类似,但我们的方法允许训练整个网络,从而增加了所学模型的容量,同时提高了分类准确性。

正则化技巧 Wang 等人 ^30^ 提出了用于正则化目的的回归网络,通过将微调模型的参数优化为接近预训练模型来实现正则化。最近,Lee 等人 ^31^ 利用线性分类器的隐式微分与hinge loss和L2正则化应用于基于CNN的特征学习器。Dvornik 等人 ^32^ 使用网络集成来减少分类器的方差。

数据增强 另一类技术依赖于在小样本情境下的额外数据进行训练,大多数情况下遵循元学习训练程序 ^33^ ^34^ ^35^ ^36^ ^37^ ^38^ ^39^ ^40^ ^41^ ^42^。为此,已经提出了几种方法,包括特征幻觉(FH) ^37^,该方法通过学习样本之间的映射并使用辅助生成器来在特征空间中生成额外的训练样本。随后,Wang 等人 ^40^ 提出了使用GAN来实现相同的目的,从而解决了FH框架泛化能力差的问题。不幸的是,这种方法被证明存在模式崩溃的问题 ^35^。与生成人工数据以进行增强不同,其他方法提出了利用额外的未标记数据 ^43^ ^44^ ^45^ ^46^。Liu 等人 ^47^ 提出了一种从少量标记数据向大量未标记数据传播标签的方法,类似于我们对相关基础样本的检测。我们同样依赖更多的数据进行训练,但与这些方法不同,我们的方法不需要任何新数据,也不需要生成数据。相反,我们利用已有的基础域数据,并通过微调将新类域对齐到相关的基础样本。

以前的工作也利用了基础训练数据,与我们最相关的工作是 ^33^ 和 ^48^。Chen 等人 ^33^ 提出了使用嵌入和变形子网络来利用额外的训练样本,而我们依赖于一个单一的特征提取网络,这更易于实现和训练。与随机基础样本采样 ^33^ 通过图像空间中的新样本变形插值不同,我们提出在特征空间中借用检测到的相关类别的内部分布结构。此外,我们的对齐策略引入了额外的准则,使学习者的注意力集中在新类上,防止新类成为异常点。针对物体检测,Lim 等人 ^48^ 提出了一种使用稀疏组Lasso框架搜索相似物体类别的模型。与 ^48^ 不同,我们在小样本图像分类的背景下提出并评估了两种关联对齐方法。

从对齐的角度来看,我们的工作与Jiang 等人 ^49^ 的工作相关,该工作是在零样本学习的背景下,提出在视觉-语义结构中通过匹配词典来找到匹配概念。相比之下,我们提出了关联的基础类-新类对齐方法,并提出了两种策略来强制统一相关概念。

3 Preliminaries

假设我们有一个大型的基础数据集 X b = { ( x b i , y b i ) } i = 1 N b X_b=\{(x_b^i,y_b^i)\}{i=1}^{N_b} Xb={(xbi,ybi)}i=1Nb,其中 x b i ∈ R d x_b^i\in\mathbb{R}^d xbi∈Rd 是第 i i i 个数据实例, y b i ∈ Y b y_b^i\in Y_b ybi∈Yb 是相应的类别标签。我们还给定了少量的新类别数据 X n = { ( x n i , y n i ) } i = 1 N n X_n=\{(x_n^i,y_n^i)\}{i=1}^{N_n} Xn={(xni,yni)}i=1Nn,其中标签 y n i ∈ Y n y_n^i\in Y_n yni∈Yn 来自一个与基础类别集不同的新类别集 Y n Y_n Yn。小样本分类的目标是仅通过每个新类别的少量样本(例如5个甚至1个)来训练分类器。在这项工作中,我们使用了 Chen 等人^5^ 提出的标准迁移学习策略,该策略分为以下两个阶段:

预训练阶段

学习模型是一个由特征提取器 f ( ⋅ ∣ θ ) f(\cdot|\theta) f(⋅∣θ)(由参数 θ \theta θ 表示)和线性分类器 c ( x ∣ W ) ≡ W ⊤ f ( x ∣ θ ) c(x|W)\equiv W^\top f(x|\theta) c(x∣W)≡W⊤f(x∣θ) 组成的神经网络,其中 W W W 是描述分类器的矩阵,最后通过如 softmax 之类的评分函数生成输出。该网络在基础类别集 X b X_b Xb 的样本上从头开始训练。

  • 微调阶段*
      为了使网络适应新类别,网络随后在来自 X n X_n Xn 的少量样本上进行微调。由于如果更新所有网络权重很可能会导致过拟合,因此在这一阶段,特征提取器的权重 θ \theta θ 被冻结,只有分类器的权重 W W W 会被更新。

4 Associative alignment

冻结特征提取器的权重 θ \theta θ 确实减少了过拟合,但也限制了模型的学习能力。在本文中,我们力求两全其美,并提出了一种在控制过拟合的同时保持模型原有学习能力的方法。我们借用了相关基础类别子集的内部分布结构, X r b ⊂ X b X_{rb} \subset X_b Xrb⊂Xb。为了处理新类别与相关基础类别之间的差异,我们提出在特征空间中将新类别对齐到相关基础类别。这种映射允许拥有更大的训练数据池,同时使这两个集合的实例更加一致。注意,与 [4] 相反,我们并不以任何方式修改相关基础实例:我们只是希望将新样本对齐到其相关类别实例的分布。

在本节中,我们首先描述如何确定相关基础类别。接着,我们提出本文的主要贡献:"中心点关联对齐"方法,该方法利用相关基础实例来提高新类别的分类性能。最后,我们提出了一种替代的关联对齐策略,它依赖于对抗框架。

我们开发了一个简单但有效的过程来选择与新类别相关的一组基础类别。我们的方法将 B B B 个基础类别与每个新类关联。在 X b X_b Xb 上训练 c ( f ( ⋅ ∣ θ ) ∣ W ) c(f(\cdot|\theta)|W) c(f(⋅∣θ)∣W) 后,我们首先在 X n X_n Xn 上微调 c ( ⋅ ∣ W ) c(\cdot|W) c(⋅∣W),同时保持 θ \theta θ 不变。然后,我们定义 M ∈ R K b × K n M\in\mathbb{R}^{K_b\times K_n} M∈RKb×Kn 为一个基础类别与新类别的相似性矩阵,其中 K b K_b Kb 和 K n K_n Kn 分别是 X b X_b Xb 和 X n X_n Xn 中类别的数量。矩阵 M M M 的元素 m i , j m_{i,j} mi,j 对应于与第 i i i 个基础类别相关的样本被分类为第 j j j 个新类别的比率:

m i , j = 1 ∣ X b i ∣ ∑ ( x b l , ⋅ ) ∈ X b i I [ j = arg ⁡ max ⁡ k = 1 K n ( c k ( f ( x b l ∣ θ ) ∣ W ) ) ] , (1) m_{i,j}=\frac{1}{|X_b^i|}\sum_{(x_b^l,\cdot)\in X_b^i}\mathbb{I}\left[j=\arg\max_{k=1}^{K_n}\left(c_k(f(x_b^l|\theta)|W)\right)\right], \tag{1} mi,j=∣Xbi∣1(xbl,⋅)∈Xbi∑I[j=argk=1maxKn(ck(f(xbl∣θ)∣W))],(1)

其中 c k ( f ( x ∣ θ ) ∣ W ) c_k(f(x|\theta)|W) ck(f(x∣θ)∣W) 是类别 k k k 的分类器输出 c ( ⋅ ∣ W ) c(\cdot|W) c(⋅∣W)。接下来,针对每个给定的新类别,得分最高的 B B B 个基础类别将被保留作为该类别的相关基础类别。图 2 展示了在 5-shot, 5-way 场景下使用此方法获得的示例结果。

4.2 Centroid associative alignment

假设 X n i X_n^i Xni 是属于第 i i i 个新类别 i ∈ Y n i \in Y_n i∈Yn 的实例集合,定义为 X n i = { ( x n j , y n j ) ∈ X n ∣ y n j = i } X_n^i = \{(x_n^j,y_n^j)\in X_n | y_n^j=i\} Xni={(xnj,ynj)∈Xn∣ynj=i},以及根据映射函数 g ( ⋅ ∣ M ) g(\cdot|M) g(⋅∣M) 属于相同新类别 i i i 的相关基础样本集合 X r b i X_{rb}^i Xrbi,定义为 X r b i = { ( x b j , y b j ) ∈ X r b ∣ g ( y j ∣ M ) = i } X_{rb}^i = \{(x_b^j,y_b^j)\in X_{rb} | g(y_j|M)=i\} Xrbi={(xbj,ybj)∈Xrb∣g(yj∣M)=i}。函数 g ( y j ∣ M ) : Y b → Y n g(y_j|M):Y_b\rightarrow Y_n g(yj∣M):Yb→Yn 根据相似性矩阵 M M M 将基础类别标签映射到新类别。我们希望找到一个对齐变换,用于匹配概率密度 p ( f ( x n i , k ∣ θ ) ) p(f(x_n^{i,k}|\theta)) p(f(xni,k∣θ)) 和 p ( f ( x r b i , l ∣ θ ) ) p(f(x_{rb}^{i,l}|\theta)) p(f(xrbi,l∣θ))。其中, x n i , k x_n^{i,k} xni,k 是新类别集中类别 i i i 的第 k k k 个元素, x r b i , l x_{rb}^{i,l} xrbi,l 是相关基础集中类别 i i i 的第 l l l 个元素。这种方法的额外好处是能够通过降低过拟合的水平,微调模型的所有参数 θ \theta θ 和 W W W。

我们提出了一种基于度量的质心分布对齐策略。该策略的核心思想是在对齐过程中强制类内紧致性。具体来说,我们在特征空间中显式地将来自第 i i i 个新类别的训练样本 X n i X_n^i Xni 推向其相关样本 X r b i X_{rb}^i Xrbi 的质心。 X r b i X_{rb}^i Xrbi 的质心 μ i \mu_i μi 通过以下公式计算:

μ i = 1 ∣ X r b i ∣ ∑ ( x j , ⋅ ) ∈ X r b i f ( x j ∣ θ ) , (2) \mu_i=\frac{1}{|X_{rb}^i|}\sum_{(x_j,\cdot)\in X_{rb}^i}f(x_j|\theta), \tag{2} μi=∣Xrbi∣1(xj,⋅)∈Xrbi∑f(xj∣θ),(2)

其中, N n N_n Nn 和 N r b N_{rb} Nrb 分别是 X n X_n Xn 和 X r b X_{rb} Xrb 中的样本数量。这使得质心对齐损失可以定义为:

L c a ( X n ) = − 1 N n N r b ∑ i = 1 K n ∑ ( x j , ⋅ ) ∈ X n i log ⁡ exp ⁡ [ − 1 2 ∥ f ( x j ∣ θ ) − μ i ∥ 2 ] ∑ k = 1 K n exp ⁡ [ − 1 2 ∥ f ( x j ∣ θ ) − μ k ∥ 2 ] . (3) L_{ca}(X_n)=-\frac{1}{N_n N_{rb}}\sum_{i=1}^{K_n}\sum_{(x_j,\cdot)\in X_n^i}\log\frac{\exp\left[-\frac{1}{2}\|f(x_j|\theta)-\mu_i\|^2\right]}{\sum_{k=1}^{K_n}\exp\left[-\frac{1}{2}\|f(x_j|\theta)-\mu_k\|^2\right]}. \tag{3} Lca(Xn)=−NnNrb1i=1∑Kn(xj,⋅)∈Xni∑log∑k=1Knexp[−21∥f(xj∣θ)−μk∥2]exp[−21∥f(xj∣θ)−μi∥2].(3)

我们的对齐策略与 ^3^ 相似,也是在元学习框架中使用公式 (3)。在我们的情况下,我们使用相同的公式来匹配分布。图 3 说明了我们提出的质心对齐,算法 1 展示了整体流程。首先,我们使用公式 (3) 更新特征提取网络 f ( ⋅ ∣ θ ) f(\cdot|\theta) f(⋅∣θ) 的参数。其次,整个网络通过分类损失 L c l f L_{clf} Lclf(定义见第5节)进行更新。

4.3 Adversarial associative alignment

作为另一种关联对齐策略,并受到WGAN \cite{arjovsky2017wasserstein} 的启发,我们尝试训练编码器 f ( ⋅ ∣ θ ) f(\cdot|\theta) f(⋅∣θ) 进行对抗性对齐,使用基于Wasserstein-1距离的条件判别网络 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ) 来衡量两个概率密度 p x p_x px 和 p y p_y py 之间的差异:

D ( p x , p y ) = sup ⁡ ∥ h ∥ L ≤ 1 E x ∼ p x [ h ( x ) ] − E x ∼ p y [ h ( x ) ] , (4) D(p_x,p_y)=\sup_{\|h\|L\leq 1}\mathbb{E}{x\sim p_x}[h(x)]-\mathbb{E}_{x\sim p_y}[h(x)], \tag{4} D(px,py)=∥h∥L≤1supEx∼px[h(x)]−Ex∼py[h(x)],(4)

其中, sup ⁡ \sup sup 是上确界, h h h 是一个1-Lipschitz函数。类似于Arjovsky等人 \cite{arjovsky2017wasserstein},我们使用参数化的判别网络 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ),其输入为 x n i x_n^i xni 或 x r b j x_{rb}^j xrbj 的特征嵌入与相应标签 y n i y_n^i yni 连接后的向量(标签 y n i y_n^i yni 编码为独热向量)。通过对判别器 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ) 进行条件处理,帮助其匹配新类别和对应的相关基础类别。

判别器 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ) 的训练损失为:

L h ( X n , X r b ) = 1 N r b ∑ ( x r b i , y r b i ) ∈ X r b h ( [ f ( x r b i ∣ θ ) , y r b i ] ∣ ϕ ) − 1 N n ∑ ( x n i , y n i ) ∈ X n h ( [ f ( x n i ∣ θ ) , y n i ] ∣ ϕ ) , (5) L_h(X_n,X_{rb})=\frac{1}{N_{rb}}\sum_{(x_{rb}^i,y_{rb}^i)\in X_{rb}}h([f(x_{rb}^i|\theta),y_{rb}^i]|\phi)-\frac{1}{N_n}\sum_{(x_n^i,y_n^i)\in X_n}h([f(x_n^i|\theta),y_n^i]|\phi), \tag{5} Lh(Xn,Xrb)=Nrb1(xrbi,yrbi)∈Xrb∑h([f(xrbi∣θ),yrbi]∣ϕ)−Nn1(xni,yni)∈Xn∑h([f(xni∣θ),yni]∣ϕ),(5)

其中, [ ⋅ ] [\cdot] [⋅] 是连接操作符。接着,编码器参数 θ \theta θ 通过以下损失函数更新:

L a a ( X n ) = 1 K n ∑ ( x n i , y n i ) ∈ X n h ( [ f ( x n i ∣ θ ) , y n i ] ∣ ϕ ) . (6) L_{aa}(X_n)=\frac{1}{K_n}\sum_{(x_n^i,y_n^i)\in X_n}h([f(x_n^i|\theta),y_n^i]|\phi). \tag{6} Laa(Xn)=Kn1(xni,yni)∈Xn∑h([f(xni∣θ),yni]∣ϕ).(6)

算法2总结了我们的对抗性对齐方法。首先,我们使用公式(5)更新判别器 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ) 的参数。类似于WGAN \cite{arjovsky2017wasserstein},我们执行 n c r i t i c n_{critic} ncritic 次迭代来优化 h h h,然后使用公式(6)更新 f ( ⋅ ∣ θ ) f(\cdot|\theta) f(⋅∣θ)。最后,通过分类损失 L c l f L_{clf} Lclf(定义见第5节)来更新整个网络。

5 Establishing a strong baseline

在第6节评估我们的对齐策略之前,我们首先通过遵循最近的文献建立一个强有力的基线用于比较。特别是,我们基于Chen等人 ^5^ 的工作,但在预训练阶段引入了不同的损失函数以及情景早停机制。

5.1 分类损失函数

Deng等人 已表明,加法角度边距(以下简称为"arcmax")在面部识别中优于其他度量学习算法。arcmax 具有度量学习的性质,因为它在归一化的超球体上强制施加测地距离边距惩罚。我们认为这对于小样本分类是有益的,因为它有助于保持类别簇紧凑且相互分离。

令 z z z 为特征空间中 x x x 的表示。如 \cite{deng2019arcface} 所述,我们将 logit 转换为 w j ⊤ z = ∥ w j ∥ ∥ z ∥ cos ⁡ ϕ j w_j^\top z = \|w_j\|\|z\|\cos\phi_j wj⊤z=∥wj∥∥z∥cosϕj,其中 ϕ j \phi_j ϕj 是 z z z 与 w j w_j wj 之间的夹角, w j w_j wj 是权重矩阵 W W W 的第 j j j 列。通过 l 2 l_2 l2 归一化,每个权重 ∥ w j ∥ = 1 \|w_j\|=1 ∥wj∥=1。arcmax 在超球体上的分布样本中添加一个角度边距 m m m:

L c l f = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( s cos ⁡ ( ϕ y i + m ) ) exp ⁡ ( s cos ⁡ ( ϕ y i + m ) ) + ∑ ∀ j ≠ y i exp ⁡ ( s cos ⁡ ϕ j ) , (7) L_{clf}=-\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s\cos(\phi_{y_i}+m))}{\exp(s\cos(\phi_{y_i}+m))+\sum_{\forall j\neq y_i}\exp(s\cos\phi_j)}, \tag{7} Lclf=−N1i=1∑Nlogexp(scos(ϕyi+m))+∑∀j=yiexp(scosϕj)exp(scos(ϕyi+m)),(7)

其中, s s s 是 z z z 分布在超球体上的半径, N N N 是样本数量, m m m 和 s s s 是超参数(详见第6.1节)。该边距的总体目标是强制类间差异和类内紧凑性。

5.2 情景早停机制

在预训练阶段使用固定数量的训练轮次是常见的做法(例如 ^5^, ^1^, ^3^, ^4^),但这可能会在微调阶段影响性能。通过观察验证误差,我们发现预训练阶段需要使用早停机制(验证误差图见补充材料)。因此,我们在预训练时使用了基于验证集的情景早停机制,具体操作是在最近几个训练轮次的窗口期内,当平均准确率开始下降时停止训练,并选择窗口期内表现最好的模型作为最终结果。

6 Experimental validation

在接下来的部分中,我们将对所提出的小样本学习的关联对齐策略进行实验评估和比较。首先,我们介绍所使用的数据集,并评估第5节中提出的强基线。

6.1 数据集与实现细节

数据集

我们在四个基准数据集上进行了实验:mini-ImageNet ^4^、tieredImageNet ^44^ 和 FC100 ^25^ 用于通用物体识别;CUB200-2011 (CUB) ^50^ 用于细粒度图像分类。mini-ImageNet 是 ImageNet ILSVRC-12 数据集 ^51^ 的一个子集,包含 100 个类别,每个类别 600 个样本。我们使用了 Ravi 和 Larochelle ^2^ 的相同数据集划分,其中基础类别、验证类别和新类别分别包含 64、16 和 20 个类别。tieredImageNet ^44^ 是一个更大的基准数据集,同样是 ImageNet ILSVRC-12 数据集 ^51^ 的子集,包含 351 个基础类别、97 个验证类别和 160 个新类别。FC100 数据集 ^25^ 源自 CIFAR-100 ^52^,包含 100 个类别,分为 20 个超类,目的是减少类别之间的重叠。基础类别、验证类别和新类别的划分分别为 60、20 和 20 个类别,分别属于 12、5 和 5 个超类。CUB 数据集 ^50^ 包含来自 200 个鸟类类别的 11,788 张图像。我们使用了 Hilliard 等人 ^53^ 的相同划分,基础类别、验证类别和新类别分别包含 100、50 和 50 个类别。

网络架构

我们使用三种特征学习器 f ( ⋅ ∣ θ ) f(\cdot|\theta) f(⋅∣θ) 的骨干网络进行实验:1)一个4层卷积网络 ("Conv4"),输入图像分辨率为 84 × 84,类似于 ^1^, ^2^, ^3^; 2)一个 ResNet-18 ^54^,输入大小为 224 × 224;3)一个 28 层宽残差网络 ("WRN-28-10") ^55^,输入大小为 80 × 80,包含三步降维。我们使用一个 1024 维的单隐藏层多层感知器(MLP)作为判别网络 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ)(参见第4.3节)。

实现细节

回顾第3节的内容,训练分为两个阶段:1)使用基础类别 X b X_b Xb 进行预训练;2)在新类别 X n X_n Xn 上进行微调。对于预训练,我们使用了第5.2节中的早停算法,窗口大小为50。标准的数据增强方法(即颜色抖动、随机裁剪和左右翻转,类似于 ^5^)被采用,Adam算法的学习率设置为 1 0 − 3 10^{-3} 10−3,批量大小为64,用于预训练和微调。arcmax 损失函数(公式 (7))的参数配置为 s = 20 s=20 s=20 和 m = 0.1 m=0.1 m=0.1,通过交叉验证设置。在微调阶段,情景通过从新类别 X n X_n Xn 中随机选择 N = 5 N=5 N=5 个类别来定义,每个类别随后采样 k k k 个样本(我们的实验中 k = 1 k=1 k=1 和 k = 5 k=5 k=5)。与 Chen 等人 ^5^ 类似,这一阶段没有使用标准的数据增强。在固定编码器的情况下,我们使用情景交叉验证来找到 s s s 和 m m m 的最佳值。具体来说,对于 Conv4 骨干网络, s s s 和 m m m 分别为 (5, 0.1),对于 WRN-28-10 和 ResNet-18 骨干网络, s s s 和 m m m 分别为 (5, 0.01)。Adam 的学习率在质心对齐和对抗对齐中分别设置为 1 0 − 3 10^{-3} 10−3 和 1 0 − 5 10^{-5} 10−5。类似于 ^8^,判别器 h ( ⋅ ∣ ϕ ) h(\cdot|\phi) h(⋅∣ϕ) 的训练迭代次数(算法2的内循环)为5次。我们将相关基础类别的数量固定为 B = 10 B=10 B=10(关于 B B B 的消融研究见补充材料)。因此,我们在 mini-ImageNet 数据集中使用了相对较多的类别(从64个可用类别中选取50个类别)。

6.2 mini-ImageNet 和 CUB 使用浅层 Conv4 骨干网络

我们首先使用 Conv4 骨干网络在 mini-ImageNet(关于更多分类任务的评估见补充材料)和 CUB 数据集上评估了第5节中提出的新基线以及我们的关联对齐策略,相关结果展示在表1中。我们注意到,在 1-shot 和 5-shot 场景下,arcmax 配合早停机制在 mini-ImageNet 和 CUB 数据集上的表现均优于 cosmax 和 softmax(无论是否使用早停机制)。我们遵循了 ^5^ 中给出的相同数据集划分配置、网络架构和实现细节进行测试。我们的质心关联对齐在所有实验中均超过了现有技术水平,在 mini-ImageNet 上,1-shot 和 5-shot 实验中比我们的基线分别提升了 1.24% 和 2.38%。对于 CUB 数据集,对抗对齐在质心对齐的基础上进一步提升了 0.6% 和 0.87%。

6.3 mini-ImageNet 和 tieredImageNet 使用深层骨干网络

现在我们使用两个深层骨干网络 ResNet-18 和 WRN-28-10 对我们提出的关联对齐策略在 mini-ImageNet 和 tieredImageNet 数据集上进行评估。表2比较了我们提出的对齐方法与几种其他方法的表现。

mini-ImageNet

我们的质心关联对齐策略在 ResNet-18 和 WRN-28-10 骨干网络上的 1-shot 和 5-shot 分类任务中均取得了最佳成绩,分别比 MetaOptNet ^31^ 和 Robust-dist++ ^32^ 取得了显著的绝对精度提升,分别为 2.72% 和 1.68%。唯一一种在某个任务中表现优于我们方法的情况是 MetaOptNet,在 1-shot 中超过了我们的方法 2.76%。对于 WRN-28-10 骨干网络,我们在 1-shot 中与 Transductive-ft ^14^ 取得了相似的结果,但在 5-shot 中超出了他们的方法 4.45%。值得注意的是,与 IDeMe-Net ^33^、SNAIL ^56^ 和 TADAM ^25^ 不同,我们的方法在不引入额外模块的情况下,在这些方法的基础上取得了显著改进。

tieredImageNet

表2还显示了我们的质心关联对齐策略在 tieredImageNet 数据集的 1-shot 和 5-shot 场景中优于比较方法。特别是,我们的质心对齐在使用 ResNet-18 时比 MetaOptNet ^31^ 分别提高了 3.3% 和 4.41%。同样,使用 WRN-28-10 时,我们的质心对齐策略比最佳的比较方法分别提高了 1.06% 和 1.11%。

6.4 FC100 and CUB with a ResNet-18 backbone

我们在表3中展示了使用 ResNet-18 骨干网络在 FC100 和 CUB 数据集上的额外结果。在 FC100 数据集中,我们的质心对齐在 1-shot 和 5-shot 中分别比 MTL ^57^ 提升了 0.73% 和 2.14%。在 CUB 数据集中,我们的关联对齐方法也显示了改进,其中质心对齐在 1-shot 中比 ProtoNet ^3^ 提升了 2.3%,在 5-shot 中提升了 1.2%。我们在 CUB 数据集中还超越了 Robust-20 ^32^(一个由20个网络组成的集成方法),分别提升了 4.03% 和 4.15%。

6.5 跨领域评估

我们还在跨领域图像分类任务中评估了我们的对齐策略。在这里,按照 ^5^,基础类别取自 mini-ImageNet,而新类别取自 CUB。正如表4所示,我们提出的质心对齐方法在 1-shot 和 5-shot 中相较基线分别提升了 1.3% 和 5.4%。对抗对齐在 1-shot 中比基线低 1.2%,但在 5-shot 中提升了 5.9%。总体而言,我们的质心对齐方法在 1-shot 和 5-shot 中相较于现有技术(即 cosmax ^5^)的绝对精度分别提升了 3.8% 和 6.0%。我们还在 mini-ImageNet 到 CUB 的跨领域任务中的 5-shot 中超越了 Robust-20 ^32^(一个由 20 个网络组成的集成方法),提升了 4.65%。有人可能会认为 mini-ImageNet 中的三个鸟类类别(即 house finch、robin 和 toucan)会对跨领域评估产生偏差。通过排除这些类别重新训练后,结果仍然表现出类似的性能,如表4所示。

7 讨论

本文提出了用于小样本图像分类的关联对齐思想,该方法通过使整个网络参与训练,同时避免过拟合,从而提高了泛化性能。为此,我们设计了一个过程,用于为每个新类检测相关的基础类别。然后,我们提出了一种基于质心的对齐策略,以保持类内对齐,同时进行分类任务的更新。我们还探索了一种对抗性对齐策略作为替代方案。实验表明,我们的方法,特别是基于质心的对齐策略,在几乎所有场景中均优于以往的工作。我们工作的当前局限性为未来研究提供了有趣的方向。首先,基于第4节的对齐方法可能包括基础类别中的无关样本,因此使用分类语义信息可以帮助过滤掉不良样本。分析表明,使用 ResNet-18 在 mini-ImageNet 的 5-way 1-shot 和 5-shot 任务中,约有12%的样本使用质心最近邻标准成为了分布外(OOD)样本。在每次迭代中丢弃这些 OOD 样本对分类结果没有显著影响。其次,某些基础类别的多模态性似乎是不可避免的,这可能会相比于我们质心对齐策略假设的单模态情况降低泛化性能。因此,研究使用混合家族模型可能会提高泛化性能。最后,我们的算法在一次迭代中仅计算一次相关基础类别,随后在整个情景训练过程中保持不变,未考虑到情景训练期间应用于潜在空间的变化。因此,在微调阶段,更复杂的动态采样机制可能会有所帮助。


  1. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: The International Conference on Machine Learning (2017) ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  2. Ravi, S., Larochelle, H.: Optimization as a model for few-shot learning (2016) ↩︎ ↩︎ ↩︎ ↩︎

  3. Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: Advances in Neural Information Processing Systems (2017) ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  4. Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al.: Matching networks for one shot learning. In: Advances in Neural Information Processing Systems (2016) ↩︎ ↩︎ ↩︎ ↩︎

  5. Chen, W.Y., Liu, Y.C., Kira, Z., Wang, Y.C.F., Huang, J.B.: A closer look at few-shot classification. arXiv preprint arXiv:1904.04232 (2019) ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  6. Gidaris, S., Komodakis, N.: Dynamic few-shot visual learning without forgetting. In: The Conference on Computer Vision and Pattern Recognition (2018) ↩︎ ↩︎

  7. Qi, H., Brown, M., Lowe, D.G.: Low-shot learning with imprinted weights. In: The Conference on Computer Vision and Pattern Recognition (2018) ↩︎ ↩︎

  8. Arjovsky, M., Chintala, S., Bottou, L.: Wasserstein gan. arXiv preprint arXiv:1701.07875 (2017) ↩︎ ↩︎

  9. Deng, J., Guo, J., Xue, N., Zafeiriou, S.: Arcface: Additive angular margin loss for deep face recognition. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  10. Zhao, F., Zhao, J., Yan, S., Feng, J.: Dynamic conditional networks for few-shot learning. In: The European Conference on Computer Vision (2018) ↩︎

  11. Zhu, L., Yang, Y.: Compound memory networks for few-shot video classification. In: The European Conference on Computer Vision (2018) ↩︎

  12. Gui, L.Y., Wang, Y.X., Ramanan, D., Moura, J.M.F.: Few-shot human motion prediction via meta-learning. In: The European Conference on Computer Vision (2018) ↩︎

  13. Wang, Y.X., Ramanan, D., Hebert, M.: Meta-learning to detect rare objects. In: The International Conference on Computer Vision (2019) ↩︎ ↩︎

  14. Dhillon, G.S., Chaudhari, P., Ravichandran, A., Soatto, S.: A baseline for few-shot image classification. arXiv preprint arXiv:1909.02729 (2019) ↩︎ ↩︎

  15. Rusu, A.A., Rao, D., Sygnowski, J., Vinyals, O., Pascanu, R., Osindero, S., Hadsell, R.: Meta-learning with latent embedding optimization. arXiv preprint arXiv:1807.05960 (2018) ↩︎

  16. Vilalta, R., Drissi, Y.: A perspective view and survey of meta-learning. Artificial Intelligence Review (2002) ↩︎

  17. Yoon, S.W., Seo, J., Moon, J.: Tapnet: Neural network augmented with taskadaptive projection for few-shot learning. arXiv preprint arXiv:1905.06549 (2019) ↩︎

  18. Finn, C., Xu, K., Levine, S.: Probabilistic model-agnostic meta-learning. In: Advances in Neural Information Processing Systems (2018) ↩︎

  19. Kim, T., Yoon, J., Dia, O., Kim, S., Bengio, Y., Ahn, S.: Bayesian model-agnostic meta-learning. arXiv preprint arXiv:1806.03836 (2018) ↩︎

  20. Bertinetto, L., Henriques, J.F., Torr, P., Vedaldi, A.: Meta-learning with differentiable closed-form solvers. In: The International Conference on Learning Representations (2019) ↩︎

  21. Garcia, V., Bruna, J.: Few-shot learning with graph neural networks. arXiv preprint arXiv:1711.04043 (2017) ↩︎

  22. Kim, J., Oh, T.H., Lee, S., Pan, F., Kweon, I.S.: Variational prototyping-encoder: One-shot learning with prototypical images. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  23. Li, W., Wang, L., Xu, J., Huo, J., Gao, Y., Luo, J.: Revisiting local descriptor based image-to-class measure for few-shot learning. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  24. Lifchitz, Y., Avrithis, Y., Picard, S., Bursuc, A.: Dense classification and implanting for few-shot learning. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  25. Oreshkin, B., López, P.R., Lacoste, A.: Tadam: Task dependent adaptive metric for improved few-shot learning. In: Advances in Neural Information Processing Systems (2018) ↩︎ ↩︎ ↩︎ ↩︎

  26. Sung, F., Yang, Y., Zhang, L., Xiang, T., Torr, P.H., Hospedales, T.M.: Learning to compare: Relation network for few-shot learning. In: The Conference on Computer Vision and Pattern Recognition (2018) ↩︎

  27. Tseng, H.Y., Lee, H.Y., Huang, J.B., Yang, M.H.: Cross-domain few-shot classification via learned feature-wise transformation. arXiv preprint arXiv:2001.08735 (2020) ↩︎

  28. Wertheimer, D., Hariharan, B.: Few-shot learning with localization in realistic settings. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  29. Zhang, J., Zhao, C., Ni, B., Xu, M., Yang, X.: Variational few-shot learning. In: The International Conference on Computer Vision (2019) ↩︎

  30. Wang, Y.X., Hebert, M.: Learning to learn: Model regression networks for easy small sample learning. In: The European Conference on Computer Vision. Springer (2016) ↩︎

  31. Lee, K., Maji, S., Ravichandran, A., Soatto, S.: Meta-learning with differentiable convex optimization. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎ ↩︎ ↩︎

  32. Dvornik, N., Schmid, C., Mairal, J.: Diversity with cooperation: Ensemble methods for few-shot classification. In: The International Conference on Computer Vision (2019) ↩︎ ↩︎ ↩︎ ↩︎

  33. Chen, Z., Fu, Y., Wang, Y.X., Ma, L., Liu, W., Hebert, M.: Image deformation meta-networks for one-shot learning. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  34. Chu, W.H., Li, Y.J., Chang, J.C., Wang, Y.C.F.: Spot and learn: A maximumentropy patch sampler for few-shot image classification. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  35. Gao, H., Shou, Z., Zareian, A., Zhang, H., Chang, S.F.: Low-shot learning via covariance-preserving adversarial augmentation networks. In: Advances in Neural Information Processing Systems (2018) ↩︎ ↩︎

  36. Gidaris, S., Komodakis, N.: Generating classification weights with gnn denoising autoencoders for few-shot learning. arXiv preprint arXiv:1905.01102 (2019) ↩︎

  37. Hariharan, B., Girshick, R.: Low-shot visual recognition by shrinking and hallucinating features. In: The International Conference on Computer Vision (2017) ↩︎ ↩︎

  38. Mehrotra, A., Dukkipati, A.: Generative adversarial residual pairwise networks for one shot learning. arXiv preprint arXiv:1703.08033 (2017) ↩︎

  39. Schwartz, E., Karlinsky, L., Shtok, J., Harary, S., Marder, M., Kumar, A., Feris, R., Giryes, R., Bronstein, A.: Delta-encoder: an effective sample synthesis method for few-shot object recognition. In: Advances in Neural Information Processing Systems (2018) ↩︎

  40. Wang, Y.X., Girshick, R., Hebert, M., Hariharan, B.: Low-shot learning from imaginary data. In: The Conference on Computer Vision and Pattern Recognition (2018) ↩︎ ↩︎

  41. Zhang, H., Zhang, J., Koniusz, P.: Few-shot learning via saliency-guided hallucination of samples. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

  42. Zhang, H., Cisse, M., Dauphin, Y.N., Lopez-Paz, D.: mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412 (2017) ↩︎

  43. Gidaris, S., Bursuc, A., Komodakis, N., Pérez, P., Cord, M.: Boosting few-shot visual learning with self-supervision. In: The International Conference on Computer Vision (2019) ↩︎

  44. Ren, M., Triantafillou, E., Ravi, S., Snell, J., Swersky, K., Tenenbaum, J.B., Larochelle, H., Zemel, R.S.: Meta-learning for semi-supervised few-shot classification. arXiv preprint arXiv:1803.00676 (2018) ↩︎ ↩︎ ↩︎

  45. Li, X., Sun, Q., Liu, Y., Zhou, Q., Zheng, S., Chua, T.S., Schiele, B.: Learning to self-train for semi-supervised few-shot classification. In: Advances in Neural Information Processing Systems (2019) ↩︎

  46. Wang, Y.X., Hebert, M.: Learning from small sample sets by combining unsupervised meta-training with cnns. In: Advances in Neural Information Processing Systems (2016) ↩︎

  47. Liu, B., Wu, Z., Hu, H., Lin, S.: Deep metric transfer for label propagation with limited annotated data. In: The IEEE International Conference on Computer Vision (ICCV) Workshops (Oct 2019) ↩︎

  48. Lim, J.J., Salakhutdinov, R.R., Torralba, A.: Transfer learning by borrowing examples for multiclass object detection. In: Advances in Neural Information Processing Systems (2011) ↩︎ ↩︎ ↩︎

  49. Jiang, H., Wang, R., Shan, S., Chen, X.: Learning class prototypes via structure alignment for zero-shot recognition. In: The European Conference on Computer Vision (2018) ↩︎

  50. Wah, C., Branson, S., Welinder, P., Perona, P., Belongie, S.: The caltech-ucsd birds-200-2011 dataset (2011) ↩︎ ↩︎

  51. Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., et al.: Imagenet large scale visual recognition challenge. The International Journal of Computer Vision (2015) ↩︎ ↩︎

  52. Krizhevsky, A., Nair, V., Hinton, G.: Cifar-10 and cifar-100 datasets. URl: https://www. cs. toronto. edu/kriz/cifar. html (2009) ↩︎

  53. Hilliard, N., Phillips, L., Howland, S., Yankov, A., Corley, C.D., Hodas, N.O.: Few-shot learning with metric-agnostic conditional embeddings. arXiv preprint arXiv:1802.04376 (2018) ↩︎

  54. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: The Conference on Computer Vision and Pattern Recognition (2016) ↩︎

  55. Sergey, Z., Nikos, K.: Wide residual networks. In: British Machine Vision Conference (2016) ↩︎

  56. Mishra, N., Rohaninejad, M., Chen, X., Abbeel, P.: A simple neural attentive meta-learner. arXiv preprint arXiv:1707.03141 (2017) ↩︎

  57. Sun, Q., Liu, Y., Chua, T.S., Schiele, B.: Meta-transfer learning for few-shot learning. In: The Conference on Computer Vision and Pattern Recognition (2019) ↩︎

相关推荐
四口鲸鱼爱吃盐10 分钟前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗11 分钟前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
bastgia1 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
菜狗woc1 小时前
opencv-python的简单练习
人工智能·python·opencv
15年网络推广青哥1 小时前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
weixin_387545641 小时前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
engchina2 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
paixiaoxin3 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
OpenCSG3 小时前
CSGHub开源版本v1.2.0更新
人工智能
weixin_515202493 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习