【论文笔记】Dual-Balancing for Multi-Task Learning

Abstract

多任务学习(Multi-task learning, MTL)中,任务平衡问题仍然是重要的挑战,损失、梯度尺度的不同,会导致性能的折中。

本文提出Dual-Balancing for Multi-Task Learning (DB-MTL),从损失和梯度两个角度缓解任务均衡问题。

DB-MTL通过对每个任务的损失进行对数变换,保证损失-尺度平衡,通过将所有任务梯度归一化到与最大梯度范数相同的幅度来保证梯度-幅度平衡。

1 Introduction

很多方法被提出,用来动态调整任务权重,可以粗略划分为损失平衡方法、梯度平衡方法。

本文重点关注同时平衡损失级别的损失尺度和梯度级别的梯度幅度,以减轻任务平衡问题。

不同任务的损失尺度和梯度幅度不同,大的一方可以左右模型更新的方向,导致部分任务的表现下降。

本文提出的DB-MTL简单且有效第平衡损失尺度和梯度规模。

① 对每个任务的损失施加对数变换 (logarithm transformation),保证所有任务的损失都有相同的尺度,是非参数的变换。对数变换有利于现有的梯度平衡方法,如图1所示。

② 将所有任务的梯度标准化成与最大梯度范数相同的幅度,这是免训练的,与GradNorm相比,所有梯度的大小都相同。归一化梯度大小对性能起着重要的作用,将其设置为任务中最大梯度范数效果最好,如图4所示。

总结贡献:

  • 提出DB-MTL方法,缓解任务平衡问题的双重平衡方法,包含损失尺度和梯度幅度平衡方法。
  • 大量实验证明DB-MTL在多个基准数据集上实现最先进的性能。
  • 实验结果表明,损失尺度平衡方法有利于现有的梯度平衡方法。

给定 T T T个任务和每个任务 t t t的训练集 D t \mathcal{D}t Dt,MTL的目标是在 { D t } t = 1 T \{\mathcal{D}t\}{t=1}^T {Dt}t=1T训练一个模型。MTL模型的参数包括两个部分:任务共享参数 θ \theta θ和任务独享参数 { ψ t } t = 1 T \{\psi_t\}{t=1}^T {ψt}t=1T。 θ \theta θ占据MTL模型的大部分参数,这对性能至关重要。

令 ℓ t ( D t ; θ , ψ t ) \ell_t(\mathcal{D}_t;\theta,\psi_t) ℓt(Dt;θ,ψt)为 ( θ , ψ t ) (\theta,\psi_t) (θ,ψt)下任务 t t t在数据集 D t \mathcal{D}t Dt的平均损失。目标函数可表示为:
∑ t = 1 T γ t ℓ t ( D t ; θ , ψ t ) (1) \sum
{t=1}^T \gamma_t\ell_t(\mathcal{D}_t;\theta,\psi_t)\tag{1} t=1∑Tγtℓt(Dt;θ,ψt)(1)

其中 γ t \gamma_t γt是任务 t t t的任务权重。

Equal weighting (EW)是一种简单的MTL方法,设置所有任务 γ t = 1 \gamma_t=1 γt=1。但是EW会导致任务平衡问题,即某些任务执行不理想。因此后续提出了很多MTL方法,通过在训练过程中动态调整任务权重 { γ t } t = 1 T \{\gamma_t\}_{t=1}^T {γt}t=1T,来提高EW的性能。这可以归类为损失平衡(loss balancing)、梯度平衡(gradient balancing)、混合平衡(hybrid balancing)。

2.1 Loss Balancing Methods

这一类方法,通过不同的衡量方法动态地计算任务权重 { γ t } t = 1 T \{\gamma_t\}_{t=1}^T {γt}t=1T,如同方差不确定性 (homoscedastic uncertainty),学习速率 (learning speed),验证集性能 (validation performance)。不同于上述方法,IMTL-L期望所有任务的权重损失 { γ t ℓ t ( D t ; θ , ψ t ) } t = 1 T \{\gamma_t\ell_t(\mathcal{D}t;\theta,\psi_t)\}{t=1}^T {γtℓt(Dt;θ,ψt)}t=1T是常值,对每一个损失实施变换 e s t ℓ t ( D t ; θ , ψ t ) − s t e^{s_t}\ell_t(\mathcal{D}_t;\theta,\psi_t)-s_t estℓt(Dt;θ,ψt)−st,其中 s t s_t st是第 t t t个任务上的可学习的参数,在每次迭代中通过梯度下降近似求解。

2.2 Gradient Balancing Methods

从梯度的角度,任何共享参数 θ \theta θ的更新取决于所有任务的梯度 { ∇ θ ℓ t ( D t ; θ , ψ t ) } t = 1 T \{\nabla_\theta\ell_t(\mathcal{D}t;\theta,\psi_t)\}{t=1}^T {∇θℓt(Dt;θ,ψt)}t=1T。梯度平衡方法旨在以不同的方式聚合所有任务梯度。例如MGDA将MTL表述为多目标优化问题,避免某些任务的梯度主导更新方向,目标是找到一个更新方向 d d d,使得所有任务的梯度在该方向上的投影尽可能小;CAGrad通过将聚合梯度约束到平均梯度周围来优化MGDA;MoCo通过引入类动量梯度估计和正则化项来缓解MGDA中的偏差;GradNorm通过学习任务权重来衡量任务梯度,使其具有接近的数量级;如果两个任务的梯度冲突,PCGrad会将一个任务的法平面投影到另一个任务的法平面;无论两个任务是否发生梯度冲突,GradVac都会将梯度对齐;GradDrop随机掩盖一些符号不一致的梯度值,IMTL-G学习任务权重,以确保聚合梯度在每个任务梯度上具有相等的投影;Nash-MTL将聚合梯度制定为纳什均衡。

2.3 Hybrid Balancing Method

Towards impartial multi-task learning, Liu et al.中,发现损失平衡和梯度平衡具有互补性,提出了IMTL-L和IMTL-G相结合的混合平衡法IMTL。

3 Proposed Method

3.1 Scale-Balancing Loss Transformation

不同类型的损失函数会带来不同的损失尺度,导致任务平衡问题。

假设损失规模的先验已知,可以选择 { s t ⋆ } t = 1 T \{s_t^\star\}{t=1}^T {st⋆}t=1T,使得 { s t ⋆ ℓ t ( D ; θ , ψ t ) } t = 1 T \{s_t^\star\ell_t(\mathcal{D};\theta,\psi_t)\}{t=1}^T {st⋆ℓt(D;θ,ψt)}t=1T具有相同的尺度,并最小化 ∑ t = 1 T s t ⋆ ℓ t ( D ; θ , ψ t ) \sum_{t=1}^T s_t^\star\ell_t(\mathcal{D};\theta,\psi_t) ∑t=1Tst⋆ℓt(D;θ,ψt)。之前的工作在学习任务权重 { γ t } t = 1 T \{\gamma_t\}{t=1}^T {γt}t=1T时直接学习 { s t ⋆ } t = 1 T \{s_t^\star\}{t=1}^T {st⋆}t=1T,但由于训练过程无法获得最优的 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st⋆}t=1T,这种方法会导致结果不是最优。

对数变换(Logarithmic transformation)可以实现所有损失达到相同的尺度,而不需要 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st⋆}t=1T。

由于 ∇ θ , ψ t log ⁡ ℓ t ( D ; θ , ψ t ) = ∇ θ , ψ t ℓ t ( D ; θ , ψ t ) ℓ t ( D ; θ , ψ t ) \nabla_{\theta,\psi_t}\log\ell_t(\mathcal{D};\theta,\psi_t)=\frac{\nabla_{\theta,\psi_t}\ell_t(\mathcal{D};\theta,\psi_t)}{\ell_t(\mathcal{D};\theta,\psi_t)} ∇θ,ψtlogℓt(D;θ,ψt)=ℓt(D;θ,ψt)∇θ,ψtℓt(D;θ,ψt)(普通的对log求导),这等价于对调整了尺度的任务损失 ℓ t ( D ; θ , ψ t ) stop-gradient ( ℓ t ( D ; θ , ψ t ) ) \frac{\ell_t(\mathcal{D};\theta,\psi_t)}{\text{stop-gradient}(\ell_t(\mathcal{D};\theta,\psi_t))} stop-gradient(ℓt(D;θ,ψt))ℓt(D;θ,ψt)取梯度,这一项对于任意的任务都有相同的尺度。

Discussion

尽管对数变换可以轻易实现尺度平衡,在MTL中使用得很少。本文对其进行深入研究,并将其集成到现有的梯度平衡方法(PCGrad、GradVac、IMTL-G、CAGrad、Nash-MTL、Aligned-MTL),大幅提升了他们的性能,如图1所示,证明了在MTL中平衡损失规模的有效性。

图1:现有梯度平衡方法+损失尺度平衡方法在NYUv2上的表现。

IMTL-L用一个转换后的损失来处理损失尺度问题: e s t ℓ t ( D t ; θ , ψ t ) − s t e^{s_t}\ell_t(\mathcal{D}_t;\theta,\psi_t)-s_t estℓt(Dt;θ,ψt)−st,其中 s t s_t st是第 t t t个任务上的可学习的参数,在每次迭代中通过梯度下降近似求解。这不能保证每次迭代中所有的损失尺度是相同的,而对数变换却可以。

3.2 Magnitude-Balancing Gradient Normalization

除了任务损失中的尺度问题,任务梯度也存在尺度问题。通过均匀平均所有可能主导最终梯度的任务的更新方向,导致次优的性能。

一个简单的方法是将任务梯度归一化到相同的幅度。

对于任务的梯度,计算一个batch的梯度 ∇ θ log ⁡ ℓ t ( D t ; θ , ψ t ) \nabla_\theta \log\ell_t(\mathcal{D}t;\theta,\psi_t) ∇θlogℓt(Dt;θ,ψt)的计算开销很大,通常使用小批量随机梯度下降方法。在第 k k k次迭代中,从 D t \mathcal{D}t Dt中采样一个小批量 B t , k \mathcal{B}{t,k} Bt,k(Algorithm 1中的第5步),计算这个小批量的梯度 g t , k = ∇ θ k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) g{t,k}=\nabla_{\theta_k} \log\ell_t(\mathcal{B}{t,k};\theta_k,\psi{t,k}) gt,k=∇θklogℓt(Bt,k;θk,ψt,k)(Algorithm 1中的第6步)。在动态估计 E B t , k ∼ D t ∇ θ k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) \mathbb{E}{\mathcal{B}{t,k}\sim\mathcal{D}t}\nabla{\theta_k} \log\ell_t(\mathcal{B}{t,k};\theta_k,\psi{t,k}) EBt,k∼Dt∇θklogℓt(Bt,k;θk,ψt,k)中使用了指数移动平均(Exponential moving average, EMA):
g ^ t , k = β g ^ t , k − 1 + ( 1 − β ) g t , k \hat{g}{t,k}=\beta\hat{g}{t,k-1}+(1-\beta)g_{t,k} g^t,k=βg^t,k−1+(1−β)gt,k

其中 β ∈ ( 0 , 1 ) \beta\in(0,1) β∈(0,1)控制遗忘率。

获得任务梯度 { g ^ t , k } t = 1 T \{\hat{g}{t,k}\}{t=1}^T {g^t,k}t=1T后,进行标准化,使得具有相同的 ℓ 2 \ell_2 ℓ2范数,计算聚合梯度为:
g ~ k = α k ∑ t = 1 T g ^ t , k ∣ ∣ g ^ t , k ∣ ∣ 2 (2) \tilde{g}k=\alpha_k\sum{t=1}^T\frac{\hat{g}{t,k}}{||\hat{g}{t,k}||_2}\tag{2} g~k=αkt=1∑T∣∣g^t,k∣∣2g^t,k(2)

其中 α k \alpha_k αk是控制更新尺度的尺度因子。标准化后,所有任务对更新的方向提供相同的贡献。

α k \alpha_k αk的选择对于缓解任务均衡的问题至关重要。

当某些任务梯度范数较大,其他任务梯度范数较小时,意味着模型 θ k \theta_k θk接近于前者尚未收敛而后者已经收敛的点。这一点在MTL中是不令人满意的,会导致任务平衡问题,因为期望的是所有任务都能实现收敛。因此, α k \alpha_k αk需要足够大来躲避这种不令人满意的点。

当所有任务的梯度范数都很小,表明模型 θ k \theta_k θk对于所有任务都接近令人满意的点了, α k \alpha_k αk应当足够小,使得模型 θ k \theta_k θk可以捕捉到这个最好的点。

因此可选择 α k = max ⁡ 1 ≤ t ≤ T ∣ ∣ g ^ t , k ∣ ∣ 2 \alpha_k=\max_{1\leq t\leq T} ||\hat{g}_{t,k}||_2 αk=max1≤t≤T∣∣g^t,k∣∣2。

图4展示了在NYUv2数据集上,使用不同的策略调整 α k \alpha_k αk的表现区别,实验设置在4.1节。由此可见,最大范数策略表现得更好。

图4:选择 α k \alpha_k αk的不同策略在NYUv2数据集的表现 Δ p \Delta_p Δp。

对损失和梯度进行缩放后,任务共享参数由 θ k + 1 = θ k − η g ~ k \theta_{k+1}=\theta_k-\eta\tilde{g}_k θk+1=θk−ηg~k(Algorithm 1中第10步)来更新。

对于任务独享参数 { ψ t } t = 1 T \{\psi_t\}{t=1}^T {ψt}t=1T,不同任务之间是独立更新的,因此不必进行梯度缩放。因此任务独享参数由 ψ t , k + 1 = ψ t , k − η ∇ ψ t , k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) \psi{t,k+1}=\psi_{t,k}-\eta\nabla_{\psi_{t,k}}\log\ell_t(\mathcal{B}{t,k};\theta_k,\psi{t,k}) ψt,k+1=ψt,k−η∇ψt,klogℓt(Bt,k;θk,ψt,k)(Algorithm 1中第11~13步)更新。

相关推荐
春末的南方城市32 分钟前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成
AI完全体1 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
卷心菜小温2 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
陈苏同学2 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
FL16238631293 小时前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~3 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
龙的爹23333 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现3 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙3 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习