Liu, Bo, et al. Conflict-averse gradient descent for multi-task learning. NeurIPS 2021.
在多任务学习(Multi-task Learning, MTL)中,一个模型同时学习多个任务,期待通过知识共享提升整体效率与性能。然而,现实情况下,不同任务的梯度方向可能南辕北辙,直接按平均梯度更新模型,反而会导致某些任务性能大幅下降。这篇题为《Conflict-Averse Gradient Descent for Multi-task Learning》的论文,提出了一个简洁而有效的解决方案 ------ 冲突避免梯度下降(CAGrad)。
文章目录
-
- 一、问题背景:多任务学习与梯度冲突
-
- [1.1 多任务学习的目标](#1.1 多任务学习的目标)
- [1.2 梯度冲突:为何平均损失优化会失败?](#1.2 梯度冲突:为何平均损失优化会失败?)
- 二、现有方法及其局限
- 三、CAGrad:冲突避免梯度下降
-
- [3.1 核心直觉](#3.1 核心直觉)
- [3.2 形式化定义](#3.2 形式化定义)
- [3.3 如何求解?对偶变换降维](#3.3 如何求解?对偶变换降维)
- [3.4 超参数 c 的角色:平衡的艺术](#3.4 超参数 c 的角色:平衡的艺术)
- [四、理论:CAGrad 保证收敛到平均损失最小点](#四、理论:CAGrad 保证收敛到平均损失最小点)
-
- [4.1 收敛定理](#4.1 收敛定理)
- [4.2 定理含义](#4.2 定理含义)
- 五、实验验证:全方位领先
-
- [5.1 监督学习:视觉多任务数据集](#5.1 监督学习:视觉多任务数据集)
- [5.2 强化学习:Meta-World 机器人操作](#5.2 强化学习:Meta-World 机器人操作)
- [5.3 半监督学习:CIFAR-10 辅助任务](#5.3 半监督学习:CIFAR-10 辅助任务)
- 六、总结与展望
一、问题背景:多任务学习与梯度冲突
1.1 多任务学习的目标
我们假设有 K 个任务(例如 K=3,分别是图像分类、目标检测、语义分割),它们共享一个模型参数 θ。每个任务有自己的损失函数 L i ( θ ) L_i(\theta) Li(θ)。常见的优化目标是最小化所有任务的平均损失:
L 0 ( θ ) = 1 K ∑ i = 1 K L i ( θ ) L_0(\theta) = \frac{1}{K} \sum_{i=1}^{K} L_i(\theta) L0(θ)=K1i=1∑KLi(θ)
1.2 梯度冲突:为何平均损失优化会失败?
设任务 i 的梯度为 g i = ∇ L i ( θ ) g_i = \nabla L_i(\theta) gi=∇Li(θ),平均梯度为 g 0 = 1 K ∑ i g i g_0 = \frac{1}{K}\sum_i g_i g0=K1∑igi。直接按照平均梯度更新参数: θ ← θ − α g 0 \theta \leftarrow \theta - \alpha g_0 θ←θ−αg0,看似合理,但存在两大问题:
- 梯度尺度差异:某个任务的梯度可能特别大,主导更新方向,淹没其他任务。
- 梯度方向冲突 :可能存在任务 i,使得 ⟨ g i , g 0 ⟩ < 0 \langle g_i, g_0 \rangle < 0 ⟨gi,g0⟩<0,即该任务的梯度与平均梯度方向相反。此时,按平均梯度更新反而会恶化该任务的性能。
举例说明:假设任务 A(分类)需要模型关注整体轮廓,梯度方向指向右侧;任务 B(细节分割)需要模型关注局部纹理,梯度方向指向左侧。平均梯度可能指向正前方,对两个任务都没有明显好处,甚至可能让两者都变差。
二、现有方法及其局限
梯度操作的主流方法:
- MGDA(Multiple Gradient Descent Algorithm) :将多任务学习视为多目标优化,寻找帕累托最优解。其更新方向是使得任何任务损失都不增加的最小范数梯度方向。问题:只能保证收敛到帕累托集上的某个点,无法控制具体是哪个点,可能远离平均损失最优点。
- PCGrad(Projecting Conflicting Gradients) :将每个任务的梯度投影到其他梯度的正交方向上,消除冲突分量后再平均。问题:同样只能收敛到任意帕累托点,且计算效率低(需循环处理所有任务对)。
- 梯度重加权(如 GradNorm) :根据任务难度或梯度幅值动态调整损失权重。问题:缺乏理论收敛保证,性能不稳定。
核心局限:现有方法大多 缺乏收敛到平均损失最小点的理论保证,且可能过度偏向某些任务,忽视整体优化目标。
三、CAGrad:冲突避免梯度下降
3.1 核心直觉
CAGrad 希望在平均梯度 g 0 g_0 g0 的附近 寻找一个新的更新方向 d d d,这个方向不仅要降低平均损失,还要确保每个任务的损失都能有所下降(至少不要恶化最差的任务)。换句话说,CAGrad 扮演一个"公平的教练",在确保团队整体进步的同时,特别关照进步最慢的队员。
3.2 形式化定义
每一步,CAGrad 求解如下优化问题:
max d ∈ R m min i ∈ [ K ] ⟨ g i , d ⟩ s.t. ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \max_{d \in \mathbb{R}^m} \min_{i \in [K]} \langle g_i, d \rangle \quad \text{s.t.} \quad \|d - g_0\| \leq c \|g_0\| d∈Rmmaxi∈[K]min⟨gi,d⟩s.t.∥d−g0∥≤c∥g0∥
- min i ⟨ g i , d ⟩ \min_i \langle g_i, d \rangle mini⟨gi,d⟩ 表示所有任务中,损失下降最慢的那个任务的"下降量"。我们希望最大化这个最慢的下降量,从而保证所有任务都能受益。
- 约束条件 ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \|d - g_0\| \leq c \|g_0\| ∥d−g0∥≤c∥g0∥ 要求更新方向 d d d 不能偏离平均梯度 g 0 g_0 g0 太远,其中 c ∈ [ 0 , 1 ) c \in [0,1) c∈[0,1) 是一个超参数,控制着偏离的程度。这个约束是理论收敛的关键:它确保我们不会完全脱离优化平均损失的主航道。
3.3 如何求解?对偶变换降维
直接优化高维向量 d d d(参数数量可能数百万)计算量大。论文通过拉格朗日对偶 ,将原问题转化为一个仅需优化 K 维权重向量 w w w 的问题:
-
令 W \mathcal{W} W 为概率单纯形(所有权重非负且和为 1)。可以证明:
min i ⟨ g i , d ⟩ = min w ∈ W ⟨ ∑ i w i g i , d ⟩ \min_i \langle g_i, d \rangle = \min_{w \in \mathcal{W}} \langle \sum_{i} w_i g_i, d \rangle imin⟨gi,d⟩=w∈Wmin⟨i∑wigi,d⟩因为线性函数在单纯形上的最小值必然在某个顶点(即某个任务梯度)处取得。
-
记 g w = ∑ i w i g i g_w = \sum_i w_i g_i gw=∑iwigi,原问题的对偶形式为:
min w ∈ W ( ⟨ g w , g 0 ⟩ + c ∥ g 0 ∥ ∥ g w ∥ ) \min_{w \in \mathcal{W}} \left( \langle g_w, g_0 \rangle + c \|g_0\| \|g_w\| \right) w∈Wmin(⟨gw,g0⟩+c∥g0∥∥gw∥)这是一个仅关于 w w w 的凸优化问题,维度为任务数 K(通常很小,如2~50),可用标准凸优化库快速求解。
-
求解得到最优权重 w ∗ w^* w∗ 后,原更新方向 d ∗ d^* d∗ 为:
d ∗ = g 0 + c ∥ g 0 ∥ ∥ g w ∗ ∥ g w ∗ d^* = g_0 + \frac{c \|g_0\|}{\|g_{w^*}\|} g_{w^*} d∗=g0+∥gw∗∥c∥g0∥gw∗直观上看, d ∗ d^* d∗ 是在平均梯度 g 0 g_0 g0 的基础上,沿着加权梯度 g w ∗ g_{w^*} gw∗ 的方向进行了一定程度的修正。
3.4 超参数 c 的角色:平衡的艺术
- c = 0: d ∗ = g 0 d^* = g_0 d∗=g0,CAGrad 退化为普通梯度下降(GD)。
- 0 < c < 1:在"优化平均损失"与"照顾最差任务"之间取得平衡。理论保证成立(见下一节)。
- c → ∞ \infty ∞(实践中取较大值如 10): d ∗ d^* d∗ 趋近于 g w g_w gw 方向,行为类似 MGDA,追求帕累托最优,但可能偏离平均损失最优点。
四、理论:CAGrad 保证收敛到平均损失最小点
4.1 收敛定理
在以下假设下:
- 所有损失函数 L i L_i Li 可微,且梯度是 Lipschitz 连续的(即变化不会太快)。
- 平均损失 L 0 L_0 L0 有下界(通常成立)。
定理 :若步长 α \alpha α 足够小,且超参数满足 0 ≤ c < 1 0 \leq c < 1 0≤c<1,则 CAGrad 生成的序列 { θ t } \{\theta_t\} {θt} 满足:
∑ t = 0 T ∥ ∇ L 0 ( θ t ) ∥ 2 ≤ 2 ( L 0 ( θ 0 ) − L 0 ∗ ) α ( 1 − c 2 ) \sum_{t=0}^{T} \| \nabla L_0(\theta_t) \|^2 \leq \frac{2(L_0(\theta_0) - L_0^*)}{\alpha(1 - c^2)} t=0∑T∥∇L0(θt)∥2≤α(1−c2)2(L0(θ0)−L0∗)
其中 L 0 ∗ L_0^* L0∗ 是平均损失的全局下界。
4.2 定理含义
- 收敛性 :不等式右边是常数,因此当 T → ∞ T \to \infty T→∞ 时, ∥ ∇ L 0 ( θ t ) ∥ → 0 \|\nabla L_0(\theta_t)\| \to 0 ∥∇L0(θt)∥→0,即算法必然收敛到平均损失 L 0 L_0 L0 的一个驻点(最优解或鞍点)。
- 收敛速度:梯度范数平方的累积和有一个明确上界,这提供了收敛速度的保证。
- 与 MGDA / PCGrad 的关键区别 :CAGrad 在 c < 1 c<1 c<1 时保证收敛到平均损失的最优点,而 MGDA / PCGrad 只能收敛到任意的帕累托点,无法控制具体位置。
证明核心思想 :利用 Lipschitz 条件推导单步下降量,再通过约束 ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \|d - g_0\| \leq c \|g_0\| ∥d−g0∥≤c∥g0∥ 将下降量与平均梯度范数关联,最后通过望远镜求和和损失下界得到不等式。
五、实验验证:全方位领先
论文在三大类任务上进行了全面实验:监督学习、强化学习、半监督学习。
5.1 监督学习:视觉多任务数据集
- 数据集:NYU-v2(语义分割、深度估计、表面法线预测)、CityScapes(语义分割、深度估计)。
- 基线方法:MTAN(多任务注意力网络)、PCGrad、MGDA、GradNorm、Cross-Stitch 等。
- 指标 :各任务专用指标(如 mIoU、深度误差)及平均任务性能下降率 Δ m % \Delta m\% Δm%(越低越好)。
- 结果:CAGrad 在保持主任务性能的同时,显著提升了其他方法通常忽视的任务(如表面法线预测),取得了最优的平均性能。
5.2 强化学习:Meta-World 机器人操作
- 环境:MT10(10任务)、MT50(50任务)。
- 基础算法:Soft Actor-Critic(SAC)。
- 挑战:任务数多,梯度冲突严重,计算开销大。
- CAGrad 技巧 :
- 任务子采样(CAGrad-Fast):每一步随机采样部分任务(MT10 采4个,MT50 采8个)计算梯度,大幅降低计算量,理论保证仍成立。
- 近似求解对偶问题:用梯度下降近似求解权重 w w w,而非调用凸优化库,进一步提高速度。
- 结果:CAGrad 在成功率上显著优于 PCGrad、Soft Modularization 等方法;CAGrad-Fast 在几乎不损失性能的情况下,速度比 PCGrad 快 2-5 倍。
5.3 半监督学习:CIFAR-10 辅助任务
- 设定:主任务为分类,两个辅助任务(旋转预测、一致性正则)用于利用无标签数据。
- 方法:将 CAGrad 与 ARML、GradNorm 等半监督学习方法结合。
- 结果:CAGrad 在不同标注数据量(500/1000/2000)下均取得最高平均测试准确率,且分析显示它有效平衡了主任务与辅助任务的优化。
六、总结与展望
CAGrad 的核心贡献在于:用一个简单而严谨的优化公式,解决了多任务学习中的梯度冲突问题,并首次提供了收敛到平均损失最优点的理论保证。其关键创新点包括:
- 冲突度量:以最差任务的损失下降量作为冲突程度的度量。
- 约束搜索:将更新方向限制在平均梯度邻域内,确保不偏离主目标。
- 高效求解:通过对偶变换将高维问题转化为低维凸优化,并支持任务子采样进一步加速。
当前 CAGrad 主要针对平均损失优化,未来可探索其他主要目标(如最差任务性能最大化)下的扩展。