标题:
Bootstrap Your Own Latent
A New Approach to Self-Supervised Learning
作者:Jean-Bastien Grill,Florian Strub,Florent Altché,Corentin Tallec,Pierre H. Richemond,Elena Buchatskaya,Carl Doersch,Bernardo Avila Pires,Zhaohan Daniel Guo,Mohammad Gheshlaghi Azar,Bilal Piot,Koray Kavukcuoglu,Rémi Munos,Michal Valko
单位:DeepMind, Imperial College
摘要
我们提出了一种新的自监督图像表示学习方法------Bootstrap Your Own Latent (BYOL) 。BYOL依赖于两个神经网络,分别称为在线网络和目标网络,它们通过相互作用进行学习。从一张图像的增强视图中,我们训练在线网络来预测同一图像在不同增强视图下目标网络的表示。同时,目标网络使用在线网络的慢速移动平均值进行更新。虽然现有最先进的方法依赖于负样本对,BYOL在不使用负样本对的情况下,达到了新的性能水平。BYOL在ImageNet上使用ResNet-50架构进行线性评估时,获得了74.3%的top-1分类准确率,使用更大的ResNet网络获得了79.6%的准确率。我们证明了BYOL在迁移学习和半监督学习基准测试中表现得与当前最先进方法相当或更好。我们的实现和预训练模型已在GitHub上提供。(点击跳转至BYOL_GitHub)
1 介绍
图像表示学习是计算机视觉中的一个关键挑战,因为它能够有效地训练下游任务。许多不同的训练方法被提出用于学习这样的表示,通常依赖于视觉的预训练任务。在这些方法中,当前最先进的对比方法是通过减少相同图像不同增强视图的表示之间的距离(称为正样本对),并增大来自不同图像的增强视图的表示之间的距离(称为负样本对)来进行训练。这些方法需要对负样本对进行仔细处理,例如依赖于大的批次大小、存储库或定制的挖掘策略来检索负样本。此外,它们的性能在很大程度上取决于图像增强的选择。
在本文中,我们提出了一种新的算法------Bootstrap Your Own Latent (BYOL) ,用于图像表示的自监督学习。BYOL在不使用负样本对的情况下,达到了比现有最先进对比方法更高的性能。它通过迭代地引导网络的输出,作为增强表示的目标。此外,BYOL在图像增强选择上比对比方法更具鲁棒性;我们怀疑,这种改进的鲁棒性主要得益于它不依赖于负样本对。虽然以前基于自举的方法使用了伪标签、聚类索引或少量标签,我们提出直接自举表示。特别地,BYOL使用两个神经网络,分别称为在线网络和目标网络,它们相互作用并从对方学习。通过图像的增强视图,BYOL训练在线网络来预测目标网络对同一图像的另一个增强视图的表示。虽然这一目标允许塌陷的解决方案,例如对所有图像输出相同的向量,但我们通过实验表明BYOL并未收敛到这种解决方案。我们假设,在线网络增加预测器,以及使用在线网络参数的移动平均值作为目标网络,鼓励编码越来越多的信息,避免了塌陷的解决方案。
我们在ImageNet和其他视觉基准上使用ResNet架构评估了BYOL所学习的表示。在ImageNet上的线性评估协议中,即在冻结的表示上训练线性分类器,BYOL使用标准的ResNet-50达到了74.3%的top-1准确率,使用较大的ResNet达到了79.6%的top-1准确率。在ImageNet的半监督和迁移学习设置中,我们的结果与当前的最先进方法相当或更好。
我们的贡献包括:
- 我们提出了BYOL,一种自监督表示学习方法,它在ImageNet的线性评估协议下,能够在不使用负样本对的情况下实现最先进的结果。
- 我们表明,BYOL所学习的表示在半监督和迁移学习基准上表现得比现有技术更好。
- 我们展示了与对比方法相比,BYOL对批次大小和图像增强的选择更加具有鲁棒性,尤其是当仅使用随机裁剪作为图像增强时,BYOL的性能下降幅度远小于SimCLR(一种强大的对比学习基准)。
2. 相关工作
大多数用于表示学习的无监督方法可以分为生成式或判别式方法。生成式方法通过构建数据和潜在嵌入的分布,并使用所学习的嵌入作为图像表示。这些方法大多依赖于图像的自动编码或对抗学习,联合建模数据和表示。生成式方法通常直接在像素空间中运行,这种方法计算开销大,并且图像生成所需的高细节水平可能对表示学习来说并不是必要的。
在判别式方法中,对比方法目前在自监督学习中取得了最先进的性能。这些方法通过使相同图像不同视图的表示更加接近,并将来自不同图像的视图的表示分开,避免了在像素空间中的代价高昂的生成步骤。对比方法通常需要将每个样本与许多其他样本进行比较才能取得良好效果,因此引发了是否必须使用负样本对来防止塌陷的问题。
DeepCluster部分回答了这个问题。它使用之前版本的表示生成目标以训练新的表示,它通过聚类数据点,并使用每个样本的聚类索引作为新表示的分类目标。虽然避免了使用负样本对,但这需要一个代价高昂的聚类过程,并需要特定的预防措施以避免塌陷到简单的解决方案。
某些自监督方法不依赖于对比,而是通过使用辅助的手工设计的预测任务来学习其表示。特别是,相对的patch预测、灰度图像的着色、图像修复、图像拼图、图像超分辨率和几何变换已经被证明是有用的。然而,即便有合适的架构,这些方法的性能依然被对比方法超越。
我们的方法与Bootstrapped Latents预测有一些相似之处,这是一种用于强化学习的自监督表示学习技术。PBL同时训练代理的历史表示和未来观察的编码。观察编码用于训练代理的表示,代理的表示作为目标训练观察编码。与PBL不同,BYOL使用表示的慢速移动平均作为目标,不需要第二个网络。
3. 方法
我们首先动机阐述我们的方法,接下来在第3.1节中详细说明其细节。许多成功的自监督学习方法构建在 [63] 中引入的跨视图预测框架之上。通常,这些方法通过预测同一图像的不同视图(例如,不同的随机裁剪)来学习表示。许多此类方法直接将预测问题转换到表示空间中:一个图像的增强视图的表示应该能够预测同一图像的另一个增强视图的表示。然而,直接在表示空间中进行预测可能会导致表示塌陷:例如,一个在所有视图中恒定的表示总是能够完全预测自身。对比方法通过将预测问题重新表述为区分问题来规避这一问题:从一个增强视图的表示中,他们学会区分同一图像的另一个增强视图的表示与不同图像的增强视图的表示。在大多数情况下,这防止了训练出现塌陷表示。然而,这种区分方法通常需要将每个增强视图的表示与许多负样本进行比较,以找到那些足够接近的,使区分任务具有挑战性。在这项工作中,我们的任务是找出这些负样本是否对于防止塌陷是不可或缺的,同时保持高性能。
为了防止塌陷,一个直接的解决方案是使用一个随机初始化的固定网络来生成我们预测的目标。尽管这种方法可以避免塌陷,但经验上它并不能产生非常好的表示。尽管如此,值得注意的是,使用这种方法得到的表示已经比最初的固定表示要好得多。在我们的消融研究中(见第5节),我们应用了这种通过预测一个固定的随机初始化网络来进行训练的程序,并在ImageNet上的线性评估协议中实现了18.8%的top-1准确率(表5a),而随机初始化的网络本身仅能实现1.4%的准确率。这一实验发现是BYOL的核心动机:从一个给定的表示(称为目标)开始,我们可以通过预测目标表示来训练一个新的、潜在增强的表示(称为在线表示)。通过迭代这一程序,使用后续的在线网络作为新的目标网络进行进一步的训练,我们可以期望构建出质量逐步提升的表示序列。在实践中,BYOL通过迭代地优化其表示来推广这一自举过程,但使用在线网络的慢速移动平均值作为目标网络,而不是使用固定的检查点。
3.1 BYOL描述
BYOL的目标是学习一个表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> y θ y_{\theta} </math>yθ,该表示随后可以用于下游任务。正如前文所述,BYOL使用两个神经网络进行学习:在线网络和目标网络。在线网络由一组权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 定义,分为三个阶段:编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_{\theta} </math>fθ,投影器 <math xmlns="http://www.w3.org/1998/Math/MathML"> g θ g_{\theta} </math>gθ,以及预测器 <math xmlns="http://www.w3.org/1998/Math/MathML"> q θ q_{\theta} </math>qθ,如图2所示。目标网络与在线网络具有相同的架构,但使用一组不同的权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ。目标网络为在线网络提供回归目标,其参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 是在线网络参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的指数移动平均。更具体地,给定目标衰减率 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ ∈ [ 0 , 1 ] \tau \in [0, 1] </math>τ∈[0,1],在每次训练步骤后我们进行以下更新:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ξ ← τ ξ + ( 1 − τ ) θ . \xi \leftarrow \tau \xi + (1 - \tau)\theta. </math>ξ←τξ+(1−τ)θ.
给定一组图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D,图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∼ D x \sim D </math>x∼D 从 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D 中均匀采样,并从两个图像增强分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> T ′ T' </math>T′ 分别生成增强视图 <math xmlns="http://www.w3.org/1998/Math/MathML"> v = t ( x ) v = t(x) </math>v=t(x) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ′ = t ′ ( x ) v' = t'(x) </math>v′=t′(x),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> t ∼ T t \sim T </math>t∼T, <math xmlns="http://www.w3.org/1998/Math/MathML"> t ′ ∼ T ′ t' \sim T' </math>t′∼T′。从第一个增强视图 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v,在线网络输出表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> y θ = f θ ( v ) y_{\theta} = f_{\theta}(v) </math>yθ=fθ(v) 和投影 <math xmlns="http://www.w3.org/1998/Math/MathML"> z θ = g θ ( y ) z_{\theta} = g_{\theta}(y) </math>zθ=gθ(y)。目标网络从第二个增强视图 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ′ v' </math>v′ 输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ξ ′ = f ξ ( v ′ ) y'{\xi} = f{\xi}(v') </math>yξ′=fξ(v′) 和目标投影 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ = g ξ ( y ′ ) z'{\xi} = g{\xi}(y') </math>zξ′=gξ(y′)。然后,我们对 <math xmlns="http://www.w3.org/1998/Math/MathML"> z θ z_{\theta} </math>zθ 进行预测,并输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> q θ ( z θ ) q_{\theta}(z_{\theta}) </math>qθ(zθ)。我们对 <math xmlns="http://www.w3.org/1998/Math/MathML"> q θ ( z θ ) q_{\theta}(z_{\theta}) </math>qθ(zθ) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ z'{\xi} </math>zξ′ 进行 <math xmlns="http://www.w3.org/1998/Math/MathML"> l 2 l_2 </math>l2 归一化,即:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q θ ( z θ ) = q θ ( z θ ) ∥ q θ ( z θ ) ∥ 2 , z ξ ′ = z ξ ′ ∥ z ξ ′ ∥ 2 . q{\theta}(z_{\theta}) = \frac{q_{\theta}(z_{\theta})}{\|q_{\theta}(z_{\theta})\|2}, \quad z'{\xi} = \frac{z'{\xi}}{\|z'{\xi}\|_2}. </math>qθ(zθ)=∥qθ(zθ)∥2qθ(zθ),zξ′=∥zξ′∥2zξ′.
注意,该预测器只应用于在线分支,使得架构在在线和目标流程之间不对称。最后,我们定义归一化预测与目标投影之间的均方误差:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L θ , ξ = ∥ q θ ( z θ ) − z ξ ′ ∥ 2 2 = 2 − 2 ⋅ ⟨ q θ ( z θ ) , z ξ ′ ⟩ ∥ q θ ( z θ ) ∥ 2 ⋅ ∥ z ξ ′ ∥ 2 . L_{\theta, \xi} = \| q_{\theta}(z_{\theta}) - z'{\xi} \|^2_2 = 2 - 2 \cdot \frac{ \langle q{\theta}(z_{\theta}), z'{\xi} \rangle}{ \|q{\theta}(z_{\theta})\|2 \cdot \|z'{\xi}\|_2 }. </math>Lθ,ξ=∥qθ(zθ)−zξ′∥22=2−2⋅∥qθ(zθ)∥2⋅∥zξ′∥2⟨qθ(zθ),zξ′⟩.
我们通过将 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ′ v' </math>v′ 输入在线网络并将 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 输入目标网络来对损失函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L θ , ξ L_{\theta, \xi} </math>Lθ,ξ 进行对称化,从而计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ~ θ , ξ \tilde{L}{\theta, \xi} </math>L~θ,ξ。在每个训练步骤中,我们执行一次随机优化步骤,以最小化损失 <math xmlns="http://www.w3.org/1998/Math/MathML"> L BYOL = L θ , ξ + L ~ θ , ξ L{\text{BYOL}} = L_{\theta, \xi} + \tilde{L}{\theta, \xi} </math>LBYOL=Lθ,ξ+L~θ,ξ,仅对 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 进行优化,但不对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 进行优化,正如图2中的"停止梯度"所示。BYOL的动态总结为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ← optimizer ( θ , ∇ θ L BYOL , η ) , \theta \leftarrow \text{optimizer} \left( \theta, \nabla{\theta}L_{\text{BYOL}}, \eta \right), </math>θ←optimizer(θ,∇θLBYOL,η),
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ξ ← τ ξ + ( 1 − τ ) θ , \xi \leftarrow \tau \xi + (1 - \tau) \theta, </math>ξ←τξ+(1−τ)θ,
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> optimizer \text{optimizer} </math>optimizer 是一个优化器, <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 是学习率。训练结束时,我们只保留编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_{\theta} </math>fθ;如同在文献中提到的那样,当与其他方法进行比较时,我们仅考虑最终表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_{\theta} </math>fθ 中的推理时权重数。
当与其他方法进行比较时,如同在 [9] 中提到的那样,我们仅考虑推理时使用的最终表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ f_{\theta} </math>fθ 的权重数量。架构、超参数和训练细节在附录 A 中进行了说明,完整的训练过程在附录 B 中总结,基于 JAX [64] 和 Haiku [65] 的 Python 伪代码则在附录 J 中提供。
3.2 关于 BYOL 行为的直觉
由于 BYOL 在最小化 <math xmlns="http://www.w3.org/1998/Math/MathML"> L θ , ξ L_{\theta,\xi} </math>Lθ,ξ 时没有使用显式的术语来防止塌陷(例如,负样本 [10]),因此看起来 BYOL 应该收敛到 <math xmlns="http://www.w3.org/1998/Math/MathML"> L θ , ξ L_{\theta,\xi} </math>Lθ,ξ 的某个最小值(例如,一个塌陷的常量表示)。然而,BYOL 的目标参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 的更新方向并不是 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ ξ L θ , ξ \nabla_\xi L_{\theta,\xi} </math>∇ξLθ,ξ。更一般地,我们假设不存在一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> L θ , ξ L_{\theta,\xi} </math>Lθ,ξ,使得 BYOL 的动态在 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 上是联合的梯度下降。这类似于 GANs [66],在 GAN 中并不存在一个在判别器和生成器参数上联合最小化的损失。因此,没有理由认为 BYOL 的参数会收敛到 <math xmlns="http://www.w3.org/1998/Math/MathML"> L θ , ξ L_{\theta,\xi} </math>Lθ,ξ 的最小值。
尽管 BYOL 的动态仍然允许不理想的平衡点,但在我们的实验中我们没有观察到收敛到这种平衡点。此外,当假设 BYOL 的预测器是最优的,即
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q θ = q ∗ 其中 q ∗ = arg min q E [ ∥ q ( z θ ) − z ξ ′ ∥ 2 2 ] , q ∗ ( z θ ) = E [ z ξ ′ ∣ z θ ] , q_{\theta} = q^{*} \quad \text{其中} \quad q^{*} = \arg\min_q \mathbb{E}[\|q(z_{\theta}) - z'{\xi}\|^2_2], \quad q^{*}(z{\theta}) = \mathbb{E}[z'{\xi} | z{\theta}], </math>qθ=q∗其中q∗=argqminE[∥q(zθ)−zξ′∥22],q∗(zθ)=E[zξ′∣zθ],
我们假设这些不理想的平衡点是不稳定的。确实,在这种最优预测器的情况下,BYOL 对 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的更新期望是遵循期望条件方差的梯度(详见附录 I),我们将 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ , i ′ z'{\xi,i} </math>zξ,i′ 记作 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ z'{\xi} </math>zξ′ 的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个特征:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ θ E [ ∥ q ∗ ( z θ ) − z ξ ′ ∥ 2 2 ] = ∇ θ E [ ∥ E [ z ξ ′ ∣ z θ ] − z ξ ′ ∥ 2 2 ] = ∇ θ E [ ∑ i Var ( z ξ , i ′ ∣ z θ ) ] , \nabla_{\theta} \mathbb{E}[\|q^{*}(z_{\theta}) - z'{\xi}\|^2_2] = \nabla{\theta} \mathbb{E}[\|\mathbb{E}[z'{\xi} | z{\theta}] - z'{\xi}\|^2_2] = \nabla{\theta} \mathbb{E}\left[\sum_i \text{Var}(z'{\xi,i} | z{\theta})\right], </math>∇θE[∥q∗(zθ)−zξ′∥22]=∇θE[∥E[zξ′∣zθ]−zξ′∥22]=∇θE[i∑Var(zξ,i′∣zθ)],
注意,对于任何随机变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X、 <math xmlns="http://www.w3.org/1998/Math/MathML"> Y Y </math>Y 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> Z Z </math>Z,有 <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( X ∣ Y , Z ) ≤ Var ( X ∣ Y ) \text{Var}(X | Y, Z) \leq \text{Var}(X | Y) </math>Var(X∣Y,Z)≤Var(X∣Y)。令 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 为目标投影, <math xmlns="http://www.w3.org/1998/Math/MathML"> Y Y </math>Y 为当前的在线投影, <math xmlns="http://www.w3.org/1998/Math/MathML"> Z Z </math>Z 为训练动态中由于随机性引入的额外变异:简单地丢弃来自在线投影的信息不能减少条件方差。
特别地,BYOL 避免了 <math xmlns="http://www.w3.org/1998/Math/MathML"> z θ z_{\theta} </math>zθ 中的常量特征,因为对于任何常量 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 和随机变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z θ z_{\theta} </math>zθ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ z'{\xi} </math>zξ′, <math xmlns="http://www.w3.org/1998/Math/MathML"> Var ( z ξ ′ ∣ z θ ) ≤ Var ( z ξ ′ ∣ c ) \text{Var}(z'{\xi} | z_{\theta}) \leq \text{Var}(z'{\xi} | c) </math>Var(zξ′∣zθ)≤Var(zξ′∣c);因此我们假设这些塌陷的常量平衡点是不稳定的。有趣的是,如果我们最小化 <math xmlns="http://www.w3.org/1998/Math/MathML"> E [ ∑ i Var ( z ξ , i ′ ∣ z θ ) ] \mathbb{E}[\sum_i \text{Var}(z'{\xi,i} | z_{\theta})] </math>E[∑iVar(zξ,i′∣zθ)],那么我们将得到一个塌陷的 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ z'{\xi} </math>zξ′,因为方差在常量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ′ z'{\xi} </math>zξ′ 时被最小化。然而,BYOL 通过使 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 更接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 来避免这一问题,从而将在线投影中捕获的变异引入到目标投影中。
此外,注意到在线参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的硬复制到目标参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> ξ \xi </math>ξ 将足以传播新的变异源。然而,目标网络中的突然变化可能会打破最优预测器的假设,在这种情况下,BYOL 的损失可能不会接近条件方差。我们假设 BYOL 的移动平均目标网络的主要作用是确保在训练过程中预测器的近似最优性;第 5 节和附录 J 提供了对这一解释的一些经验支持。
4. 实验评估
我们在ImageNet ILSVRC-2012数据集的训练集上进行了自监督预训练后,评估了BYOL的表示性能。我们首先在ImageNet上进行线性评估和半监督评估。接着,我们测量了其在其他数据集和任务上的迁移能力,包括分类、分割、目标检测和深度估计任务。为了比较,我们还报告了使用ImageNet子集标签进行监督训练的表示性能,称为Supervised-IN。在附录F中,我们通过在Places365-Standard数据集上预训练表示来评估BYOL的通用性,然后复现这一评估协议。
ImageNet上的线性评估
我们首先通过在冻结表示上训练线性分类器来评估BYOL的表示,按照[48, 74, 41, 10, 8]中描述的程序和附录D.1中的说明进行操作;我们在测试集上报告top-1和top-5的准确率(百分比),见表1。使用标准的ResNet-50(×1),BYOL获得了74.3%的top-1准确率(91.6%的top-5准确率),这比之前的自监督技术[12]提高了1.3%(top-1)和0.5%(top-5)。这缩小了与监督基准的差距[8],其为76.5%,但仍显著低于更强的监督基准[75],其为78.9%。使用更深、更宽的架构时,BYOL持续超越先前的自监督技术(附录D.2),并取得了79.6%的top-1准确率,这一成绩高于之前的自监督技术。使用ResNet-50(4×),BYOL达到了78.6%,与[8]中相同架构的最佳监督基准(78.9%)相当。
ImageNet上的半监督训练
接下来,我们评估BYOL的表示在使用少量ImageNet训练集标签的分类任务中的表现,此时使用标签信息。我们遵循[74, 76, 8, 32]中详细描述的半监督协议,使用与[8]相同的1%和10% ImageNet标签训练数据的固定分割。我们在测试集上分别报告top-1和top-5准确率,见表2。BYOL在一系列架构中始终优于之前的方法。此外,如附录D.1中详细描述的,BYOL在使用ImageNet的100%标签进行微调时,达到77.7%的top-1准确率。
转移到其他分类任务
为了评估我们表示的泛化能力,我们在其他分类数据集上评估表示,衡量其是否是通用的并且能够在其他图像领域中有效,或者其仅在ImageNet中表现出色。我们在[8, 74]中使用的相同分类任务集上进行线性评估和微调,严格遵循他们的评估协议,详见附录E。我们在每个基准任务中使用标准的评估指标,并在验证集上进行超参数选择后,报告在保留的测试集上的结果。我们在表3中同时报告了线性评估和微调的结果。BYOL在所有基准上均优于SimCLR,并且在12个基准中的7个上超越了Supervised-IN基准,在剩下的5个基准上表现略微逊色。BYOL的表示能够迁移到小尺寸图像(例如CIFAR[78])、风景图像(例如SUN397[79])或纹理图像(例如DTD[81])中。
转移到其他视觉任务
我们在其他与计算机视觉从业者相关的任务上评估我们的表示,包括语义分割、目标检测和深度估计。通过这些评估,我们测试了BYOL的表示是否能够在超越分类任务的情境下泛化。
我们首先在VOC2012语义分割任务上评估BYOL,详见附录E.4。该任务的目标是对图像中的每个像素进行分类[7]。我们在表4a中报告了结果。BYOL在平均交并比(mIoU)指标上超越了Supervised-IN基准(+1.9 mIoU)和SimCLR(+1.1 mIoU)。
类似地,我们复现[9]中的目标检测设置,在Faster R-CNN架构[82]下进行评估,详见附录E.5。我们在trainval2007上进行微调,并使用标准的AP50指标在test2007上报告结果;BYOL在AP50上显著优于Supervised-IN基准(+3.1 AP50)和SimCLR(+2.3 AP50)。
最后,我们在NYU v2数据集上评估深度估计,目标是根据单张RGB图像估计场景的深度图。深度预测衡量了网络表示几何信息的能力,以及这些信息是否能够在像素精度上进行本地化[40]。该设置基于[83],详见附录E.6。我们在654张常用测试子集上进行评估,并使用几个常见的指标报告结果,见表4b:相对(rel)误差、均方根(rms)误差,以及像素百分比(pct),其中误差 <math xmlns="http://www.w3.org/1998/Math/MathML"> max ( d gt d p , d p d gt ) \max(\frac{d_{\text{gt}}}{d_{\text{p}}}, \frac{d_{\text{p}}}{d_{\text{gt}}}) </math>max(dpdgt,dgtdp) 小于1.25的阈值。BYOL在每个指标上都比其他方法更好或相当。例如,具有挑战性的pct.<1.25指标相比监督基准和SimCLR分别提高了+3.5和+1.3个百分点。
5. 通过消融实验构建直觉
我们对BYOL进行了消融实验,以帮助理解其行为和性能。为了确保可复现性,我们为每组参数配置运行三次实验,并报告平均性能。当最佳和最差运行之间的差异超过0.25时,我们还报告该差异。虽然之前的工作在100个epoch上进行消融实验[8, 12],但我们注意到在100个epoch时的相对改进在更长时间的训练中并不总是成立。因此,我们在64个TPU v3核心上运行了300个epoch的消融实验,这与我们1000个epoch的基线训练结果一致。在本节的所有实验中,我们将初始学习率设置为0.3,批次大小为4096,权重衰减设置为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 6 10^{-6} </math>10−6,与SimCLR[8]一致,基本目标衰减率 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ base \tau_{\text{base}} </math>τbase设置为0.99。在本节中,我们报告了ImageNet线性评估协议下的top-1准确率,详见附录D.1。
批次大小
在对比方法中,负样本从小批次中抽取的那些方法在批次大小减小时性能会下降。BYOL不使用负样本,因此我们预计它对较小批次大小更为鲁棒。为了经验验证这一假设,我们使用不同的批次大小(从128到4096)训练BYOL和SimCLR。为了避免重新调整其他超参数,我们在批次大小减少N倍时,在更新在线网络之前先对连续N步的梯度进行平均。在线网络每N步更新一次,在线网络更新后再更新目标网络;我们在运行中并行累积这些N步。如图3a所示,SimCLR的性能随着批次大小的减少迅速下降,可能是由于负样本数量的减少。相比之下,BYOL的性能在批次大小从256到4096的范围内保持稳定,只有在更小的值下才会下降,这是因为编码器中的批归一化层。
图像增强
对比方法对图像增强的选择非常敏感。例如,当从SimCLR的图像增强集中去除颜色失真时,它的表现会明显下降。作为解释,SimCLR表明,同一图像的裁剪通常共享其颜色直方图。同时,不同图像的颜色直方图变化较大。因此,当对比任务仅依赖随机裁剪作为图像增强时,可以通过仅关注颜色直方图来解决该任务。因此,表示没有被鼓励保留超出颜色直方图之外的信息。为了防止这种情况,SimCLR将颜色失真添加到了其图像增强集中。而BYOL则被鼓励将目标表示中捕获的任何信息保留在其在线网络中,以改进其预测。因此,即使同一图像的增强视图共享相同的颜色直方图,BYOL仍然被鼓励在其表示中保留额外的特征。因此,我们认为BYOL在图像增强选择上比对比方法更为鲁棒。
图3b中的结果支持了这一假设:当从图像增强集中去除颜色失真时,BYOL的性能比SimCLR的性能下降幅度更小(BYOL下降9.1个百分点,SimCLR下降22.2个百分点)。当图像增强减少到仅随机裁剪时,BYOL仍然表现出良好的性能(59.4%,即从72.5%下降13.1个百分点),而SimCLR失去了三分之一以上的性能(40.3%,即从67.9%下降27.6个百分点)。我们在附录G.3中报告了更多的消融实验。
自举
BYOL使用目标网络的投影表示,其权重是在线网络权重的指数移动平均值,作为预测的目标。目标网络的权重表示了在线网络权重的一个延迟且更加稳定的版本。当目标衰减率为1时,目标网络从不更新,并保持其初始化的固定值。当目标衰减率为0时,目标网络在每一步都被即时更新为在线网络。在更新目标过于频繁和更新过于缓慢之间存在权衡,如表5a所示。即时更新目标网络( <math xmlns="http://www.w3.org/1998/Math/MathML"> τ = 0 \tau = 0 </math>τ=0)会使训练不稳定,导致非常差的性能,而从不更新目标网络( <math xmlns="http://www.w3.org/1998/Math/MathML"> τ = 1 \tau = 1 </math>τ=1)则使得训练稳定,但阻止了迭代改进,导致最终得到的表示质量较低。在300个epoch的实验中,所有衰减率在0.9和0.999之间的值都能实现超过68.4%的top-1准确率。
与对比方法的消融
在本小节中,我们通过相同的形式化方法重新表述SimCLR和BYOL,以便更好地理解BYOL相对于SimCLR的改进来自于哪里。让我们考虑以下扩展了InfoNCE目标的公式 [10, 84](见附录G.4):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> InfoNCE α , β θ = 2 B ∑ i = 1 B S θ ( v i , v i ′ ) − β ⋅ 2 α B ∑ i = 1 B log ( ∑ j ≠ i exp ( S θ ( v i , v j ) α ) + ∑ j exp ( S θ ( v i , v j ′ ) α ) ) , \text{InfoNCE}{\alpha, \beta}^{\theta} = \frac{2}{B} \sum{i=1}^{B} S_{\theta}(v_i, v'i) - \beta \cdot \frac{2\alpha}{B} \sum{i=1}^{B} \log \left( \sum_{j \neq i} \exp \left( \frac{S_{\theta}(v_i, v_j)}{\alpha} \right) + \sum_{j} \exp \left( \frac{S_{\theta}(v_i, v'_j)}{\alpha} \right) \right), </math>InfoNCEα,βθ=B2i=1∑BSθ(vi,vi′)−β⋅B2αi=1∑Blog j=i∑exp(αSθ(vi,vj))+j∑exp(αSθ(vi,vj′)) ,
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> α > 0 \alpha > 0 </math>α>0 是一个固定温度, <math xmlns="http://www.w3.org/1998/Math/MathML"> β ∈ [ 0 , 1 ] \beta \in [0, 1] </math>β∈[0,1] 是一个权重系数, <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B 是批次大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v ′ v' </math>v′ 是增强视图的批次,对于每个批次索引 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i, <math xmlns="http://www.w3.org/1998/Math/MathML"> v i v_i </math>vi 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v i ′ v'i </math>vi′ 是来自同一图像的增强视图;实数值函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> S θ S{\theta} </math>Sθ 量化了增强视图之间的成对相似性。对于任何增强视图 <math xmlns="http://www.w3.org/1998/Math/MathML"> u u </math>u,我们定义 <math xmlns="http://www.w3.org/1998/Math/MathML"> z θ ( u ) = f θ ( g θ ( u ) ) z_{\theta}(u) = f_{\theta}(g_{\theta}(u)) </math>zθ(u)=fθ(gθ(u)) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ξ ( u ) = f ξ ( g ξ ( u ) ) z_{\xi}(u) = f_{\xi}(g_{\xi}(u)) </math>zξ(u)=fξ(gξ(u))。给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϕ \phi </math>ϕ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ψ \psi </math>ψ,我们考虑归一化的点积:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> S θ ( u 1 , u 2 ) = ⟨ ϕ ( u 1 ) , ψ ( u 2 ) ⟩ ∥ ϕ ( u 1 ) ∥ 2 ⋅ ∥ ψ ( u 2 ) ∥ 2 . S_{\theta}(u_1, u_2) = \frac{\langle \phi(u_1), \psi(u_2) \rangle}{\|\phi(u_1)\|_2 \cdot \|\psi(u_2)\|_2}. </math>Sθ(u1,u2)=∥ϕ(u1)∥2⋅∥ψ(u2)∥2⟨ϕ(u1),ψ(u2)⟩.
我们使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϕ ( u 1 ) = z θ ( u 1 ) \phi(u_1) = z_{\theta}(u_1) </math>ϕ(u1)=zθ(u1)(无预测器)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ψ ( u 2 ) = z θ ( u 2 ) \psi(u_2) = z_{\theta}(u_2) </math>ψ(u2)=zθ(u2)(无目标网络)且 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = 1 \beta = 1 </math>β=1 时恢复了SimCLR损失。当使用一个预测器和一个目标网络时,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϕ ( u 1 ) = p θ ( z θ ( u 1 ) ) \phi(u_1) = p_{\theta}(z_{\theta}(u_1)) </math>ϕ(u1)=pθ(zθ(u1)) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ψ ( u 2 ) = z ξ ( u 2 ) \psi(u_2) = z_{\xi}(u_2) </math>ψ(u2)=zξ(u2),并且 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = 0 \beta = 0 </math>β=0 时,我们恢复了BYOL的损失。为了评估目标网络、预测器和系数 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 的影响,我们对它们进行了消融实验。结果见表5b,附录G.4中提供了更多细节。唯一在没有负样本(即 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = 0 \beta = 0 </math>β=0)的情况下仍然表现良好的变体是BYOL,它同时使用了一个自举目标网络和一个预测器。将负样本添加到BYOL的损失中而不重新调整温度参数会损害其性能。在附录G.4中,我们展示了在适当调整温度的情况下,可以重新添加负样本,并且仍然与BYOL的性能相匹配。
简单地将目标网络添加到SimCLR已经提升了性能(+1.6点)。这为MoCo [9] 中目标网络的使用提供了新的理解,在MoCo中,目标网络用于提供更多的负样本。这里,我们展示了即使在使用相同数量的负样本时,仅通过稳定的效果,使用目标网络也是有益的。最后,我们观察到在SimCLR的 <math xmlns="http://www.w3.org/1998/Math/MathML"> S θ S_{\theta} </math>Sθ架构中添加预测器仅对性能有轻微的影响。
与"Mean Teacher"的关系
另一种半监督方法,Mean Teacher (MT) [20],将一个监督损失与一个额外的一致性损失相结合。在[20]中,这个一致性损失是学生网络的logits与其时间平均版本的教师网络logits之间的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L 2 L_2 </math>L2距离。去除BYOL中的预测器,结果是没有分类损失的MT的无监督版本,使用图像增强代替了原来的架构噪声(例如,dropout)。这一BYOL变体会出现塌陷(见表5b的第7行),这表明附加的预测器对于防止无监督场景中的塌陷是至关重要的。
接近最优预测器的重要性
表5b已经表明了结合预测器和目标网络的重要性:在去除两者时,表示会塌陷。我们进一步发现,通过使预测器接近最优,可以在不出现塌陷的情况下去除目标网络,无论是通过 (i) 使用一个最优线性预测器(通过线性回归当前批次的数据得到),然后将误差反向传播到网络中(top-1准确率为52.5%),还是通过 (ii) 提高预测器的学习率(top-1准确率为66.5%)。相比之下,如果同时提高投影器和预测器的学习率(不使用目标网络),则会产生较差的结果(约25%的top-1准确率)。有关更多细节,请参见附录J。这似乎表明,确保预测器始终接近最优对于防止塌陷至关重要,而这可能是BYOL的目标网络的一个作用。
6. 结论
我们提出了BYOL,一种新的自监督图像表示学习算法。BYOL通过预测其输出的先前版本来学习其表示,而不使用负样本。我们展示了BYOL在多个基准测试中的性能优越性。特别是在ImageNet上的线性评估协议下,使用ResNet-50 (1×),BYOL取得了新的最优成绩,并缩小了自监督方法与监督学习基准之间的差距。使用ResNet-200 (2×),BYOL取得了79.6%的top-1准确率,较先前的最优成绩(76.8%)提高,同时使用了30%更少的参数。
然而,BYOL仍然依赖于现有的适用于视觉应用的增强集。要将BYOL推广到其他模态(例如,音频、视频、文本等),需要为每个模态找到相应的合适增强方法。设计这些增强可能需要大量的精力和专业知识。因此,自动化这些增强的搜索将是使BYOL推广到其他模态的一个重要步骤。
更广泛的影响
本文所展示的研究属于无监督学习领域。此项工作可能会激发新的算法、理论研究和实验研究。本文提出的算法可以应用于多种视觉应用,其具体应用可能具有积极或消极的影响,这就是所谓的双重用途问题。此外,由于视觉数据集可能存在偏见,BYOL学习的表示可能容易复制这些偏见。
致谢
作者感谢以下人员在撰写本文过程中所提供的帮助,按字母顺序排列:Aaron van den Oord,Andrew Brock,Jason Ramapuram,Jeffrey De Fauw,Karen Simonyan,Katrina McKinney,Nathalie Beauguerlange,Olivier Henaff,Oriol Vinyals,Pauline Luc,Razvan Pascanu,Sander Dieleman,以及DeepMind团队。特别感谢Jason Ramapuram和Jeffrey De Fauw,他们提供了本文中使用的JAX SimCLR复现。