CAGrad:保证收敛到平均损失最小的多任务梯度算法

Liu, Bo, et al. Conflict-averse gradient descent for multi-task learning. NeurIPS 2021.

在多任务学习(Multi-task Learning, MTL)中,一个模型同时学习多个任务,期待通过知识共享提升整体效率与性能。然而,现实情况下,不同任务的梯度方向可能南辕北辙,直接按平均梯度更新模型,反而会导致某些任务性能大幅下降。这篇题为《Conflict-Averse Gradient Descent for Multi-task Learning》的论文,提出了一个简洁而有效的解决方案 ------ 冲突避免梯度下降(CAGrad)


文章目录

    • 一、问题背景:多任务学习与梯度冲突
      • [1.1 多任务学习的目标](#1.1 多任务学习的目标)
      • [1.2 梯度冲突:为何平均损失优化会失败?](#1.2 梯度冲突:为何平均损失优化会失败?)
    • 二、现有方法及其局限
    • 三、CAGrad:冲突避免梯度下降
      • [3.1 核心直觉](#3.1 核心直觉)
      • [3.2 形式化定义](#3.2 形式化定义)
      • [3.3 如何求解?对偶变换降维](#3.3 如何求解?对偶变换降维)
      • [3.4 超参数 c 的角色:平衡的艺术](#3.4 超参数 c 的角色:平衡的艺术)
    • [四、理论:CAGrad 保证收敛到平均损失最小点](#四、理论:CAGrad 保证收敛到平均损失最小点)
      • [4.1 收敛定理](#4.1 收敛定理)
      • [4.2 定理含义](#4.2 定理含义)
    • 五、实验验证:全方位领先
      • [5.1 监督学习:视觉多任务数据集](#5.1 监督学习:视觉多任务数据集)
      • [5.2 强化学习:Meta-World 机器人操作](#5.2 强化学习:Meta-World 机器人操作)
      • [5.3 半监督学习:CIFAR-10 辅助任务](#5.3 半监督学习:CIFAR-10 辅助任务)
    • 六、总结与展望

一、问题背景:多任务学习与梯度冲突

1.1 多任务学习的目标

我们假设有 K 个任务(例如 K=3,分别是图像分类、目标检测、语义分割),它们共享一个模型参数 θ。每个任务有自己的损失函数 L i ( θ ) L_i(\theta) Li(θ)。常见的优化目标是最小化所有任务的平均损失

L 0 ( θ ) = 1 K ∑ i = 1 K L i ( θ ) L_0(\theta) = \frac{1}{K} \sum_{i=1}^{K} L_i(\theta) L0(θ)=K1i=1∑KLi(θ)

1.2 梯度冲突:为何平均损失优化会失败?

设任务 i 的梯度为 g i = ∇ L i ( θ ) g_i = \nabla L_i(\theta) gi=∇Li(θ),平均梯度为 g 0 = 1 K ∑ i g i g_0 = \frac{1}{K}\sum_i g_i g0=K1∑igi。直接按照平均梯度更新参数: θ ← θ − α g 0 \theta \leftarrow \theta - \alpha g_0 θ←θ−αg0,看似合理,但存在两大问题:

  1. 梯度尺度差异:某个任务的梯度可能特别大,主导更新方向,淹没其他任务。
  2. 梯度方向冲突 :可能存在任务 i,使得 ⟨ g i , g 0 ⟩ < 0 \langle g_i, g_0 \rangle < 0 ⟨gi,g0⟩<0,即该任务的梯度与平均梯度方向相反。此时,按平均梯度更新反而会恶化该任务的性能。

举例说明:假设任务 A(分类)需要模型关注整体轮廓,梯度方向指向右侧;任务 B(细节分割)需要模型关注局部纹理,梯度方向指向左侧。平均梯度可能指向正前方,对两个任务都没有明显好处,甚至可能让两者都变差。

二、现有方法及其局限

梯度操作的主流方法:

  • MGDA(Multiple Gradient Descent Algorithm) :将多任务学习视为多目标优化,寻找帕累托最优解。其更新方向是使得任何任务损失都不增加的最小范数梯度方向。问题:只能保证收敛到帕累托集上的某个点,无法控制具体是哪个点,可能远离平均损失最优点。
  • PCGrad(Projecting Conflicting Gradients) :将每个任务的梯度投影到其他梯度的正交方向上,消除冲突分量后再平均。问题:同样只能收敛到任意帕累托点,且计算效率低(需循环处理所有任务对)。
  • 梯度重加权(如 GradNorm) :根据任务难度或梯度幅值动态调整损失权重。问题:缺乏理论收敛保证,性能不稳定。

核心局限:现有方法大多 缺乏收敛到平均损失最小点的理论保证,且可能过度偏向某些任务,忽视整体优化目标。

三、CAGrad:冲突避免梯度下降

3.1 核心直觉

CAGrad 希望在平均梯度 g 0 g_0 g0 的附近 寻找一个新的更新方向 d d d,这个方向不仅要降低平均损失,还要确保每个任务的损失都能有所下降(至少不要恶化最差的任务)。换句话说,CAGrad 扮演一个"公平的教练",在确保团队整体进步的同时,特别关照进步最慢的队员。

3.2 形式化定义

每一步,CAGrad 求解如下优化问题:

max ⁡ d ∈ R m min ⁡ i ∈ [ K ] ⟨ g i , d ⟩ s.t. ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \max_{d \in \mathbb{R}^m} \min_{i \in [K]} \langle g_i, d \rangle \quad \text{s.t.} \quad \|d - g_0\| \leq c \|g_0\| d∈Rmmaxi∈[K]min⟨gi,d⟩s.t.∥d−g0∥≤c∥g0∥

  • min ⁡ i ⟨ g i , d ⟩ \min_i \langle g_i, d \rangle mini⟨gi,d⟩ 表示所有任务中,损失下降最慢的那个任务的"下降量"。我们希望最大化这个最慢的下降量,从而保证所有任务都能受益。
  • 约束条件 ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \|d - g_0\| \leq c \|g_0\| ∥d−g0∥≤c∥g0∥ 要求更新方向 d d d 不能偏离平均梯度 g 0 g_0 g0 太远,其中 c ∈ [ 0 , 1 ) c \in [0,1) c∈[0,1) 是一个超参数,控制着偏离的程度。这个约束是理论收敛的关键:它确保我们不会完全脱离优化平均损失的主航道。

3.3 如何求解?对偶变换降维

直接优化高维向量 d d d(参数数量可能数百万)计算量大。论文通过拉格朗日对偶 ,将原问题转化为一个仅需优化 K 维权重向量 w w w 的问题:

  1. 令 W \mathcal{W} W 为概率单纯形(所有权重非负且和为 1)。可以证明:
    min ⁡ i ⟨ g i , d ⟩ = min ⁡ w ∈ W ⟨ ∑ i w i g i , d ⟩ \min_i \langle g_i, d \rangle = \min_{w \in \mathcal{W}} \langle \sum_{i} w_i g_i, d \rangle imin⟨gi,d⟩=w∈Wmin⟨i∑wigi,d⟩

    因为线性函数在单纯形上的最小值必然在某个顶点(即某个任务梯度)处取得。

  2. 记 g w = ∑ i w i g i g_w = \sum_i w_i g_i gw=∑iwigi,原问题的对偶形式为:
    min ⁡ w ∈ W ( ⟨ g w , g 0 ⟩ + c ∥ g 0 ∥ ∥ g w ∥ ) \min_{w \in \mathcal{W}} \left( \langle g_w, g_0 \rangle + c \|g_0\| \|g_w\| \right) w∈Wmin(⟨gw,g0⟩+c∥g0∥∥gw∥)

    这是一个仅关于 w w w 的凸优化问题,维度为任务数 K(通常很小,如2~50),可用标准凸优化库快速求解。

  3. 求解得到最优权重 w ∗ w^* w∗ 后,原更新方向 d ∗ d^* d∗ 为:
    d ∗ = g 0 + c ∥ g 0 ∥ ∥ g w ∗ ∥ g w ∗ d^* = g_0 + \frac{c \|g_0\|}{\|g_{w^*}\|} g_{w^*} d∗=g0+∥gw∗∥c∥g0∥gw∗

    直观上看, d ∗ d^* d∗ 是在平均梯度 g 0 g_0 g0 的基础上,沿着加权梯度 g w ∗ g_{w^*} gw∗ 的方向进行了一定程度的修正。

3.4 超参数 c 的角色:平衡的艺术

  • c = 0: d ∗ = g 0 d^* = g_0 d∗=g0,CAGrad 退化为普通梯度下降(GD)。
  • 0 < c < 1:在"优化平均损失"与"照顾最差任务"之间取得平衡。理论保证成立(见下一节)。
  • c → ∞ \infty ∞(实践中取较大值如 10): d ∗ d^* d∗ 趋近于 g w g_w gw 方向,行为类似 MGDA,追求帕累托最优,但可能偏离平均损失最优点。

四、理论:CAGrad 保证收敛到平均损失最小点

4.1 收敛定理

在以下假设下:

  1. 所有损失函数 L i L_i Li 可微,且梯度是 Lipschitz 连续的(即变化不会太快)。
  2. 平均损失 L 0 L_0 L0 有下界(通常成立)。

定理 :若步长 α \alpha α 足够小,且超参数满足 0 ≤ c < 1 0 \leq c < 1 0≤c<1,则 CAGrad 生成的序列 { θ t } \{\theta_t\} {θt} 满足:
∑ t = 0 T ∥ ∇ L 0 ( θ t ) ∥ 2 ≤ 2 ( L 0 ( θ 0 ) − L 0 ∗ ) α ( 1 − c 2 ) \sum_{t=0}^{T} \| \nabla L_0(\theta_t) \|^2 \leq \frac{2(L_0(\theta_0) - L_0^*)}{\alpha(1 - c^2)} t=0∑T∥∇L0(θt)∥2≤α(1−c2)2(L0(θ0)−L0∗)

其中 L 0 ∗ L_0^* L0∗ 是平均损失的全局下界。

4.2 定理含义

  1. 收敛性 :不等式右边是常数,因此当 T → ∞ T \to \infty T→∞ 时, ∥ ∇ L 0 ( θ t ) ∥ → 0 \|\nabla L_0(\theta_t)\| \to 0 ∥∇L0(θt)∥→0,即算法必然收敛到平均损失 L 0 L_0 L0 的一个驻点(最优解或鞍点)。
  2. 收敛速度:梯度范数平方的累积和有一个明确上界,这提供了收敛速度的保证。
  3. 与 MGDA / PCGrad 的关键区别 :CAGrad 在 c < 1 c<1 c<1 时保证收敛到平均损失的最优点,而 MGDA / PCGrad 只能收敛到任意的帕累托点,无法控制具体位置。

证明核心思想 :利用 Lipschitz 条件推导单步下降量,再通过约束 ∥ d − g 0 ∥ ≤ c ∥ g 0 ∥ \|d - g_0\| \leq c \|g_0\| ∥d−g0∥≤c∥g0∥ 将下降量与平均梯度范数关联,最后通过望远镜求和和损失下界得到不等式。

五、实验验证:全方位领先

论文在三大类任务上进行了全面实验:监督学习、强化学习、半监督学习。

5.1 监督学习:视觉多任务数据集

  • 数据集:NYU-v2(语义分割、深度估计、表面法线预测)、CityScapes(语义分割、深度估计)。
  • 基线方法:MTAN(多任务注意力网络)、PCGrad、MGDA、GradNorm、Cross-Stitch 等。
  • 指标 :各任务专用指标(如 mIoU、深度误差)及平均任务性能下降率 Δ m % \Delta m\% Δm%(越低越好)。
  • 结果:CAGrad 在保持主任务性能的同时,显著提升了其他方法通常忽视的任务(如表面法线预测),取得了最优的平均性能。

5.2 强化学习:Meta-World 机器人操作

  • 环境:MT10(10任务)、MT50(50任务)。
  • 基础算法:Soft Actor-Critic(SAC)。
  • 挑战:任务数多,梯度冲突严重,计算开销大。
  • CAGrad 技巧
    1. 任务子采样(CAGrad-Fast):每一步随机采样部分任务(MT10 采4个,MT50 采8个)计算梯度,大幅降低计算量,理论保证仍成立。
    2. 近似求解对偶问题:用梯度下降近似求解权重 w w w,而非调用凸优化库,进一步提高速度。
  • 结果:CAGrad 在成功率上显著优于 PCGrad、Soft Modularization 等方法;CAGrad-Fast 在几乎不损失性能的情况下,速度比 PCGrad 快 2-5 倍。

5.3 半监督学习:CIFAR-10 辅助任务

  • 设定:主任务为分类,两个辅助任务(旋转预测、一致性正则)用于利用无标签数据。
  • 方法:将 CAGrad 与 ARML、GradNorm 等半监督学习方法结合。
  • 结果:CAGrad 在不同标注数据量(500/1000/2000)下均取得最高平均测试准确率,且分析显示它有效平衡了主任务与辅助任务的优化。

六、总结与展望

CAGrad 的核心贡献在于:用一个简单而严谨的优化公式,解决了多任务学习中的梯度冲突问题,并首次提供了收敛到平均损失最优点的理论保证。其关键创新点包括:

  1. 冲突度量:以最差任务的损失下降量作为冲突程度的度量。
  2. 约束搜索:将更新方向限制在平均梯度邻域内,确保不偏离主目标。
  3. 高效求解:通过对偶变换将高维问题转化为低维凸优化,并支持任务子采样进一步加速。

当前 CAGrad 主要针对平均损失优化,未来可探索其他主要目标(如最差任务性能最大化)下的扩展。

相关推荐
棒棒的皮皮2 小时前
【深度学习】YOLO模型评估之指标、可视化曲线分析
人工智能·深度学习·yolo·计算机视觉
驭白.3 小时前
不止于自动化:新能源汽车智造的数字基座如何搭建?
大数据·人工智能·自动化·汽车·数字化转型·制造业
企业智能研究3 小时前
什么是数据治理?数据治理对企业有什么用?
大数据·人工智能·数据分析·agent
阿里云大数据AI技术3 小时前
面向 Interleaved Thinking 的大模型 Agent 蒸馏实践
人工智能
AI Echoes4 小时前
LangChain 非分割类型的文档转换器使用技巧
人工智能·python·langchain·prompt·agent
哔哔龙4 小时前
LangChain核心组件可用工具
人工智能
全栈独立开发者4 小时前
点餐系统装上了“DeepSeek大脑”:基于 Spring AI + PgVector 的 RAG 落地指南
java·人工智能·spring
2501_941878744 小时前
在班加罗尔工程实践中构建可持续演进的机器学习平台体系与技术实现分享
人工智能·机器学习
guoketg4 小时前
BERT的技术细节和面试问题汇总
人工智能·深度学习·bert
永远在Debug的小殿下4 小时前
SLAM开发环境(虚拟机的安装)
人工智能